aboutsummaryrefslogtreecommitdiffstats
path: root/simulation/vi_blocks.py
diff options
context:
space:
mode:
Diffstat (limited to 'simulation/vi_blocks.py')
-rw-r--r--simulation/vi_blocks.py17
1 files changed, 8 insertions, 9 deletions
diff --git a/simulation/vi_blocks.py b/simulation/vi_blocks.py
index b78375b..5deb6f6 100644
--- a/simulation/vi_blocks.py
+++ b/simulation/vi_blocks.py
@@ -51,10 +51,10 @@ def create_vi_model(n_nodes, n_samp=100):
if __name__ == "__main__":
- n_cascades = 10000
- batch_size = 10
+ batch_size = 100
+ frequency = 10
n_samples = 50
- graph = utils.create_random_graph(n_nodes=4)
+ graph = utils.create_random_graph(n_nodes=10)
print('GRAPH:\n', graph, '\n-------------\n')
x, s, mu, sig, cost = create_vi_model(len(graph), n_samples)
@@ -65,17 +65,16 @@ if __name__ == "__main__":
alg = algorithms.GradientDescent(cost=cost, parameters=[mu, sig],
step_rule=step_rules)
- data_stream = ub.fixed_data_stream(n_cascades, graph, batch_size,
- shuffle=False)
- # data_stream = ub.dynamic_data_stream(graph, batch_size)
+ data_stream = ub.dynamic_data_stream(graph, batch_size)
loop = main_loop.MainLoop(
alg, data_stream,
log_backend="sqlite",
extensions=[
be.FinishAfter(after_n_batches=10**4),
- bm.TrainingDataMonitoring([cost, mu, sig, rmse],
- every_n_batches=10),
- be.Printing(every_n_batches=100, after_epoch=False),
+ bm.TrainingDataMonitoring([cost, rmse, mu], every_n_batches=frequency),
+ be.Printing(every_n_batches=frequency, after_epoch=False),
+ ub.JSONDump("logs/tmp.json", every_n_batches=10),
+ #ub.ActiveLearning(dataset=data_stream.dataset, params=graph)
]
)
loop.run()