aboutsummaryrefslogtreecommitdiffstats
path: root/simulation/vi_blocks.py
diff options
context:
space:
mode:
authorjeanpouget-abadie <jean.pougetabadie@gmail.com>2015-11-30 16:45:25 -0500
committerjeanpouget-abadie <jean.pougetabadie@gmail.com>2015-11-30 16:45:25 -0500
commit52cf8293061a1e35b5b443ef6dc70aa51727cf00 (patch)
tree0a3c2cb895c7564906016209ead389024aeaa6af /simulation/vi_blocks.py
parentf9e3d5e4dda32f33e5e5a0e82dda30a23f5dfae6 (diff)
downloadcascades-52cf8293061a1e35b5b443ef6dc70aa51727cf00.tar.gz
syncing with Thibaut
Diffstat (limited to 'simulation/vi_blocks.py')
-rw-r--r--simulation/vi_blocks.py9
1 files changed, 5 insertions, 4 deletions
diff --git a/simulation/vi_blocks.py b/simulation/vi_blocks.py
index dcf6b46..2b03198 100644
--- a/simulation/vi_blocks.py
+++ b/simulation/vi_blocks.py
@@ -67,16 +67,17 @@ if __name__ == "__main__":
alg = blocks.algorithms.GradientDescent(cost=cost, parameters=[mu, sig],
step_rule=step_rules)
- #data_stream = ab.create_fixed_data_stream(n_cascades, graph, batch_size,
- # shuffle=False)
- data_stream = ab.create_learned_data_stream(graph, batch_size)
+ data_stream = ab.create_fixed_data_stream(n_cascades, graph, batch_size,
+ shuffle=False)
+ #data_stream = ab.create_learned_data_stream(graph, batch_size)
loop = blocks.main_loop.MainLoop(
alg, data_stream,
extensions=[
blocks.extensions.FinishAfter(after_n_batches = 10**4),
blocks.extensions.monitoring.TrainingDataMonitoring([cost, mu, sig,
rmse, g_shared], after_batch=True),
- blocks.extensions.Printing(every_n_batches = 100),
+ blocks.extensions.Printing(every_n_batches = 100,
+ after_epoch=False),
]
)
loop.run()