summaryrefslogtreecommitdiffstats
path: root/hawkes/data.py
blob: 0f6135be53d9aed80dc3b2b360bb941ad14a3d84 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from csv import DictReader
import sys
from itertools import product
from cPickle import dump
from math import cos

# Time value that fluctuation_int substitutes when it is given None.
MAX_TIME = 3012


def parse(s):
    """Convert a CSV field to an integer, treating "NA" as missing.

    The field is read as a float first and then truncated by ``int``, so
    values such as "3.7" or "1e2" are accepted.

    Returns:
        None when ``s`` equals "NA", otherwise ``int(float(s))``.
    """
    if s == "NA":
        return None
    as_float = float(s)
    return int(as_float)


def fluctuation_int(t):
    """Return the pair ``(t, t + 0.43 / 0.0172 * (cos(4.36) - cos(0.0172 * t + 4.36)))``.

    A ``t`` of None is replaced by MAX_TIME before anything is computed, so
    the first element of the result is never None.

    The second element equals the integral from 0 to t of
    ``1 + 0.43 * sin(0.0172 * s + 4.36)`` with respect to s.
    """
    when = MAX_TIME if t is None else t
    cos_gap = cos(4.36) - cos(0.0172 * when + 4.36)
    return (when, when + 0.43 / 0.0172 * cos_gap)


def load_nodes(filename):
    """Read a node CSV file and map every node to its fluctuation tuple.

    Args:
        filename: path of a CSV file whose header row provides at least
            the "name" and "fatal_day" columns.

    Returns:
        A dict mapping ``parse(row["name"])`` to
        ``fluctuation_int(parse(row["fatal_day"]))``. When a name occurs
        more than once, the last row wins.

    Raises:
        KeyError: if a required column is missing from the file.
    """
    with open(filename) as fh:
        # Built in a single pass: no dict.iteritems() (which only exists on
        # Python 2) and no rewriting of values while iterating the dict.
        return {
            parse(row["name"]): fluctuation_int(parse(row["fatal_day"]))
            for row in DictReader(fh)
        }


def load_edges(filename):
    """Read an edge CSV file into per-source adjacency and event-time maps.

    Args:
        filename: path of a CSV file whose header row provides the "from",
            "to", "t1" and "dist" columns; each field goes through ``parse``.

    Returns:
        A tuple ``(edges, events)`` where ``edges[fro][to]`` is the parsed
        "dist" of the last row seen for that pair, and ``events[fro]`` is the
        set of parsed "t1" values seen on rows whose "from" is ``fro``.
    """
    edges, events = {}, {}
    columns = ("from", "to", "t1", "dist")
    with open(filename) as handle:
        for record in DictReader(handle):
            # Look up all four fields first, then parse them in column order.
            raw = [record[column] for column in columns]
            src, dst, when, dist = [parse(field) for field in raw]
            edges.setdefault(src, dict())[dst] = dist
            events.setdefault(src, set()).add(when)
    return edges, events


def compute_event_edges(events, edges):
    """Link every event to the earlier events on nodes that point at it.

    Args:
        events: dict mapping a node to the set of its event times.
        edges: dict mapping a source node to a dict ``{target: dist}``.

    Returns:
        A dict keyed by ``(node, time)`` for every event in ``events``. Each
        value is the set of ``(source, source_time, dist)`` triples such that
        ``edges[source][node]`` exists, ``source`` has an event at
        ``source_time``, and ``source_time`` is strictly less than ``time``.
        Edges whose source or target has no entry in ``events`` are ignored.
    """
    # Every event gets an entry up front, even if nothing precedes it.
    incoming = {(node, when): set()
                for node in events
                for when in events[node]}

    for src, neighbours in edges.items():
        if src not in events:
            continue
        src_times = events[src]
        for dst, dist in neighbours.items():
            if dst not in events:
                continue
            for t_src, t_dst in product(src_times, events[dst]):
                if t_src < t_dst:
                    incoming[(dst, t_dst)].add((src, t_src, dist))
    return incoming


if __name__ == "__main__":
    # Command line: first argument is the node CSV, second is the edge CSV.
    nodes = load_nodes(sys.argv[1])
    edges, events = load_edges(sys.argv[2])
    event_edges = compute_event_edges(events, edges)
    # Write through a context manager so the output file is flushed and
    # closed even if pickling fails part-way.
    with open("data-all.pickle", "wb") as out:
        dump((nodes, edges, events, event_edges), out)