aboutsummaryrefslogtreecommitdiffstats
path: root/simulation
diff options
context:
space:
mode:
Diffstat (limited to 'simulation')
-rw-r--r--simulation/active_blocks.py38
-rw-r--r--simulation/active_learning.ipynb248
-rw-r--r--simulation/vi_blocks.py37
3 files changed, 39 insertions, 284 deletions
diff --git a/simulation/active_blocks.py b/simulation/active_blocks.py
index e3924c6..7aa1afb 100644
--- a/simulation/active_blocks.py
+++ b/simulation/active_blocks.py
@@ -33,7 +33,7 @@ class LearnedDataset(fuel.datasets.Dataset):
i = 0
while i < request:
x_tmp, s_tmp = mn.build_cascade_list(
- mn.simulate_cascades(self.n_cascades, graph, self.source),
+ mn.simulate_cascades(self.n_cascades, self.graph, self.source),
collapse=True
)
x_obs[i:i + len(x_tmp)] = x_tmp[:request - i]
@@ -60,11 +60,9 @@ class ActiveLearning(blocks.extensions.SimpleExtension):
class ShuffledBatchesScheme(fuel.schemes.ShuffledScheme):
- """
- Iteration scheme over finite dataset:
+ """Iteration scheme over finite dataset:
-shuffles batches but not within batch
- -arguments: dataset_size (int) ; batch_size (int)
- """
+ -arguments: dataset_size (int) ; batch_size (int)"""
def get_request_iterator(self):
indices = list(self.indices) # self.indices = xrange(dataset_size)
start = np.random.randint(self.batch_size)
@@ -80,12 +78,8 @@ class ShuffledBatchesScheme(fuel.schemes.ShuffledScheme):
def create_mle_model(graph):
- """
- return cascade likelihood theano computation graph
- """
+ """return cascade likelihood theano computation graph"""
n_nodes = len(graph)
- g_shared = theano.shared(value=graph, name='graph')
-
x = tsr.matrix(name='x', dtype='int8')
s = tsr.matrix(name='s', dtype='int8')
params = theano.shared(
@@ -100,12 +94,19 @@ def create_mle_model(graph):
lkl_mle = lkl_pos + lkl_neg
lkl_mle.name = 'cost'
+ return x, s, params, lkl_mle
+
+
+def rmse_error(graph, params):
+ n_nodes = len(graph)
+ g_shared = theano.shared(value=graph, name='graph')
diff = (g_shared - params) ** 2
subarray = tsr.arange(g_shared.shape[0])
tsr.set_subtensor(diff[subarray, subarray], 0)
rmse = tsr.sum(diff) / (n_nodes ** 2)
rmse.name = 'rmse'
- return x, s, params, lkl_mle, rmse
+ g_shared.name = 'graph'
+ return rmse, g_shared
def create_fixed_data_stream(n_cascades, graph, batch_size, shuffle=True):
@@ -137,27 +138,26 @@ def create_learned_data_stream(graph, batch_size):
if __name__ == "__main__":
batch_size = 1000
- #graph = mn.create_random_graph(n_nodes=1000)
graph = mn.create_star(1000)
print('GRAPH:\n', graph, '\n-------------\n')
- x, s, params, cost, rmse = create_mle_model(graph)
+ x, s, params, cost = create_mle_model(graph)
+ rmse, g_shared = rmse_error(graph, params)
alg = blocks.algorithms.GradientDescent(
cost=-cost, parameters=[params], step_rule=blocks.algorithms.AdaDelta()
)
data_stream = create_learned_data_stream(graph, batch_size)
- #n_cascades = 10000
- #data_stream = create_fixed_data_stream(n_cascades, graph, batch_size,
- # shuffle=False)
loop = blocks.main_loop.MainLoop(
alg, data_stream,
extensions=[
blocks.extensions.FinishAfter(after_n_batches = 10**4),
blocks.extensions.monitoring.TrainingDataMonitoring([cost, params,
- rmse], after_batch=True),
- blocks.extensions.Printing(every_n_batches = 10)#,
- #ActiveLearning(data_stream.dataset)
+ rmse, g_shared], after_batch=True),
+ blocks.extensions.Printing(every_n_batches = 10),
+ ActiveLearning(data_stream.dataset),
+ blocks.extras.extensions.Plot('graph rmse', channels=[],
+ every_n_batches = 10)
]
)
loop.run()
diff --git a/simulation/active_learning.ipynb b/simulation/active_learning.ipynb
deleted file mode 100644
index 1803946..0000000
--- a/simulation/active_learning.ipynb
+++ /dev/null
@@ -1,248 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {
- "collapsed": false
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Couldn't import dot_parser, loading of dot files will not be possible.\n"
- ]
- }
- ],
- "source": [
- "%matplotlib inline\n",
- "import numpy as np\n",
- "import main as main\n",
- "import bayes as bayes\n",
- "import matplotlib.pyplot as plt\n",
- "import pymc\n",
- "import mleNode as mn"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Active Learning"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "(a) Comparing a star network between always choosing the main source and choosing a source uniformly at random"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {
- "collapsed": false
- },
- "outputs": [],
- "source": [
- "n_nodes = 5\n",
- "n_cascades = 40\n",
- "g = np.vstack((np.ones(n_nodes), np.zeros((n_nodes-1, n_nodes))))\n",
- "g[0, 0] = 0\n",
- "p = 0.5\n",
- "g = np.log(1. / (1 - p * g))\n",
- "cascades = main.simulate_cascades(n_cascades, g)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "metadata": {
- "collapsed": false
- },
- "outputs": [],
- "source": [
- "infected, susc = main.build_cascade_list(cascades)\n",
- "model = bayes.mc_graph_setup(infected, susc)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 27,
- "metadata": {
- "collapsed": false
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- " [-----------------100%-----------------] 1000 of 1000 complete in 18.1 sec"
- ]
- }
- ],
- "source": [
- "mcmc = pymc.MCMC(model)\n",
- "n_total = 1000\n",
- "burn = 100\n",
- "n_simul = n_total - burn\n",
- "mcmc.sample(n_total, burn)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 28,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "def formatLabel(s, n):\n",
- " return '0'*(len(str(n)) - len(str(s))) + str(s)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 29,
- "metadata": {
- "collapsed": false
- },
- "outputs": [
- {
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAkYAAAGSCAYAAAAGmg7IAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3V+IXNmdH/Cf1Gr9Kc2oZWlHajMNUsMimGBIHgzBeCAL\n+5CQwMI+bLzrXgXiFXgMeQhkIJYJTmwcW5Alj/biF0OMlA2EsItZdrMZkuxm7AmsvV4vYrURM1PF\njHrULalH6m6p1Cp1SXnQ7Z7q6nvrT3f9uffW5/Oi6ntPVR/V6Vv3W+ece+6B58+fPw8AAOLguCsA\nAJAXghEAQEIwAgBICEYAAAnBCAAgcShrx8bGRly/fj1eeeWVmJqaGmWd2Idmsxl3796Nz3zmM3H0\n6NGI0JZFpB3LQ1uWh7Ysh7R2bJUZjK5fvx4LCwtDrRzDc/Xq1fjsZz8bEdqyyLRjeWjL8tCW5dDa\njq0yg9Err7yy/cTZ2dnh1YyBWlpaioWFhe32i/ikLec+90ZMHzu5vf1z5x7G7/z2r4+8jnTXTzu+\neuR2fOPNfzbyOtKbTm2Z9flarVbja9/7cRyfOZv5uo9Wl+PbX/l8zM/PD77SpOrnuIyI+Ae//Dgu\nfuHXRlpHuktrx1aZwWirS3B2djbm5uaGUzuGprVLd+vx9LGTMV05tb395ROHtG3O9dKOx44+0o4F\nkNaWWZ+v9Xo9Dh09saOd2x168jjOnj2r7cegl+MyIuLEjGMzz7KGPk2+BgBICEYAAAnBCAAgkTnH\nCIDh+ne/+5/i2PETu7bfX7kTES56gXEQjADG5NaTT8f01O4J1g8bOvNhXBx9AAAJwQgAICEYAQAk\nBCMAgIRgBACQEIwAABKCEQBAQjACAEgIRgAACcEIACAhGAEAJAQjAICEYAQAkBCMAAASghEAQOLQ\nuCsA7E+zuRk3b97ctf38+fNx+PDhMdQIoLgEIyi49Qcfx8XL16Iyc2Z7W331TvzwO1+MCxcujLFm\nAMUjGEEJVGbOxEufenXc1QAoPHOMAAASghEAQEIwAgBICEYAAAmTryfUs+bTuLt822XeANBCMJpQ\n9bU78T8/rMf/vfLWzu0u8wZggglGE8wl3gCwkzlGAAAJwQgAICEYAQAkBCMAgIRgBACQEIwAABIu\n1weAAXvWfBrLS+mL6EZYSDfPBCMooWfNzahWq7u2+zCG0aiv3Yn//uF6/J/339q9z0K6uSYYQQlt\nPFyJr3//najMvLe9zYcxjFbWIrpZX1wifHnJA8EISsrK5pBPaV9cInx5yQvBiB0MwQAMny8u+SUY\nsYMhGAAmmWDELr7JADCpBCOAEmo0GlGr1TqWMUQOuwlGACVUq9Xi4uVrUZk5k7rfEDmkE4wASsqw\nOPTPLUEAABJ6jAAKqNMigRHRcR+QTTACKKCsRQK3rNy6EafnXhtxraD4BCOAguo0h6i+ujzi2kA5\nCEYwIaxqDtCdYERXTqjlYFVzgO4EI7pyQi0Pl28DdCYY0RMnVAAmgXWMAAASghEAQMJQGkwwE+sB\ndhKMYIKZWA+wk2AEE87EeoBPCEYApGo0GlGr1TqWMexK2QhG7EmnG1j6oIRyqNVqcfHytajMnEnd\nb9iVMhKM2JOsG1j6oCw+E7InQ6cvN1uq1aqh1hHq1iaOwdEQjNgzH5jlZEL2ZMj6ctNq5daNOD33\n2ghrNdk6tYljcHQEI2AXoXcydGvn+uryCGtDhGMvDwQjAMg5w2yjIxgxUOanAAyeYbbREYwYKPNT\nAIYja5jNVcKDJRgxcMbIAUbHVcKDJRgBe5K1+J9vqJOjl0v+I/xNjELaF1LzkvZGMGLo+unmdbIt\njrTF/3xDnSy9XPLf7W/C6trDY17S3ghGDF0/3bxOtvmUFm7TFv8z12Hy7Hfo3Oraw2VeUv8EI0ai\n125eJ9t8Sgu3aYv/metAu27DOd1W1zZcNxxZx+rD+7fjW2+8HvPz86nPm4T3WTBibJxsi6X95JW1\n+F8/cx0m4UN20nUbbuu2uvYghutIl3as1leXM9/vTqGpTMdyZjBqNpsREbG0tDSyyrB/W+211X6t\nj2cafxtHpl6KiIippytx68Gh2NxY2/H8+v3FaDYe7dje67a9lD1y/GRsbhzd3vZ883Gs331/1/Pb\ny0VENBsPY3l5OSqVSg/vTLH02o4REYfiUep7ttd2HPS2iIgHS+/Gm1d+HkcqJ7e3Pak/iO9+80uZ\n30zLop+23JJ1fLbKeq8HtX+Qr5F2/G5JO+b7eX7E6D4L+mnLTm3Yz+fnfvbt9TlZ7/fG6u1488rv\n7ziOI4p3LKe1Y6vMYLS4uBgREQsLC0OoFsO2uLgY586d234cEfFXb//BrnLrKc9tpGzvdduwymY9\n/9KlH6VsLY9e2zFi7+/jKLa17mt16dJPMkqWTz9tuSXrfdzS6b0exP68vEYvv2OUnwX7OS639POZ\nuJ99g37O1v52RTyWW9ux1cEx1AUAIJcye4xeffXFuOPVq1djdnZ2ZBUqrNdff/Hv22/v2lWtVuNr\n3/txHJ85m/n0R6vL8e2vfH7fXZFLS0uxsLCw3X4R2rKIOrXj3OfeiOljJ+OXDn4YVy5/aVxVpEe9\ntGWaX/nlx/HbX/i1kdSR3vTaln9v9kH8i9/5jbHUke7S2rFVZjCampqKiIjZ2dmYm5sbTu3KKOW9\nqtfrcejoiZiunMp82qEnj+Ps2bMDe6+32q/1sbYsnrR2nD52MqYrp+LowVXtWSCd2jLNiZlH2jen\nurXlSy8/13YF0NqOrQylAQAkXK5fIL2sELu8nH4JNQDQnWBUIN1WiI2IWL/7/ghrBADlIhjlRC+r\nu3ZbITYiYnNjretlrQBAOsEoJ3pZ3bXbCrEAwP4IRjnSrTco6xYMAMBguCoNACAhGAEAJAQjAICE\nYAQAkBCMAAASghEAQEIwAgBICEYAAAnBCAAgIRgBACTcEmSfGo1G1Gq1mN/cjIiI6s2bu8p0uzks\nAJAPgtE+1Wq1uHj5Wvzh6uOIiPjylbd2lXHzVwAoBsFoACozZ+LAwamIiNSbwLr5KwAUg2AEBdVs\nbsbNlqHb8+fPx+HDh8dYI4DiE4ygoB6ufhwXL1+LysyZqK/eiR9+54tx4cKFcVcLoNAEIyiwysyZ\n1OFbAPZGMOpg64qzTlxxBgDlIRh1sHXFWWXmTGYZV5wBg/as+TSWl27vmEPWzpwyGA7BqItuQxWu\nOCMPnjU3d/ReOmkWW33tTvzJh+vx5+/vXv4jIswpy7Fnzadx7+7SrlDrmCwOwQhKYOPhSnz9++9E\nZeY9J82SMH+smOprd+KdDxvxi5Y17RyTxSIYQUk4kUI+OBaLzb3SAAASghEAQEIwAgBICEYAAImJ\nnXxt8UYAoN3EBiOLNwIA7SY2GEVYvBEA2MkcIwCAxET3GAHAsLXfsmeL24Tkk2AEAEPUesueLW4T\nkl+CEQAMmduEFIc5RgAACcEIACAhGAEAJAQjAICEYAQAkHBVGpRM+5op1kopn6x1cVppd9gbwQhK\npnXNFGullFPaujittHv+WfQxvwQjKCFrppSfNi42iz7ml2AEAGMg3OZTKYNRo9GIWq3WsUy38XkA\nYPKUMhjVarW4ePlaVGbOZJZZuXUjTs+9NsJaAQB5V8pgFNG9i7K+ujzC2gAARWAdIwCAhGAEAJAQ\njAAAEqWdYwQARZK26KMFH0dPMAKAHGhf9NGCj+MhGAFATlj0cfzMMQIASAhGAAAJQ2kAJZN15/ZW\nJvVCOsEIoGTS7tzeyqReyCYYAZSQSbywN+YYAQAkCtdj1Gg0olardSzTbWwdJkX7XBPzSqA4suaK\nOY6Hq3DBqFarxcXL16IycyazzMqtG3F67rUR1gryqXWuiXklUCxpc8Ucx8OXq2DUa29Qt7Hz+ury\ngGsGxbV1vLjdABRP+/lOL9Lw5SoY6Q2C4XG7ASg+vUjDl6tgFNH9Sgq9QbB3rlQiord1jiL0QuSV\n43i4cheMgPFqH9J2ciyfbuscReiFKBLDa4M10mD0h3/0p7G0vJK5f3npdkQcHV2FYIJlXbHWOqTt\n5Fheeh3Kw/DaYI00GP3Rn9+I25vnM/c/vC8Uwai0fpg+vH87vvXG6zE/P9/TBQ5AvvQySVsPUm8M\npcEE2/owra8ub4ckFzgwCL1cZexEPTwuttg7wQiIiJ0haYtvnZOr2wTtRqMREZH5t1CtVpMTc/pV\nxk7Uw9fai2QeUu8yg1Gz2YyIiKWlpYH9so1Ha/G0+XHm/s2NtXi8fi82N9Yyy9TvL0az8Sh3ZT46\n8DwiIh7cvjHW+jz6+IOI+KT9Wh8Psi0Zrq22SmvHp48fRETEwY1Hsb72fmxurO342+jlcUT0VO7B\n0rvx5pWfx5HKyYiIeFJ/EN/95pdifn5+lG9HofXSlu26fRZ2+yzY7/6I2NX27dbvfRCHKyc67n/5\nzHxsbqRPkWg2HsZPf/rTWF7e+5XGo/477KUt09ou7f3uZdten5e2La09J/V4TmvHVpnBaHFxMSIi\nFhYWhlCtzta77G/ksMxvnj314oef/WDs9Yl40X7nzp3bfhwxnrZkf9La8dY7v7ejzNbfQ6PPx/2W\n23Lp0k/6/n/QW1u263Ssd/ss2O/+1nKZ+x503r/+4HrH3/GNb/yohxrkTz/H5Za097uXbXt9Xqdt\nrSb5eG5tx1ZuIgsAkMjsMXr11RfjklevXo3Z2dmRVSjPqtVqfO17P47jM2d37bvyXy5HRMRXv/Cd\nuHfrehx7+ZXUcq0erS7Ht7/y+YF2Yy4tLcXCwsJ2+0V80pb/9Iv/PF4+cSIOxmb84V/U4/jM2aHU\ngf3r1I6OyWLp1JZzn3sjpo+lD0Vtee3UvfhXX/mtodaR3nRry4PTx+M3/v6R+Mf/6FfHVUV6kNaO\nrTKD0dTUVEREzM7Oxtzc3HBqVzD1ej0OHT0R05VTu/Z9uvksIiKmK6di6vBLmeVaHXryOM6ePTuU\n93er/Vof/8XS2ZheOxUnH/8sDh2djenKqaHWgf1La0fHZDGlteX0sZNdPyeOv/xUe+dMVlsePPxS\nfOpTR7VXQbS2YytDaQAACcEIACAhGAEAJAQjAICEYAQAkBCMAAASghEAQEIwAgBICEYAAInMla/J\nj0ajEbVaraey+7lTNQBMOsGoAGq1Wly8fC0qM2e6ll2/+/4IagQA5SQYFURl5ky89Kn0G9612txY\ni/UR1AcAysgcIwCAhGAEAJAQjAAAEhM/x6ifK76q1epwKwMAjNXEB6N+rvhauXUjTs+9NoJaAQDj\nMPHBKKL3K77qq9YIAqA37SMS58+fj8OHD4+vQvREMAKAIWgdkaiv3okffueLceHChXFXiy4EIwAY\nkl5HJMgPV6UBACQEIwCAhKG0MXrW3OxpCYBhLhOQVgcTBAGYVILRGG08XImvf/+dqMy817HcMJcJ\naK+DCYIATDLBaMx6mZg37GUCTA4EgBfMMQIASAhGAACJ0g6l9XoPNPc/A2AQnjU346OPFuPmzZsR\n4fxSVKUNRr3eA839zwAYhMdrd+La/74ff/DztyLC+aWoShuMIvIxsRmAydF63nF+KabCBSNDZAAU\njTXjiqNwwcgQGQBFY8244ihcMIowRAZA8Vgzrhhcrg8AkBCMAAASghEAQEIwAgBICEYAAAnBCAAg\nIRgBACQEIwCARCEXeARgp15vlxThVhTQiWBEz9I+eH3AQj70erskt6KAzgQjduh0o8P2D14fsJAv\nbjkB+ycYsUO3Gx364AWgzAQjdhF+AJhUrkoDAEjoMQLIqWfNp7Fydzlu3rzZtWz73EBgbwQjgJyq\nr92Jv1xtxpevvNW17MqtG3F67rUR1ArKLRfBqJ/1N3wrAiZJr3P+6qvLI6gNlF8uglGv629E+FYE\nAAxPLoJRhG9FAMD45SYYAcCkaF1Mt9FoRERs30XAHQXGSzCio9aD1/wugMFoXUx35daNOPby6ajM\nnHFHgRwQjOio/eA1vwuKLe22P53ovRierSkk9dVlC+vmiGBEV60HL1Bs7bf96UTvBZNIMAKYMHon\nIJtgxJ6ldcnrdgegyAQj9qy9S163OwBFJxixL7rkobz6maitt5iyEIwASNXrRG29xZSJYARAJr3C\nTBrBCABywkUt4ycYAUBOuKhl/AQjAMgRw5fjdXDcFQAAyIuh9hg1Go2o1Wpdy7k5KQCQB0MNRrVa\nLS5evhaVmTMdy7k5afmkhWITCAHIu6HPMeplrNTNScunPRSbQAhAEZh8zcC0XmZarVZ3hGKXoEJ5\nWSGbMhGMGJjWy0zbh0ddggrlZYVsykQwYqC2eonShkddggrl1cvx3U/PUoTepYjd75n3ZPj2FIx+\n+pe/iD975xddyy0vLUbEzF5+BQAl02vPUkTEw/u341tvvB7z8/MdyzUajYiInsJCP2XzEkBa37P2\n96T9/5OXOhddZjBqNpsREbG0tLRr3//4X2/Hn71b6friq3fq8bSxEpsbax3L1e8vRrPxqGu5fsoO\nuly3sh8deB4REQ9u3xhrHR99/EFEfNJ+rY+fPn4QERFPNh7F+r33Y3Njbddrd/p5kGUfr9+L5eXl\nqFS6/x1Noq3jLq0d045J8qtTW24dk1k2N9bi8fq9sXw2DuM16/cX48jxk7G5cbTr795YvR1vXvn9\nOFI52bHc+r0P4nDlRNdy/ZR9Un8Q3/3ml3aFsm5tufnkYTy5v9r3Z2K3clvvWft70vr/yaozu6W1\nY6vMYLS4uBgREQsLC/uuxHoPZRo9luun7KDLdSr7m2dPvXjwsx8M5Xf3U8eIF+137ty57ccREbfe\n+b0dZbZer/21O/08yLKXLv2ol//KREtrx0Eck4xeL8dklnF9Ng7jNfv9LGv0UuZBb+X6KXvp0k8y\n93Vry718Jvb7Obv9uOX/06nO7Nbajq2sfA0AkMjsMXr11ReT6K5evRqzs7Mjq1BeVKvV+Nr3fhzH\nZ85mlnm0uhzf/srnX3Rdvv76i41vvz2iGqZbWlqKhYWF7faL+KQt5z73RkwfOxmH6+/Fd//9G+Oq\nIj3opR0jIj71/IP4D//md8ZSR3rTqS0n9fO1qHo5LlfvvB+Hjhzfce7Yca5g7NLasVVmMJqamoqI\niNnZ2ZibmxtO7XKsXq/HoaMnYrpyKrPMoSeP4+zZszvfn5y8V1vt1/p4+tjJmK6cisPP70xkmxZR\np3aMiDj6/L62LIi0tpzUz9ei63RcTh25G4eOvrTj3JF6rmDsWtuxlaE0AICEYAQAkLDA45Cl3Uw1\njfUnAGD8BKMha7+ZappelsnvNWAtL7shLwDslWA0AoO4FUYvASsiYv3u+/v6PQAwyQSjAuklYG1u\nrPW1eBoA8AmTrwEAEoIRAEBCMAIASAhGAAAJk69z4FlzM6rVascy3fYDAPsnGOXAxsOV+Pr334nK\nzHuZZVZu3YjTc6+NsFYAMHkEo5zodil+fdXCjQAwbOYYAQAkBCMAgIRgBACQmMg5Rr3ckNVVYAAM\nQtqVx+fPn4/Dhw+PqUZ0MpHBqJcbsroKDIBBaL/yuL56J374nS/GhQsXxlwz0kxkMIpwFRgAo9PL\nTcDJh4kNRlAGzeZm3Lx5c8c2XfQAeycY7UPruPH85mZERFTbTlLmKjFMD9fu7xgW1kUPsD+C0T60\njhv/59XHERHx5Stv7SiTx7lKz/QylIoueoDBEYz2aeukdODgVETErhNUHucqPVrXywAAaQSjCaWX\nAQB2s8AjAECiMD1GvSzKGGGuDACwd4UJRr0symiuDACwH4UJRhHmxQAAw2WOEQBAQjACAEgIRgAA\nCcEIACAhGAEAJAQjAICEYAQAkCjUOkYAUHTPmptRrVZ3bHPXhvwoVTBK+2NL00sZABiGjYcr8fXv\nvxOVmfciwl0b8qZUwaj9jy3Lyq0bcXrutRHVCgB2cieH/CpVMIro7Y+tvro8otoUg25dAHihdMGI\n/unWBYAXBCMiQrcuAES4XB8AYJtgBACQMJQGJWIiPcD+CEZQIibSA+yPYAQlYyI9wN6ZYwQAkBCM\nAAASghEAQEIwAgBICEYAAAlXpQFAzjQajajVajt+johda5JZp2zwBCMAyJlarRYXL1+LysyZiIhY\nuXUjjr18evvnCOuUDYtgBAA51LomWX112RplIyIY0VV7l26E7luAQUm7lU/7z4yOYMQu7QdptVpN\nbjPxogtX921xuHca5F/7rXwiXgydnZ57bYy1mlyCEbu0H6RbB6gu3OJx7zQohvZhsvrq8hhrM9kE\nI1K1j21TXOYlAPTOOkYAAIlc9BilTe5tZyIaADBsuQhG7es1pDERDQAYtlwEo4ju8yDMc8mP9iud\nXOUEMHquOh2O3AQjiqP1SidXOQGMh6tOh0MwYk9c6VReFvSE4vBZPHiCEbBD+5w/30KBSSIYwQRJ\nm5MQsbtHyLdQKJ5ej286E4xggqTdeuDh/dvxrTdej/n5+YiwNAYUVdrxrce3f4IR++KqiOJJu/VA\n2i1ggOLR27t/ghH74qqIcuh0CxjhF4rL8du/oQejv77+N/Hf/vjHcfBA9t1Hlm7fiohXhl0VhqTT\nNxRXOBWf8AvF5fjt39CD0Y3/9278xeKn4uDUdGaZlVt34sjxYdeEUWj/dlKtVpOD0hVORdYafn0D\nhWIxvNYfQ2kMVPu3k635Kg7K8mhv4/bJ21uEJaCIBCMGznyV8mtvY1fCAGWRGYyazWZERCwtLe3r\nF9y//3E8uH0rDk5lZ7D6/Q+iUT8emxtrHcosRrPxaN9lBvlarWU+OvA8IiIe3L4x9N/VyaOPP4iI\nT9qv9fHanXdj6sjL0by/FI1Hsf1a7a/dz8/9PvfB0rvx5pWfx5HKyYiIeFJ/EN/95pd29TZMuq3j\nrlM7RkQcrd+J9ebTPbflXp6T9vOR4ydjc+PoJ3VtPIzl5eWoVCoDfmeKp1Nb7vfzldHq5bh8dO/D\nmDpybF/H2DDLPF6/N/HHZlo7tspMK4uLixERsbCwMIRqpVvvsr8xoDKDfK2tMr959tSLDT/7wdB/\nVy8WFxfj3Llz248jIlau/9cdZVpfq/21+/m53+dubdty6dJPMv8fk66XdozYX1vu5Tm9vOalSz/a\nVc9JltaWo/x8ZXD6/XyNGM5xudcyjs0XWtuxVfalYgAAEyazx+jVV1/MH7h69WrMzs6OrEKF9frr\nL/59++09Pb1arcbXvvfjOD5zNrPMo9Xl+PZXPt9x2GlpaSkWFha22y/ik7ac+9wbMX3s5Pb2f/h3\nGvEbv/5P9lRfhqufdvyVX34cv/2FXxt5HelNP20ZoT3zrFNbOlcWR1o7tsoMRlNTUxERMTs7G3Nz\nc8OpXRnt8b2q1+tx6OiJmK6cyixz6MnjOHv2bE/tsdV+rY+nj53c8fonTz7RtjnXSzvOzDzSjgXQ\nS1tGRJzQnrmX1pbOlcXT2o6tDKUBACQEIwCAhHWMBqDRaMSBzc2IiKjevJlaxlo9AJB/gtEA1Gq1\neOn+wzh4cCq+fOWtXfstdgcAxSAYDcjBg1Nx4OCUW18AQIEJRgA58az5NJaXbsfNlCF5w/EwGoIR\nQE7U1+7En3y4Hn/+/s4hecPxMDqCEUCOtN6gFxg9l+sDACQEIwCAhGAEAJAQjAAAEoIRAEBCMAIA\nSAhGAAAJwQgAICEYAQAkrHxdIM+am1GtVjuWmZ6eHlFtAKB8BKMC2Xi4El///jtRmXkvdX999U78\n7r/81RHXCgDKQzAqGPdRAoDhMccIACChx6iLRqMRtVqtY5lqtRp/dzTVAQCGSDDqolarxcXL16Iy\ncyazzMqtG/GnI6wTADAcglEPus3rqa8uj7A2g/Gs+TSWbn8UN2/e3LH9/Pnzcfjw4THVCgDGSzCa\nUPW1O/EHHz6IP/3bt7a3Pbx/O771xusxPz+/q7zABMAkEIwmWHtPWH11OXU5gPrqnfjhd74YFy5c\nGHUVAWCkBCN2sBxAsTxrPo3lpduGRAEGRDCCAquv3Yk//nA9/uz9T4ZE9fAB7J1gBAWnlw9gcCzw\nCACQEIwAABKCEQBAQjACAEhM9OTrXu+DBgBMhokORr3eB+303GsjrBUAMC4THYwiynkfNABgbyY+\nGAHk3bPmZsdhfSudw+AIRgA5t/FwJfU+hhFWOodBE4wACsAK5zAagtEIdOsGj3D1GwDkgWA0Ap26\nwbe4+g0Axk8wGhFXvwFA/ln5GgAgIRgBACRKO5Tmdh8AQL9KG4zc7mNwsq6qs6gcAGVT2mAUYcLz\noKRdVWdROQDKqNTBiMGxuBzkU6d10vTqQv8EI4ACy1onTa8u7I1gBFBwenRhcFyuDwCQEIwAABKC\nEQBAQjACAEiYfA0lY0FOgL0TjKBkLMgJsHeCEZSQy7cB9kYwYmA63bjXMM54GV4D6E1hg1Gnk3BE\nZC6Rz2CknWir1WoyhLPzxr0P79+Ob73xeszPz+/Y7qQ8OobXAHpT2GBUq9Xi4uVru07CW1Zu3YjT\nc6+NuFaTI+1Eu/Wetw/h1FeXnZRzwPAaQHeFDUYRnT/o66vLI67N5Gl//zu9507K+ZPW66cXD5h0\nhQ5GwN619/rpxSuXrHllEQIwdCIYwQTTk1deacPdEQIwdCMYAZSU4Av9c0sQAICEHiOACZI196jR\naEREpM49MieJSSIYAUyQrLlHK7duxLGXT+9aAsWcJCZNLoNRt8UbIyzgWHRWYobxSZt7VF9dTt3e\n6eq2CMcs5ZPLYNRt8cYICzgWnZWYoRiyepgiHLOUUy6DUUT3qyks4Fh8rpiBYnCsMknGEoz+9b/9\nj7H6JPtX31v+MOLAfOZ+ysmCdONleBNgTMFo5dHBWH5+LnP//ScbMX10hBUiFyxIN16GN+mXK9wo\no9wOpTGZ+pn86QN28Nrf/7T3Pu2k1+s2bVYurnCjjDKDUbPZjIiIpaWlgf/S9Xu12Ny8k7n/2f2P\nYv3QL8Xmxlpmmfr9xWg2HmWW6bZ/0GU+OvA8Dhx4Hg9u3xhbXR6v34t79z4dEZ+0X+vjzXt/FQeO\nHH+xbf3jqD97acdrZb1+2vZetw3i+Q+W3o03r/w8jlRObm97Un8Q3/3ml2J+vpxDrlvH3aDasZdt\naWXS3vtlOlCXAAABLElEQVT1ex/E4cqJvreVvc2y9NqWEentGdH52O/nuBvk9q19R46fjM2NnV38\nzzcfR7PxcNf2ZuNhLC8vR6VS2fVaRdCpLYdxrmQ40tqxVWYwunv3bkRELCwsDKFavVnvsr/RpUy3\n/YMs81uffuXFg5/9YKx1+epXfxQRL9rv3Llz248jIj786z/ZVb79tbJeP217r9sG8fytfa0uXfpJ\nRsnyGGQ79rKtU/vt+PnB3rZNQptl6bUtI3o/hrrtG/b2vTzn0qUfZbxScaS15TjPlexNazu2OvD8\n+fPnaU/Y2NiI69evxyuvvBJTU1NDryCD0Ww24+7du/GZz3wmjh598W1NWxaPdiwPbVke2rIc0tqx\nVWYwAgCYNG4iCwCQEIwAABKCEQBAQjACAEj8fxB90xAnEyeeAAAAAElFTkSuQmCC\n",
- "text/plain": [
- "<matplotlib.figure.Figure at 0x1195b3290>"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "fig, ax = plt.subplots(len(g), len(g))\n",
- "for i in xrange(len(g)):\n",
- " for j in xrange(len(g)):\n",
- " ax[i, j].get_xaxis().set_ticks([])\n",
- " ax[i, j].get_yaxis().set_ticks([])\n",
- " if i != j:\n",
- " it, jt = formatLabel(i, len(g)-1), formatLabel(j, len(g)-1)\n",
- " ax[i,j].hist(mcmc.trace('theta_{}{}'.format(it,jt))[:], normed=True)\n",
- " ax[i, j].set_xlim([0,1])\n",
- " ax[i, j].plot([g[j, i]]*2, [0, ax[i,j].get_ylim()[-1]], color='red')\n",
- "plt.tight_layout(pad=0.4, w_pad=0.5, h_pad=.1)\n",
- "plt.show()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 30,
- "metadata": {
- "collapsed": false
- },
- "outputs": [],
- "source": [
- "theta = np.empty((n_simul, n_nodes, n_nodes))\n",
- "for i in xrange(len(g)):\n",
- " for j in xrange(len(g)):\n",
- " it, jt = formatLabel(i, len(g)-1), formatLabel(j, len(g)-1)\n",
- " theta[:, i, j] = mcmc.trace('theta_{}{}'.format(it, jt))[:]"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "- simulate cascades for each iteration of theta starting from specific source\n",
- "- calculate the likelihood of each cascade \n",
- "- calculate the marginal probability of each cascade\n",
- "- order nodes by difference"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": 31,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "def specific_source(i, graph):\n",
- " x0 = np.zeros(graph.shape[0], dtype=bool)\n",
- " x0[i] = True\n",
- " return x0"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 33,
- "metadata": {
- "collapsed": false
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "[-0.02783062 0.02330009 0.02416332 -0.01541149 0.00959461]\n",
- "[0 3 4 1 2]\n"
- ]
- }
- ],
- "source": [
- "infected, susceptible = [], []\n",
- "rval = np.empty((theta.shape[0], theta.shape[1]))\n",
- "for node in xrange(n_nodes):\n",
- " for i, graph in enumerate(theta):\n",
- " casc = main.simulate_cascades(10, graph, \n",
- " source=lambda g, t: specific_source(node, g))\n",
- " x, s = main.build_cascade_list(casc, collapse=True)\n",
- " cond_lkl = main.cascadeLkl(graph, x, s)\n",
- " marg_lkl= np.mean([main.cascadeLkl(thet, x, s) for thet in theta])\n",
- " rval[i, node] = cond_lkl - marg_lkl\n",
- "rval = rval.mean(axis=0)\n",
- "print(rval)\n",
- "print(np.argsort(rval))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Error : the mutual information should always be positive"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 2",
- "language": "python",
- "name": "python2"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 2
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython2",
- "version": "2.7.10"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 0
-}
diff --git a/simulation/vi_blocks.py b/simulation/vi_blocks.py
index 58b68d3..dcf6b46 100644
--- a/simulation/vi_blocks.py
+++ b/simulation/vi_blocks.py
@@ -26,16 +26,16 @@ class ClippedParams(blocks.algorithms.StepRule):
def create_vi_model(n_nodes, n_samp=100):
"""return variational inference theano computation graph"""
- def aux():
- rand = .1 + .05 * np.random.normal(size=(n_nodes, n_nodes))
- return rand.astype(theano.config.floatX)
+ def aux(a, b):
+ rand = a + b * np.random.normal(size=(n_nodes, n_nodes))
+ return np.clip(rand, 1e-3, 1 - 1e-3).astype(theano.config.floatX)
x = tsr.matrix(name='x', dtype='int8')
s = tsr.matrix(name='s', dtype='int8')
- mu = theano.shared(value=aux(), name='mu1')
- sig = theano.shared(value=aux(), name='sig1')
- mu0 = theano.shared(value=aux(), name='mu0')
- sig0 = theano.shared(value=aux(), name='sig0')
+ mu = theano.shared(value=aux(.5, .1), name='mu1')
+ sig = theano.shared(value=aux(.5, .1), name='sig1')
+ mu0 = theano.shared(value=aux(.5, .1), name='mu0')
+ sig0 = theano.shared(value=aux(.5, .1), name='sig0')
srng = tsr.shared_randomstreams.RandomStreams(seed=123)
theta = srng.normal((n_samp, n_nodes, n_nodes)) * sig[None, :, :] + mu[None,
@@ -45,8 +45,8 @@ def create_vi_model(n_nodes, n_samp=100):
lkl_pos = tsr.sum(infect * (x[1:] & s[1:])) / n_samp
lkl_neg = tsr.sum(-y[0:-1].dimshuffle(1, 0, 2) * (~x[1:] & s[1:])) / n_samp
lkl = lkl_pos + lkl_neg
- kl = tsr.sum(tsr.log(sig / sig0) + (sig0**2 + (mu0 - mu)**2)/(2*sig)**2)
- cost = lkl + kl
+ kl = tsr.sum(tsr.log(sig0 / sig) + (sig**2 + (mu0 - mu)**2)/(2*sig0)**2)
+ cost = - lkl + kl
cost.name = 'cost'
return x, s, mu, sig, cost
@@ -55,25 +55,28 @@ def create_vi_model(n_nodes, n_samp=100):
if __name__ == "__main__":
n_cascades = 10000
batch_size = 1000
- graph = mn.create_random_graph(n_nodes=3)
+ n_samples = 50
+ graph = mn.create_random_graph(n_nodes=4)
print('GRAPH:\n', graph, '\n-------------\n')
- x, s, mu, sig, cost = create_vi_model(len(graph))
+ x, s, mu, sig, cost = create_vi_model(len(graph), n_samples)
+ rmse, g_shared = ab.rmse_error(graph, mu)
step_rules= blocks.algorithms.CompositeRule([blocks.algorithms.AdaDelta(),
ClippedParams(1e-3, 1 - 1e-3)])
- alg = blocks.algorithms.GradientDescent(cost=-cost, parameters=[mu, sig],
+ alg = blocks.algorithms.GradientDescent(cost=cost, parameters=[mu, sig],
step_rule=step_rules)
- data_stream = ab.create_fixed_data_stream(n_cascades, graph, batch_size,
- shuffle=False)
+ #data_stream = ab.create_fixed_data_stream(n_cascades, graph, batch_size,
+ # shuffle=False)
+ data_stream = ab.create_learned_data_stream(graph, batch_size)
loop = blocks.main_loop.MainLoop(
alg, data_stream,
extensions=[
blocks.extensions.FinishAfter(after_n_batches = 10**4),
- blocks.extensions.monitoring.TrainingDataMonitoring([cost, mu, sig],
- after_batch=True),
- blocks.extensions.Printing(every_n_batches = 10),
+ blocks.extensions.monitoring.TrainingDataMonitoring([cost, mu, sig,
+ rmse, g_shared], after_batch=True),
+ blocks.extensions.Printing(every_n_batches = 100),
]
)
loop.run()