| author | jeanpouget-abadie <jean.pougetabadie@gmail.com> | 2015-12-02 12:31:05 -0500 |
|---|---|---|
| committer | jeanpouget-abadie <jean.pougetabadie@gmail.com> | 2015-12-02 12:31:05 -0500 |
| commit | 5fbd4664e76d25de95f89329ee5f0f912fee4259 (patch) | |
| tree | 05812c8f49cd44aca4ebf19a9b836df41ea33396 /simulation/active_blocks.py | |
| parent | 600251accf79333d487c7186dcc5354e310c84c7 (diff) | |
| download | cascades-5fbd4664e76d25de95f89329ee5f0f912fee4259.tar.gz | |
frequency param introduced + plots_utils file sketch
Diffstat (limited to 'simulation/active_blocks.py')
| mode | path | lines changed |
|---|---|---|
| -rw-r--r-- | simulation/active_blocks.py | 23 |

1 file changed, 13 insertions(+), 10 deletions(-)
```diff
diff --git a/simulation/active_blocks.py b/simulation/active_blocks.py
index 1495eb8..be1fc3d 100644
--- a/simulation/active_blocks.py
+++ b/simulation/active_blocks.py
@@ -43,7 +43,9 @@ class ActiveLearning(blocks.extensions.SimpleExtension):
     def do(self, which_callback, *args):
         out_degree = np.sum(self.dataset.graph, axis=1)
         self.dataset.node_p = out_degree / np.sum(out_degree)
-        print(self.dataset.node_p)
+
+# def do(self, which_callback, *args):
+
 
 
 class JSONDump(blocks.extensions.SimpleExtension):
@@ -149,7 +151,8 @@ def create_learned_data_stream(graph, batch_size):
 if __name__ == "__main__":
     batch_size = 100
     n_obs = 1000
-    graph = utils.create_wheel(10)
+    frequency = 1
+    graph = utils.create_wheel(1000)
     print('GRAPH:\n', graph, '\n-------------\n')
 
     g_shared = theano.shared(value=graph, name='graph')
@@ -160,17 +163,17 @@ if __name__ == "__main__":
     alg = algorithms.GradientDescent(
         cost=-cost, parameters=[params], step_rule=blocks.algorithms.AdaDelta()
     )
-    # data_stream = create_learned_data_stream(graph, batch_size)
-    data_stream = create_fixed_data_stream(n_obs, graph, batch_size)
+    data_stream = create_learned_data_stream(graph, batch_size)
+    #data_stream = create_fixed_data_stream(n_obs, graph, batch_size)
     loop = main_loop.MainLoop(
         alg, data_stream,
         extensions=[
-            be.FinishAfter(after_n_batches=10**3),
-            bm.TrainingDataMonitoring([cost, params,
-                                       rmse, error], every_n_batches=10),
-            be.Printing(every_n_batches=10),
-            JSONDump("log.json", every_n_batches=10)
-            # ActiveLearning(data_stream.dataset),
+            be.FinishAfter(after_n_batches=10**4),
+            bm.TrainingDataMonitoring([cost, rmse, error],
+                                      every_n_batches=frequency),
+            be.Printing(every_n_batches=frequency),
+            JSONDump("tmpactive_log.json", every_n_batches=frequency),
+            ActiveLearning(data_stream.dataset, every_n_batches=frequency)
         ],
     )
     loop.run()
```
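For context, the re-enabled `ActiveLearning(data_stream.dataset, every_n_batches=frequency)` entry implies a constructor that forwards Blocks trigger keywords to `SimpleExtension`. That constructor is not part of this diff, so the sketch below is an assumption; only the `do` body matches the hunk at line 43 above.

```python
import numpy as np
import blocks.extensions


class ActiveLearning(blocks.extensions.SimpleExtension):
    """Re-weight node sampling probabilities by out-degree on each trigger.

    Sketch only: the __init__ here is assumed so that trigger keywords such
    as every_n_batches can be forwarded to SimpleExtension; it does not
    appear in this diff.
    """

    def __init__(self, dataset, **kwargs):
        super(ActiveLearning, self).__init__(**kwargs)
        self.dataset = dataset

    def do(self, which_callback, *args):
        # Each node's sampling probability is proportional to its out-degree
        # in the adjacency matrix (same computation as the hunk above).
        out_degree = np.sum(self.dataset.graph, axis=1)
        self.dataset.node_p = out_degree / np.sum(out_degree)
```

With `frequency = 1`, the monitoring, printing, JSON dump, and active-learning extensions all fire on the same per-batch schedule set by the new parameter.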
