diff options
Diffstat (limited to 'jpa_test')
| -rw-r--r-- | jpa_test/algorithms.py | 28 | ||||
| -rw-r--r-- | jpa_test/cascade_creation.py | 1 |
2 files changed, 23 insertions, 6 deletions
diff --git a/jpa_test/algorithms.py b/jpa_test/algorithms.py index 9973dc3..99430b6 100644 --- a/jpa_test/algorithms.py +++ b/jpa_test/algorithms.py @@ -1,17 +1,34 @@ import numpy as np import networkx as nx import cascade_creation +from collections import Counter def greedy_prediction(G, cascades): """ - returns estimated graph + Returns estimated graph """ G_hat = cascade_creation.InfluenceGraph(max_proba=None) G_hat.add_nodes_from(G.nodes()) - for node in G.nodes(): - unaccounted = cascades - for cascade in cascades: + for node in G_hat.nodes(): + unaccounted = np.ones(len(cascades), dtype=bool) + for t, cascade in zip(xrange(len(cascades)), cascades): + if not cascade.infection_time(node) or \ + cascade.infection_time(node)[0] == 0: + unaccounted[t] = False + while unaccounted.any(): + tmp = [cascade for boolean, cascade in zip(unaccounted, + cascades) if boolean] + parents = Counter() + for cascade in tmp: + parents += cascade.candidate_infectors(node) + parent = parents.most_common(1)[0][0] + G_hat.add_edge(parent, node) + for t, cascade in zip(xrange(len(cascades)), cascades): + if (cascade.infection_time(parent) == \ + [item - 1 for item in cascade.infection_time(node)]): + unaccounted[t] = False + def test(): """ @@ -21,7 +38,8 @@ def test(): G.erdos_init(n = 100, p = 1) import time t0 = time.time() - print len(cascade_creation.icc_cascade(G, p_init=.1)) + A = cascade_creation.generate_cascades(G, .1, 4) + greedy_prediction(G, A) t1 = time.time() print t1 - t0 diff --git a/jpa_test/cascade_creation.py b/jpa_test/cascade_creation.py index 01332f8..32d9c4a 100644 --- a/jpa_test/cascade_creation.py +++ b/jpa_test/cascade_creation.py @@ -12,7 +12,6 @@ class InfluenceGraph(nx.Graph): def erdos_init(self, n, p): G = nx.erdos_renyi_graph(n, p, directed=True) - self.add_nodes_from(G.nodes()) self.add_edges_from(G.edges()) def import_from_file(self, file_name): |
