summaryrefslogtreecommitdiffstats
path: root/hawkes/data.py
blob: 4d7744ee03800f981a8442c7f6d62878bc0748fc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from csv import DictReader
import sys
from itertools import product
from cPickle import dump

# Fallback fatal_day that load_nodes() assigns to nodes whose "fatal_day"
# cell is "NA" (which parse() turns into None).
# NOTE(review): presumably the end of the observation window in days, so
# such nodes are treated as never failing within the data — confirm
# against the data source.
MAX_TIME = 3012


def parse(s):
    """Convert one CSV cell to an int, or to None for the "NA" marker.

    The text goes through float() first so values such as "12.0" or "1e3"
    are accepted; int() then truncates any fractional part toward zero.
    """
    if s == "NA":
        return None
    as_float = float(s)
    return int(as_float)


def load_nodes(filename):
    """Read the node CSV and map each node id to its fatal day.

    The file must have "name" and "fatal_day" columns; both cells go
    through parse(). A fatal_day of "NA" (parsed to None) is replaced by
    MAX_TIME. If a name occurs on several rows, the last row wins.

    :param filename: path of the node CSV file.
    :returns: dict {node id: fatal day as int}.
    """
    nodes = {}
    with open(filename) as handle:
        for record in DictReader(handle):
            node = parse(record["name"])
            day = parse(record["fatal_day"])
            # Missing fatal days fall back to the global time horizon.
            nodes[node] = MAX_TIME if day is None else day
    return nodes


def load_edges(filename):
    """Read the edge CSV into adjacency weights and per-node event times.

    Every row's "from", "to", "t1" and "w1" cells go through parse().

    :param filename: path of the edge CSV file.
    :returns: (edges, events) where
        edges[fro][to] is the w1 value from the last row seen for that
        (fro, to) pair, and
        events[fro] is the set of t1 values from rows whose source is fro.
    """
    edges = {}
    events = {}
    with open(filename) as handle:
        for record in DictReader(handle):
            # Fetch all raw cells before parsing any of them.
            raw = [record["from"], record["to"], record["t1"], record["w1"]]
            src, dst, when, weight = [parse(cell) for cell in raw]
            edges.setdefault(src, {})[dst] = weight
            events.setdefault(src, set()).add(when)
    return edges, events


def compute_event_edges(events, edges):
    """Link every event to the strictly earlier events on its in-neighbours.

    :param events: {node: iterable of event times}.
    :param edges: {fro: {to: weight}}.
    :returns: dict with one key (node, t) per event in ``events``. Each
        value is a set of (fro, t1, weight) triples, one for every event t1
        of a node fro that has an edge fro -> node with t1 < t. Edges where
        either endpoint has no entry in ``events`` are ignored.
    """
    event_edges = {(node, t): set()
                   for node, times in events.items()
                   for t in times}

    for src, targets in edges.items():
        if src not in events:
            continue
        src_times = events[src]
        for dst, weight in targets.items():
            if dst not in events:
                continue
            dst_times = events[dst]
            for t1 in src_times:
                for t2 in dst_times:
                    if t1 < t2:
                        # The stored set is mutated in place; no re-store needed.
                        event_edges[(dst, t2)].add((src, t1, weight))
    return event_edges


if __name__ == "__main__":
    # Usage: python data.py NODES_CSV EDGES_CSV
    # Writes the tuple (nodes, edges, events, event_edges) to data.pickle
    # in the current working directory.
    if len(sys.argv) < 3:
        # Fail with a readable message instead of a bare IndexError.
        sys.exit("usage: %s NODES_CSV EDGES_CSV" % sys.argv[0])
    nodes = load_nodes(sys.argv[1])
    edges, events = load_edges(sys.argv[2])
    event_edges = compute_event_edges(events, edges)
    # Close the output explicitly so the pickle is flushed deterministically
    # rather than whenever the file object happens to be collected.
    with open("data.pickle", "wb") as out:
        dump((nodes, edges, events, event_edges), out)