aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--simulation/mle.py33
1 files changed, 17 insertions, 16 deletions
diff --git a/simulation/mle.py b/simulation/mle.py
index c6b2e85..21b50b5 100644
--- a/simulation/mle.py
+++ b/simulation/mle.py
@@ -1,5 +1,6 @@
import numpy as np
from scipy.optimize import minimize
+import utils
def likelihood(p, x, y):
@@ -53,20 +54,20 @@ def confidence_interval(counts, bins):
return bins[len(bins)/2-k], bins[len(bins)/2+k]
-def build_matrix(cascades, node):
-
- def aux(cascade, node):
- xlist, slist = zip(*cascade)
- indices = [i for i, s in enumerate(slist) if s[node] and i >= 1]
- if indices:
- x = np.vstack(xlist[i-1] for i in indices)
- y = np.array([xlist[i][node] for i in indices])
- return x, y
- else:
- return None
-
- pairs = (aux(cascade, node) for cascade in cascades)
- xs, ys = zip(*(pair for pair in pairs if pair))
- x = np.vstack(xs)
- y = np.concatenate(ys)
+def build_matrix(x_obs, s_obs, node):
+ ind = s_obs[:, node]
+ ind_bis = np.zeros(x_obs.shape[0], dtype=bool)
+ ind_bis[:-1] = ind[1:]
+ ind_bis[-1] = False
+ y = x_obs[ind, node]
+ x = x_obs[ind_bis]
return x, y
+
+if __name__ == "__main__":
+ n_obs = 10000
+ graph = utils.create_wheel(10)
+ source = lambda graph: utils.constant_source(graph, 0)
+ x, s = utils.simulate_cascades(n_obs, graph, source)
+ x, y = build_matrix(x, s, 1)
+ print x, y
+ print infer(x, y)[0]