aboutsummaryrefslogtreecommitdiffstats
path: root/simulation/vi_blocks.py
diff options
context:
space:
mode:
authorThibaut Horel <thibaut.horel@gmail.com>2015-12-01 15:00:51 -0500
committerThibaut Horel <thibaut.horel@gmail.com>2015-12-01 15:00:51 -0500
commit3b5321add6cd71c6e23ff65e75faaa48e6829634 (patch)
treeafe0bd06cff712efed3f0de9366311fad318e046 /simulation/vi_blocks.py
parentc39e365d71fc07f3ae5b198252b4a7247efb9bc5 (diff)
downloadcascades-3b5321add6cd71c6e23ff65e75faaa48e6829634.tar.gz
Extract blocks utils and reorganize code
Diffstat (limited to 'simulation/vi_blocks.py')
-rw-r--r--simulation/vi_blocks.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/simulation/vi_blocks.py b/simulation/vi_blocks.py
index 11dc4de..94038a5 100644
--- a/simulation/vi_blocks.py
+++ b/simulation/vi_blocks.py
@@ -1,4 +1,5 @@
import utils
+import utils_blocks as ub
import theano
from theano import tensor as tsr
from blocks import algorithms, main_loop
@@ -6,7 +7,6 @@ import blocks.extensions as be
import blocks.extensions.monitoring as bm
import theano.tensor.shared_randomstreams
import numpy as np
-import active_blocks as ab
class ClippedParams(algorithms.StepRule):
@@ -58,16 +58,16 @@ if __name__ == "__main__":
print('GRAPH:\n', graph, '\n-------------\n')
x, s, mu, sig, cost = create_vi_model(len(graph), n_samples)
- rmse = ab.rmse_error(graph, mu)
+ rmse = ub.rmse_error(graph, mu)
step_rules = algorithms.CompositeRule([algorithms.AdaDelta(),
ClippedParams(1e-3, 1 - 1e-3)])
alg = algorithms.GradientDescent(cost=cost, parameters=[mu, sig],
step_rule=step_rules)
- data_stream = ab.create_fixed_data_stream(n_cascades, graph, batch_size,
- shuffle=False)
- # data_stream = ab.create_learned_data_stream(graph, batch_size)
+ data_stream = ub.fixed_data_stream(n_cascades, graph, batch_size,
+ shuffle=False)
+ # data_stream = ub.dynamic_data_stream(graph, batch_size)
loop = main_loop.MainLoop(
alg, data_stream,
extensions=[