summaryrefslogtreecommitdiffstats
path: root/hawkes/cause.py
blob: fddbfa9164a1f984bb847a96fca127b40126ec6c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from cPickle import load
from math import exp, sin
from csv import reader
from data2 import parse
import sys
import networkx as nx
import matplotlib.pyplot as plt


def fatal():
    """Return the truthy parsed values of column 7 of the CSV in sys.argv[1].

    The first line of the file is discarded as a header.  Each remaining row
    is numbered starting at 1, its eighth field (``row[7]``) is passed through
    ``parse``, and only rows whose parsed value is truthy are kept.

    Returns:
        dict mapping 1-based data-row number -> ``parse(row[7])``.
    """
    result = {}
    with open(sys.argv[1]) as handle:
        handle.readline()  # skip the header line
        for number, row in enumerate(reader(handle), 1):
            value = parse(row[7])
            if value:
                result[number] = value
    return result


def main(lamb, alpha, mu):
    G = nx.DiGraph()
    r, dr, i = 0, 0, 0
    drf, iff, rf = 0, 0, 0
    dnf, rnf, inf = 0, 0, 0
    si = 0
    f = fatal().items()
    l = []
    for ((n1, t1), s) in event_edges.iteritems():
        G.add_node((n1, t1))
        if not s:
            dr += 1
            if (n1, t1) in f:
                drf += 1
            else:
                dnf += 1
            continue
        br = lamb * (1 + 0.43 * sin(0.0172 * t1 + 4.36))
        prl = sorted([(n2, t2, alpha / d * mu * exp(-mu * (t1 - t2)))
                      for (n2, t2, d) in s], reverse=True)
        pr = sum(e[2] for e in prl)
        #if sum(e[2] for e in prl[:1]) > br:
        #    G.add_edge((n1, t1), tuple(prl[0][:2]))
        if br > pr:
            r += 1
            if (n1, t1) in f:
                rf += 1
            else:
                rnf += 1
        else:
            G.add_edge((n1, t1), tuple(prl[0][:2]))
            l.append(prl[0][2] / br)
            i += 1
            if (n1, t1) in f:
                iff += 1
            else:
                inf += 1
    print "nedges:", G.number_of_edges()
    cs = {}
    for c in nx.weakly_connected_components(G):
        cs[len(c)] = cs.get(len(c), 0) + 1
    cs = sorted(cs.iteritems(), key=lambda x: x[0])
    x, y = zip(*cs)
    print cs
    plt.loglog(x, y, "-")
    plt.xlabel("Cascade size")
    plt.ylabel("Number of cascades")
    plt.savefig("dist.pdf")
    l.sort(reverse=True)
    plt.plot(l)
    plt.show()
    return (lamb, alpha, mu, dr, r, i, drf, rf, iff,
            dnf, rnf, inf, si, len(event_edges))


if __name__ == "__main__":
    nodes, edges, events, event_edges = load(open("data2.pickle", "rb"))
    lamb, alpha, mu = 1.1847510744e-05, 0.00316718040144, 0.00393069204339
    # print len(event_edges), sum(len(e) for e in events.itervalues())
    # print len(fatal())
    print main(lamb, alpha, mu)