From 3b5321add6cd71c6e23ff65e75faaa48e6829634 Mon Sep 17 00:00:00 2001 From: Thibaut Horel Date: Tue, 1 Dec 2015 15:00:51 -0500 Subject: Extract blocks utils and reorganize code --- simulation/mle_blocks.py | 59 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 simulation/mle_blocks.py (limited to 'simulation/mle_blocks.py') diff --git a/simulation/mle_blocks.py b/simulation/mle_blocks.py new file mode 100644 index 0000000..89aaf2e --- /dev/null +++ b/simulation/mle_blocks.py @@ -0,0 +1,59 @@ +import utils +import utils_blocks as ub +import theano +from theano import tensor as tsr +from blocks import algorithms, main_loop +import blocks.extensions as be +import blocks.extensions.monitoring as bm +import numpy as np + + +def create_mle_model(graph): + """return cascade likelihood theano computation graph""" + n_nodes = len(graph) + x = tsr.matrix(name='x', dtype='int8') + s = tsr.matrix(name='s', dtype='int8') + params = theano.shared( + .5 + .01 * + np.random.normal(size=(n_nodes, n_nodes)).astype(theano.config.floatX), + name='params' + ) + y = tsr.maximum(tsr.dot(x, params), 1e-5) + infect = tsr.log(1. - tsr.exp(-y[0:-1])) + lkl_pos = tsr.sum(infect * (x[1:] & s[1:])) + lkl_neg = tsr.sum(-y[0:-1] * (~x[1:] & s[1:])) + lkl_mle = lkl_pos + lkl_neg + lkl_mle.name = 'cost' + + return x, s, params, lkl_mle + + +if __name__ == "__main__": + batch_size = 100 + n_obs = 100000 + graph = utils.create_wheel(100) + + print('GRAPH:\n', graph, '\n-------------\n') + + g_shared = theano.shared(value=graph, name='graph') + x, s, params, cost = create_mle_model(graph) + rmse = ub.rmse_error(g_shared, params) + error = ub.relative_error(g_shared, params) + + alg = algorithms.GradientDescent( + cost=-cost, parameters=[params], step_rule=algorithms.AdaDelta() + ) + data_stream = ub.dynamic_data_stream(graph, batch_size) + # data_stream = ub.fixed_data_stream(n_obs, graph, batch_size) + loop = main_loop.MainLoop( + alg, data_stream, + extensions=[ + be.FinishAfter(after_n_batches=10**3), + bm.TrainingDataMonitoring([cost, params, + rmse, error], every_n_batches=10), + be.Printing(every_n_batches=10), + ub.JSONDump("log.json", every_n_batches=10), + ub.ActiveLearning(data_stream.dataset), + ], + ) + loop.run() -- cgit v1.2.3-70-g09d2