aboutsummaryrefslogtreecommitdiffstats
path: root/src/algorithms.py
diff options
context:
space:
mode:
authorjeanpouget-abadie <jean.pougetabadie@gmail.com>2014-12-07 12:08:31 -0500
committerjeanpouget-abadie <jean.pougetabadie@gmail.com>2014-12-07 12:08:31 -0500
commit9de35421f25bf45158187daea4ddfedd1c93f3d8 (patch)
treef917008b6363a2b9dbff7855781f4fd5a10a6e94 /src/algorithms.py
parent6c874852773329f6fecbbc54476b30a37aa85b79 (diff)
downloadcascades-9de35421f25bf45158187daea4ddfedd1c93f3d8.tar.gz
renaming directory + creating dataset directory
Diffstat (limited to 'src/algorithms.py')
-rw-r--r--src/algorithms.py103
1 files changed, 103 insertions, 0 deletions
diff --git a/src/algorithms.py b/src/algorithms.py
new file mode 100644
index 0000000..0e240c9
--- /dev/null
+++ b/src/algorithms.py
@@ -0,0 +1,103 @@
+import numpy as np
+import networkx as nx
+import cascade_creation
+from collections import Counter
+from itertools import izip
+import convex_optimization
+import timeout
+
+
+def greedy_prediction(G, cascades):
+ """
+ Returns estimated graph from Greedy algorithm in "Learning Epidemic ..."
+ Only words for independent cascade model!
+ """
+ G_hat = cascade_creation.InfluenceGraph(max_proba=None)
+ G_hat.add_nodes_from(G.nodes())
+ for node in G_hat.nodes():
+ print node
+ # Avoid cases where infection time is None or 0
+ tmp = [cascade for cascade in cascades if cascade.infection_time(node)
+ [0]]
+ while tmp:
+ parents = Counter()
+ for cascade in tmp:
+ parents += cascade.candidate_infectors(node)
+ parent = parents.most_common(1)[0][0]
+ G_hat.add_edge(parent, node)
+ tmp = [cascade for cascade in tmp if (
+ cascade.infection_time(parent)[0] is not None and
+ cascade.infection_time(parent)[0]+1 not in
+ cascade.infection_time(node))]
+ return G_hat
+
+
+def recovery_l1obj_l2constraint(G, cascades, floor_cstt, passed_function,
+ *args, **kwargs):
+ """
+ Returns estimated graph from convex program specified by passed_function
+ passed_function should have similar structure to ones in convex_optimation
+ """
+ G_hat = cascade_creation.InfluenceGraph(max_proba=None)
+ G_hat.add_nodes_from(G.nodes())
+ for node in G_hat.nodes():
+ print node
+ try:
+ M, w = cascade_creation.icc_matrixvector_for_node(cascades, node)
+ p_node, __ = passed_function(M,w, *args, **kwargs)
+ G_hat = cascade_creation.add_edges_from_proba_vector(G=G_hat,
+ p_node=p_node, node=node, floor_cstt=floor_cstt)
+ except timeout.TimeoutError:
+ print "TimeoutError, skipping to next node"
+ return G_hat
+
+
+def correctness_measure(G, G_hat, print_values=False):
+ """
+ Measures correctness of estimated graph G_hat to ground truth G
+ """
+ edges = set(G.edges())
+ edges_hat = set(G_hat.edges())
+ fp = len(edges_hat - edges)
+ fn = len(edges - edges_hat)
+ tp = len(edges | edges_hat)
+ tn = G.number_of_nodes() ** 2 - fp - fn - tp
+
+ #Other metrics
+ precision = 1. * tp / (tp + fp)
+ recall = 1. * tp / (tp + fn)
+ f1_score = 2.* tp / (2 * tp + fp + fn)
+
+ if print_values:
+ print "False Positives: {}".format(fp)
+ print "False Negatives: {}".format(fn)
+ print "True Positives: {}".format(tp)
+ print "True Negatives: {}".format(tn)
+ print "-------------------------------"
+ print "Precision: {}".format(precision)
+ print "Recall: {}".format(recall)
+ print "F1 score: {}".format(f1_score)
+
+ return fp, fn, tp, tn
+
+
+def test():
+ """
+ unit test
+ """
+ G = cascade_creation.InfluenceGraph(max_proba = .8)
+ G.erdos_init(n = 100, p = .05)
+ import time
+ t0 = time.time()
+ A = cascade_creation.generate_cascades(G, .2, 10000)
+ if 1:
+ G_hat = greedy_prediction(G, A)
+ if 0:
+ G_hat = recovery_l1obj_l2constraint(G, A,
+ passed_function=convex_optimization.l1obj_l2penalization,
+ floor_cstt=.1, lbda=10)
+ correctness_measure(G, G_hat, print_values=True)
+
+
+if __name__=="__main__":
+ test()