diff options
| author | Ben Green <ben@SEASITs-MacBook-Pro.local> | 2015-06-20 19:09:29 -0400 |
|---|---|---|
| committer | Ben Green <ben@SEASITs-MacBook-Pro.local> | 2015-06-20 19:09:29 -0400 |
| commit | b26412ed5a3e08e9dc32fc73c27c42be54d82aa8 (patch) | |
| tree | c7813f644f68fa819263339185afcc47a4439ae8 /experiments/ml2.pyx | |
| parent | a473003961419502b66b5111374de26331bf4fc3 (diff) | |
| download | criminal_cascades-b26412ed5a3e08e9dc32fc73c27c42be54d82aa8.tar.gz | |
altered weight_success to show dist and dt between each parent and child
Diffstat (limited to 'experiments/ml2.pyx')
| -rw-r--r-- | experiments/ml2.pyx | 16 |
1 files changed, 14 insertions, 2 deletions
diff --git a/experiments/ml2.pyx b/experiments/ml2.pyx index 8974106..e3678dd 100644 --- a/experiments/ml2.pyx +++ b/experiments/ml2.pyx @@ -46,6 +46,8 @@ def ml2(dict root_victims, dict victims, dict non_victims, DTYPE_t age, np.ndarray[DTYPE_t] probs = np.zeros(n_victims, dtype=DTYPE) np.ndarray[DTYPE_t] probs_fail = np.zeros(n_victims, dtype=DTYPE) np.ndarray[DTYPE_t] probs_nv = np.zeros(len(non_victims), dtype=DTYPE) + np.ndarray[DTYPE_t] parent_dists = np.zeros(n_victims, dtype=DTYPE) + np.ndarray[DTYPE_t] parent_dts = np.zeros(n_victims, dtype=DTYPE) # loop through victims for i, parents in enumerate(victims.itervalues()): @@ -57,9 +59,17 @@ def ml2(dict root_victims, dict victims, dict non_victims, DTYPE_t age, probs_fail[i] = sum(failures) successes = [weight_success(dist, dt, alpha, delta, w1, w2, w3) for (dist, dt, w1, w2, w3) in parents] + dists = [dist for (dist, dt, w1, w2, w3) in parents] + dts = [dt for (dist, dt, w1, w2, w3) in parents] # find parent that maximizes log(p) - log(\tilde{p}) - probs[i] = max(s - failures[l] for l, s in enumerate(successes)) - probs_data[i] = + # probs[i] = max(s - failures[l] for l, s in enumerate(successes)) + probs[i] = float("-inf") + for l, s in enumerate(successes): + prob = s - failures[l] + if prob > probs[i]: + probs[i] = prob + parent_dists[i] = dists[l] + parent_dts[i] = dts[l] # loop through non-victims for i, parents in enumerate(non_victims.itervalues()): @@ -84,4 +94,6 @@ def ml2(dict root_victims, dict victims, dict non_victims, DTYPE_t age, roots = n_roots beta = 0 # print n_nodes, n_roots, n_victims, max_i, roots + print parent_dists[1:100] + print parent_dts[1:100] return (beta, roots, ll) |
