aboutsummaryrefslogtreecommitdiffstats
path: root/simulation
diff options
context:
space:
mode:
Diffstat (limited to 'simulation')
-rw-r--r--simulation/mle_blocks.py1
-rw-r--r--simulation/utils_blocks.py31
-rw-r--r--simulation/vi_blocks.py3
3 files changed, 3 insertions, 32 deletions
diff --git a/simulation/mle_blocks.py b/simulation/mle_blocks.py
index 89aaf2e..ab8816f 100644
--- a/simulation/mle_blocks.py
+++ b/simulation/mle_blocks.py
@@ -47,6 +47,7 @@ if __name__ == "__main__":
# data_stream = ub.fixed_data_stream(n_obs, graph, batch_size)
loop = main_loop.MainLoop(
alg, data_stream,
+ log_backend="sqlite",
extensions=[
be.FinishAfter(after_n_batches=10**3),
bm.TrainingDataMonitoring([cost, params,
diff --git a/simulation/utils_blocks.py b/simulation/utils_blocks.py
index 0d30786..3b29972 100644
--- a/simulation/utils_blocks.py
+++ b/simulation/utils_blocks.py
@@ -121,34 +121,3 @@ def dynamic_data_stream(graph, batch_size):
data_set = LearnedDataset(node_p, graph)
scheme = fuel.schemes.ConstantScheme(batch_size)
return fuel.streams.DataStream(dataset=data_set, iteration_scheme=scheme)
-
-
-if __name__ == "__main__":
- batch_size = 100
- n_obs = 1000
- frequency = 1
- graph = utils.create_wheel(1000)
- print('GRAPH:\n', graph, '\n-------------\n')
-
- g_shared = theano.shared(value=graph, name='graph')
- x, s, params, cost = create_mle_model(graph)
- rmse = rmse_error(g_shared, params)
- error = relative_error(g_shared, params)
-
- alg = algorithms.GradientDescent(
- cost=-cost, parameters=[params], step_rule=blocks.algorithms.AdaDelta()
- )
- data_stream = create_learned_data_stream(graph, batch_size)
- #data_stream = create_fixed_data_stream(n_obs, graph, batch_size)
- loop = main_loop.MainLoop(
- alg, data_stream,
- extensions=[
- be.FinishAfter(after_n_batches=10**4),
- bm.TrainingDataMonitoring([cost, rmse, error],
- every_n_batches=frequency),
- be.Printing(every_n_batches=frequency),
- JSONDump("tmpactive_log.json", every_n_batches=frequency),
- ActiveLearning(data_stream.dataset, every_n_batches=frequency)
- ],
- )
- loop.run()
diff --git a/simulation/vi_blocks.py b/simulation/vi_blocks.py
index 1177979..b78375b 100644
--- a/simulation/vi_blocks.py
+++ b/simulation/vi_blocks.py
@@ -51,7 +51,7 @@ def create_vi_model(n_nodes, n_samp=100):
if __name__ == "__main__":
- #n_cascades = 10000
+ n_cascades = 10000
batch_size = 10
n_samples = 50
graph = utils.create_random_graph(n_nodes=4)
@@ -70,6 +70,7 @@ if __name__ == "__main__":
# data_stream = ub.dynamic_data_stream(graph, batch_size)
loop = main_loop.MainLoop(
alg, data_stream,
+ log_backend="sqlite",
extensions=[
be.FinishAfter(after_n_batches=10**4),
bm.TrainingDataMonitoring([cost, mu, sig, rmse],