diff options
Diffstat (limited to 'simulation/main.py')
| -rw-r--r-- | simulation/main.py | 42 |
1 files changed, 8 insertions, 34 deletions
diff --git a/simulation/main.py b/simulation/main.py index 402aa8d..4fa8f6c 100644 --- a/simulation/main.py +++ b/simulation/main.py @@ -19,7 +19,8 @@ def simulate_cascade(x, graph): - susc: the nodes susceptible at the beginning of this time step - x: the subset of susc who became infected """ - susc = np.ones(graph.shape[0], dtype=bool) # t=0, everyone is susceptible + susc = np.ones(graph.shape[0], dtype=bool) + susc = susc ^ x # t=0, the source is not susceptible yield x, susc while np.any(x): susc = susc ^ x # nodes infected at previous step are now inactive @@ -66,41 +67,14 @@ def cascadeLkl(graph, infect, sus): if __name__ == "__main__": - # g = np.array([[0, 1, 1, 0], [1, 0, 0, 1], [1, 0, 0, 1], [0, 1, 1, 0]]) g = np.array([[0, 0, 1], [0, 0, 0.5], [0, 0, 0]]) p = 0.5 g = np.log(1. / (1 - p * g)) - # error = [] + cascades = simulate_cascades(100, g) + cascade, y_obs = mn.build_matrix(cascades, 0) + conf = mn.bootstrap(x, y, n_iter=100) - def source(graph, t): - x0 = np.zeros(graph.shape[0], dtype=bool) - a = randint(0, 1) - x0[a] = True - if random() > t: - x0[1-a] = True - return x0 - - thresh = np.arange(0., 1.1, step=0.2) - sizes = np.arange(10, 100, step=10) - nsimul = 10 - r = np.zeros(len(sizes), len(thresh)) - for t in thresh: - for i in nsimul: - cascades = simulate_cascades(np.max(sizes), g, - source=lambda graph: source(graph, t)) - e = np.zeros(g.shape[0]) - for j, s in enumerate(sizes): - x, y = mn.build_matrix(cascades, 2) - e += mn.infer(x[:s], y[:s]) - - for i, t in enumerate(thresh): - plt.plot(sizes, e[:, i], label=str(t)) - plt.legend() + estimand = np.linalg.norm(np.delete(conf - g[0], 0, axis=1), axis=1) + error.append(mn.confidence_interval(*np.histogram(estimand, bins=50))) + plt.semilogx(sizes, error) plt.show() - - - # conf = mn.bootstrap(x, y, n_iter=100) - # estimand = np.linalg.norm(np.delete(conf - g[0], 0, axis=1), axis=1) - # error.append(mn.confidence_interval(*np.histogram(estimand, bins=50))) - # plt.semilogx(sizes, error) - # plt.show() |
