diff options
Diffstat (limited to 'simulation')
| -rw-r--r-- | simulation/mle.py | 33 |
1 files changed, 17 insertions, 16 deletions
diff --git a/simulation/mle.py b/simulation/mle.py index c6b2e85..21b50b5 100644 --- a/simulation/mle.py +++ b/simulation/mle.py @@ -1,5 +1,6 @@ import numpy as np from scipy.optimize import minimize +import utils def likelihood(p, x, y): @@ -53,20 +54,20 @@ def confidence_interval(counts, bins): return bins[len(bins)/2-k], bins[len(bins)/2+k] -def build_matrix(cascades, node): - - def aux(cascade, node): - xlist, slist = zip(*cascade) - indices = [i for i, s in enumerate(slist) if s[node] and i >= 1] - if indices: - x = np.vstack(xlist[i-1] for i in indices) - y = np.array([xlist[i][node] for i in indices]) - return x, y - else: - return None - - pairs = (aux(cascade, node) for cascade in cascades) - xs, ys = zip(*(pair for pair in pairs if pair)) - x = np.vstack(xs) - y = np.concatenate(ys) +def build_matrix(x_obs, s_obs, node): + ind = s_obs[:, node] + ind_bis = np.zeros(x_obs.shape[0], dtype=bool) + ind_bis[:-1] = ind[1:] + ind_bis[-1] = False + y = x_obs[ind, node] + x = x_obs[ind_bis] return x, y + +if __name__ == "__main__": + n_obs = 10000 + graph = utils.create_wheel(10) + source = lambda graph: utils.constant_source(graph, 0) + x, s = utils.simulate_cascades(n_obs, graph, source) + x, y = build_matrix(x, s, 1) + print x, y + print infer(x, y)[0] |
