aboutsummaryrefslogtreecommitdiffstats
path: root/simulation/mle_blocks.py
diff options
context:
space:
mode:
Diffstat (limited to 'simulation/mle_blocks.py')
-rw-r--r--simulation/mle_blocks.py59
1 file changed, 59 insertions, 0 deletions
diff --git a/simulation/mle_blocks.py b/simulation/mle_blocks.py
new file mode 100644
index 0000000..89aaf2e
--- /dev/null
+++ b/simulation/mle_blocks.py
@@ -0,0 +1,59 @@
+import utils
+import utils_blocks as ub
+import theano
+from theano import tensor as tsr
+from blocks import algorithms, main_loop
+import blocks.extensions as be
+import blocks.extensions.monitoring as bm
+import numpy as np
+
+
def create_mle_model(graph):
    """Build the symbolic cascade log-likelihood graph for MLE training.

    Parameters
    ----------
    graph : 2-D array-like
        Ground-truth diffusion network adjacency; only ``len(graph)``
        (the node count) is used here.

    Returns
    -------
    tuple ``(x, s, params, lkl_mle)`` where
        x : int8 symbolic matrix of infection indicators per time step
            (presumably shape (T, n_nodes) — confirm against the data stream)
        s : int8 symbolic matrix observation mask, same shape as ``x``
        params : shared (n_nodes, n_nodes) float matrix of edge weights,
            initialised near 0.5 with small Gaussian noise
        lkl_mle : scalar log-likelihood expression, named ``'cost'``
    """
    n_nodes = len(graph)
    x = tsr.matrix(name='x', dtype='int8')
    s = tsr.matrix(name='s', dtype='int8')
    params = theano.shared(
        .5 + .01 *
        np.random.normal(size=(n_nodes, n_nodes)).astype(theano.config.floatX),
        name='params'
    )
    # Aggregate incoming "hazard" per node per step; clamped away from zero
    # so the log term below stays finite.
    y = tsr.maximum(tsr.dot(x, params), 1e-5)
    # log(1 - exp(-y)) computed as log1p(-exp(-y)): for small y the naive
    # form computes the log of a tiny difference and loses precision,
    # while log1p keeps full accuracy near zero.
    infect = tsr.log1p(-tsr.exp(-y[0:-1]))
    # Nodes that became infected while observed contribute log P(infect);
    # observed non-infections contribute log P(no infect) = -y.
    lkl_pos = tsr.sum(infect * (x[1:] & s[1:]))
    lkl_neg = tsr.sum(-y[0:-1] * (~x[1:] & s[1:]))
    lkl_mle = lkl_pos + lkl_neg
    lkl_mle.name = 'cost'

    return x, s, params, lkl_mle
+
+
+if __name__ == "__main__":
+ batch_size = 100
+ n_obs = 100000
+ graph = utils.create_wheel(100)
+
+ print('GRAPH:\n', graph, '\n-------------\n')
+
+ g_shared = theano.shared(value=graph, name='graph')
+ x, s, params, cost = create_mle_model(graph)
+ rmse = ub.rmse_error(g_shared, params)
+ error = ub.relative_error(g_shared, params)
+
+ alg = algorithms.GradientDescent(
+ cost=-cost, parameters=[params], step_rule=algorithms.AdaDelta()
+ )
+ data_stream = ub.dynamic_data_stream(graph, batch_size)
+ # data_stream = ub.fixed_data_stream(n_obs, graph, batch_size)
+ loop = main_loop.MainLoop(
+ alg, data_stream,
+ extensions=[
+ be.FinishAfter(after_n_batches=10**3),
+ bm.TrainingDataMonitoring([cost, params,
+ rmse, error], every_n_batches=10),
+ be.Printing(every_n_batches=10),
+ ub.JSONDump("log.json", every_n_batches=10),
+ ub.ActiveLearning(data_stream.dataset),
+ ],
+ )
+ loop.run()