diff options
| author | Thibaut Horel <thibaut.horel@gmail.com> | 2015-12-01 15:00:51 -0500 |
|---|---|---|
| committer | Thibaut Horel <thibaut.horel@gmail.com> | 2015-12-01 15:00:51 -0500 |
| commit | 3b5321add6cd71c6e23ff65e75faaa48e6829634 (patch) | |
| tree | afe0bd06cff712efed3f0de9366311fad318e046 /simulation/vi_blocks.py | |
| parent | c39e365d71fc07f3ae5b198252b4a7247efb9bc5 (diff) | |
| download | cascades-3b5321add6cd71c6e23ff65e75faaa48e6829634.tar.gz | |
Extract blocks utils and reorganize code
Diffstat (limited to 'simulation/vi_blocks.py')
| -rw-r--r-- | simulation/vi_blocks.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/simulation/vi_blocks.py b/simulation/vi_blocks.py index 11dc4de..94038a5 100644 --- a/simulation/vi_blocks.py +++ b/simulation/vi_blocks.py @@ -1,4 +1,5 @@ import utils +import utils_blocks as ub import theano from theano import tensor as tsr from blocks import algorithms, main_loop @@ -6,7 +7,6 @@ import blocks.extensions as be import blocks.extensions.monitoring as bm import theano.tensor.shared_randomstreams import numpy as np -import active_blocks as ab class ClippedParams(algorithms.StepRule): @@ -58,16 +58,16 @@ if __name__ == "__main__": print('GRAPH:\n', graph, '\n-------------\n') x, s, mu, sig, cost = create_vi_model(len(graph), n_samples) - rmse = ab.rmse_error(graph, mu) + rmse = ub.rmse_error(graph, mu) step_rules = algorithms.CompositeRule([algorithms.AdaDelta(), ClippedParams(1e-3, 1 - 1e-3)]) alg = algorithms.GradientDescent(cost=cost, parameters=[mu, sig], step_rule=step_rules) - data_stream = ab.create_fixed_data_stream(n_cascades, graph, batch_size, - shuffle=False) - # data_stream = ab.create_learned_data_stream(graph, batch_size) + data_stream = ub.fixed_data_stream(n_cascades, graph, batch_size, + shuffle=False) + # data_stream = ub.dynamic_data_stream(graph, batch_size) loop = main_loop.MainLoop( alg, data_stream, extensions=[ |
