aboutsummaryrefslogtreecommitdiffstats
path: root/simulation/vi.py
diff options
context:
space:
mode:
authorjeanpouget-abadie <jean.pougetabadie@gmail.com>2015-11-15 20:16:25 -0500
committerjeanpouget-abadie <jean.pougetabadie@gmail.com>2015-11-15 20:16:25 -0500
commitec129983fbac109000d436ce526763206b81aa76 (patch)
treeecc83ddf6752e8f4c20cb843ff3173524bf8ec79 /simulation/vi.py
parent730b35ee58cd616ea626bde7171144c326d7cfd8 (diff)
downloadcascades-ec129983fbac109000d436ce526763206b81aa76.tar.gz
variational inference sgd function
Diffstat (limited to 'simulation/vi.py')
-rw-r--r--simulation/vi.py54
1 files changed, 46 insertions, 8 deletions
diff --git a/simulation/vi.py b/simulation/vi.py
index 9f2ed40..3ba6b96 100644
--- a/simulation/vi.py
+++ b/simulation/vi.py
@@ -1,3 +1,5 @@
+import time
+import main as mn
import autograd.numpy as np
from autograd import grad
@@ -10,20 +12,56 @@ def h(m):
return -m
-def ll(cascades, theta):
+def ll(x, s, theta):
+ """
+ x : infected
+ s : susceptible
+ """
res = 0
- for x, s in cascades:
- for t in range(1, x.shape[1] + 1):
- w = np.dot(x[t-1], theta)
- res += g(w)[x[t]].sum() + h(w)[~x[t] & s[t]].sum()
+ for t in range(1, x.shape[1] + 1):
+ w = np.dot(x[t-1], theta)
+ res += g(w)[x[t]].sum() + h(w)[~x[t] & s[t]].sum()
def sample(params):
return np.random.normal(*params)
-def ll_full(params, cascades):
- nsamples = 50
- return np.mean([ll(cascades, sample(params)) for _ in xrange(nsamples)])
+def ll_full(params, x, s, nsamples=50):
+ return np.mean([ll(x, s, sample(params)) for _ in xrange(nsamples)])
+
grad_ll_full = grad(ll_full)
+
+
+def kl(params1, params0):
+ mu0, sig0 = params0
+ mu1, sig1 = params1
+ return np.sum(np.log(sig1/sig0) + (sig0**2 + (mu0 - mu1)**2)/(2*sig1)**2)
+
+
+grad_kl = grad(kl)
+
+
+def sgd(mu1, sig1, mu0, sig0, cascades, n_e, lr=lambda t: 1e-3):
+ g_mu1, g_sig1 = grad_kl((mu1, sig1), (mu0, sig0))
+ for t in xrange(n_e):
+ lrt = lr(t) # learning rate
+ mu1, sig1 = mu1 + lrt * g_mu1, sig1 + lrt * g_sig1
+ for x, s in cascades:
+ g_mu1, g_sig1 = grad_ll_full((mu1, sig1), x, s)
+ mu1, sig1 = mu1 + lrt * g_mu1, sig1 + lrt * g_sig1
+ res = np.sum(ll_full((mu1, sig1), x, s) for x, s in cascades) +\
+ kl((mu1, sig1), (mu0, sig0))
+ print("Epoch: {}\t LB: {}\t Time: {}".format(t, res, time.time()))
+
+
+if __name__=='__main__':
+ g = np.array([[0, 0, 1], [0, 0, 0.5], [0, 0, 0]])
+ p = 0.5
+ g = np.log(1. / (1 - p * g))
+ cascades = mn.build_cascade_list(mn.simulate_cascades(100, g))
+ params0 = (1. + .2 * np.random.normal(len(g)), \
+ 1 + .2 * np.random.normal(len(g)))
+ params1 = (1. + .2 * np.random.normal(len(g)), \
+ 1 + .2 * np.random.normal(len(g)))