diff options
Diffstat (limited to 'simulation')
| -rw-r--r-- | simulation/main.py | 28 |
1 files changed, 20 insertions, 8 deletions
diff --git a/simulation/main.py b/simulation/main.py index bdc5c97..e4b30f2 100644 --- a/simulation/main.py +++ b/simulation/main.py @@ -125,22 +125,34 @@ if __name__ == "__main__": g = np.array([[0, 0, 1], [0, 0, 0.5], [0, 0, 0]]) p = 0.5 g = np.log(1. / (1 - p * g)) - sizes = [100, 500, 1000, 5000, 10000] # error = [] - def source(graph): + def source(graph, t): x0 = np.zeros(graph.shape[0], dtype=bool) a = randint(0, 1) x0[a] = True - if random() > 0.01: + if random() > t: x0[1-a] = True return x0 - for i in sizes: - cascades = simulate_cascades(i, g, source=source) - x, y = build_matrix(cascades, 2) - e = infer(x, y) - print e + thresh = np.arange(0., 1.1, step=0.2) + sizes = np.arange(10, 100, step=10) + nsimul = 10 + r = np.zeros(len(sizes), len(thresh)) + for t in thresh: + for i in nsimul: + cascades = simulate_cascades(np.max(sizes), g, + source=lambda graph: source(graph, t)) + e = np.zeros(g.shape[0]) + for j, s in enumerate(sizes): + x, y = build_matrix(cascades, 2) + e += infer(x[:s], y[:s]) + + for i, t in enumerate(thresh): + plt.plot(sizes, m[:, i], label=str(t)) + plt.legend() + plt.show() + # conf = bootstrap(x, y, n_iter=100) # estimand = np.linalg.norm(np.delete(conf - g[0], 0, axis=1), axis=1) |
