diff options
| author | jeanpouget-abadie <jean.pougetabadie@gmail.com> | 2014-11-30 23:17:41 -0500 |
|---|---|---|
| committer | jeanpouget-abadie <jean.pougetabadie@gmail.com> | 2014-11-30 23:17:41 -0500 |
| commit | b49e39e300f9d87310f6cce20018427a98f34486 (patch) | |
| tree | 272262ef3f19a5170dfd74dadda37dd9c3b61fb8 /jpa_test/algorithms.py | |
| parent | 2a6010634417eac9bf2ac4682ac3675dc5074518 (diff) | |
| download | cascades-b49e39e300f9d87310f6cce20018427a98f34486.tar.gz | |
convex_optimization first draft
Diffstat (limited to 'jpa_test/algorithms.py')
| -rw-r--r-- | jpa_test/algorithms.py | 19 |
1 files changed, 18 insertions, 1 deletions
diff --git a/jpa_test/algorithms.py b/jpa_test/algorithms.py index 2a32f57..cf4ce50 100644 --- a/jpa_test/algorithms.py +++ b/jpa_test/algorithms.py @@ -2,13 +2,14 @@ import numpy as np import networkx as nx import cascade_creation from collections import Counter - from itertools import izip +import convex_optimization def greedy_prediction(G, cascades): """ Returns estimated graph from Greedy algorithm in "Learning Epidemic ..." + TODO: write cleaner code? """ G_hat = cascade_creation.InfluenceGraph(max_proba=None) G_hat.add_nodes_from(G.nodes()) @@ -33,6 +34,19 @@ def greedy_prediction(G, cascades): return G_hat +def sparserecovery(G, cascades): + """ + Returns estimated graph from following convex program: + min |theta_1| + lbda | exp(M theta) -(1- w)| + where theta = log (1 - p); w = 1_{infected}; lbda = lagrange cstt + """ + G_hat = cascade_creation.InfluenceGraph(max_proba=None) + G_hat.add_nodes_from(G.nodes()) + for node in G_hat.nodes(): + M, w = cascade_creation.icc_matrixvector_for_node(cascades, node) + edges_node = convex_optimization.l1regls(M,w) + + def correctness_measure(G, G_hat): """ Measures correctness of estimated graph G_hat to ground truth G @@ -54,6 +68,9 @@ def test(): import time t0 = time.time() A = cascade_creation.generate_cascades(G, .1, 100) + + sparserecovery(G, A) + G_hat = greedy_prediction(G, A) fp, fn, gp = correctness_measure(G, G_hat) print "False Positive: {}".format(len(fp)) |
