blob: 99430b6bfdf8bb724300e7bbcafc319b61eea3b3 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
|
import numpy as np
import networkx as nx
import cascade_creation
from collections import Counter
def _needs_parent(cascade, node):
    # True when `node` was infected in `cascade` at a time other than 0,
    # i.e. its infection has to be explained by some parent.
    times = cascade.infection_time(node)
    return bool(times) and times[0] != 0


def _explains(cascade, parent, node):
    # True when `parent`'s infection times in `cascade` equal `node`'s
    # infection times shifted back by one step.
    return cascade.infection_time(parent) == \
        [item - 1 for item in cascade.infection_time(node)]


def greedy_prediction(G, cascades):
    """
    Returns estimated graph.

    Greedy reconstruction: for every node, the cascades in which that node
    was infected at a non-zero time start out unexplained. Repeatedly, the
    node that appears most often among the candidate infectors of the still
    unexplained cascades is chosen as a parent, an edge parent -> node is
    added, and every cascade in which the parent's infection times are the
    node's infection times minus one is marked as explained.

    Parameters
    ----------
    G : graph whose ``nodes()`` form the node set of the estimate
        (its edges are not read).
    cascades : sequence of cascade objects exposing
        ``infection_time(node)`` and ``candidate_infectors(node)``.

    Returns
    -------
    cascade_creation.InfluenceGraph
        The estimated graph ``G_hat``.
    """
    G_hat = cascade_creation.InfluenceGraph(max_proba=None)
    G_hat.add_nodes_from(G.nodes())
    # Snapshot the node list so that add_edge() cannot disturb the
    # iteration (e.g. if a candidate infector is not a node of G).
    for node in list(G_hat.nodes()):
        # Cascades, in their original order, whose infection of `node`
        # is still unexplained.
        pending = [c for c in cascades if _needs_parent(c, node)]
        # Candidates that explain none of the pending cascades. `pending`
        # only ever shrinks, so such a candidate can never become useful
        # later and is skipped for the rest of this node.
        rejected = set()
        while pending:
            parents = Counter()
            for cascade in pending:
                # NOTE(review): assumes candidate_infectors returns a
                # Counter/mapping of node -> count (or an iterable of
                # nodes) -- confirm against cascade_creation.
                parents.update(cascade.candidate_infectors(node))
            for candidate in rejected:
                parents.pop(candidate, None)
            if not parents:
                # No remaining candidate can explain what is left.
                break
            parent = parents.most_common(1)[0][0]
            remaining = [c for c in pending
                         if not _explains(c, parent, node)]
            if len(remaining) == len(pending):
                # The most frequent candidate explains nothing; without
                # this guard the loop would pick it again forever.
                rejected.add(parent)
                continue
            G_hat.add_edge(parent, node)
            pending = remaining
    return G_hat
def test():
    """
    Timing smoke run of greedy_prediction (nothing is asserted).

    Builds an InfluenceGraph with ``max_proba=.3``, initialises it with
    ``erdos_init(n=100, p=1)``, generates cascades with
    ``generate_cascades(G, .1, 4)`` and runs greedy_prediction on them.
    The wall-clock time of cascade generation plus prediction is printed.

    Returns
    -------
    float
        Elapsed seconds, as printed.
    """
    import time
    G = cascade_creation.InfluenceGraph(max_proba=.3)
    G.erdos_init(n=100, p=1)
    t0 = time.time()
    A = cascade_creation.generate_cascades(G, .1, 4)
    greedy_prediction(G, A)
    elapsed = time.time() - t0
    # Single-argument print(...) prints identically on Python 2 and 3.
    print(elapsed)
    return elapsed
if __name__ == "__main__":
    # Executed only when the module is run as a script, not on import.
    test()
|