aboutsummaryrefslogtreecommitdiffstats
path: root/simulation/vi_theano.py
diff options
context:
space:
mode:
author jeanpouget-abadie <jean.pougetabadie@gmail.com> 2015-11-22 22:22:34 -0500
committer jeanpouget-abadie <jean.pougetabadie@gmail.com> 2015-11-22 22:22:34 -0500
commit 65638ea7d886c25b8bf75e43a5ce46db2ebbaf53 (patch)
tree a2cc54487d224d4d69a3aa7c6a3fe4d1820baceb /simulation/vi_theano.py
parent 2a193599c837b5dd12d38b23577b8403a18f2822 (diff)
download cascades-65638ea7d886c25b8bf75e43a5ce46db2ebbaf53.tar.gz
first semi working theano version
Diffstat (limited to 'simulation/vi_theano.py')
-rw-r--r-- simulation/vi_theano.py 66
1 files changed, 66 insertions, 0 deletions
diff --git a/simulation/vi_theano.py b/simulation/vi_theano.py
new file mode 100644
index 0000000..562fa67
--- /dev/null
+++ b/simulation/vi_theano.py
@@ -0,0 +1,66 @@
+import main as mn
+import theano
+from theano import tensor as tsr
+import theano.tensor.shared_randomstreams
+import numpy as np
+
+n_cascades = 1000
+n_nodes = 4
+n_samples = 100
+srng = tsr.shared_randomstreams.RandomStreams(seed=123)
+lr = 1e-2
+n_epochs = 10
+
+
+# Declare Theano variables
+mu = theano.shared(.5 * np.random.rand(1, n_nodes, n_nodes), name="mu",
+ broadcastable=(True, False, False))
+sig = theano.shared(.3 * np.random.rand(1, n_nodes, n_nodes), name="sig",
+ broadcastable=(True, False, False))
+mu0 = theano.shared(.5 * np.random.rand(1, n_nodes, n_nodes), name="mu",
+ broadcastable=(True, False, False))
+sig0 = theano.shared(.3 * np.random.rand(1, n_nodes, n_nodes), name="sig",
+ broadcastable=(True, False, False))
+x = tsr.matrix(name='x', dtype='int8')
+s = tsr.matrix(name='s', dtype='int8')
+
+# Construct Theano graph
+theta = srng.normal((n_samples, n_nodes, n_nodes)) * sig + mu
+y = tsr.clip(tsr.dot(x, theta), 1e-3, 1)
+infect = tsr.log(1. - tsr.exp(-y[0:-1])).dimshuffle(1, 0, 2)
+lkl_pos = tsr.sum(infect * (x[1:] & s[1:])) / n_samples
+lkl_neg = tsr.sum(-y[0:-1].dimshuffle(1, 0, 2) * (~x[1:] & s[1:])) / n_samples
+
+lkl = lkl_pos + lkl_neg
+kl = tsr.sum(tsr.log(sig / sig0) + (sig0**2 + (mu0 - mu)**2)/(2*sig)**2)
+res = lkl + kl
+
+gmu, gsig = theano.gradient.grad(lkl, [mu, sig])
+gmukl, gsigkl = theano.grad(kl, [mu, sig])
+
+# Compile into functions
+loglkl_full = theano.function([x, s], lkl)
+train = theano.function(inputs=[x, s], outputs=res,
+ updates=((mu, tsr.clip(mu + lr * gmu, 0, 1)),
+ (sig, tsr.clip(sig + lr * gsig, 1e-3, 1))))
+train_kl = theano.function(inputs=[], outputs=[],
+ updates=((mu, tsr.clip(mu + lr * gmukl, 0, 1)),
+ (sig, tsr.clip(sig + lr * gsigkl, 1e-3, 1))))
+
+
+if __name__ == "__main__":
+ graph = np.random.binomial(2, p=.2, size=(n_nodes, n_nodes))
+ for k in range(len(graph)):
+ graph[k, k] = 0
+ p = 0.5
+ graph = np.log(1. / (1 - p * graph))
+ cascades = mn.build_cascade_list(mn.simulate_cascades(n_cascades, graph),
+ collapse=True)
+ x_obs, s_obs = cascades[0], cascades[1]
+ for i in range(n_epochs):
+ train_kl()
+ for k in xrange(len(x_obs)/100):
+ cost = train(x_obs[k*100:(k+1)*100], s_obs[k*100:(k+1)*100])
+ print(cost)
+ print(mu.get_value())
+ print(graph)