aboutsummaryrefslogtreecommitdiffstats
path: root/jpa_test
diff options
context:
space:
mode:
Diffstat (limited to 'jpa_test')
-rw-r--r--jpa_test/algorithms.py28
-rw-r--r--jpa_test/cascade_creation.py1
2 files changed, 23 insertions, 6 deletions
diff --git a/jpa_test/algorithms.py b/jpa_test/algorithms.py
index 9973dc3..99430b6 100644
--- a/jpa_test/algorithms.py
+++ b/jpa_test/algorithms.py
@@ -1,17 +1,34 @@
import numpy as np
import networkx as nx
import cascade_creation
+from collections import Counter
def greedy_prediction(G, cascades):
"""
- returns estimated graph
+ Returns estimated graph
"""
G_hat = cascade_creation.InfluenceGraph(max_proba=None)
G_hat.add_nodes_from(G.nodes())
- for node in G.nodes():
- unaccounted = cascades
- for cascade in cascades:
+ for node in G_hat.nodes():
+ unaccounted = np.ones(len(cascades), dtype=bool)
+ for t, cascade in zip(xrange(len(cascades)), cascades):
+ if not cascade.infection_time(node) or \
+ cascade.infection_time(node)[0] == 0:
+ unaccounted[t] = False
+ while unaccounted.any():
+ tmp = [cascade for boolean, cascade in zip(unaccounted,
+ cascades) if boolean]
+ parents = Counter()
+ for cascade in tmp:
+ parents += cascade.candidate_infectors(node)
+ parent = parents.most_common(1)[0][0]
+ G_hat.add_edge(parent, node)
+ for t, cascade in zip(xrange(len(cascades)), cascades):
+ if (cascade.infection_time(parent) == \
+ [item - 1 for item in cascade.infection_time(node)]):
+ unaccounted[t] = False
+
def test():
"""
@@ -21,7 +38,8 @@ def test():
G.erdos_init(n = 100, p = 1)
import time
t0 = time.time()
- print len(cascade_creation.icc_cascade(G, p_init=.1))
+ A = cascade_creation.generate_cascades(G, .1, 4)
+ greedy_prediction(G, A)
t1 = time.time()
print t1 - t0
diff --git a/jpa_test/cascade_creation.py b/jpa_test/cascade_creation.py
index 01332f8..32d9c4a 100644
--- a/jpa_test/cascade_creation.py
+++ b/jpa_test/cascade_creation.py
@@ -12,7 +12,6 @@ class InfluenceGraph(nx.Graph):
def erdos_init(self, n, p):
G = nx.erdos_renyi_graph(n, p, directed=True)
- self.add_nodes_from(G.nodes())
self.add_edges_from(G.edges())
def import_from_file(self, file_name):