summaryrefslogtreecommitdiffstats
path: root/experiments/ml2.pyx
diff options
context:
space:
mode:
Diffstat (limited to 'experiments/ml2.pyx')
-rw-r--r--experiments/ml2.pyx16
1 files changed, 14 insertions, 2 deletions
diff --git a/experiments/ml2.pyx b/experiments/ml2.pyx
index 8974106..e3678dd 100644
--- a/experiments/ml2.pyx
+++ b/experiments/ml2.pyx
@@ -46,6 +46,8 @@ def ml2(dict root_victims, dict victims, dict non_victims, DTYPE_t age,
np.ndarray[DTYPE_t] probs = np.zeros(n_victims, dtype=DTYPE)
np.ndarray[DTYPE_t] probs_fail = np.zeros(n_victims, dtype=DTYPE)
np.ndarray[DTYPE_t] probs_nv = np.zeros(len(non_victims), dtype=DTYPE)
+ np.ndarray[DTYPE_t] parent_dists = np.zeros(n_victims, dtype=DTYPE)
+ np.ndarray[DTYPE_t] parent_dts = np.zeros(n_victims, dtype=DTYPE)
# loop through victims
for i, parents in enumerate(victims.itervalues()):
@@ -57,9 +59,17 @@ def ml2(dict root_victims, dict victims, dict non_victims, DTYPE_t age,
probs_fail[i] = sum(failures)
successes = [weight_success(dist, dt, alpha, delta, w1, w2, w3)
for (dist, dt, w1, w2, w3) in parents]
+ dists = [dist for (dist, dt, w1, w2, w3) in parents]
+ dts = [dt for (dist, dt, w1, w2, w3) in parents]
# find parent that maximizes log(p) - log(\tilde{p})
- probs[i] = max(s - failures[l] for l, s in enumerate(successes))
- probs_data[i] =
+ # probs[i] = max(s - failures[l] for l, s in enumerate(successes))
+ probs[i] = float("-inf")
+ for l, s in enumerate(successes):
+ prob = s - failures[l]
+ if prob > probs[i]:
+ probs[i] = prob
+ parent_dists[i] = dists[l]
+ parent_dts[i] = dts[l]
# loop through non-victims
for i, parents in enumerate(non_victims.itervalues()):
@@ -84,4 +94,6 @@ def ml2(dict root_victims, dict victims, dict non_victims, DTYPE_t age,
roots = n_roots
beta = 0
# print n_nodes, n_roots, n_victims, max_i, roots
+ print parent_dists[1:100]
+ print parent_dts[1:100]
return (beta, roots, ll)