path: root/simulation/mle_blocks.py
import main as mn
import theano
from theano import tensor as tsr
import blocks
import blocks.algorithms
import blocks.main_loop
import blocks.extensions
import blocks.extensions.monitoring
import theano.tensor.shared_randomstreams
import picklable_itertools
import picklable_itertools.extras
import numpy as np
from six.moves import range
import fuel
import fuel.datasets
import fuel.schemes
import fuel.streams
import collections


class JeaninuScheme(fuel.schemes.ShuffledScheme):
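    """Iteration scheme that shuffles the order of minibatches but not the
    examples inside them: batch boundaries are offset by a random amount and
    the resulting batches are visited in random order.
    """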
    def get_request_iterator(self):
        indices = list(self.indices)
        start = np.random.randint(self.batch_size)
        batches = list(map(
            list,
            picklable_itertools.extras.partition_all(self.batch_size,
                                                     indices[start:])
        ))
        if indices[:start]:
            batches.append(indices[:start])
        batches = np.asarray(batches)
        return iter(batches[np.random.permutation(len(batches))])


def create_model(n_nodes):
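    """Build the symbolic cascade log-likelihood.

    x, s: int8 input matrices, one row per time step and one column per node
    (x = infection indicators, s = mask selecting the entries that enter the
    likelihood).  params is the shared (n_nodes, n_nodes) matrix of
    transmission rates being estimated.  Returns (x, s, params, cost), where
    cost is the log-likelihood expression.
    """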
    x = tsr.matrix(name='x', dtype='int8')
    s = tsr.matrix(name='s', dtype='int8')
    params = theano.shared(
        .5 + .01 *
        np.random.normal(size=(n_nodes, n_nodes)).astype(theano.config.floatX),
        name='params'
    )
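    # y[t]: hazard each node accumulates from the nodes active in x[t],
    # clipped away from zero so the log below stays finite.
    # Log-likelihood: log(1 - exp(-y[t])) for masked entries that turn on at
    # t + 1, and -y[t] for masked entries that stay off.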
    y = tsr.maximum(tsr.dot(x, params), 1e-5)
    infect = tsr.log(1. - tsr.exp(-y[0:-1]))
    lkl_pos = tsr.sum(infect * (x[1:] & s[1:]))
    lkl_neg = tsr.sum(-y[0:-1] * (~x[1:] & s[1:]))
    lkl_mle = lkl_pos + lkl_neg
    lkl_mle.name = 'cost'
    return x, s, params, lkl_mle


def create_random_graph(n_nodes, p=.5):
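    """Draw a random graph with edge weights in {0, .5, 1} and a zero
    diagonal, and return the corresponding transmission-rate matrix
    -log(1 - p * weight).
    """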
    graph = .5 * np.random.binomial(2, p=.5, size=(n_nodes, n_nodes))
    for k in range(len(graph)):
        graph[k, k] = 0
    return np.log(1. / (1 - p * graph))


def create_data_stream(n_cascades, graph, batch_size, shuffle=True):
    """
    shuffle (bool): shuffle minibatches but not within minibatch
    """
    cascades = mn.build_cascade_list(mn.simulate_cascades(n_cascades, graph),
                                     collapse=True)
    x_obs, s_obs = cascades[0], cascades[1]
    data_set = fuel.datasets.base.IndexableDataset(collections.OrderedDict(
        [('x', x_obs), ('s', s_obs)]
    ))
    if shuffle:
        scheme = JeaninuScheme(len(x_obs), batch_size=batch_size)
    else:
        scheme = fuel.schemes.SequentialScheme(len(x_obs),
                                               batch_size=batch_size)
    return fuel.streams.DataStream(dataset=data_set, iteration_scheme=scheme)


if __name__ == "__main__":
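    # Small end-to-end check: simulate cascades from a known 3-node graph and
    # fit the transmission rates by maximum likelihood.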
    n_cascades = 10000
    batch_size = 1000
    graph = np.array([[0, 0, 1], [0, 0, 0.5], [0, 0, 0]])
    graph = np.log(1. / (1 - .5 * graph))
    print('GRAPH:\n', graph, '\n-------------\n')

    x, s, params, cost = create_model(len(graph))

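    # Minimise the negative log-likelihood with AdaDelta.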
    alg = blocks.algorithms.GradientDescent(
        cost=-cost, parameters=[params], step_rule=blocks.algorithms.AdaDelta()
    )
    data_stream = create_data_stream(n_cascades, graph, batch_size,
                                     shuffle=True)
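    # Monitor the cost and the parameter estimates after every batch so the
    # fit can be compared against the true graph printed above.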
    loop = blocks.main_loop.MainLoop(
        alg, data_stream,
        extensions=[
            blocks.extensions.FinishAfter(after_n_epochs=1000),
            blocks.extensions.monitoring.TrainingDataMonitoring(
                [cost, params], after_batch=True),
            blocks.extensions.Printing()
        ]
    )
    loop.run()