aboutsummaryrefslogtreecommitdiffstats
path: root/simulation
diff options
context:
space:
mode:
authorjeanpouget-abadie <jean.pougetabadie@gmail.com>2015-11-11 13:39:53 -0500
committerjeanpouget-abadie <jean.pougetabadie@gmail.com>2015-11-11 13:39:53 -0500
commiteab2cc05424475a1b14bf950decde1bae8d8cc9a (patch)
treed93ff5bd5d2aacc7f91f78816a898143e0933a15 /simulation
parent43b07f68505ee10f2fe9fa94f365b27fb72f953c (diff)
downloadcascades-eab2cc05424475a1b14bf950decde1bae8d8cc9a.tar.gz
moving node-specific methods to new file, adding cascade-level log-likelihood, committing active learning ipython notebook
Diffstat (limited to 'simulation')
-rw-r--r--simulation/active_learning.ipynb248
-rw-r--r--simulation/main.py109
-rw-r--r--simulation/mleNode.py72
3 files changed, 345 insertions, 84 deletions
diff --git a/simulation/active_learning.ipynb b/simulation/active_learning.ipynb
new file mode 100644
index 0000000..1803946
--- /dev/null
+++ b/simulation/active_learning.ipynb
@@ -0,0 +1,248 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Couldn't import dot_parser, loading of dot files will not be possible.\n"
+ ]
+ }
+ ],
+ "source": [
+ "%matplotlib inline\n",
+ "import numpy as np\n",
+ "import main as main\n",
+ "import bayes as bayes\n",
+ "import matplotlib.pyplot as plt\n",
+ "import pymc\n",
+ "import mleNode as mn"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Active Learning"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "(a) Comparing a star network between always choosing the main source and choosing a source uniformly at random"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "n_nodes = 5\n",
+ "n_cascades = 40\n",
+ "g = np.vstack((np.ones(n_nodes), np.zeros((n_nodes-1, n_nodes))))\n",
+ "g[0, 0] = 0\n",
+ "p = 0.5\n",
+ "g = np.log(1. / (1 - p * g))\n",
+ "cascades = main.simulate_cascades(n_cascades, g)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "infected, susc = main.build_cascade_list(cascades)\n",
+ "model = bayes.mc_graph_setup(infected, susc)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " [-----------------100%-----------------] 1000 of 1000 complete in 18.1 sec"
+ ]
+ }
+ ],
+ "source": [
+ "mcmc = pymc.MCMC(model)\n",
+ "n_total = 1000\n",
+ "burn = 100\n",
+ "n_simul = n_total - burn\n",
+ "mcmc.sample(n_total, burn)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {
+ "collapsed": true
+ },
+ "outputs": [],
+ "source": [
+ "def formatLabel(s, n):\n",
+ " return '0'*(len(str(n)) - len(str(s))) + str(s)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAkYAAAGSCAYAAAAGmg7IAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3V+IXNmdH/Cf1Gr9Kc2oZWlHajMNUsMimGBIHgzBeCAL\n+5CQwMI+bLzrXgXiFXgMeQhkIJYJTmwcW5Alj/biF0OMlA2EsItZdrMZkuxm7AmsvV4vYrURM1PF\njHrULalH6m6p1Cp1SXnQ7Z7q6nvrT3f9uffW5/Oi6ntPVR/V6Vv3W+ece+6B58+fPw8AAOLguCsA\nAJAXghEAQEIwAgBICEYAAAnBCAAgcShrx8bGRly/fj1eeeWVmJqaGmWd2Idmsxl3796Nz3zmM3H0\n6NGI0JZFpB3LQ1uWh7Ysh7R2bJUZjK5fvx4LCwtDrRzDc/Xq1fjsZz8bEdqyyLRjeWjL8tCW5dDa\njq0yg9Err7yy/cTZ2dnh1YyBWlpaioWFhe32i/ikLec+90ZMHzu5vf1z5x7G7/z2r4+8jnTXTzu+\neuR2fOPNfzbyOtKbTm2Z9flarVbja9/7cRyfOZv5uo9Wl+PbX/l8zM/PD77SpOrnuIyI+Ae//Dgu\nfuHXRlpHuktrx1aZwWirS3B2djbm5uaGUzuGprVLd+vx9LGTMV05tb395ROHtG3O9dKOx44+0o4F\nkNaWWZ+v9Xo9Dh09saOd2x168jjOnj2r7cegl+MyIuLEjGMzz7KGPk2+BgBICEYAAAnBCAAgkTnH\nCIDh+ne/+5/i2PETu7bfX7kTES56gXEQjADG5NaTT8f01O4J1g8bOvNhXBx9AAAJwQgAICEYAQAk\nBCMAgIRgBACQEIwAABKCEQBAQjACAEgIRgAACcEIACAhGAEAJAQjAICEYAQAkBCMAAASghEAQOLQ\nuCsA7E+zuRk3b97ctf38+fNx+PDhMdQIoLgEIyi49Qcfx8XL16Iyc2Z7W331TvzwO1+MCxcujLFm\nAMUjGEEJVGbOxEufenXc1QAoPHOMAAASghEAQEIwAgBICEYAAAmTryfUs+bTuLt822XeANBCMJpQ\n9bU78T8/rMf/vfLWzu0u8wZggglGE8wl3gCwkzlGAAAJwQgAICEYAQAkBCMAgIRgBACQEIwAABIu\n1weAAXvWfBrLS+mL6EZYSDfPBCMooWfNzahWq7u2+zCG0aiv3Yn//uF6/J/339q9z0K6uSYYQQlt\nPFyJr3//najMvLe9zYcxjFbWIrpZX1wifHnJA8EISsrK5pBPaV9cInx5yQvBiB0MwQAMny8u+SUY\nsYMhGAAmmWDELr7JADCpBCOAEmo0GlGr1TqWMUQOuwlGACVUq9Xi4uVrUZk5k7rfEDmkE4wASsqw\nOPTPLUEAABJ6jAAKqNMigRHRcR+QTTACKKCsRQK3rNy6EafnXhtxraD4BCOAguo0h6i+ujzi2kA5\nCEYwIaxqDtCdYERXTqjlYFVzgO4EI7pyQi0Pl28DdCYY0RMnVAAmgXWMAAASghEAQMJQGkwwE+sB\ndhKMYIKZWA+wk2AEE87EeoBPCEYApGo0GlGr1TqWMexK2QhG7EmnG1j6oIRyqNVqcfHytajMnEnd\nb9iVMhKM2JOsG1j6oCw+E7InQ6cvN1uq1aqh1hHq1iaOwdEQjNgzH5jlZEL2ZMj6ctNq5daNOD33\n2ghrNdk6tYljcHQEI2AXoXcydGvn+uryCGtDhGMvDwQjAMg5w2yjIxgxUOanAAyeYbbREYwYKPNT\nAIYja5jNVcKDJRgxcMbIAUbHVcKDJRgBe5K1+J9vqJOjl0v+I/xNjELaF1LzkvZGMGLo+unmdbIt\njrTF/3xDnSy9XPLf7W/C6trDY17S3ghGDF0/3bxOtvmUFm7TFv8z12Hy7Hfo3Oraw2VeUv8EI0ai\n125eJ9t8Sgu3aYv/metAu27DOd1W1zZcNx
xZx+rD+7fjW2+8HvPz86nPm4T3WTBibJxsi6X95JW1\n+F8/cx0m4UN20nUbbuu2uvYghutIl3as1leXM9/vTqGpTMdyZjBqNpsREbG0tDSyyrB/W+211X6t\nj2cafxtHpl6KiIippytx68Gh2NxY2/H8+v3FaDYe7dje67a9lD1y/GRsbhzd3vZ883Gs331/1/Pb\ny0VENBsPY3l5OSqVSg/vTLH02o4REYfiUep7ttd2HPS2iIgHS+/Gm1d+HkcqJ7e3Pak/iO9+80uZ\n30zLop+23JJ1fLbKeq8HtX+Qr5F2/G5JO+b7eX7E6D4L+mnLTm3Yz+fnfvbt9TlZ7/fG6u1488rv\n7ziOI4p3LKe1Y6vMYLS4uBgREQsLC0OoFsO2uLgY586d234cEfFXb//BrnLrKc9tpGzvdduwymY9\n/9KlH6VsLY9e2zFi7+/jKLa17mt16dJPMkqWTz9tuSXrfdzS6b0exP68vEYvv2OUnwX7OS639POZ\nuJ99g37O1v52RTyWW9ux1cEx1AUAIJcye4xeffXFuOPVq1djdnZ2ZBUqrNdff/Hv22/v2lWtVuNr\n3/txHJ85m/n0R6vL8e2vfH7fXZFLS0uxsLCw3X4R2rKIOrXj3OfeiOljJ+OXDn4YVy5/aVxVpEe9\ntGWaX/nlx/HbX/i1kdSR3vTaln9v9kH8i9/5jbHUke7S2rFVZjCampqKiIjZ2dmYm5sbTu3KKOW9\nqtfrcejoiZiunMp82qEnj+Ps2bMDe6+32q/1sbYsnrR2nD52MqYrp+LowVXtWSCd2jLNiZlH2jen\nurXlSy8/13YF0NqOrQylAQAkXK5fIL2sELu8nH4JNQDQnWBUIN1WiI2IWL/7/ghrBADlIhjlRC+r\nu3ZbITYiYnNjretlrQBAOsEoJ3pZ3bXbCrEAwP4IRjnSrTco6xYMAMBguCoNACAhGAEAJAQjAICE\nYAQAkBCMAAASghEAQEIwAgBICEYAAAnBCAAgIRgBACTcEmSfGo1G1Gq1mN/cjIiI6s2bu8p0uzks\nAJAPgtE+1Wq1uHj5Wvzh6uOIiPjylbd2lXHzVwAoBsFoACozZ+LAwamIiNSbwLr5KwAUg2AEBdVs\nbsbNlqHb8+fPx+HDh8dYI4DiE4ygoB6ufhwXL1+LysyZqK/eiR9+54tx4cKFcVcLoNAEIyiwysyZ\n1OFbAPZGMOpg64qzTlxxBgDlIRh1sHXFWWXmTGYZV5wBg/as+TSWl27vmEPWzpwyGA7BqItuQxWu\nOCMPnjU3d/ReOmkWW33tTvzJh+vx5+/vXv4jIswpy7Fnzadx7+7SrlDrmCwOwQhKYOPhSnz9++9E\nZeY9J82SMH+smOprd+KdDxvxi5Y17RyTxSIYQUk4kUI+OBaLzb3SAAASghEAQEIwAgBICEYAAImJ\nnXxt8UYAoN3EBiOLNwIA7SY2GEVYvBEA2MkcIwCAxET3GAHAsLXfsmeL24Tkk2AEAEPUesueLW4T\nkl+CEQAMmduEFIc5RgAACcEIACAhGAEAJAQjAICEYAQAkHBVGpRM+5op1kopn6x1cVppd9gbwQhK\npnXNFGullFPaujittHv+WfQxvwQjKCFrppSfNi42iz7ml2AEAGMg3OZTKYNRo9GIWq3WsUy38XkA\nYPKUMhjVarW4ePlaVGbOZJZZuXUjTs+9NsJaAQB5V8pgFNG9i7K+ujzC2gAARWAdIwCAhGAEAJAQ\njAAAEqWdYwQARZK26KMFH0dPMAKAHGhf9NGCj+MhGAFATlj0cfzMMQIASAhGAAAJQ2kAJZN15/ZW\nJvVCOsEIoGTS7tzeyqReyCYYAZSQSbywN+YYAQAkCtdj1Gg0olardSzTbWwdJkX7XBPzSqA4suaK\nOY6Hq3DBqFarxcXL16IycyazzMqtG3F67rUR1gryqXWuiXklUCxpc8Ucx8OXq2DUa29Qt7Hz+ury\ngGsGxb
V1vLjdABRP+/lOL9Lw5SoY6Q2C4XG7ASg+vUjDl6tgFNH9Sgq9QbB3rlQiord1jiL0QuSV\n43i4cheMgPFqH9J2ciyfbuscReiFKBLDa4M10mD0h3/0p7G0vJK5f3npdkQcHV2FYIJlXbHWOqTt\n5Fheeh3Kw/DaYI00GP3Rn9+I25vnM/c/vC8Uwai0fpg+vH87vvXG6zE/P9/TBQ5AvvQySVsPUm8M\npcEE2/owra8ub4ckFzgwCL1cZexEPTwuttg7wQiIiJ0haYtvnZOr2wTtRqMREZH5t1CtVpMTc/pV\nxk7Uw9fai2QeUu8yg1Gz2YyIiKWlpYH9so1Ha/G0+XHm/s2NtXi8fi82N9Yyy9TvL0az8Sh3ZT46\n8DwiIh7cvjHW+jz6+IOI+KT9Wh8Psi0Zrq22SmvHp48fRETEwY1Hsb72fmxurO342+jlcUT0VO7B\n0rvx5pWfx5HKyYiIeFJ/EN/95pdifn5+lG9HofXSlu26fRZ2+yzY7/6I2NX27dbvfRCHKyc67n/5\nzHxsbqRPkWg2HsZPf/rTWF7e+5XGo/477KUt09ou7f3uZdten5e2La09J/V4TmvHVpnBaHFxMSIi\nFhYWhlCtzta77G/ksMxvnj314oef/WDs9Yl40X7nzp3bfhwxnrZkf9La8dY7v7ejzNbfQ6PPx/2W\n23Lp0k/6/n/QW1u263Ssd/ss2O/+1nKZ+x503r/+4HrH3/GNb/yohxrkTz/H5Za097uXbXt9Xqdt\nrSb5eG5tx1ZuIgsAkMjsMXr11RfjklevXo3Z2dmRVSjPqtVqfO17P47jM2d37bvyXy5HRMRXv/Cd\nuHfrehx7+ZXUcq0erS7Ht7/y+YF2Yy4tLcXCwsJ2+0V80pb/9Iv/PF4+cSIOxmb84V/U4/jM2aHU\ngf3r1I6OyWLp1JZzn3sjpo+lD0Vtee3UvfhXX/mtodaR3nRry4PTx+M3/v6R+Mf/6FfHVUV6kNaO\nrTKD0dTUVEREzM7Oxtzc3HBqVzD1ej0OHT0R05VTu/Z9uvksIiKmK6di6vBLmeVaHXryOM6ePTuU\n93er/Vof/8XS2ZheOxUnH/8sDh2djenKqaHWgf1La0fHZDGlteX0sZNdPyeOv/xUe+dMVlsePPxS\nfOpTR7VXQbS2YytDaQAACcEIACAhGAEAJAQjAICEYAQAkBCMAAASghEAQEIwAgBICEYAAInMla/J\nj0ajEbVaraey+7lTNQBMOsGoAGq1Wly8fC0qM2e6ll2/+/4IagQA5SQYFURl5ky89Kn0G9612txY\ni/UR1AcAysgcIwCAhGAEAJAQjAAAEhM/x6ifK76q1epwKwMAjNXEB6N+rvhauXUjTs+9NoJaAQDj\nMPHBKKL3K77qq9YIAqA37SMS58+fj8OHD4+vQvREMAKAIWgdkaiv3okffueLceHChXFXiy4EIwAY\nkl5HJMgPV6UBACQEIwCAhKG0MXrW3OxpCYBhLhOQVgcTBAGYVILRGG08XImvf/+dqMy817HcMJcJ\naK+DCYIATDLBaMx6mZg37GUCTA4EgBfMMQIASAhGAACJ0g6l9XoPNPc/A2AQnjU346OPFuPmzZsR\n4fxSVKUNRr3eA839zwAYhMdrd+La/74ff/DztyLC+aWoShuMIvIxsRmAydF63nF+KabCBSNDZAAU\njTXjiqNwwcgQGQBFY8244ihcMIowRAZA8Vgzrhhcrg8AkBCMAAASghEAQEIwAgBICEYAAAnBCAAg\nIRgBACQEIwCARCEXeARgp15vlxThVhTQiWBEz9I+eH3AQj70erskt6KAzgQjduh0o8P2D14fsJAv\nbjkB+ycYsUO3Gx364AWgzAQjdhF+AJhUrkoDAEjoMQLIqWfNp7Fydzlu3rzZtWz73EBgbwQjgJyq\nr92Jv1xtxpevvNW17MqtG3F67rUR1ArKLRfBqJ/1N3wrAiZJr3P+6qvL
I6gNlF8uglGv629E+FYE\nAAxPLoJRhG9FAMD45SYYAcCkaF1Mt9FoRERs30XAHQXGSzCio9aD1/wugMFoXUx35daNOPby6ajM\nnHFHgRwQjOio/eA1vwuKLe22P53ovRierSkk9dVlC+vmiGBEV60HL1Bs7bf96UTvBZNIMAKYMHon\nIJtgxJ6ldcnrdgegyAQj9qy9S163OwBFJxixL7rkobz6maitt5iyEIwASNXrRG29xZSJYARAJr3C\nTBrBCABywkUt4ycYAUBOuKhl/AQjAMgRw5fjdXDcFQAAyIuh9hg1Go2o1Wpdy7k5KQCQB0MNRrVa\nLS5evhaVmTMdy7k5afmkhWITCAHIu6HPMeplrNTNScunPRSbQAhAEZh8zcC0XmZarVZ3hGKXoEJ5\nWSGbMhGMGJjWy0zbh0ddggrlZYVsykQwYqC2eonShkddggrl1cvx3U/PUoTepYjd75n3ZPj2FIx+\n+pe/iD975xddyy0vLUbEzF5+BQAl02vPUkTEw/u341tvvB7z8/MdyzUajYiInsJCP2XzEkBa37P2\n96T9/5OXOhddZjBqNpsREbG0tLRr3//4X2/Hn71b6friq3fq8bSxEpsbax3L1e8vRrPxqGu5fsoO\nuly3sh8deB4REQ9u3xhrHR99/EFEfNJ+rY+fPn4QERFPNh7F+r33Y3Njbddrd/p5kGUfr9+L5eXl\nqFS6/x1Noq3jLq0d045J8qtTW24dk1k2N9bi8fq9sXw2DuM16/cX48jxk7G5cbTr795YvR1vXvn9\nOFI52bHc+r0P4nDlRNdy/ZR9Un8Q3/3ml3aFsm5tufnkYTy5v9r3Z2K3clvvWft70vr/yaozu6W1\nY6vMYLS4uBgREQsLC/uuxHoPZRo9luun7KDLdSr7m2dPvXjwsx8M5Xf3U8eIF+137ty57ccREbfe\n+b0dZbZer/21O/08yLKXLv2ol//KREtrx0Eck4xeL8dklnF9Ng7jNfv9LGv0UuZBb+X6KXvp0k8y\n93Vry718Jvb7Obv9uOX/06nO7Nbajq2sfA0AkMjsMXr11ReT6K5evRqzs7Mjq1BeVKvV+Nr3fhzH\nZ85mlnm0uhzf/srnX3Rdvv76i41vvz2iGqZbWlqKhYWF7faL+KQt5z73RkwfOxmH6+/Fd//9G+Oq\nIj3opR0jIj71/IP4D//md8ZSR3rTqS0n9fO1qHo5LlfvvB+Hjhzfce7Yca5g7NLasVVmMJqamoqI\niNnZ2ZibmxtO7XKsXq/HoaMnYrpyKrPMoSeP4+zZszvfn5y8V1vt1/p4+tjJmK6cisPP70xkmxZR\np3aMiDj6/L62LIi0tpzUz9ei63RcTh25G4eOvrTj3JF6rmDsWtuxlaE0AICEYAQAkLDA45Cl3Uw1\njfUnAGD8BKMha7+ZappelsnvNWAtL7shLwDslWA0AoO4FUYvASsiYv3u+/v6PQAwyQSjAuklYG1u\nrPW1eBoA8AmTrwEAEoIRAEBCMAIASAhGAAAJk69z4FlzM6rVascy3fYDAPsnGOXAxsOV+Pr334nK\nzHuZZVZu3YjTc6+NsFYAMHkEo5zodil+fdXCjQAwbOYYAQAkBCMAgIRgBACQmMg5Rr3ckNVVYAAM\nQtqVx+fPn4/Dhw+PqUZ0MpHBqJcbsroKDIBBaL/yuL56J374nS/GhQsXxlwz0kxkMIpwFRgAo9PL\nTcDJh4kNRlAGzeZm3Lx5c8c2XfQAeycY7UPruPH85mZERFTbTlLmKjFMD9fu7xgW1kUPsD+C0T60\njhv/59XHERHx5Stv7SiTx7lKz/QylIoueoDBEYz2aeukdODgVETErhNUHucqPVrXywAAaQSjCaWX\nAQB2s8AjAECiMD1GvSzKGGGuDACwd4UJRr0symiuDACwH4UJRhHmxQAAw2WOEQBAQjACAEgIRgAA\nCcEIACAhGAEAJAQjAICEYAQAkCjU
OkYAUHTPmptRrVZ3bHPXhvwoVTBK+2NL00sZABiGjYcr8fXv\nvxOVmfciwl0b8qZUwaj9jy3Lyq0bcXrutRHVCgB2cieH/CpVMIro7Y+tvro8otoUg25dAHihdMGI\n/unWBYAXBCMiQrcuAES4XB8AYJtgBACQMJQGJWIiPcD+CEZQIibSA+yPYAQlYyI9wN6ZYwQAkBCM\nAAASghEAQEIwAgBICEYAAAlXpQFAzjQajajVajt+johda5JZp2zwBCMAyJlarRYXL1+LysyZiIhY\nuXUjjr18evvnCOuUDYtgBAA51LomWX112RplIyIY0VV7l26E7luAQUm7lU/7z4yOYMQu7QdptVpN\nbjPxogtX921xuHca5F/7rXwiXgydnZ57bYy1mlyCEbu0H6RbB6gu3OJx7zQohvZhsvrq8hhrM9kE\nI1K1j21TXOYlAPTOOkYAAIlc9BilTe5tZyIaADBsuQhG7es1pDERDQAYtlwEo4ju8yDMc8mP9iud\nXOUEMHquOh2O3AQjiqP1SidXOQGMh6tOh0MwYk9c6VReFvSE4vBZPHiCEbBD+5w/30KBSSIYwQRJ\nm5MQsbtHyLdQKJ5ej286E4xggqTdeuDh/dvxrTdej/n5+YiwNAYUVdrxrce3f4IR++KqiOJJu/VA\n2i1ggOLR27t/ghH74qqIcuh0CxjhF4rL8du/oQejv77+N/Hf/vjHcfBA9t1Hlm7fiohXhl0VhqTT\nNxRXOBWf8AvF5fjt39CD0Y3/9278xeKn4uDUdGaZlVt34sjxYdeEUWj/dlKtVpOD0hVORdYafn0D\nhWIxvNYfQ2kMVPu3k635Kg7K8mhv4/bJ21uEJaCIBCMGznyV8mtvY1fCAGWRGYyazWZERCwtLe3r\nF9y//3E8uH0rDk5lZ7D6/Q+iUT8emxtrHcosRrPxaN9lBvlarWU+OvA8IiIe3L4x9N/VyaOPP4iI\nT9qv9fHanXdj6sjL0by/FI1Hsf1a7a/dz8/9PvfB0rvx5pWfx5HKyYiIeFJ/EN/95pd29TZMuq3j\nrlM7RkQcrd+J9ebTPbflXp6T9vOR4ydjc+PoJ3VtPIzl5eWoVCoDfmeKp1Nb7vfzldHq5bh8dO/D\nmDpybF/H2DDLPF6/N/HHZlo7tspMK4uLixERsbCwMIRqpVvvsr8xoDKDfK2tMr959tSLDT/7wdB/\nVy8WFxfj3Llz248jIlau/9cdZVpfq/21+/m53+dubdty6dJPMv8fk66XdozYX1vu5Tm9vOalSz/a\nVc9JltaWo/x8ZXD6/XyNGM5xudcyjs0XWtuxVfalYgAAEyazx+jVV1/MH7h69WrMzs6OrEKF9frr\nL/59++09Pb1arcbXvvfjOD5zNrPMo9Xl+PZXPt9x2GlpaSkWFha22y/ik7ac+9wbMX3s5Pb2f/h3\nGvEbv/5P9lRfhqufdvyVX34cv/2FXxt5HelNP20ZoT3zrFNbOlcWR1o7tsoMRlNTUxERMTs7G3Nz\nc8OpXRnt8b2q1+tx6OiJmK6cyixz6MnjOHv2bE/tsdV+rY+nj53c8fonTz7RtjnXSzvOzDzSjgXQ\nS1tGRJzQnrmX1pbOlcXT2o6tDKUBACQEIwCAhHWMBqDRaMSBzc2IiKjevJlaxlo9AJB/gtEA1Gq1\neOn+wzh4cCq+fOWtXfstdgcAxSAYDcjBg1Nx4OCUW18AQIEJRgA58az5NJaXbsfNlCF5w/EwGoIR\nQE7U1+7En3y4Hn/+/s4hecPxMDqCEUCOtN6gFxg9l+sDACQEIwCAhGAEAJAQjAAAEoIRAEBCMAIA\nSAhGAAAJwQgAICEYAQAkrHxdIM+am1GtVjuWmZ6eHlFtAKB8BKMC2Xi4El///jtRmXkvdX999U78\n7r/81RHXCgDKQzAqGPdRAoDhMccIACChx6iLRqMRtVqtY5lqtRp/dzTVAQCGSDDqolarxcXL16Iy\n
cyazzMqtG/GnI6wTADAcglEPus3rqa8uj7A2g/Gs+TSWbn8UN2/e3LH9/Pnzcfjw4THVCgDGSzCa\nUPW1O/EHHz6IP/3bt7a3Pbx/O771xusxPz+/q7zABMAkEIwmWHtPWH11OXU5gPrqnfjhd74YFy5c\nGHUVAWCkBCN2sBxAsTxrPo3lpduGRAEGRDCCAquv3Yk//nA9/uz9T4ZE9fAB7J1gBAWnlw9gcCzw\nCACQEIwAABKCEQBAQjACAEhM9OTrXu+DBgBMhokORr3eB+303GsjrBUAMC4THYwiynkfNABgbyY+\nGAHk3bPmZsdhfSudw+AIRgA5t/FwJfU+hhFWOodBE4wACsAK5zAagtEIdOsGj3D1GwDkgWA0Ap26\nwbe4+g0Axk8wGhFXvwFA/ln5GgAgIRgBACRKO5Tmdh8AQL9KG4zc7mNwsq6qs6gcAGVT2mAUYcLz\noKRdVWdROQDKqNTBiMGxuBzkU6d10vTqQv8EI4ACy1onTa8u7I1gBFBwenRhcFyuDwCQEIwAABKC\nEQBAQjACAEiYfA0lY0FOgL0TjKBkLMgJsHeCEZSQy7cB9kYwYmA63bjXMM54GV4D6E1hg1Gnk3BE\nZC6Rz2CknWir1WoyhLPzxr0P79+Ob73xeszPz+/Y7qQ8OobXAHpT2GBUq9Xi4uVru07CW1Zu3YjT\nc6+NuFaTI+1Eu/Wetw/h1FeXnZRzwPAaQHeFDUYRnT/o66vLI67N5Gl//zu9507K+ZPW66cXD5h0\nhQ5GwN619/rpxSuXrHllEQIwdCIYwQTTk1deacPdEQIwdCMYAZSU4Av9c0sQAICEHiOACZI196jR\naEREpM49MieJSSIYAUyQrLlHK7duxLGXT+9aAsWcJCZNLoNRt8UbIyzgWHRWYobxSZt7VF9dTt3e\n6eq2CMcs5ZPLYNRt8cYICzgWnZWYoRiyepgiHLOUUy6DUUT3qyks4Fh8rpiBYnCsMknGEoz+9b/9\nj7H6JPtX31v+MOLAfOZ+ysmCdONleBNgTMFo5dHBWH5+LnP//ScbMX10hBUiFyxIN16GN+mXK9wo\no9wOpTGZ+pn86QN28Nrf/7T3Pu2k1+s2bVYurnCjjDKDUbPZjIiIpaWlgf/S9Xu12Ny8k7n/2f2P\nYv3QL8Xmxlpmmfr9xWg2HmWW6bZ/0GU+OvA8Dhx4Hg9u3xhbXR6v34t79z4dEZ+0X+vjzXt/FQeO\nHH+xbf3jqD97acdrZb1+2vZetw3i+Q+W3o03r/w8jlRObm97Un8Q3/3ml2J+vpxDrlvH3aDasZdt\naWXS3vtlOlCXAAABLElEQVT1ex/E4cqJvreVvc2y9NqWEentGdH52O/nuBvk9q19R46fjM2NnV38\nzzcfR7PxcNf2ZuNhLC8vR6VS2fVaRdCpLYdxrmQ40tqxVWYwunv3bkRELCwsDKFavVnvsr/RpUy3\n/YMs81uffuXFg5/9YKx1+epXfxQRL9rv3Llz248jIj786z/ZVb79tbJeP217r9sG8fytfa0uXfpJ\nRsnyGGQ79rKtU/vt+PnB3rZNQptl6bUtI3o/hrrtG/b2vTzn0qUfZbxScaS15TjPlexNazu2OvD8\n+fPnaU/Y2NiI69evxyuvvBJTU1NDryCD0Ww24+7du/GZz3wmjh598W1NWxaPdiwPbVke2rIc0tqx\nVWYwAgCYNG4iCwCQEIwAABKCEQBAQjACAEj8fxB90xAnEyeeAAAAAElFTkSuQmCC\n",
+ "text/plain": [
+ "<matplotlib.figure.Figure at 0x1195b3290>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "fig, ax = plt.subplots(len(g), len(g))\n",
+ "for i in xrange(len(g)):\n",
+ " for j in xrange(len(g)):\n",
+ " ax[i, j].get_xaxis().set_ticks([])\n",
+ " ax[i, j].get_yaxis().set_ticks([])\n",
+ " if i != j:\n",
+ " it, jt = formatLabel(i, len(g)-1), formatLabel(j, len(g)-1)\n",
+ " ax[i,j].hist(mcmc.trace('theta_{}{}'.format(it,jt))[:], normed=True)\n",
+ " ax[i, j].set_xlim([0,1])\n",
+ " ax[i, j].plot([g[j, i]]*2, [0, ax[i,j].get_ylim()[-1]], color='red')\n",
+ "plt.tight_layout(pad=0.4, w_pad=0.5, h_pad=.1)\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "theta = np.empty((n_simul, n_nodes, n_nodes))\n",
+ "for i in xrange(len(g)):\n",
+ " for j in xrange(len(g)):\n",
+ " it, jt = formatLabel(i, len(g)-1), formatLabel(j, len(g)-1)\n",
+ " theta[:, i, j] = mcmc.trace('theta_{}{}'.format(it, jt))[:]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "- simulate cascades for each iteration of theta starting from specific source\n",
+ "- calculate the likelihood of each cascade \n",
+ "- calculate the marginal probability of each cascade\n",
+ "- order nodes by difference"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": true
+ },
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "metadata": {
+ "collapsed": true
+ },
+ "outputs": [],
+ "source": [
+ "def specific_source(i, graph):\n",
+ " x0 = np.zeros(graph.shape[0], dtype=bool)\n",
+ " x0[i] = True\n",
+ " return x0"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[-0.02783062 0.02330009 0.02416332 -0.01541149 0.00959461]\n",
+ "[0 3 4 1 2]\n"
+ ]
+ }
+ ],
+ "source": [
+ "infected, susceptible = [], []\n",
+ "rval = np.empty((theta.shape[0], theta.shape[1]))\n",
+ "for node in xrange(n_nodes):\n",
+ " for i, graph in enumerate(theta):\n",
+ " casc = main.simulate_cascades(10, graph, \n",
+ " source=lambda g, t: specific_source(node, g))\n",
+ " x, s = main.build_cascade_list(casc, collapse=True)\n",
+ " cond_lkl = main.cascadeLkl(graph, x, s)\n",
+ " marg_lkl= np.mean([main.cascadeLkl(thet, x, s) for thet in theta])\n",
+ " rval[i, node] = cond_lkl - marg_lkl\n",
+ "rval = rval.mean(axis=0)\n",
+ "print(rval)\n",
+ "print(np.argsort(rval))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Error: the mutual information should always be non-negative"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 2",
+ "language": "python",
+ "name": "python2"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 2
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython2",
+ "version": "2.7.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/simulation/main.py b/simulation/main.py
index 3e3de4b..3e458a7 100644
--- a/simulation/main.py
+++ b/simulation/main.py
@@ -1,3 +1,5 @@
+import mleNode as mn
+
import numpy as np
from numpy.linalg import norm
import numpy.random as nr
@@ -9,85 +11,6 @@ from random import random, randint
seaborn.set_style("white")
-def likelihood(p, x, y):
- a = np.dot(x, p)
- return np.log(1. - np.exp(-a[y])).sum() - a[~y].sum()
-
-
-def likelihood_gradient(p, x, y):
- a = np.dot(x, p)
- l = np.log(1. - np.exp(-a[y])).sum() - a[~y].sum()
- g1 = 1. / (np.exp(a[y]) - 1.)
- g = (x[y] * g1[:, np.newaxis]).sum(0) - x[~y].sum(0)
- return l, g
-
-
-def test_gradient(x, y):
- eps = 1e-10
- for i in xrange(x.shape[1]):
- p = 0.5 * np.ones(x.shape[1])
- a = np.dot(x, p)
- g1 = 1. / (np.exp(a[y]) - 1.)
- g = (x[y] * g1[:, np.newaxis]).sum(0) - x[~y].sum(0)
- p[i] += eps
- f1 = likelihood(p, x, y)
- p[i] -= 2 * eps
- f2 = likelihood(p, x, y)
- print g[i], (f1 - f2) / (2 * eps)
-
-
-def infer(x, y):
- def f(p):
- l, g = likelihood_gradient(p, x, y)
- return -l, -g
- x0 = np.ones(x.shape[1])
- bounds = [(1e-10, None)] * x.shape[1]
- return minimize(f, x0, jac=True, bounds=bounds, method="L-BFGS-B").x
-
-
-def bootstrap(x, y, n_iter=100):
- rval = np.zeros((n_iter, x.shape[1]))
- for i in xrange(n_iter):
- indices = np.random.choice(len(y), replace=False, size=int(len(y)*.9))
- rval[i] = infer(x[indices], y[indices])
- return rval
-
-
-def confidence_interval(counts, bins):
- k = 0
- while np.sum(counts[len(counts)/2-k:len(counts)/2+k]) <= .95*np.sum(counts):
- k += 1
- return bins[len(bins)/2-k], bins[len(bins)/2+k]
-
-
-def build_matrix(cascades, node):
-
- def aux(cascade, node):
- xlist, slist = zip(*cascade)
- indices = [i for i, s in enumerate(slist) if s[node] and i >= 1]
- if indices:
- x = np.vstack(xlist[i-1] for i in indices)
- y = np.array([xlist[i][node] for i in indices])
- return x, y
- else:
- return None
-
- pairs = (aux(cascade, node) for cascade in cascades)
- xs, ys = zip(*(pair for pair in pairs if pair))
- x = np.vstack(xs)
- y = np.concatenate(ys)
- return x, y
-
-
-def build_cascade_list(cascades):
- x, s = [], []
- for cascade in cascades:
- xlist, slist = zip(*cascade)
- x.append(xlist)
- s.append(slist)
- return x, s
-
-
def simulate_cascade(x, graph):
"""
Simulate an IC cascade given a graph and initial state.
@@ -120,6 +43,24 @@ def simulate_cascades(n, graph, source=uniform_source):
yield simulate_cascade(x0, graph)
def build_cascade_list(cascades, collapse=False):
    """
    Split cascades into per-cascade stacks of their two components.

    Each cascade is an iterable of pairs. The first members of the pairs are
    stacked row-wise into one array and the second members into another.

    cascades -- iterable of cascades (each an iterable of 2-tuples of arrays)
    collapse -- when true, additionally stack all cascades into one array
                per component instead of returning one array per cascade

    Returns a pair (first components, second components): two lists of
    arrays when `collapse` is false, two arrays when it is true.
    """
    first_blocks = []
    second_blocks = []
    for steps in cascades:
        firsts, seconds = zip(*steps)
        first_blocks.append(np.vstack(firsts))
        second_blocks.append(np.vstack(seconds))
    if collapse:
        return np.vstack(first_blocks), np.vstack(second_blocks)
    return first_blocks, second_blocks
+
+
def cascadeLkl(graph, infect, sus):
    """
    Return the log-likelihood of stacked cascade observations under `graph`.

    With ``a = infect.dot(graph)``, entries that are both infected and
    susceptible contribute ``log(1 - exp(-a))`` and entries that are
    susceptible but not infected contribute ``-a``.

    graph  -- 2-D array of edge weights
    infect -- 2-D array of infection indicators (coerced to boolean)
    sus    -- 2-D array of susceptibility indicators, same shape as `infect`
    """
    # The original author flagged this implementation as having a problem.
    # NOTE(review): row t of `a` is masked with row t of `infect`/`sus`,
    # whereas mleNode.build_matrix pairs the state at step i-1 with the
    # outcome at step i. This looks like an off-by-one in time -- confirm
    # against simulate_cascade before relying on the returned values.
    # Coerce to boolean masks: `~` on an integer array is a bitwise NOT and
    # `*` would yield integers that index instead of mask.
    infect = np.asarray(infect, dtype=bool)
    sus = np.asarray(sus, dtype=bool)
    a = np.dot(infect, graph)
    infected_mask = infect & sus
    spared_mask = ~infect & sus
    # -expm1(-a) equals 1 - exp(-a) but keeps precision for tiny `a`, where
    # the naive form rounds to 0 and the log becomes -inf.
    return np.log(-np.expm1(-a[infected_mask])).sum() - a[spared_mask].sum()
+
+
if __name__ == "__main__":
# g = np.array([[0, 1, 1, 0], [1, 0, 0, 1], [1, 0, 0, 1], [0, 1, 1, 0]])
g = np.array([[0, 0, 1], [0, 0, 0.5], [0, 0, 0]])
@@ -145,17 +86,17 @@ if __name__ == "__main__":
source=lambda graph: source(graph, t))
e = np.zeros(g.shape[0])
for j, s in enumerate(sizes):
- x, y = build_matrix(cascades, 2)
- e += infer(x[:s], y[:s])
+ x, y = mn.build_matrix(cascades, 2)
+ e += mn.infer(x[:s], y[:s])
for i, t in enumerate(thresh):
- plt.plot(sizes, m[:, i], label=str(t))
+ plt.plot(sizes, e[:, i], label=str(t))
plt.legend()
plt.show()
- # conf = bootstrap(x, y, n_iter=100)
+ # conf = mn.bootstrap(x, y, n_iter=100)
# estimand = np.linalg.norm(np.delete(conf - g[0], 0, axis=1), axis=1)
- # error.append(confidence_interval(*np.histogram(estimand, bins=50)))
+ # error.append(mn.confidence_interval(*np.histogram(estimand, bins=50)))
# plt.semilogx(sizes, error)
# plt.show()
diff --git a/simulation/mleNode.py b/simulation/mleNode.py
new file mode 100644
index 0000000..ed32a12
--- /dev/null
+++ b/simulation/mleNode.py
@@ -0,0 +1,72 @@
+import numpy as np
+from scipy.optimize import minimize
+
+
def likelihood(p, x, y):
    """
    Return the log-likelihood of the outcomes `y` under the weights `p`.

    With ``a = x.dot(p)``, rows where `y` is true contribute
    ``log(1 - exp(-a))`` and rows where `y` is false contribute ``-a``.

    p -- 1-D array of weights, one per column of `x`
    x -- 2-D array with one row per observation
    y -- outcome per row of `x` (coerced to a boolean mask)
    """
    # A non-boolean `y` would otherwise be used as fancy indices, and `~`
    # would be a bitwise NOT rather than a logical one.
    y = np.asarray(y, dtype=bool)
    a = np.dot(x, p)
    # -expm1(-a) equals 1 - exp(-a) but keeps precision for tiny `a`, where
    # the naive form rounds to 0 and the log becomes -inf.
    return np.log(-np.expm1(-a[y])).sum() - a[~y].sum()
+
+
def likelihood_gradient(p, x, y):
    """
    Return the log-likelihood and its gradient with respect to `p`.

    With ``a = x.dot(p)``, the value sums ``log(1 - exp(-a))`` over rows
    where `y` is true and ``-a`` over rows where `y` is false. The gradient
    sums ``x_row / (exp(a) - 1)`` over the true rows and ``-x_row`` over the
    false rows.

    p -- 1-D array of weights, one per column of `x`
    x -- 2-D array with one row per observation
    y -- outcome per row of `x` (coerced to a boolean mask)

    Returns a tuple (value, gradient).
    """
    # A non-boolean `y` would otherwise be used as fancy indices, and `~`
    # would be a bitwise NOT rather than a logical one.
    y = np.asarray(y, dtype=bool)
    a = np.dot(x, p)
    a_pos = a[y]
    # expm1 avoids the cancellation in 1 - exp(-a) and exp(a) - 1 when `a`
    # is tiny (both naive forms round to exactly 0).
    value = np.log(-np.expm1(-a_pos)).sum() - a[~y].sum()
    weights = 1. / np.expm1(a_pos)
    gradient = (x[y] * weights[:, np.newaxis]).sum(0) - x[~y].sum(0)
    return value, gradient
+
+
def test_gradient(x, y):
    """
    Print the analytic gradient next to a central finite-difference estimate.

    For every coordinate i, evaluated at the point p = 0.5 * ones, one line
    is printed containing the analytic partial derivative followed by
    ``(f(p + eps * e_i) - f(p - eps * e_i)) / (2 * eps)``, where f is
    `likelihood`. Nothing is returned; agreement is checked by eye.

    x -- 2-D array with one row per observation
    y -- boolean outcome per row of `x`
    """
    eps = 1e-10
    n_params = x.shape[1]
    # The analytic gradient is taken at the unperturbed point, which is the
    # same for every coordinate, so it is computed once outside the loop.
    _, g = likelihood_gradient(0.5 * np.ones(n_params), x, y)
    for i in range(n_params):
        p = 0.5 * np.ones(n_params)
        p[i] += eps
        f1 = likelihood(p, x, y)
        p[i] -= 2 * eps
        f2 = likelihood(p, x, y)
        # A single pre-formatted argument prints "a b" under both the
        # Python 2 print statement and the Python 3 print function.
        print("%s %s" % (g[i], (f1 - f2) / (2 * eps)))
+
+
def infer(x, y):
    """
    Estimate the weights that maximise the log-likelihood of (x, y).

    The negated value and gradient returned by `likelihood_gradient` are
    minimised with L-BFGS-B, starting from a vector of ones and with every
    weight bounded below by 1e-10 (no upper bound).

    x -- 2-D array with one row per observation
    y -- boolean outcome per row of `x`

    Returns the optimiser's solution vector (one entry per column of `x`).
    """
    n_params = x.shape[1]

    def negated_objective(p):
        value, gradient = likelihood_gradient(p, x, y)
        return -value, -gradient

    lower_bounded = [(1e-10, None) for _ in range(n_params)]
    result = minimize(negated_objective, np.ones(n_params), jac=True,
                      bounds=lower_bounded, method="L-BFGS-B")
    return result.x
+
+
def bootstrap(x, y, n_iter=100):
    """
    Re-estimate the weights on `n_iter` random 90% subsamples of the data.

    Each iteration draws int(0.9 * len(y)) distinct row indices (sampling
    without replacement) and runs `infer` on the selected rows.

    x      -- 2-D array with one row per observation
    y      -- outcome per row of `x`
    n_iter -- number of subsamples to draw

    Returns an array of shape (n_iter, x.shape[1]), one estimate per row.
    """
    rval = np.zeros((n_iter, x.shape[1]))
    n_obs = len(y)
    subsample_size = int(n_obs * .9)
    # `range` instead of `xrange` so the function also runs on Python 3.
    for i in range(n_iter):
        indices = np.random.choice(n_obs, replace=False, size=subsample_size)
        rval[i] = infer(x[indices], y[indices])
    return rval
+
+
def confidence_interval(counts, bins):
    """
    Return the bin edges of a centred window holding over 95% of the counts.

    The window ``counts[mid - k:mid + k]`` (mid = len(counts) // 2) is grown
    one step at a time until its sum exceeds 95% of the total; the edges
    ``bins[len(bins) // 2 - k]`` and ``bins[len(bins) // 2 + k]`` are then
    returned.

    counts -- histogram counts
    bins   -- bin edges, e.g. the second output of np.histogram

    Returns a tuple (lower edge, upper edge).
    """
    counts = np.asarray(counts)
    n = len(counts)
    # Floor division keeps the indices integral on Python 3 as well.
    mid = n // 2
    bin_mid = len(bins) // 2
    target = .95 * np.sum(counts)
    k = 0
    # Clamp the slice start at 0 so it never wraps to the end of the array,
    # and cap k so all-zero counts cannot loop forever.
    while k < n and np.sum(counts[max(mid - k, 0):mid + k]) <= target:
        k += 1
    lower = max(bin_mid - k, 0)
    upper = min(bin_mid + k, len(bins) - 1)
    return bins[lower], bins[upper]
+
+
def build_matrix(cascades, node):
    """
    Build the observation matrix and outcome vector for a single node.

    Each cascade is a sequence of pairs ``(x_t, s_t)``. For every step
    t >= 1 at which ``s_t[node]`` is true, one row ``x_{t-1}`` is emitted
    together with the outcome ``x_t[node]``. Cascades that contribute no
    such step (including empty cascades) are skipped.

    cascades -- iterable of cascades
    node     -- index of the node whose observations are collected

    Returns a tuple (x, y): the stacked rows and the matching outcomes.
    Raises ValueError when no cascade contributes any observation.
    """

    def aux(cascade, node):
        steps = list(cascade)
        if not steps:
            return None
        xlist, slist = zip(*steps)
        indices = [i for i, s in enumerate(slist) if i >= 1 and s[node]]
        if not indices:
            return None
        # np.vstack needs a real sequence; recent NumPy rejects generators.
        x = np.vstack([xlist[i - 1] for i in indices])
        y = np.array([xlist[i][node] for i in indices])
        return x, y

    pairs = [pair for pair in (aux(cascade, node) for cascade in cascades)
             if pair is not None]
    if not pairs:
        # Previously surfaced as an opaque tuple-unpacking ValueError.
        raise ValueError("no observations for node %r in the given cascades"
                         % (node,))
    xs, ys = zip(*pairs)
    return np.vstack(xs), np.concatenate(ys)