# simulation/vi_beta.py
import main as mn  # local module providing simulate_cascades / build_cascade_list
import numpy as np
import logging
import scipy.special as ssp
from itertools import product

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s',
                    level=logging.INFO)

def g(m):
    # log P(infection) given total incoming hazard m: log(1 - exp(-m))
    assert (m > 0).all()
    return np.log(1 - np.exp(-m))


def h(m):
    # log P(no infection) given total incoming hazard m: log(exp(-m)) = -m
    return -m
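

def _check_gh():
    # Hypothetical sanity check (not part of the original script): under the
    # hazard reading above, exp(g(m)) is the infection probability and
    # exp(h(m)) the survival probability, so the two must sum to one.
    m = np.array([0.1, 1.0, 5.0])
    assert np.allclose(np.exp(g(m)) + np.exp(h(m)), 1.0)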


def ll(x, s, theta):
    """
    Log-likelihood of one cascade.

    x : (T, n) boolean array, x[t] indexes the infected nodes at step t
    s : (T, n) boolean array, s[t] indexes the still-susceptible nodes at step t
    theta : (n, n) array of edge transmission weights
    """
    res = 0
    for t in range(1, x.shape[0]):
        w = np.dot(x[t-1], theta)  # total hazard each node receives at step t
        res += g(w)[x[t]].sum() + h(w)[~x[t] & s[t]].sum()
    return res
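

def _example_ll():
    # Hypothetical usage sketch (not part of the original script), under one
    # plausible reading of ll's conventions: a 3-node, 2-step cascade where
    # node 0 is the seed, node 1 gets infected at step 1, and node 2 stays
    # susceptible.
    x = np.array([[True, False, False],
                  [False, True, False]])
    s = np.array([[False, True, True],
                  [False, False, True]])
    theta = np.full((3, 3), 0.5)
    return ll(x, s, theta)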


def sample(params):
    # Draw one theta sample from the variational Beta factors.  Despite the
    # names, mu and v are passed to np.random.beta as its two shape
    # parameters a and b, not as a mean and variance.  The lower clip keeps
    # the sampled weights strictly positive, so g's assert holds.
    mu, v = params
    return np.clip(np.random.beta(mu, v, size=mu.shape), 1e-3, 1e5)


def ll_full(params, x, s, nsamples=50):
    # Monte Carlo estimate of E_q[log p(cascade | theta)] under the
    # variational distribution q = Beta(mu, v).
    return np.mean([ll(x, s, sample(params)) for _ in range(nsamples)])


def kl(params1, params0):
    # Closed-form KL(Beta(mu1, sig1) || Beta(mu0, sig0)), summed over edges:
    # ln B(a0, b0) - ln B(a1, b1) + (a1 - a0) psi(a1) + (b1 - b0) psi(b1)
    #   + (a0 - a1 + b0 - b1) psi(a1 + b1)
    mu0, sig0 = params0
    mu1, sig1 = params1
    return (ssp.betaln(mu0, sig0) - ssp.betaln(mu1, sig1) +
            (mu1 - mu0) * ssp.psi(mu1) + (sig1 - sig0) * ssp.psi(sig1) +
            (mu0 - mu1 + sig0 - sig1) * ssp.psi(mu1 + sig1)).sum()
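

def _kl_mc_check(params1, params0, nsamples=100000):
    # Hypothetical sanity check (not part of the original script): a Monte
    # Carlo estimate of the same KL, useful for validating the closed form
    # above on small parameter arrays.
    from scipy.stats import beta
    mu0, sig0 = params0
    mu1, sig1 = params1
    x = np.random.beta(mu1, sig1, size=(nsamples,) + mu1.shape)
    return (beta.logpdf(x, mu1, sig1) -
            beta.logpdf(x, mu0, sig0)).mean(axis=0).sum()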


def aux(var, res, i, j, f, eps):
    # Central finite difference of f with respect to var[i, j], written into
    # res[i, j]; var is restored to its original value afterwards.
    var[i, j] += eps
    res[i, j] = f(var)  # assignment, not "+=": res may be uninitialised
    var[i, j] -= 2 * eps
    res[i, j] -= f(var)
    res[i, j] /= 2 * eps
    var[i, j] += eps
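

def _check_aux():
    # Hypothetical check (not part of the original script): the central
    # difference in aux() should recover the gradient of f(t) = sum(t**2),
    # which is 2 * t.
    var = np.arange(1., 5.).reshape(2, 2)
    res = np.empty_like(var)
    for i, j in product(range(2), range(2)):
        aux(var, res, i, j, lambda t: (t ** 2).sum(), 1e-5)
    assert np.allclose(res, 2 * var)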


def grad_ll_full(params, x, s, nsamples=50, eps=1e-5):
    # Finite-difference gradient of the Monte Carlo likelihood term with
    # respect to both Beta parameters.
    mu, v = params
    n, m = mu.shape
    mugrad = np.empty((n, m))
    vgrad = np.empty((n, m))
    for (i, j) in product(range(n), range(m)):
        aux(mu, mugrad, i, j, lambda t: ll_full((t, v), x, s, nsamples), eps)
        aux(v, vgrad, i, j, lambda t: ll_full((mu, t), x, s, nsamples), eps)

    return mugrad, vgrad


def grad_kl(params1, params0, eps=1e-5):
    # Finite-difference gradient of the KL term with respect to params1.
    mu0, sig0 = params0
    mu1, sig1 = params1

    n, m = mu0.shape
    mugrad = np.empty((n, m))
    vgrad = np.empty((n, m))
    for (i, j) in product(range(n), range(m)):
        aux(mu1, mugrad, i, j, lambda t: kl((t, sig1), params0), eps)
        aux(sig1, vgrad, i, j, lambda t: kl((mu1, t), params0), eps)

    return mugrad, vgrad


def sgd(mu1, sig1, mu0, sig0, cascades, n_e=100, lr=lambda t: 1e-1, n_print=10):
    # Stochastic gradient ascent on the ELBO = E_q[log-lik] - KL(q || p):
    # one KL step per epoch, then one likelihood step per cascade.
    for t in range(n_e):
        lrt = lr(t)  # learning rate
        g_mu1, g_sig1 = grad_kl((mu1, sig1), (mu0, sig0))
        mu1 = np.clip(mu1 - lrt * g_mu1, 1e-3, 1e5)   # KL enters with a minus sign
        sig1 = np.clip(sig1 - lrt * g_sig1, 1e-3, 1e5)
        for step, (x, s) in enumerate(zip(*cascades)):
            g_mu1, g_sig1 = grad_ll_full((mu1, sig1), x, s)
            mu1 = np.clip(mu1 + lrt * g_mu1, 1e-3, 1e5)
            sig1 = np.clip(sig1 + lrt * g_sig1, 1e-3, 1e5)
        # evaluate the lower bound once per epoch rather than at every step
        res = sum(ll_full((mu1, sig1), x, s) for x, s in zip(*cascades)) \
                - kl((mu1, sig1), (mu0, sig0))
        logging.info("Epoch:{}\tStep:{}\tLB:{}".format(t, step, res))
        print(mu1)
        print(sig1)
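

def _posterior_mean(mu1, sig1):
    # Hypothetical helper (not part of the original script): the mean of a
    # Beta(a, b) factor is a / (a + b), a convenient point estimate of the
    # inferred edge weights to compare against the true graph.
    return mu1 / (mu1 + sig1)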


if __name__ == '__main__':
    graph = np.array([[0, 0, 1], [0, 0, 0.5], [0, 0, 0]])
    #graph = np.random.binomial(2, p=.2, size=(4, 4))
    p = 0.5
    # theta_ij = -log(1 - p * graph_ij), so that 1 - exp(-theta_ij) recovers
    # the edge transmission probability p * graph_ij
    graph = np.log(1. / (1 - p * graph))
    print(graph)
    cascades = mn.build_cascade_list(mn.simulate_cascades(100, graph))
    mu0, sig0 = (1. + .2 * np.random.normal(size=graph.shape),
                 1. + .2 * np.random.normal(size=graph.shape))
    mu1, sig1 = (1. + .2 * np.random.normal(size=graph.shape),
                 1. + .2 * np.random.normal(size=graph.shape))
    sgd(mu1, sig1, mu0, sig0, cascades, n_e=30, n_print=1)