summaryrefslogtreecommitdiffstats
path: root/hawkes_experiments/cause.py
diff options
context:
space:
mode:
authorThibaut Horel <thibaut.horel@gmail.com>2015-09-14 23:08:02 -0400
committerThibaut Horel <thibaut.horel@gmail.com>2015-09-14 23:08:02 -0400
commitab0b1f3cefedb35327a19ec1b6afd560bfdf802d (patch)
treeb777f3e2c0ac0e712d8c5faab5107b1d236e2c3a /hawkes_experiments/cause.py
parent960676226862d2d68c7a9c04c56d4f8157803025 (diff)
downloadcriminal_cascades-ab0b1f3cefedb35327a19ec1b6afd560bfdf802d.tar.gz
Import supplements and repo reorganization
Diffstat (limited to 'hawkes_experiments/cause.py')
-rw-r--r--hawkes_experiments/cause.py72
1 files changed, 72 insertions, 0 deletions
diff --git a/hawkes_experiments/cause.py b/hawkes_experiments/cause.py
new file mode 100644
index 0000000..17ec884
--- /dev/null
+++ b/hawkes_experiments/cause.py
@@ -0,0 +1,72 @@
+from cPickle import load
+from math import exp, sin
+from collections import Counter
+from csv import reader, writer
+from data2 import parse
+import sys
+import networkx as nx
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+def get_fatals():
+ with open(sys.argv[1]) as fh:
+ fh.readline()
+ r = reader(fh)
+ d = {i + 1: parse(row[7]) for (i, row) in enumerate(r)}
+ d = {k: v for k, v in d.iteritems() if v}
+ return d.items()
+
+
+def cause(lamb, alpha, mu):
+ G = nx.DiGraph()
+ roots, droots, infections = 0, 0, 0
+ fatal_droots, fatal_infections, fatal_roots = 0, 0, 0
+ fatals = get_fatals()
+ for ((n1, t1), s) in event_edges.iteritems():
+ G.add_node((n1, t1))
+ if not s:
+ droots += 1
+ if (n1, t1) in fatals:
+ fatal_droots += 1
+ continue
+ background_rate = lamb * (1 + 0.43 * sin(0.0172 * t1 + 4.36))
+ neighbors = sorted([(n2, t2, alpha / d * mu * exp(-mu * (t1 - t2)))
+ for (n2, t2, d) in s], reverse=True)
+ neighbor_rate = sum(e[2] for e in neighbors)
+ # if sum(e[2] for e in prl[:1]) > br:
+ # G.add_edge((n1, t1), tuple(prl[0][:2]))
+ if background_rate > neighbor_rate:
+ roots += 1
+ if (n1, t1) in fatals:
+ fatal_roots += 1
+ else:
+ G.add_edge((n1, t1), tuple(neighbors[0][:2]))
+ # l.append(prl[0][2] / br)
+ infections += 1
+ if (n1, t1) in fatals:
+ fatal_infections += 1
+ # l.sort(reverse=True)
+ # plt.plot(l)
+ # plt.show()
+ return (droots, roots, infections, fatal_droots,
+ fatal_roots, fatal_infections, G)
+
+
+def analyze_graph(G):
+ counts = Counter(len(c) for c in nx.weakly_connected_components(G))
+ w = writer(open("components_dist.csv", "w"))
+ w.writerows(counts.most_common())
+ edges = ((n1, t1, n2, t2) for ((n1, t1), (n2, t2)) in G.edges_iter())
+ e = writer(open("edges.csv", "w"))
+ e.writerows(edges)
+
+
+if __name__ == "__main__":
+ nodes, edges, events, event_edges = load(open("data2.pickle", "rb"))
+ lamb, alpha, mu = 1.1847510744e-05, 0.00316718040144, 0.00393069204339
+ # print len(event_edges), sum(len(e) for e in events.itervalues())
+ # print len(fatal())
+ (doors, roots, infections, fatal_droots,
+ fatal_roots, fatal_infections, G) = cause(lamb, alpha, mu)
+ analyze_graph(G)