aboutsummaryrefslogtreecommitdiffstats
path: root/simulation/utils.py
blob: aad777135b841e3712cdbf02987bdd2302ed8042 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import numpy as np
import numpy.random as nr
from six.moves import range


def create_random_graph(n_nodes, p=.5):
    """
    Create a random weighted directed graph without self-loops.

    Each off-diagonal entry ``g`` is drawn independently as
    ``Binomial(2, .5) / 2``, i.e. 0, .5 or 1 with probabilities 1/4, 1/2
    and 1/4. It is then mapped to the edge weight ``log(1 / (1 - p * g))``.
    Under the ``1 - exp(-weight)`` transmission rule of ``simulate_cascade``,
    an edge therefore transmits with probability ``p * g``.

    Note that ``p`` only scales the transmission probabilities. The edge
    density is fixed by the hard-coded Binomial parameter ``.5``.

    Args:
        n_nodes: number of nodes.
        p: largest possible transmission probability of an edge. ``p=1``
            gives infinite weight (certain transmission, with a
            divide-by-zero warning) to edges with ``g == 1``.

    Returns:
        (n_nodes, n_nodes) float array of edge weights. ``[i, j]`` is the
        weight of the edge i -> j, and the diagonal is 0.
    """
    graph = .5 * nr.binomial(2, p=.5, size=(n_nodes, n_nodes))
    np.fill_diagonal(graph, 0)  # no self-loops
    return np.log(1. / (1 - p * graph))


def create_wheel(n_nodes, p=.5):
    """
    Create a directed wheel graph with uniform edge weights.

    Node 0 is the hub and has an edge to every other node. Nodes
    1 .. n_nodes-1 form the directed rim cycle
    1 -> 2 -> ... -> n_nodes-1 -> 1. Every edge gets weight
    ``log(1 / (1 - p))``, so under the ``1 - exp(-weight)`` rule of
    ``simulate_cascade`` it transmits with probability ``p``. Non-edges get
    weight 0.

    With fewer than three nodes the rim has at most one node. In that case
    no rim edge is added, which avoids a self-loop. ``n_nodes`` of 0 or 1
    yields an all-zero graph.

    Args:
        n_nodes: number of nodes, including the hub.
        p: transmission probability of every edge.

    Returns:
        (n_nodes, n_nodes) float array. ``[i, j]`` is the weight of i -> j.
    """
    graph = np.zeros((n_nodes, n_nodes))
    if n_nodes > 0:
        graph[0, 1:] = 1  # hub -> every rim node
    if n_nodes > 2:
        # rim cycle: each rim node points to the next one, the last wraps to 1
        rim = np.arange(1, n_nodes)
        graph[rim, np.roll(rim, -1)] = 1
    return np.log(1. / (1 - p * graph))


def simulate_cascade(x, graph):
    """
    Simulate an independent-cascade (IC) process given a graph and initial state.

    ``graph[i, j]`` is the weight of the edge i -> j. At each step, a
    susceptible node ``j`` becomes infected with probability
    ``1 - exp(-sum_i graph[i, j] * x[i])``, where ``x`` marks the nodes
    infected at the previous step. A node is infectious for one step only and
    is then removed from the susceptible set.

    Args:
        x: boolean (or 0/1) vector marking the initially infected nodes.
        graph: square array of edge weights, e.g. as built by
            ``create_random_graph`` or ``create_wheel``.

    Yields:
        One ``(x, susc)`` pair of boolean vectors per time step:
            - susc: the nodes susceptible at the beginning of this time step
              (all False for the first yield, which only reports the sources)
            - x: the subset of susc who became infected
        The cascade stops once a step infects nobody or no susceptible node
        remains.
    """
    n_nodes = graph.shape[0]
    x = np.asarray(x, dtype=bool)
    yield x, np.zeros(n_nodes, dtype=bool)
    susc = np.ones(n_nodes, dtype=bool)
    while np.any(x):
        # nodes infected at the previous step are now inactive
        susc = susc & ~x
        if not np.any(susc):
            break
        prob = 1 - np.exp(-np.dot(graph.T, x))
        # nr.random draws from [0, 1); the strict comparison makes
        # P(infection) == prob exactly, so a zero probability never infects
        draws = nr.random(n_nodes)
        x = (draws < prob) & susc
        yield x, susc


def random_source(graph, node_p=None):
    """
    Choose one random source node and return it as a boolean mask.

    Args:
        graph: square weight matrix; only its size is used.
        node_p: optional probability of each node being the source.
            Defaults to a uniform distribution over all nodes.

    Returns:
        Boolean vector with exactly one True entry, at the chosen source.
    """
    size = graph.shape[0]
    probs = node_p if node_p is not None else np.ones(size) / size
    mask = np.zeros(size, dtype=bool)
    mask[nr.choice(size, p=probs)] = True
    return mask


def simulate_cascades(n_obs, graph, source=random_source):
    """
    Collect ``n_obs`` time steps from repeatedly simulated cascades.

    Cascades are run back to back, each seeded with ``source(graph)``.
    Every ``(x, susc)`` step they yield is stored as one row until ``n_obs``
    rows are filled, so the final cascade may be cut short.

    Args:
        n_obs: number of time steps (rows) to collect.
        graph: square weight matrix passed to the cascade simulator.
        source: callable mapping ``graph`` to an initial infection mask.

    Returns:
        ``(x_obs, s_obs)``: two ``(n_obs, n_nodes)`` boolean arrays holding,
        row by row, the newly infected and the susceptible nodes.
    """
    size = graph.shape[0]
    x_obs = np.zeros((n_obs, size), dtype=bool)
    s_obs = np.zeros((n_obs, size), dtype=bool)
    row = 0
    while row < n_obs:
        cascade = simulate_cascade(source(graph), graph)
        for infected, susceptible in cascade:
            x_obs[row], s_obs[row] = infected, susceptible
            row += 1
            # stop before advancing the generator again (it would consume RNG draws)
            if row >= n_obs:
                break
    return x_obs, s_obs