author    jeanpouget-abadie <jean.pougetabadie@gmail.com>    2015-11-29 17:03:22 -0500
committer jeanpouget-abadie <jean.pougetabadie@gmail.com>    2015-11-29 17:03:22 -0500
commit    7322c00eafcde38dadbf9d4f05a1572d627355bf (patch)
tree      45db3a6e4d71b0a17be2a87b49a225d0d430cb1b /simulation/vi_blocks.py
parent    041aa021657a3c290952b222e3141449638bad19 (diff)
download  cascades-7322c00eafcde38dadbf9d4f05a1572d627355bf.tar.gz
Active learning for MLE + variational inference (not bug-free)
Diffstat (limited to 'simulation/vi_blocks.py')
-rw-r--r--    simulation/vi_blocks.py    98
1 files changed, 98 insertions, 0 deletions
diff --git a/simulation/vi_blocks.py b/simulation/vi_blocks.py
new file mode 100644
index 0000000..3f88611
--- /dev/null
+++ b/simulation/vi_blocks.py
@@ -0,0 +1,98 @@
+import main as mn
+import theano
+from theano import tensor as tsr
+import blocks
+import blocks.algorithms
+import blocks.extensions
+import blocks.extensions.monitoring
+import blocks.main_loop
+import theano.tensor.shared_randomstreams
+import numpy as np
+from six.moves import range
+import fuel
+import fuel.datasets
+import active_blocks as ab
+
+
+class ClippedParams(blocks.algorithms.StepRule):
+    """A step rule that keeps parameters within a specified range."""
+    def __init__(self, min_value, max_value):
+        self.min_value = min_value
+        self.max_value = max_value
+
+    def compute_step(self, parameter, previous_step):
+        # Blocks applies a step as parameter - step, so clip the value the
+        # parameter would take and return the step that reaches it
+        new_value = tsr.clip(parameter - previous_step,
+                             self.min_value, self.max_value)
+        return parameter - new_value, []
+
+
+def create_vi_model(n_nodes, n_samp=100):
+    """Build the Theano computation graph for variational inference."""
+    def aux():
+        # small random positive values to initialize the Gaussian parameters
+        rand = .1 + .05 * np.random.normal(size=(n_nodes, n_nodes))
+        return rand.astype(theano.config.floatX)
+
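+    # x: 0/1 infection states per (time step, node); s: 0/1 observation mask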
+    x = tsr.matrix(name='x', dtype='int8')
+    s = tsr.matrix(name='s', dtype='int8')
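+    # Gaussian variational posterior (mu, sig) and prior (mu0, sig0)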
+    mu = theano.shared(value=aux(), name='mu1')
+    sig = theano.shared(value=aux(), name='sig1')
+    mu0 = theano.shared(value=aux(), name='mu0')
+    sig0 = theano.shared(value=aux(), name='sig0')
+
+    srng = tsr.shared_randomstreams.RandomStreams(seed=123)
+    # reparameterization: theta = mu + sig * eps with eps ~ N(0, 1), one
+    # (n_nodes, n_nodes) draw per Monte Carlo sample
+    theta = (srng.normal((n_samp, n_nodes, n_nodes)) * sig[None, :, :]
+             + mu[None, :, :])
+    # exposure of each node at each step; floor keeps log(1-exp(-y)) finite
+    y = tsr.maximum(tsr.dot(x, theta), 1e-3)
+    infect = tsr.log(1. - tsr.exp(-y[0:-1])).dimshuffle(1, 0, 2)
+    # Monte Carlo log-likelihood: infected entries contribute log(1 - exp(-y)),
+    # uninfected ones -y; both restricted to the observation mask s
+    # (~x & s selects entries with x == 0 and s == 1 for 0/1 int8 inputs)
+    lkl_pos = tsr.sum(infect * (x[1:] & s[1:])) / n_samp
+    lkl_neg = tsr.sum(-y[0:-1].dimshuffle(1, 0, 2) * (~x[1:] & s[1:])) / n_samp
+    lkl = lkl_pos + lkl_neg
+    # KL(q || prior) between diagonal Gaussians, additive constants dropped
+    # since they do not affect the gradients
+    kl = tsr.sum(tsr.log(sig0 / sig)
+                 + (sig**2 + (mu - mu0)**2) / (2 * sig0**2))
+    # evidence lower bound: expected log-likelihood minus KL penalty
+    cost = lkl - kl
+    cost.name = 'cost'
+
+    return x, s, mu, sig, cost
+
+
+if __name__ == "__main__":
+    n_cascades = 10000
+    batch_size = 1000
+    graph = mn.create_random_graph(n_nodes=3)
+    print('GRAPH:\n', graph, '\n-------------\n')
+
+    x, s, mu, sig, cost = create_vi_model(len(graph))
+
+    step_rules = blocks.algorithms.CompositeRule(
+        [blocks.algorithms.AdaDelta(), ClippedParams(1e-3, 1 - 1e-3)])
+
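+    # maximize the ELBO by descending on its negative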
+    alg = blocks.algorithms.GradientDescent(cost=-cost, parameters=[mu, sig],
+                                            step_rule=step_rules)
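+    # fixed stream of simulated cascades from the local active_blocks helper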
+    data_stream = ab.create_fixed_data_stream(n_cascades, graph, batch_size,
+                                              shuffle=False)
+    loop = blocks.main_loop.MainLoop(
+        alg, data_stream,
+        extensions=[
+            blocks.extensions.FinishAfter(after_n_batches=10**4),
+            blocks.extensions.monitoring.TrainingDataMonitoring(
+                [cost, mu, sig], after_batch=True),
+            blocks.extensions.Printing(every_n_batches=10),
+        ]
+    )
+    loop.run()