aboutsummaryrefslogtreecommitdiffstats
path: root/simulation/vi.py
blob: 9f2ed401dadd44da4bf5ef450973e340691899f9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import autograd.numpy as np
from autograd import grad


def g(m):
    """Return ``log(1 - exp(-m))`` elementwise.

    This is the log of ``1 - exp(-m)``; ``h`` returns ``-m == log(exp(-m))``,
    the log of the complementary quantity.

    Computed as ``log(-expm1(-m))``: for small ``m`` the naive form
    ``1 - exp(-m)`` cancels catastrophically (``m = 1e-20`` gives ``log(0)``,
    i.e. ``-inf``), while ``expm1`` stays accurate.  The value and gradient
    are otherwise unchanged; ``m == 0`` still yields ``-inf``.
    """
    return np.log(-np.expm1(-m))


def h(m):
    """Return ``-m`` elementwise, i.e. ``log(exp(-m))``.

    This is the log of ``exp(-m)``, the complement of the ``1 - exp(-m)``
    whose log ``g`` computes (presumably the per-node probability of not
    being activated; confirm against the model).
    """
    negated = -m
    return negated


def ll(cascades, theta):
    """Return the log-likelihood of ``cascades`` given edge weights ``theta``.

    Each element of ``cascades`` is a pair ``(x, s)`` of arrays whose axis 0
    is the time step.  For every step ``t >= 1`` the weight vector
    ``w = x[t-1] . theta`` is formed.  Nodes selected by ``x[t]`` contribute
    ``g(w)``, and nodes selected by ``~x[t] & s[t]`` contribute ``h(w)``.

    NOTE(review): ``x`` and ``s`` are assumed to be boolean masks, since
    ``~`` and mask indexing only behave as intended on bool arrays.
    ``theta`` is assumed to be an ``(n, n)`` matrix with
    ``n == x.shape[1]``; confirm against callers.  Presumably ``x[t]`` marks
    nodes newly infected at step ``t`` and ``s[t]`` nodes still susceptible.

    Returns 0 when ``cascades`` is empty.
    """
    res = 0
    for x, s in cascades:
        # Step t is explained by the nodes marked at step t-1, so iterate over
        # the time axis (axis 0) from the second row to the last.  The old
        # bound, x.shape[1] + 1, walked the node axis and ran past the end.
        for t in range(1, x.shape[0]):
            w = np.dot(x[t - 1], theta)
            res += g(w)[x[t]].sum() + h(w)[~x[t] & s[t]].sum()
    # The original fell off the end and returned None.
    return res


def sample(params):
    """Draw ``theta ~ Normal(mu, sigma)`` via the reparameterization trick.

    ``params`` is a ``(mu, sigma)`` pair of location and scale, broadcastable
    against each other.  The draw is returned as ``mu + sigma * eps`` with
    ``eps`` standard normal.  This has the same distribution and shape as
    ``np.random.normal(mu, sigma)``, but it is a differentiable function of
    ``mu`` and ``sigma``, so ``grad`` can propagate through it.  Calling
    ``np.random.normal(mu, sigma)`` directly is not differentiable with
    respect to its arguments under autograd.
    """
    mu, sigma = params[0], params[1]
    # Use the broadcast shape np.random.normal(mu, sigma) would produce.
    # `or None` turns the scalar shape () into size=None, so scalar
    # parameters still give a scalar draw.
    shape = np.shape(mu + sigma) or None
    eps = np.random.normal(size=shape)
    return mu + sigma * eps


def ll_full(params, cascades, nsamples=50):
    """Monte Carlo average of ``ll(cascades, theta)`` over ``theta = sample(params)``.

    Presumably this is the expected log-likelihood term of the variational
    objective; ``grad_ll_full`` differentiates it with respect to ``params``.

    Args:
        params: variational parameters, forwarded unchanged to ``sample``.
        cascades: observed cascades, forwarded unchanged to ``ll``.
        nsamples: number of Monte Carlo draws (default 50, as before).

    Raises:
        ValueError: if ``nsamples`` is not positive.
    """
    if nsamples < 1:
        raise ValueError("nsamples must be a positive integer")
    # Builtin sum folds the per-sample values one by one, so each autograd
    # box stays in the trace; np.mean on a plain list of boxes does not.
    # `range` replaces the Python-2-only `xrange`.
    total = sum(ll(cascades, sample(params)) for _ in range(nsamples))
    # float() avoids Python 2 integer division if every ll value is an int.
    return total / float(nsamples)

# Gradient of ll_full with respect to its first positional argument (params).
grad_ll_full = grad(ll_full, argnum=0)