aboutsummaryrefslogtreecommitdiffstats
path: root/simulation
diff options
context:
space:
mode:
Diffstat (limited to 'simulation')
-rw-r--r--simulation/vi_blocks.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/simulation/vi_blocks.py b/simulation/vi_blocks.py
index 2b03198..84e637f 100644
--- a/simulation/vi_blocks.py
+++ b/simulation/vi_blocks.py
@@ -53,8 +53,8 @@ def create_vi_model(n_nodes, n_samp=100):
if __name__ == "__main__":
- n_cascades = 10000
- batch_size = 1000
+ #n_cascades = 10000
+ batch_size = 10
n_samples = 50
graph = mn.create_random_graph(n_nodes=4)
print('GRAPH:\n', graph, '\n-------------\n')
@@ -67,9 +67,9 @@ if __name__ == "__main__":
alg = blocks.algorithms.GradientDescent(cost=cost, parameters=[mu, sig],
step_rule=step_rules)
- data_stream = ab.create_fixed_data_stream(n_cascades, graph, batch_size,
- shuffle=False)
- #data_stream = ab.create_learned_data_stream(graph, batch_size)
+ #data_stream = ab.create_fixed_data_stream(n_cascades, graph, batch_size,
+ # shuffle=False)
+ data_stream = ab.create_learned_data_stream(graph, batch_size)
loop = blocks.main_loop.MainLoop(
alg, data_stream,
extensions=[