diff options
| author | jeanpouget-abadie <jean.pougetabadie@gmail.com> | 2015-11-30 16:45:25 -0500 |
|---|---|---|
| committer | jeanpouget-abadie <jean.pougetabadie@gmail.com> | 2015-11-30 16:45:25 -0500 |
| commit | 52cf8293061a1e35b5b443ef6dc70aa51727cf00 (patch) | |
| tree | 0a3c2cb895c7564906016209ead389024aeaa6af /simulation | |
| parent | f9e3d5e4dda32f33e5e5a0e82dda30a23f5dfae6 (diff) | |
| download | cascades-52cf8293061a1e35b5b443ef6dc70aa51727cf00.tar.gz | |
syncing with Thibaut
Diffstat (limited to 'simulation')
| -rw-r--r-- | simulation/active_blocks.py | 2 | ||||
| -rw-r--r-- | simulation/vi_blocks.py | 9 |
2 files changed, 5 insertions, 6 deletions
diff --git a/simulation/active_blocks.py b/simulation/active_blocks.py index 7aa1afb..569cb6c 100644 --- a/simulation/active_blocks.py +++ b/simulation/active_blocks.py @@ -156,8 +156,6 @@ if __name__ == "__main__": rmse, g_shared], after_batch=True), blocks.extensions.Printing(every_n_batches = 10), ActiveLearning(data_stream.dataset), - blocks.extras.extensions.Plot('graph rmse', channels=[], - every_n_batches = 10) ] ) loop.run() diff --git a/simulation/vi_blocks.py b/simulation/vi_blocks.py index dcf6b46..2b03198 100644 --- a/simulation/vi_blocks.py +++ b/simulation/vi_blocks.py @@ -67,16 +67,17 @@ if __name__ == "__main__": alg = blocks.algorithms.GradientDescent(cost=cost, parameters=[mu, sig], step_rule=step_rules) - #data_stream = ab.create_fixed_data_stream(n_cascades, graph, batch_size, - # shuffle=False) - data_stream = ab.create_learned_data_stream(graph, batch_size) + data_stream = ab.create_fixed_data_stream(n_cascades, graph, batch_size, + shuffle=False) + #data_stream = ab.create_learned_data_stream(graph, batch_size) loop = blocks.main_loop.MainLoop( alg, data_stream, extensions=[ blocks.extensions.FinishAfter(after_n_batches = 10**4), blocks.extensions.monitoring.TrainingDataMonitoring([cost, mu, sig, rmse, g_shared], after_batch=True), - blocks.extensions.Printing(every_n_batches = 100), + blocks.extensions.Printing(every_n_batches = 100, + after_epoch=False), ] ) loop.run() |
