diff options
| author | Thibaut Horel <thibaut.horel@gmail.com> | 2015-11-30 20:11:09 -0500 |
|---|---|---|
| committer | Thibaut Horel <thibaut.horel@gmail.com> | 2015-11-30 20:11:09 -0500 |
| commit | d2d9ab7ba51f2cc382a2676aaacecbf75bcfadc8 (patch) | |
| tree | 58c74dde087ed45f469a1ec37e0a43f0ca67ff24 /simulation/mle.py | |
| parent | f457b3480d53920a0d8b0e3b8cdb2b18601088ee (diff) | |
| download | cascades-d2d9ab7ba51f2cc382a2676aaacecbf75bcfadc8.tar.gz | |
Fix mle.py
Diffstat (limited to 'simulation/mle.py')
| -rw-r--r-- | simulation/mle.py | 33 |
1 files changed, 17 insertions, 16 deletions
diff --git a/simulation/mle.py b/simulation/mle.py index c6b2e85..21b50b5 100644 --- a/simulation/mle.py +++ b/simulation/mle.py @@ -1,5 +1,6 @@ import numpy as np from scipy.optimize import minimize +import utils def likelihood(p, x, y): @@ -53,20 +54,20 @@ def confidence_interval(counts, bins): return bins[len(bins)/2-k], bins[len(bins)/2+k] -def build_matrix(cascades, node): - - def aux(cascade, node): - xlist, slist = zip(*cascade) - indices = [i for i, s in enumerate(slist) if s[node] and i >= 1] - if indices: - x = np.vstack(xlist[i-1] for i in indices) - y = np.array([xlist[i][node] for i in indices]) - return x, y - else: - return None - - pairs = (aux(cascade, node) for cascade in cascades) - xs, ys = zip(*(pair for pair in pairs if pair)) - x = np.vstack(xs) - y = np.concatenate(ys) +def build_matrix(x_obs, s_obs, node): + ind = s_obs[:, node] + ind_bis = np.zeros(x_obs.shape[0], dtype=bool) + ind_bis[:-1] = ind[1:] + ind_bis[-1] = False + y = x_obs[ind, node] + x = x_obs[ind_bis] return x, y + +if __name__ == "__main__": + n_obs = 10000 + graph = utils.create_wheel(10) + source = lambda graph: utils.constant_source(graph, 0) + x, s = utils.simulate_cascades(n_obs, graph, source) + x, y = build_matrix(x, s, 1) + print x, y + print infer(x, y)[0] |
