aboutsummaryrefslogtreecommitdiffstats
path: root/simulation
diff options
context:
space:
mode:
Diffstat (limited to 'simulation')
-rw-r--r--simulation/vi.py38
1 files changed, 23 insertions, 15 deletions
diff --git a/simulation/vi.py b/simulation/vi.py
index 3ba6b96..9604c7d 100644
--- a/simulation/vi.py
+++ b/simulation/vi.py
@@ -5,6 +5,7 @@ from autograd import grad
def g(m):
+ assert (m > 0).all()
return np.log(1 - np.exp(-m))
@@ -18,13 +19,16 @@ def ll(x, s, theta):
s : susceptible
"""
res = 0
- for t in range(1, x.shape[1] + 1):
+ for t in range(1, x.shape[0]):
w = np.dot(x[t-1], theta)
res += g(w)[x[t]].sum() + h(w)[~x[t] & s[t]].sum()
+ return res
def sample(params):
- return np.random.normal(*params)
+ mu, v = params
+ size = mu.shape
+ return np.maximum(np.random.normal(size=size) * v + mu, 1e-3)
def ll_full(params, x, s, nsamples=50):
@@ -43,25 +47,29 @@ def kl(params1, params0):
grad_kl = grad(kl)
-def sgd(mu1, sig1, mu0, sig0, cascades, n_e, lr=lambda t: 1e-3):
+def sgd(mu1, sig1, mu0, sig0, cascades, n_e=100, lr=lambda t: 1e-2):
g_mu1, g_sig1 = grad_kl((mu1, sig1), (mu0, sig0))
for t in xrange(n_e):
lrt = lr(t) # learning rate
mu1, sig1 = mu1 + lrt * g_mu1, sig1 + lrt * g_sig1
- for x, s in cascades:
+ for x, s in zip(*cascades):
g_mu1, g_sig1 = grad_ll_full((mu1, sig1), x, s)
- mu1, sig1 = mu1 + lrt * g_mu1, sig1 + lrt * g_sig1
- res = np.sum(ll_full((mu1, sig1), x, s) for x, s in cascades) +\
- kl((mu1, sig1), (mu0, sig0))
+ mu1 = np.maximum(mu1 + lrt * g_mu1, 0)
+ sig1 = np.maximum(sig1 + lrt * g_sig1, 1e-3)
+ res = np.sum(ll_full((mu1, sig1), x, s) for x, s in zip(*cascades)) + \
+ kl((mu1, sig1), (mu0, sig0))
print("Epoch: {}\t LB: {}\t Time: {}".format(t, res, time.time()))
+ print mu1
+ print sig1
-if __name__=='__main__':
- g = np.array([[0, 0, 1], [0, 0, 0.5], [0, 0, 0]])
+if __name__ == '__main__':
+ graph = np.array([[0, 0, 1], [0, 0, 0.5], [0, 0, 0]])
p = 0.5
- g = np.log(1. / (1 - p * g))
- cascades = mn.build_cascade_list(mn.simulate_cascades(100, g))
- params0 = (1. + .2 * np.random.normal(len(g)), \
- 1 + .2 * np.random.normal(len(g)))
- params1 = (1. + .2 * np.random.normal(len(g)), \
- 1 + .2 * np.random.normal(len(g)))
+ graph = np.log(1. / (1 - p * graph))
+ cascades = mn.build_cascade_list(mn.simulate_cascades(1000, graph))
+ mu0, sig0 = (1. + .2 * np.random.normal(size=graph.shape),
+ 1 + .2 * np.random.normal(size=graph.shape))
+ mu1, sig1 = (1. + .2 * np.random.normal(size=graph.shape),
+ 1 + .2 * np.random.normal(size=graph.shape))
+ sgd(mu1, sig1, mu0, sig0, cascades, n_e=30)