aboutsummaryrefslogtreecommitdiffstats
path: root/jpa_test
diff options
context:
space:
mode:
Diffstat (limited to 'jpa_test')
-rw-r--r--jpa_test/algorithms.py15
-rw-r--r--jpa_test/convex_optimization.py3
-rw-r--r--jpa_test/timeout.py5
3 files changed, 15 insertions, 8 deletions
diff --git a/jpa_test/algorithms.py b/jpa_test/algorithms.py
index f3bf917..0d5f154 100644
--- a/jpa_test/algorithms.py
+++ b/jpa_test/algorithms.py
@@ -34,7 +34,7 @@ def greedy_prediction(G, cascades):
unaccounted[t] = False
return G_hat
-@timeout.timeout(10)
+
def recovery_l1obj_l2constraint(G, cascades):
"""
Returns estimated graph from following convex program:
@@ -45,10 +45,13 @@ def recovery_l1obj_l2constraint(G, cascades):
G_hat.add_nodes_from(G.nodes())
for node in G_hat.nodes():
print node
- M, w = cascade_creation.icc_matrixvector_for_node(cascades, node)
- p_node, __ = convex_optimization.l1obj_l2constraint(M,w)
- G_hat = cascade_creation.add_edges_from_proba_vector(G=G_hat,
- p_node=p_node, node=node, floor_cstt=.01)
+ try:
+ M, w = cascade_creation.icc_matrixvector_for_node(cascades, node)
+ p_node, __ = convex_optimization.l1obj_l2constraint(M,w)
+ G_hat = cascade_creation.add_edges_from_proba_vector(G=G_hat,
+ p_node=p_node, node=node, floor_cstt=.01)
+ except timeout.TimeoutError:
+ print "TimeoutError, skipping to next node"
return G_hat
@@ -76,7 +79,7 @@ def test():
G.erdos_init(n = 100, p = .3)
import time
t0 = time.time()
- A = cascade_creation.generate_cascades(G, .2, 50)
+ A = cascade_creation.generate_cascades(G, .2, 100)
G_hat = recovery_l1obj_l2constraint(G, A)
diff --git a/jpa_test/convex_optimization.py b/jpa_test/convex_optimization.py
index ef54892..a9556ae 100644
--- a/jpa_test/convex_optimization.py
+++ b/jpa_test/convex_optimization.py
@@ -2,10 +2,11 @@ import theano
import cascade_creation
from theano import tensor, function
import numpy as np
+import timeout
import cvxopt
-
+@timeout.timeout(10)
def l1obj_l2constraint(M_val, w_val):
"""
Solves:
diff --git a/jpa_test/timeout.py b/jpa_test/timeout.py
index 52d7d92..d7381c3 100644
--- a/jpa_test/timeout.py
+++ b/jpa_test/timeout.py
@@ -3,10 +3,13 @@ import errno
import os
import signal
+class TimeoutError(Exception):
+ pass
+
def timeout(seconds=10, error_message=os.strerror(errno.ETIME)):
def decorator(func):
def _handle_timeout(signum, frame):
- raise Exception(error_message)
+ raise TimeoutError(error_message)
def wrapper(*args, **kwargs):
signal.signal(signal.SIGALRM, _handle_timeout)