aboutsummaryrefslogtreecommitdiffstats
path: root/jpa_test
diff options
context:
space:
mode:
authorjeanpouget-abadie <jean.pougetabadie@gmail.com>2014-11-29 16:56:30 -0500
committerjeanpouget-abadie <jean.pougetabadie@gmail.com>2014-11-29 16:56:30 -0500
commit3cbe733eb6ea3ee2a7b9b28319da6f2145ab4e46 (patch)
tree327320a2f68b84a3eef5f8118d585eb1a00b1c8d /jpa_test
parent496d7e5b022dfc005bc2f697a62522d7022519cd (diff)
downloadcascades-3cbe733eb6ea3ee2a7b9b28319da6f2145ab4e46.tar.gz
zip -> izip
Diffstat (limited to 'jpa_test')
-rw-r--r--jpa_test/algorithms.py10
-rw-r--r--jpa_test/cascade_creation.py4
2 files changed, 9 insertions, 5 deletions
diff --git a/jpa_test/algorithms.py b/jpa_test/algorithms.py
index 99430b6..7cdf093 100644
--- a/jpa_test/algorithms.py
+++ b/jpa_test/algorithms.py
@@ -3,28 +3,30 @@ import networkx as nx
import cascade_creation
from collections import Counter
+from itertools import izip
+
def greedy_prediction(G, cascades):
"""
- Returns estimated graph
+ Returns estimated graph from Greedy algorithm in "Learning Epidemic ..."
"""
G_hat = cascade_creation.InfluenceGraph(max_proba=None)
G_hat.add_nodes_from(G.nodes())
for node in G_hat.nodes():
unaccounted = np.ones(len(cascades), dtype=bool)
- for t, cascade in zip(xrange(len(cascades)), cascades):
+ for t, cascade in izip(xrange(len(cascades)), cascades):
if not cascade.infection_time(node) or \
cascade.infection_time(node)[0] == 0:
unaccounted[t] = False
while unaccounted.any():
- tmp = [cascade for boolean, cascade in zip(unaccounted,
+ tmp = [cascade for boolean, cascade in izip(unaccounted,
cascades) if boolean]
parents = Counter()
for cascade in tmp:
parents += cascade.candidate_infectors(node)
parent = parents.most_common(1)[0][0]
G_hat.add_edge(parent, node)
- for t, cascade in zip(xrange(len(cascades)), cascades):
+ for t, cascade in izip(xrange(len(cascades)), cascades):
if (cascade.infection_time(parent) == \
[item - 1 for item in cascade.infection_time(node)]):
unaccounted[t] = False
diff --git a/jpa_test/cascade_creation.py b/jpa_test/cascade_creation.py
index 32d9c4a..80054a7 100644
--- a/jpa_test/cascade_creation.py
+++ b/jpa_test/cascade_creation.py
@@ -2,6 +2,8 @@ import networkx as nx
import numpy as np
import collections
+from itertools import izip
+
class InfluenceGraph(nx.Graph):
"""
networkX graph with mat and logmat attributes
@@ -50,7 +52,7 @@ class Cascade(list):
Returns lists of infections times for node i in cascade
"""
infected_times = []
- for t, infected_set in zip(xrange(len(self)), self):
+ for t, infected_set in izip(xrange(len(self)), self):
if infected_set[node]:
infected_times.append(t)
return infected_times