From f1762904c648b2089031ba6ce46ccaaac4f3514c Mon Sep 17 00:00:00 2001 From: Thibaut Horel Date: Mon, 30 Nov 2015 19:57:58 -0500 Subject: Big code cleanup --- simulation/mcmc.py | 117 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 117 insertions(+) create mode 100644 simulation/mcmc.py (limited to 'simulation/mcmc.py') diff --git a/simulation/mcmc.py b/simulation/mcmc.py new file mode 100644 index 0000000..bde9e94 --- /dev/null +++ b/simulation/mcmc.py @@ -0,0 +1,117 @@ +import pymc +import numpy as np + +def glm_node_setup(cascade, y_obs, prior=None, *args, **kwargs): + """ + Build an IC PyMC node-level model from: + -observed cascades: cascade + -outcome vector: y_obs + -desired PyMC prior and parameters: prior, *args + Note: we use the glm formulation: y = Bernoulli[f(x.dot(theta))] + """ + n_nodes = len(cascade[0]) + + # Container class for node's parents + theta = np.empty(n_nodes, dtype=object) + for j in xrange(n_nodes): + if prior is None: + theta[j] = pymc.Beta('theta_{}'.format(j), alpha=1, beta=1) + else: + theta[j] = prior('theta_{}'.format(j), *args, **kwargs) + + # Observed container class for cascade realization + x = np.empty(n_nodes, dtype=object) + for i, val in enumerate(cascade.T): + x[i] = pymc.Normal('x_{}'.format(i), 0, 1, value=val, observed=True) + + @pymc.deterministic + def glm_p(x=x, theta=theta): + return 1. - np.exp(-x.dot(theta)) + + @pymc.observed + def y(glm_p=glm_p, value=y_obs): + return pymc.bernoulli_like(value, glm_p) + + return pymc.Model([y, pymc.Container(theta), pymc.Container(x)]) + + +def formatLabel(s, n): + return '0'*(len(str(n)) - len(str(s))) + str(s) + + +def mc_graph_setup(infected, susceptible, prior=None, *args, **kwargs): + """ + Build an IC PyMC graph-level model from: + -infected nodes over time: list/tuple of list/tuple of np.array + -susceptible nodes over time: same format as above + Note: we use the Markov Chain formulation: X_{t+1}|X_t,theta = f(X_t.theta) + """ + + # Container class for graph parameters + n_nodes = len(infected[0][0]) + theta = np.empty((n_nodes,n_nodes), dtype=object) + if prior is None: + for i in xrange(n_nodes): + for j in xrange(n_nodes): + theta[i, j] = pymc.Beta('theta_{}{}'.format(formatLabel(i, + n_nodes-1), formatLabel(j, n_nodes-1)), + alpha=1, beta=1) + else: + theta = prior(theta=theta, *args, **kwargs) + + # Container class for cascade realization + x = {} + for i, cascade in enumerate(infected): + for j, step in enumerate(cascade): + for k, node in enumerate(step): + if j and susceptible[i][j][k]: + p = 1. - pymc.exp(-cascade[j-1].dot(theta[k])) + else: + p = .5 + x[i, j, k] = pymc.Bernoulli('x_{}{}{}'.format(i, j, k), p=p, + value=node, observed=True) + + return pymc.Model([pymc.Container(theta), pymc.Container(x)]) + + +if __name__=="__main__": + import main + import matplotlib.pyplot as plt + import seaborn + seaborn.set_style('whitegrid') + g = np.array([[0, 1, 1, 0], [1, 0, 0, 1], [1, 0, 0, 1], [0, 1, 1, 0]]) + p = 0.5 + g = np.log(1. / (1 - p * g)) + + print('running the graph-level MC set-up') + n_nodes = len(g) + cascades = main.simulate_cascades(1000, g) + infected, susc = main.build_cascade_list(cascades) + model = mc_graph_setup(infected, susc) + mcmc = pymc.MCMC(model) + mcmc.sample(10**4, 1000) + fig, ax = plt.subplots(len(g), len(g)) + for i in xrange(n_nodes): + for j in xrange(n_nodes): + if n_nodes < 5: + ax[i,j].locator_params(nbins=3, axis='x') + else: + ax[i, j].get_xaxis().set_ticks([]) + ax[i, j].get_yaxis().set_ticks([]) + it, jt = formatLabel(i, n_nodes-1), formatLabel(j, n_nodes-1) + ax[i,j].hist(mcmc.trace('theta_{}{}'.format(it,jt))[:], normed=True) + ax[i, j].set_xlim([0,1]) + ax[i, j].plot([g[i, j]]*2, [0, ax[i,j].get_ylim()[-1]], color='red') + plt.tight_layout(pad=0.4, w_pad=0.5, h_pad=.1) + plt.show() + + print('running the node level set-up') + node = 0 + cascades = main.simulate_cascades(100, g) + cascade, y_obs = main.build_matrix(cascades, node) + model = glm_node_setup(cascade, y_obs) + mcmc = pymc.MCMC(model) + mcmc.sample(1e5, 1e4) + plt.hist(mcmc.trace('theta_1')[:], bins=1e2) + plt.show() + -- cgit v1.2.3-70-g09d2