| field | value | date |
|---|---|---|
| author | jeanpouget-abadie <jean.pougetabadie@gmail.com> | 2015-12-02 12:36:05 -0500 |
| committer | jeanpouget-abadie <jean.pougetabadie@gmail.com> | 2015-12-02 12:36:05 -0500 |
| commit | 5e546cb6c96e5e5e575730e27c175f558da5ec82 (patch) | |
| tree | 3b9f6b81df16d84396fac029e8c487190348e2c6 /simulation/vi_blocks.py | |
| parent | c985f325b4cfcb5563c4edb7e1f9fe89dcc3b892 (diff) | |
| parent | 3b5321add6cd71c6e23ff65e75faaa48e6829634 (diff) | |
| download | cascades-5e546cb6c96e5e5e575730e27c175f558da5ec82.tar.gz | |
merge conflicts 2
Diffstat (limited to 'simulation/vi_blocks.py')
| mode | file | lines |
|---|---|---|
| -rw-r--r-- | simulation/vi_blocks.py | 47 |

1 file changed, 22 insertions, 25 deletions
```diff
diff --git a/simulation/vi_blocks.py b/simulation/vi_blocks.py
index 84e637f..1177979 100644
--- a/simulation/vi_blocks.py
+++ b/simulation/vi_blocks.py
@@ -1,17 +1,15 @@
-import main as mn
+import utils
+import utils_blocks as ub
 import theano
 from theano import tensor as tsr
-import blocks
-import blocks.algorithms, blocks.main_loop, blocks.extensions.monitoring
+from blocks import algorithms, main_loop
+import blocks.extensions as be
+import blocks.extensions.monitoring as bm
 import theano.tensor.shared_randomstreams
 import numpy as np
-from six.moves import range
-import fuel
-import fuel.datasets
-import active_blocks as ab
 
 
-class ClippedParams(blocks.algorithms.StepRule):
+class ClippedParams(algorithms.StepRule):
     """A rule to maintain parameters within a specified range"""
     def __init__(self, min_value, max_value):
         self.min_value = min_value
@@ -38,8 +36,8 @@ def create_vi_model(n_nodes, n_samp=100):
     sig0 = theano.shared(value=aux(.5, .1), name='sig0')
 
     srng = tsr.shared_randomstreams.RandomStreams(seed=123)
-    theta = srng.normal((n_samp, n_nodes, n_nodes)) * sig[None, :, :] + mu[None,
-                                                                        :, :]
+    theta = (srng.normal((n_samp, n_nodes, n_nodes)) * sig[None, :, :] +
+             mu[None, :, :])
     y = tsr.maximum(tsr.dot(x, theta), 1e-3)
     infect = tsr.log(1. - tsr.exp(-y[0:-1])).dimshuffle(1, 0, 2)
     lkl_pos = tsr.sum(infect * (x[1:] & s[1:])) / n_samp
@@ -56,28 +54,27 @@ if __name__ == "__main__":
     #n_cascades = 10000
     batch_size = 10
     n_samples = 50
-    graph = mn.create_random_graph(n_nodes=4)
+    graph = utils.create_random_graph(n_nodes=4)
     print('GRAPH:\n', graph, '\n-------------\n')
 
     x, s, mu, sig, cost = create_vi_model(len(graph), n_samples)
-    rmse, g_shared = ab.rmse_error(graph, mu)
+    rmse = ub.rmse_error(graph, mu)
 
-    step_rules= blocks.algorithms.CompositeRule([blocks.algorithms.AdaDelta(),
-        ClippedParams(1e-3, 1 - 1e-3)])
+    step_rules = algorithms.CompositeRule([algorithms.AdaDelta(),
+                                           ClippedParams(1e-3, 1 - 1e-3)])
 
-    alg = blocks.algorithms.GradientDescent(cost=cost, parameters=[mu, sig],
-                                            step_rule=step_rules)
-    #data_stream = ab.create_fixed_data_stream(n_cascades, graph, batch_size,
-    #                                          shuffle=False)
-    data_stream = ab.create_learned_data_stream(graph, batch_size)
-    loop = blocks.main_loop.MainLoop(
+    alg = algorithms.GradientDescent(cost=cost, parameters=[mu, sig],
+                                     step_rule=step_rules)
+    data_stream = ub.fixed_data_stream(n_cascades, graph, batch_size,
+                                       shuffle=False)
+    # data_stream = ub.dynamic_data_stream(graph, batch_size)
+    loop = main_loop.MainLoop(
         alg, data_stream, extensions=[
-            blocks.extensions.FinishAfter(after_n_batches = 10**4),
-            blocks.extensions.monitoring.TrainingDataMonitoring([cost, mu, sig,
-                rmse, g_shared], after_batch=True),
-            blocks.extensions.Printing(every_n_batches = 100,
-                after_epoch=False),
+            be.FinishAfter(after_n_batches=10**4),
+            bm.TrainingDataMonitoring([cost, mu, sig, rmse],
+                                      every_n_batches=10),
+            be.Printing(every_n_batches=100, after_epoch=False),
         ]
     )
 
     loop.run()
```
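The first hunk shortens the import style but cuts `ClippedParams` off after `__init__`; the `compute_step` body never appears in the diff. A minimal sketch of how such a Blocks `StepRule` is typically completed (the `compute_step` body below is an assumption for illustration, not the repository's code):

```python
from theano import tensor as tsr
from blocks import algorithms


class ClippedParams(algorithms.StepRule):
    """A rule to maintain parameters within a specified range."""

    def __init__(self, min_value, max_value):
        self.min_value = min_value
        self.max_value = max_value

    def compute_step(self, parameter, previous_step):
        # Assumed implementation: Blocks applies `parameter -= step`,
        # so clip the proposed new value to [min_value, max_value] and
        # return the step that lands exactly on it (plus no extra updates).
        proposed = parameter - previous_step
        clipped = tsr.clip(proposed, self.min_value, self.max_value)
        return parameter - clipped, []
```

Placed after `AdaDelta()` in the `CompositeRule`, a rule like this constrains the post-update values to `[1e-3, 1 - 1e-3]`, keeping the probability-like parameters away from exactly 0 and 1.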
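The second hunk is only a re-wrap of the sampling expression, but that expression is the reparameterization trick at the heart of the variational model: `theta` is drawn as standard normal noise scaled by `sig` and shifted by `mu`, so the Monte Carlo cost stays differentiable in both. A standalone sketch under assumed shapes (the shared-variable initializers here are placeholders, not the file's `aux` helper):

```python
import numpy as np
import theano
import theano.tensor.shared_randomstreams
from theano import tensor as tsr

n_samp, n_nodes = 100, 4
mu = theano.shared(np.full((n_nodes, n_nodes), .5), name='mu')
sig = theano.shared(np.full((n_nodes, n_nodes), .1), name='sig')
srng = tsr.shared_randomstreams.RandomStreams(seed=123)

# Reparameterization: eps ~ N(0, 1), theta = eps * sig + mu, so each of
# the n_samp samples is N(mu, sig^2) element-wise and gradients of any
# cost built from theta flow back to mu and sig.
theta = (srng.normal((n_samp, n_nodes, n_nodes)) * sig[None, :, :] +
         mu[None, :, :])
```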
