# simulation/vi_blocks.py
import main as mn
import numpy as np
import theano
from theano import tensor as tsr
import theano.tensor.shared_randomstreams
import blocks.algorithms
import blocks.extensions
import blocks.extensions.monitoring
import blocks.main_loop
import active_blocks as ab


class ClippedParams(blocks.algorithms.StepRule):
    """A step rule that keeps parameters within [min_value, max_value]."""
    def __init__(self, min_value, max_value):
        self.min_value = min_value
        self.max_value = max_value

    def compute_step(self, parameter, previous_step):
        # The update would set the parameter to `parameter - previous_step`;
        # clamp that value to the allowed range and return the step that
        # produces the clamped value.
        new_value = tsr.clip(parameter - previous_step,
                             self.min_value, self.max_value)
        return parameter - new_value
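
# A minimal sketch (hypothetical helper, not part of the training script)
# showing the effect of ClippedParams by exercising compute_step directly
# on a concrete parameter and step.
def _demo_clipped_params():
    param = theano.shared(np.asarray([0.5, 0.9], dtype=theano.config.floatX))
    step = tsr.vector('step')
    rule = ClippedParams(1e-3, 1 - 1e-3)
    # Value the parameter takes after the clipped step is applied.
    new_value = param - rule.compute_step(param, step)
    f = theano.function([step], new_value)
    # A raw step of 1.0 would push 0.5 below the lower bound; the rule
    # clamps the result to min_value instead, so this prints ~[0.001, 0.9].
    print(f(np.asarray([1.0, 0.0], dtype=theano.config.floatX)))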


def create_vi_model(n_nodes, n_samp=100):
    """Return a variational-inference Theano computation graph."""
    def aux():
        # Small positive random values for initialising the Gaussian
        # parameters.
        rand = .1 + .05 * np.random.normal(size=(n_nodes, n_nodes))
        return rand.astype(theano.config.floatX)

    x = tsr.matrix(name='x', dtype='int8')
    s = tsr.matrix(name='s', dtype='int8')
    # Variational posterior parameters (trained below) and prior parameters
    # (left fixed) for the edge weights.
    mu = theano.shared(value=aux(), name='mu1')
    sig = theano.shared(value=aux(), name='sig1')
    mu0 = theano.shared(value=aux(), name='mu0')
    sig0 = theano.shared(value=aux(), name='sig0')

    srng = tsr.shared_randomstreams.RandomStreams(seed=123)
    # Reparameterised samples of the edge weights: theta = mu + sig * eps.
    theta = (srng.normal((n_samp, n_nodes, n_nodes)) * sig[None, :, :]
             + mu[None, :, :])
    # Infection rates, floored so the log below stays finite.
    y = tsr.maximum(tsr.dot(x, theta), 1e-3)
    infect = tsr.log(1. - tsr.exp(-y[0:-1])).dimshuffle(1, 0, 2)
    lkl_pos = tsr.sum(infect * (x[1:] & s[1:])) / n_samp
    lkl_neg = tsr.sum(-y[0:-1].dimshuffle(1, 0, 2) * (~x[1:] & s[1:])) / n_samp
    lkl = lkl_pos + lkl_neg
    # Closed-form KL(q || prior) between Gaussians,
    # KL(N(mu, sig^2) || N(mu0, sig0^2)), summed over all entries.
    kl = tsr.sum(tsr.log(sig0 / sig)
                 + (sig**2 + (mu - mu0)**2) / (2 * sig0**2) - 0.5)
    # Evidence lower bound: expected log-likelihood minus the KL penalty.
    cost = lkl - kl
    cost.name = 'cost'

    return x, s, mu, sig, cost
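
# A minimal usage sketch (hypothetical helper, toy shapes): compile the
# graph returned by create_vi_model and evaluate the stochastic bound on
# random binary data.
def _demo_vi_model(n_nodes=3, n_steps=5):
    x, s, mu, sig, cost = create_vi_model(n_nodes, n_samp=10)
    f = theano.function([x, s], cost)
    x_val = np.random.randint(0, 2, size=(n_steps, n_nodes)).astype('int8')
    s_val = np.ones((n_steps, n_nodes), dtype='int8')
    print('ELBO estimate:', f(x_val, s_val))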


if __name__ == "__main__":
    n_cascades = 10000
    batch_size = 1000
    graph = mn.create_random_graph(n_nodes=3)
    print('GRAPH:\n', graph, '\n-------------\n')

    x, s, mu, sig, cost = create_vi_model(len(graph))

    step_rules = blocks.algorithms.CompositeRule(
        [blocks.algorithms.AdaDelta(), ClippedParams(1e-3, 1 - 1e-3)])

    # Minimise the negative ELBO with respect to the variational parameters.
    alg = blocks.algorithms.GradientDescent(cost=-cost, parameters=[mu, sig],
                                            step_rule=step_rules)
    data_stream = ab.create_fixed_data_stream(n_cascades, graph, batch_size,
            shuffle=False)
    loop = blocks.main_loop.MainLoop(
        alg, data_stream,
        extensions=[
            blocks.extensions.FinishAfter(after_n_batches=10**4),
            blocks.extensions.monitoring.TrainingDataMonitoring([cost, mu, sig],
                after_batch=True),
            blocks.extensions.Printing(every_n_batches=10),
        ]
    )
    loop.run()
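
    # After training, the learned variational parameters can be read back
    # from the shared variables (get_value is standard Theano shared-variable
    # API).
    print('posterior mean:\n', mu.get_value())
    print('posterior std:\n', sig.get_value())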