aboutsummaryrefslogtreecommitdiffstats
path: root/simulation/vi_blocks.py
diff options
context:
space:
mode:
Diffstat (limited to 'simulation/vi_blocks.py')
-rw-r--r--simulation/vi_blocks.py9
1 files changed, 5 insertions, 4 deletions
diff --git a/simulation/vi_blocks.py b/simulation/vi_blocks.py
index dcf6b46..2b03198 100644
--- a/simulation/vi_blocks.py
+++ b/simulation/vi_blocks.py
@@ -67,16 +67,17 @@ if __name__ == "__main__":
alg = blocks.algorithms.GradientDescent(cost=cost, parameters=[mu, sig],
step_rule=step_rules)
- #data_stream = ab.create_fixed_data_stream(n_cascades, graph, batch_size,
- # shuffle=False)
- data_stream = ab.create_learned_data_stream(graph, batch_size)
+ data_stream = ab.create_fixed_data_stream(n_cascades, graph, batch_size,
+ shuffle=False)
+ #data_stream = ab.create_learned_data_stream(graph, batch_size)
loop = blocks.main_loop.MainLoop(
alg, data_stream,
extensions=[
blocks.extensions.FinishAfter(after_n_batches = 10**4),
blocks.extensions.monitoring.TrainingDataMonitoring([cost, mu, sig,
rmse, g_shared], after_batch=True),
- blocks.extensions.Printing(every_n_batches = 100),
+ blocks.extensions.Printing(every_n_batches = 100,
+ after_epoch=False),
]
)
loop.run()