path: root/simulation/mle_blocks.py
import utils
import utils_blocks as ub
import theano
from theano import tensor as tsr
from blocks import algorithms, main_loop
import blocks.extensions as be
import blocks.extensions.monitoring as bm
import numpy as np


def create_mle_model(graph):
    """Build the Theano computation graph for the cascade log-likelihood.

    `x` and `s` are (T, n_nodes) int8 matrices: x[t, j] = 1 if node j is
    infected at step t and s[t, j] = 1 if it is still susceptible;
    `params` holds the pairwise transmission rates to be learned.
    """
    n_nodes = len(graph)
    x = tsr.matrix(name='x', dtype='int8')
    s = tsr.matrix(name='s', dtype='int8')
    # Initialise the rate matrix near 0.5 with a little noise to break symmetry.
    params = theano.shared(
        .5 + .01 *
        np.random.normal(size=(n_nodes, n_nodes)).astype(theano.config.floatX),
        name='params'
    )
    # y[t, j]: total hazard on node j from the nodes infected at step t,
    # clipped away from zero so the log below stays finite.
    y = tsr.maximum(tsr.dot(x, params), 1e-5)
    # Log-probability that a hazard y produces an infection: log(1 - exp(-y)).
    infect = tsr.log(1. - tsr.exp(-y[0:-1]))
    # Susceptible nodes that did get infected at the next step ...
    lkl_pos = tsr.sum(infect * (x[1:] & s[1:]))
    # ... and susceptible nodes that did not (for 0/1 int8 data, ~x & s picks
    # out exactly the susceptible, still-uninfected nodes).
    lkl_neg = tsr.sum(-y[0:-1] * (~x[1:] & s[1:]))
    lkl_mle = lkl_pos + lkl_neg
    lkl_mle.name = 'cost'

    return x, s, params, lkl_mle
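

# Illustrative sketch (not part of the original pipeline): compile the
# likelihood graph for a tiny 3-node example and evaluate it on a hand-made
# cascade. The susceptibility convention assumed here (s[t, j] stays 1 up to
# and including the step at which node j is infected) and the toy data are
# assumptions for demonstration only; create_mle_model only reads len(graph),
# so a zero matrix is enough to size the model. The function is never called
# by the training script below.
def _demo_evaluate_likelihood():
    toy_graph = np.zeros((3, 3), dtype=theano.config.floatX)
    x, s, _params, lkl = create_mle_model(toy_graph)
    lkl_fn = theano.function([x, s], lkl)
    # Cascade over T=3 steps: node 0 starts infected, node 1 is infected at
    # step 1, node 2 at step 2.
    x_obs = np.array([[1, 0, 0],
                      [1, 1, 0],
                      [1, 1, 1]], dtype='int8')
    s_obs = np.array([[0, 1, 1],
                      [0, 1, 1],
                      [0, 0, 1]], dtype='int8')
    print('toy log-likelihood:', lkl_fn(x_obs, s_obs))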


if __name__ == "__main__":
    batch_size = 100
    # n_obs = 100000
    freq = 10  # monitoring / logging period, in batches
    graph = utils.create_wheel(1000)  # ground-truth graph the cascades are simulated from

    print('GRAPH:\n', graph, '\n-------------\n')

    # Wrap the true graph in a shared variable so the RMSE and relative error
    # of the learned rates can be monitored as Theano expressions.
    g_shared = theano.shared(value=graph, name='graph')
    x, s, params, cost = create_mle_model(graph)
    rmse = ub.rmse_error(g_shared, params)
    error = ub.relative_error(g_shared, params)

    # Blocks minimises its cost, so the log-likelihood is negated here;
    # AdaDelta adapts per-parameter step sizes, so no learning rate is set.
    alg = algorithms.GradientDescent(
        cost=-cost, parameters=[params], step_rule=algorithms.AdaDelta()
    )
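    # Mini-batches of cascades come from a stream built in utils_blocks; the
    # dynamic stream presumably simulates fresh cascades from `graph` on the
    # fly, while the commented-out fixed variant would reuse a pre-generated
    # set of n_obs observations.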
    data_stream = ub.dynamic_data_stream(graph, batch_size)
    # data_stream = ub.fixed_data_stream(n_obs, graph, batch_size)
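    # Run for 1000 batches with a SQLite log backend; every `freq` batches the
    # cost, current parameters and error measures are recorded and printed,
    # dumped to JSON, and the project-specific ActiveLearning extension is
    # invoked on the dataset.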
    loop = main_loop.MainLoop(
        alg, data_stream,
        log_backend="sqlite",
        extensions=[
            be.FinishAfter(after_n_batches=10**3),
            bm.TrainingDataMonitoring([cost, params,
                                       rmse, error], every_n_batches=freq),
            be.Printing(every_n_batches=freq),
            ub.JSONDump("logs/active_outdegree_mle.json", every_n_batches=freq),
            ub.ActiveLearning(data_stream.dataset, graph, every_n_batches=freq),
        ],
    )
    loop.run()