-rw-r--r--  simulation/active_blocks.py | 24 +++++++++++-------------
1 file changed, 11 insertions(+), 13 deletions(-)
diff --git a/simulation/active_blocks.py b/simulation/active_blocks.py
index a1f6e76..40266b2 100644
--- a/simulation/active_blocks.py
+++ b/simulation/active_blocks.py
@@ -98,24 +98,20 @@ def rmse_error(graph, params):
     return rmse, g_shared
 
 
-def create_fixed_data_stream(n_cascades, graph, batch_size, shuffle=True):
+def create_fixed_data_stream(n_obs, graph, batch_size, shuffle=True):
     """
     creates a datastream for a fixed (not learned) dataset:
     -shuffle (bool): shuffle minibatches but not within minibatch, else
      sequential (non-shuffled) batches are used
     """
-    cascades = utils.build_cascade_list(
-        utils.simulate_cascades(n_cascades, graph),
-        collapse=True)
-    x_obs, s_obs = cascades[0], cascades[1]
+    x_obs, s_obs = utils.simulate_cascades(n_obs, graph)
     data_set = fuel.datasets.base.IndexableDataset(collections.OrderedDict(
         [('x', x_obs), ('s', s_obs)]
     ))
     if shuffle:
-        scheme = ShuffledBatchesScheme(len(x_obs), batch_size=batch_size)
+        scheme = ShuffledBatchesScheme(n_obs, batch_size=batch_size)
     else:
-        scheme = fuel.schemes.SequentialScheme(len(x_obs),
-                                               batch_size=batch_size)
+        scheme = fuel.schemes.SequentialScheme(n_obs, batch_size=batch_size)
     return fuel.streams.DataStream(dataset=data_set, iteration_scheme=scheme)
 
 
@@ -127,8 +123,9 @@ def create_learned_data_stream(graph, batch_size):
 
 
 if __name__ == "__main__":
-    batch_size = 1000
-    graph = utils.create_wheel(1000)
+    batch_size = 100
+    n_obs = 1000
+    graph = utils.create_wheel(10)
 
     print('GRAPH:\n', graph, '\n-------------\n')
     x, s, params, cost = create_mle_model(graph)
@@ -137,15 +134,16 @@ if __name__ == "__main__":
     alg = algorithms.GradientDescent(
         cost=-cost, parameters=[params], step_rule=blocks.algorithms.AdaDelta()
     )
-    data_stream = create_learned_data_stream(graph, batch_size)
+    # data_stream = create_learned_data_stream(graph, batch_size)
+    data_stream = create_fixed_data_stream(n_obs, graph, batch_size)
 
     loop = main_loop.MainLoop(
         alg, data_stream,
         extensions=[
-            be.FinishAfter(after_n_batches=10**4),
+            be.FinishAfter(after_n_batches=10**3),
             bm.TrainingDataMonitoring([cost, params, rmse, g_shared], after_batch=True),
             be.Printing(every_n_batches=10),
-            ActiveLearning(data_stream.dataset),
+            #ActiveLearning(data_stream.dataset),
         ]
     )
     loop.run()
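
For context, the fixed stream this commit switches to is a thin wrapper over fuel's
IndexableDataset / DataStream API. Below is a minimal, self-contained sketch of that
pipeline; the random x_obs/s_obs arrays are stand-ins for the (x, s) pair returned by
utils.simulate_cascades, and ShuffledBatchesScheme is defined in this repo rather than
in fuel, so the sketch sticks to the sequential (shuffle=False) path.

import collections

import numpy
import fuel.datasets
import fuel.schemes
import fuel.streams

n_obs, batch_size = 1000, 100

# Stand-ins for the (x_obs, s_obs) pair that utils.simulate_cascades
# returns; shapes and dtypes here are illustrative only.
x_obs = numpy.random.randint(0, 2, size=(n_obs, 10)).astype('int32')
s_obs = numpy.random.randint(0, 10, size=(n_obs,)).astype('int32')

# One indexable source per named input, in a fixed order.
data_set = fuel.datasets.IndexableDataset(collections.OrderedDict(
    [('x', x_obs), ('s', s_obs)]
))

# Sequential (non-shuffled) minibatches, as in the shuffle=False branch above.
scheme = fuel.schemes.SequentialScheme(n_obs, batch_size=batch_size)
stream = fuel.streams.DataStream(dataset=data_set, iteration_scheme=scheme)

for x_batch, s_batch in stream.get_epoch_iterator():
    pass  # each iteration yields one minibatch of batch_size observations

Such a stream plugs directly into blocks' MainLoop as the data_stream argument, which is
exactly how the __main__ block above uses the result of create_fixed_data_stream.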
