diff options
Diffstat (limited to 'src/algorithms.py')
| -rw-r--r-- | src/algorithms.py | 103 |
1 files changed, 103 insertions, 0 deletions
# Origin: diff --git a/src/algorithms.py b/src/algorithms.py
# (new file mode 100644, index 0000000..0e240c9)
import numpy as np
import networkx as nx
import cascade_creation
from collections import Counter
from itertools import izip
import convex_optimization
import timeout


def greedy_prediction(G, cascades):
    """
    Returns estimated graph from Greedy algorithm in "Learning Epidemic ..."
    Only works for independent cascade model!

    G        -- graph whose nodes are copied into the estimate (its edges
                are not read here)
    cascades -- iterable of cascade objects exposing infection_time(node)
                and candidate_infectors(node)

    Returns a cascade_creation.InfluenceGraph holding the nodes of G and
    the edges selected greedily for each node.
    """
    G_hat = cascade_creation.InfluenceGraph(max_proba=None)
    G_hat.add_nodes_from(G.nodes())
    for node in G_hat.nodes():
        print(node)
        # Avoid cases where infection time is None or 0
        # (infection_time is indexed with [0] and tested with `in`, so it
        # presumably returns a list of times -- TODO confirm)
        tmp = [cascade for cascade in cascades
               if cascade.infection_time(node)[0]]
        while tmp:
            parents = Counter()
            for cascade in tmp:
                parents += cascade.candidate_infectors(node)
            if not parents:
                # No candidate infector left in the remaining cascades:
                # most_common(1) would be empty, so stop for this node
                # instead of raising IndexError.
                break
            parent = parents.most_common(1)[0][0]
            G_hat.add_edge(parent, node)
            # NOTE(review): cascades in which `parent` was never infected
            # are dropped here as well; confirm this is intended rather
            # than keeping them as "not yet explained".
            tmp = [cascade for cascade in tmp if (
                   cascade.infection_time(parent)[0] is not None and
                   cascade.infection_time(parent)[0] + 1 not in
                   cascade.infection_time(node))]
    return G_hat


def recovery_l1obj_l2constraint(G, cascades, floor_cstt, passed_function,
                                *args, **kwargs):
    """
    Returns estimated graph from convex program specified by passed_function
    passed_function should have similar structure to ones in
    convex_optimization

    G               -- graph whose nodes are copied into the estimate
    cascades        -- cascades handed to
                       cascade_creation.icc_matrixvector_for_node
    floor_cstt      -- forwarded to
                       cascade_creation.add_edges_from_proba_vector
    passed_function -- callable invoked as passed_function(M, w, *args,
                       **kwargs); must return a pair whose first element
                       is the probability vector for the node
    *args, **kwargs -- forwarded unchanged to passed_function

    Nodes whose computation raises timeout.TimeoutError are skipped.
    """
    G_hat = cascade_creation.InfluenceGraph(max_proba=None)
    G_hat.add_nodes_from(G.nodes())
    for node in G_hat.nodes():
        print(node)
        try:
            M, w = cascade_creation.icc_matrixvector_for_node(cascades, node)
            p_node, __ = passed_function(M, w, *args, **kwargs)
            G_hat = cascade_creation.add_edges_from_proba_vector(
                G=G_hat, p_node=p_node, node=node, floor_cstt=floor_cstt)
        except timeout.TimeoutError:
            print("TimeoutError, skipping to next node")
    return G_hat


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator as a float, or 0.0 if denominator is 0."""
    if denominator == 0:
        return 0.0
    return float(numerator) / denominator


def correctness_measure(G, G_hat, print_values=False):
    """
    Measures correctness of estimated graph G_hat to ground truth G

    G            -- ground-truth graph (uses edges() and number_of_nodes())
    G_hat        -- estimated graph (uses edges())
    print_values -- when True, print the counts and derived metrics

    Returns the tuple (fp, fn, tp, tn) of false positives, false negatives,
    true positives and true negatives, where the negatives are counted
    against number_of_nodes() ** 2 possible ordered node pairs.
    """
    edges = set(G.edges())
    edges_hat = set(G_hat.edges())
    fp = len(edges_hat - edges)
    fn = len(edges - edges_hat)
    # True positives are edges present in BOTH graphs: intersection, not
    # union (the union would also count every false positive/negative).
    tp = len(edges & edges_hat)
    tn = G.number_of_nodes() ** 2 - fp - fn - tp

    # Other metrics; a zero denominator (e.g. no predicted edges) yields 0.0
    # instead of raising ZeroDivisionError.
    precision = _safe_ratio(tp, tp + fp)
    recall = _safe_ratio(tp, tp + fn)
    f1_score = _safe_ratio(2. * tp, 2 * tp + fp + fn)

    if print_values:
        print("False Positives: {}".format(fp))
        print("False Negatives: {}".format(fn))
        print("True Positives: {}".format(tp))
        print("True Negatives: {}".format(tn))
        print("-------------------------------")
        print("Precision: {}".format(precision))
        print("Recall: {}".format(recall))
        print("F1 score: {}".format(f1_score))

    return fp, fn, tp, tn


def test():
    """
    unit test

    Builds a random influence graph, generates cascades from it, runs one
    of the recovery algorithms and prints the correctness measures.
    """
    G = cascade_creation.InfluenceGraph(max_proba=.8)
    G.erdos_init(n=100, p=.05)
    import time
    t0 = time.time()  # NOTE: start time is recorded but not used below
    A = cascade_creation.generate_cascades(G, .2, 10000)
    # Manual toggles selecting which algorithm is exercised.
    if 1:
        G_hat = greedy_prediction(G, A)
    if 0:
        G_hat = recovery_l1obj_l2constraint(
            G, A,
            passed_function=convex_optimization.l1obj_l2penalization,
            floor_cstt=.1, lbda=10)
    correctness_measure(G, G_hat, print_values=True)


if __name__ == "__main__":
    test()
