From d2d9ab7ba51f2cc382a2676aaacecbf75bcfadc8 Mon Sep 17 00:00:00 2001 From: Thibaut Horel Date: Mon, 30 Nov 2015 20:11:09 -0500 Subject: Fix mle.py --- simulation/mle.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) (limited to 'simulation') diff --git a/simulation/mle.py b/simulation/mle.py index c6b2e85..21b50b5 100644 --- a/simulation/mle.py +++ b/simulation/mle.py @@ -1,5 +1,6 @@ import numpy as np from scipy.optimize import minimize +import utils def likelihood(p, x, y): @@ -53,20 +54,20 @@ def confidence_interval(counts, bins): return bins[len(bins)/2-k], bins[len(bins)/2+k] -def build_matrix(cascades, node): - - def aux(cascade, node): - xlist, slist = zip(*cascade) - indices = [i for i, s in enumerate(slist) if s[node] and i >= 1] - if indices: - x = np.vstack(xlist[i-1] for i in indices) - y = np.array([xlist[i][node] for i in indices]) - return x, y - else: - return None - - pairs = (aux(cascade, node) for cascade in cascades) - xs, ys = zip(*(pair for pair in pairs if pair)) - x = np.vstack(xs) - y = np.concatenate(ys) +def build_matrix(x_obs, s_obs, node): + ind = s_obs[:, node] + ind_bis = np.zeros(x_obs.shape[0], dtype=bool) + ind_bis[:-1] = ind[1:] + ind_bis[-1] = False + y = x_obs[ind, node] + x = x_obs[ind_bis] return x, y + +if __name__ == "__main__": + n_obs = 10000 + graph = utils.create_wheel(10) + source = lambda graph: utils.constant_source(graph, 0) + x, s = utils.simulate_cascades(n_obs, graph, source) + x, y = build_matrix(x, s, 1) + print x, y + print infer(x, y)[0] -- cgit v1.2.3-70-g09d2