diff options
| author | jeanpouget-abadie <jean.pougetabadie@gmail.com> | 2015-11-23 13:50:01 -0500 |
|---|---|---|
| committer | jeanpouget-abadie <jean.pougetabadie@gmail.com> | 2015-11-23 13:50:01 -0500 |
| commit | 4383fbb4179e73e8b4c28eabd64199b3f7a16eee (patch) | |
| tree | 41ac39c33b33ff9b6f89b4ab800b7c50dab818a7 /simulation/main.py | |
| parent | 65638ea7d886c25b8bf75e43a5ce46db2ebbaf53 (diff) | |
| download | cascades-4383fbb4179e73e8b4c28eabd64199b3f7a16eee.tar.gz | |
adding likelihood to theano implementation
Diffstat (limited to 'simulation/main.py')
| -rw-r--r-- | simulation/main.py | 27 |
1 files changed, 15 insertions, 12 deletions
diff --git a/simulation/main.py b/simulation/main.py index c2446d7..decdf8f 100644 --- a/simulation/main.py +++ b/simulation/main.py @@ -19,9 +19,9 @@ def simulate_cascade(x, graph): - susc: the nodes susceptible at the beginning of this time step - x: the subset of susc who became infected """ + yield x, np.zeros(graph.shape[0], dtype=bool) susc = np.ones(graph.shape[0], dtype=bool) - susc = susc ^ x # t=0, the source is not susceptible - yield x, susc + #yield x, susc while np.any(x): susc = susc ^ x # nodes infected at previous step are now inactive if not np.any(susc): @@ -60,13 +60,16 @@ if __name__ == "__main__": g = np.array([[0, 0, 1], [0, 0, 0.5], [0, 0, 0]]) p = 0.5 g = np.log(1. / (1 - p * g)) - error = [] - sizes = [10, 10**2, 10**3] - for s in sizes: - cascades = simulate_cascades(s, g) - cascade, y_obs = mn.build_matrix(cascades, 0) - conf = mn.bootstrap(cascade, y_obs, n_iter=100) - estimand = np.linalg.norm(np.delete(conf - g[0], 0, axis=1), axis=1) - error.append(mn.confidence_interval(*np.histogram(estimand, bins=50))) - plt.semilogx(sizes, error) - plt.show() + print(g) + sizes = [10**3] + for si in sizes: + cascades = simulate_cascades(si, g) + cascade, y_obs = mn.build_matrix(cascades, 2) + print(mn.infer(cascade, y_obs)) + #conf = mn.bootstrap(cascade, y_obs, n_iter=100) + #estimand = np.linalg.norm(np.delete(conf - g[0], 0, axis=1), axis=1) + #plt.hist(estimand, bins=40) + #plt.show() + #error.append(mn.confidence_interval(*np.histogram(estimand, bins=50))) + #plt.plot(sizes, error) + #plt.show() |
