diff options
| author | jeanpouget-abadie <jean.pougetabadie@gmail.com> | 2015-11-15 20:16:25 -0500 |
|---|---|---|
| committer | jeanpouget-abadie <jean.pougetabadie@gmail.com> | 2015-11-15 20:16:25 -0500 |
| commit | ec129983fbac109000d436ce526763206b81aa76 (patch) | |
| tree | ecc83ddf6752e8f4c20cb843ff3173524bf8ec79 /simulation/vi.py | |
| parent | 730b35ee58cd616ea626bde7171144c326d7cfd8 (diff) | |
| download | cascades-ec129983fbac109000d436ce526763206b81aa76.tar.gz | |
variational inference sgd function
Diffstat (limited to 'simulation/vi.py')
| -rw-r--r-- | simulation/vi.py | 54 |
1 files changed, 46 insertions, 8 deletions
diff --git a/simulation/vi.py b/simulation/vi.py index 9f2ed40..3ba6b96 100644 --- a/simulation/vi.py +++ b/simulation/vi.py @@ -1,3 +1,5 @@ +import time +import main as mn import autograd.numpy as np from autograd import grad @@ -10,20 +12,56 @@ def h(m): return -m -def ll(cascades, theta): +def ll(x, s, theta): + """ + x : infected + s : susceptible + """ res = 0 - for x, s in cascades: - for t in range(1, x.shape[1] + 1): - w = np.dot(x[t-1], theta) - res += g(w)[x[t]].sum() + h(w)[~x[t] & s[t]].sum() + for t in range(1, x.shape[1] + 1): + w = np.dot(x[t-1], theta) + res += g(w)[x[t]].sum() + h(w)[~x[t] & s[t]].sum() def sample(params): return np.random.normal(*params) -def ll_full(params, cascades): - nsamples = 50 - return np.mean([ll(cascades, sample(params)) for _ in xrange(nsamples)]) +def ll_full(params, x, s, nsamples=50): + return np.mean([ll(x, s, sample(params)) for _ in xrange(nsamples)]) + grad_ll_full = grad(ll_full) + + +def kl(params1, params0): + mu0, sig0 = params0 + mu1, sig1 = params1 + return np.sum(np.log(sig1/sig0) + (sig0**2 + (mu0 - mu1)**2)/(2*sig1)**2) + + +grad_kl = grad(kl) + + +def sgd(mu1, sig1, mu0, sig0, cascades, n_e, lr=lambda t: 1e-3): + g_mu1, g_sig1 = grad_kl((mu1, sig1), (mu0, sig0)) + for t in xrange(n_e): + lrt = lr(t) # learning rate + mu1, sig1 = mu1 + lrt * g_mu1, sig1 + lrt * g_sig1 + for x, s in cascades: + g_mu1, g_sig1 = grad_ll_full((mu1, sig1), x, s) + mu1, sig1 = mu1 + lrt * g_mu1, sig1 + lrt * g_sig1 + res = np.sum(ll_full((mu1, sig1), x, s) for x, s in cascades) +\ + kl((mu1, sig1), (mu0, sig0)) + print("Epoch: {}\t LB: {}\t Time: {}".format(t, res, time.time())) + + +if __name__=='__main__': + g = np.array([[0, 0, 1], [0, 0, 0.5], [0, 0, 0]]) + p = 0.5 + g = np.log(1. / (1 - p * g)) + cascades = mn.build_cascade_list(mn.simulate_cascades(100, g)) + params0 = (1. + .2 * np.random.normal(len(g)), \ + 1 + .2 * np.random.normal(len(g))) + params1 = (1. + .2 * np.random.normal(len(g)), \ + 1 + .2 * np.random.normal(len(g))) |
