aboutsummaryrefslogtreecommitdiffstats
path: root/jpa_test/algorithms.py
diff options
context:
space:
mode:
Diffstat (limited to 'jpa_test/algorithms.py')
-rw-r--r--jpa_test/algorithms.py10
1 files changed, 6 insertions, 4 deletions
diff --git a/jpa_test/algorithms.py b/jpa_test/algorithms.py
index 99430b6..7cdf093 100644
--- a/jpa_test/algorithms.py
+++ b/jpa_test/algorithms.py
@@ -3,28 +3,30 @@ import networkx as nx
import cascade_creation
from collections import Counter
+from itertools import izip
+
def greedy_prediction(G, cascades):
"""
- Returns estimated graph
+ Returns estimated graph from Greedy algorithm in "Learning Epidemic ..."
"""
G_hat = cascade_creation.InfluenceGraph(max_proba=None)
G_hat.add_nodes_from(G.nodes())
for node in G_hat.nodes():
unaccounted = np.ones(len(cascades), dtype=bool)
- for t, cascade in zip(xrange(len(cascades)), cascades):
+ for t, cascade in izip(xrange(len(cascades)), cascades):
if not cascade.infection_time(node) or \
cascade.infection_time(node)[0] == 0:
unaccounted[t] = False
while unaccounted.any():
- tmp = [cascade for boolean, cascade in zip(unaccounted,
+ tmp = [cascade for boolean, cascade in izip(unaccounted,
cascades) if boolean]
parents = Counter()
for cascade in tmp:
parents += cascade.candidate_infectors(node)
parent = parents.most_common(1)[0][0]
G_hat.add_edge(parent, node)
- for t, cascade in zip(xrange(len(cascades)), cascades):
+ for t, cascade in izip(xrange(len(cascades)), cascades):
if (cascade.infection_time(parent) == \
[item - 1 for item in cascade.infection_time(node)]):
unaccounted[t] = False