# Source: root/simulation/vi_theano.py (blob 9e8fdb0b33f83ea1b307b1b23e99acdad37a6c83)
import main as mn
import theano
from theano import tensor as tsr
import theano.tensor.shared_randomstreams
import numpy as np

# Size of the simulated problem: number of cascades and of graph nodes.
n_cascades, n_nodes = 1000, 4
# Number of Monte-Carlo draws of theta per variational likelihood estimate.
n_samples = 100
# Seeded Theano random stream; it supplies the Gaussian noise for theta.
srng = tsr.shared_randomstreams.RandomStreams(seed=123)
# Optimisation settings used by both the VI and the MLE trainers.
lr = 5*1e-3
n_epochs = 20

###############Variational Inference####################
# Mean-field Gaussian variational posterior q(theta) = N(mu, sig**2) over the
# edge weights, fitted against a Gaussian prior N(mu0, sig0**2) by maximising
# the evidence lower bound (ELBO) = E_q[log-likelihood] - KL(q || prior).
# The leading broadcastable axis of size 1 lets the parameters broadcast
# against the n_samples Monte-Carlo draws of theta below.

# Declare Theano variables
mu = theano.shared(.5 + .2 * np.random.normal(size=(1, n_nodes, n_nodes)),
                   name="mu", broadcastable=(True, False, False))
# Standard deviations must be positive: a negative draw would make
# log(sig0 / sig) NaN. Clip to the same [1e-3, 1] range the updates enforce.
sig = theano.shared(np.clip(.1 + .04 * np.random.normal(size=(1, n_nodes, n_nodes)),
                            1e-3, 1),
                    name="sig", broadcastable=(True, False, False))
# Prior parameters. They get distinct names so graph printouts do not
# confuse them with the variational parameters.
mu0 = theano.shared(.5 + .2 * np.random.normal(size=(1, n_nodes, n_nodes)),
                    name="mu0", broadcastable=(True, False, False))
sig0 = theano.shared(np.clip(.1 + .04 * np.random.normal(size=(1, n_nodes, n_nodes)),
                             1e-3, 1),
                     name="sig0", broadcastable=(True, False, False))
x = tsr.matrix(name='x', dtype='int8')
s = tsr.matrix(name='s', dtype='int8')

# Construct Theano graph
# Reparameterised draws theta = mu + sig * eps with eps ~ N(0, 1), so that
# gradients flow to mu and sig.
theta = srng.normal((n_samples, n_nodes, n_nodes)) * sig + mu
# Floor at 1e-3 keeps log(1 - exp(-y)) finite.
# NOTE(review): y is presumably laid out (time, sample, node); the
# dimshuffle(1, 0, 2) below moves the sample axis to the front — confirm.
y = tsr.maximum(tsr.dot(x, theta), 1e-3)
infect = tsr.log(1. - tsr.exp(-y[0:-1])).dimshuffle(1, 0, 2)
lkl_pos = tsr.sum(infect * (x[1:] & s[1:])) / n_samples
# For 0/1 int8 values, ~x & s equals (not x) and s (two's complement).
lkl_neg = tsr.sum(-y[0:-1].dimshuffle(1, 0, 2) * (~x[1:] & s[1:])) / n_samples
lkl = lkl_pos + lkl_neg
# KL(q || prior) between diagonal Gaussians, summed over all entries:
#   log(sig0/sig) + (sig**2 + (mu - mu0)**2) / (2 * sig0**2) - 1/2
kl = tsr.sum(tsr.log(sig0 / sig)
             + (sig**2 + (mu - mu0)**2) / (2 * sig0**2) - .5)
# Evidence lower bound: the KL to the prior is subtracted.
res = lkl - kl

gmu, gsig = theano.gradient.grad(lkl, [mu, sig])
gmukl, gsigkl = theano.gradient.grad(kl, [mu, sig])

# Compile into functions
loglkl_full = theano.function([x, s], lkl)
# Gradient *ascent* on the expected log-likelihood; returns the ELBO estimate.
train = theano.function(inputs=[x, s], outputs=res,
                        updates=((mu, tsr.clip(mu + lr * gmu, 0, 1)),
                                 (sig, tsr.clip(sig + lr * gsig, 1e-3, 1))))
# Gradient *descent* on the KL term, because it enters the ELBO with a minus sign.
train_kl = theano.function(inputs=[], outputs=[],
                           updates=((mu, tsr.clip(mu - lr * gmukl, 0, 1)),
                                    (sig, tsr.clip(sig - lr * gsigkl, 1e-3, 1))))


###############Maximum Likelihood#####################
# Point estimate of the edge weights: gradient ascent on the log-likelihood.
# After every step the weights are projected back into [0, 1].

params = theano.shared(.5 + .01 * np.random.normal(size=(n_nodes, n_nodes)),
                       name='params')
x = tsr.matrix(name='x', dtype='int8')
s = tsr.matrix(name='s', dtype='int8')

# Floor at 1e-5 keeps log(1 - exp(-y)) finite.
y = tsr.maximum(tsr.dot(x, params), 1e-5)
# Row t's activity explains row t + 1: positive term for observed infections,
# negative term for susceptible nodes that were not infected.
infect = tsr.log(1. - tsr.exp(-y[:-1]))
lkl_pos = tsr.sum(infect * (x[1:] & s[1:]))
lkl_neg = tsr.sum(-y[:-1] * (~x[1:] & s[1:]))
lkl_mle = lkl_pos + lkl_neg

gparams = theano.gradient.grad(lkl_mle, params)
train_mle = theano.function(
    inputs=[x, s],
    outputs=lkl_mle,
    updates=[(params, tsr.clip(params + lr * gparams, 0, 1))],
)


if __name__ == "__main__":
    # Which estimators to run (these replace the former `if 0:` / `if 1:`).
    run_mle = False
    run_vi = True
    # Number of consecutive observation rows per VI training call.
    batch_size = 100

    # Random ground-truth graph with entries in {0, .5, 1} and an empty
    # diagonal, mapped to edge weights -log(1 - p * graph).
    graph = .5 * np.random.binomial(2, p=.5, size=(n_nodes, n_nodes))
    np.fill_diagonal(graph, 0)
    p = 0.5
    graph = np.log(1. / (1 - p * graph))
    cascades = mn.build_cascade_list(mn.simulate_cascades(n_cascades, graph),
                                     collapse=True)
    x_obs, s_obs = cascades[0], cascades[1]
    # NOTE(review): the MLE branch iterates x_obs as a sequence of matrices,
    # while the VI branch slices it as one row-indexed matrix. Confirm which
    # layout mn.build_cascade_list(..., collapse=True) returns.

    # Maximum likelihood
    lkl_plot = []
    if run_mle:
        for _ in range(n_epochs):
            for xt, st in zip(x_obs, s_obs):
                # Use a distinct name so the module-level symbolic `lkl`
                # is not shadowed.
                lkl_val = train_mle(xt, st)
                lkl_plot.append(lkl_val)
        print(graph)
        w = params.get_value()
        np.fill_diagonal(w, 0)
        print(w)
        import matplotlib.pyplot as plt
        plt.plot(lkl_plot)
        plt.show()

    # Variational inference
    if run_vi:
        for _ in range(n_epochs):
            train_kl()
            # Stays None (instead of raising NameError) if there is no data.
            cost = None
            # Stepping range() works on Python 2 and 3 and also trains on a
            # trailing partial batch instead of silently dropping it.
            for start in range(0, len(x_obs), batch_size):
                stop = start + batch_size
                cost = train(x_obs[start:stop], s_obs[start:stop])
            print(cost)
        print(graph)
        print(mu.get_value())
        print(sig.get_value())