author     Thibaut Horel <thibaut.horel@gmail.com>    2015-11-30 20:37:24 -0500
committer  Thibaut Horel <thibaut.horel@gmail.com>    2015-11-30 20:37:24 -0500
commit     5f9af1c88a19be3bff23ef3c8f3705ea43f5e0d6 (patch)
tree       ffc7b81e2bc32941eb1de4c73fcce4cbb3288179 /simulation/active_blocks.py
parent     d2d9ab7ba51f2cc382a2676aaacecbf75bcfadc8 (diff)
download   cascades-5f9af1c88a19be3bff23ef3c8f3705ea43f5e0d6.tar.gz
Fix mle with blocks
Diffstat (limited to 'simulation/active_blocks.py')
-rw-r--r--  simulation/active_blocks.py  24
1 file changed, 11 insertions(+), 13 deletions(-)
diff --git a/simulation/active_blocks.py b/simulation/active_blocks.py
index a1f6e76..40266b2 100644
--- a/simulation/active_blocks.py
+++ b/simulation/active_blocks.py
@@ -98,24 +98,20 @@ def rmse_error(graph, params):
     return rmse, g_shared
 
 
-def create_fixed_data_stream(n_cascades, graph, batch_size, shuffle=True):
+def create_fixed_data_stream(n_obs, graph, batch_size, shuffle=True):
     """
     creates a datastream for a fixed (not learned) dataset:
     -shuffle (bool): shuffle minibatches but not within minibatch, else
     sequential (non-shuffled) batches are used
     """
-    cascades = utils.build_cascade_list(
-        utils.simulate_cascades(n_cascades, graph),
-        collapse=True)
-    x_obs, s_obs = cascades[0], cascades[1]
+    x_obs, s_obs = utils.simulate_cascades(n_obs, graph)
     data_set = fuel.datasets.base.IndexableDataset(collections.OrderedDict(
         [('x', x_obs), ('s', s_obs)]
     ))
     if shuffle:
-        scheme = ShuffledBatchesScheme(len(x_obs), batch_size=batch_size)
+        scheme = ShuffledBatchesScheme(n_obs, batch_size=batch_size)
     else:
-        scheme = fuel.schemes.SequentialScheme(len(x_obs),
-                                               batch_size=batch_size)
+        scheme = fuel.schemes.SequentialScheme(n_obs, batch_size=batch_size)
     return fuel.streams.DataStream(dataset=data_set, iteration_scheme=scheme)
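
[Note] The refactored helper feeds pre-simulated cascades straight through Fuel's stream API. A minimal usage sketch (the argument values are illustrative, not from the commit; graph is assumed to come from utils.create_wheel as in __main__ below):

    # Sketch only: hypothetical values, not part of this commit.
    stream = create_fixed_data_stream(n_obs=1000, graph=graph, batch_size=100)
    for batch in stream.get_epoch_iterator(as_dict=True):
        # Batches arrive keyed by the source names ('x', 's')
        # declared in the OrderedDict above.
        x_batch, s_batch = batch['x'], batch['s']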
@@ -127,8 +123,9 @@ def create_learned_data_stream(graph, batch_size):
 
 
 if __name__ == "__main__":
-    batch_size = 1000
-    graph = utils.create_wheel(1000)
+    batch_size = 100
+    n_obs = 1000
+    graph = utils.create_wheel(10)
     print('GRAPH:\n', graph, '\n-------------\n')
 
     x, s, params, cost = create_mle_model(graph)
@@ -137,15 +134,16 @@ if __name__ == "__main__":
     alg = algorithms.GradientDescent(
         cost=-cost, parameters=[params], step_rule=blocks.algorithms.AdaDelta()
     )
-    data_stream = create_learned_data_stream(graph, batch_size)
+    # data_stream = create_learned_data_stream(graph, batch_size)
+    data_stream = create_fixed_data_stream(n_obs, graph, batch_size)
     loop = main_loop.MainLoop(
         alg, data_stream,
         extensions=[
-            be.FinishAfter(after_n_batches=10**4),
+            be.FinishAfter(after_n_batches=10**3),
             bm.TrainingDataMonitoring([cost, params,
                                        rmse, g_shared], after_batch=True),
             be.Printing(every_n_batches=10),
-            ActiveLearning(data_stream.dataset),
+            #ActiveLearning(data_stream.dataset),
         ]
     )
     loop.run()
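
[Note] ShuffledBatchesScheme is referenced above but not defined in this file. As a reference point, a minimal sketch of what such a scheme could look like, assuming Fuel's BatchScheme base class and the behavior the docstring describes (shuffle the order of minibatches, keep examples within each minibatch contiguous); the repository's actual implementation may differ:

    import numpy
    from fuel.schemes import BatchScheme
    from picklable_itertools.extras import partition_all

    class ShuffledBatchesScheme(BatchScheme):
        """Sketch only: visit contiguous minibatches in random order."""
        def get_request_iterator(self):
            # Partition the example indices into sequential batches,
            # then shuffle the batch order (not the examples inside).
            batches = [list(b) for b in partition_all(self.batch_size,
                                                      self.indices)]
            numpy.random.shuffle(batches)
            return iter(batches)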