summaryrefslogtreecommitdiffstats
path: root/data/pair-matching
diff options
context:
space:
mode:
authorThibaut Horel <thibaut.horel@gmail.com>2012-02-28 22:14:54 -0800
committerThibaut Horel <thibaut.horel@gmail.com>2012-02-28 22:20:13 -0800
commit7dc01b33cadc6d565821e19db2bb006f5b1a211f (patch)
tree9a92edbbab5a0592f671095f491070d276ece92d /data/pair-matching
parent62bd4b6b877e15238d070b580014d9dfda230342 (diff)
downloadkinect-7dc01b33cadc6d565821e19db2bb006f5b1a211f.tar.gz
Proper plotting of the ROC curves for the pair matching problem
Diffstat (limited to 'data/pair-matching')
-rwxr-xr-xdata/pair-matching/roc.py56
1 files changed, 29 insertions, 27 deletions
diff --git a/data/pair-matching/roc.py b/data/pair-matching/roc.py
index ebf8f68..46e4d23 100755
--- a/data/pair-matching/roc.py
+++ b/data/pair-matching/roc.py
@@ -30,43 +30,45 @@ def gen_pairs(var,sk_data):
m_pairs = zip(range(sk_data.shape[0]),range(sk_data.shape[0]))
result = []
for j in range(sk_data.shape[0]):
- result += [(distance(sk1[m_pairs[j][0]],sk2[m_pairs[j][1]]), distance(sk1[u_pairs[j][0]],sk2[u_pairs[j][1]]))]
+ result += [(distance(sk1[m_pairs[j][0]],sk2[m_pairs[j][1]]),
+ distance(sk1[u_pairs[j][0]],sk2[u_pairs[j][1]]))]
return result
if __name__ == "__main__":
-# eg = np.loadtxt("eigenfaces.txt",delimiter=" ")
ap = np.loadtxt("associatepredict.txt",delimiter=",")
-# plt.plot(eg[:,0],eg[:,1])
- plt.plot(ap[:,1],ap[:,0])
- plt.xlabel("False positive %")
- plt.ylabel("True positive %")
+ indices = [i for i in range(ap.shape[0]) if ap[i,1]<0.1]
+ ap_false = ap[:,1][indices]
+ ap_true = ap[:,0][indices]
+ plt.plot(ap_false,ap_true,label="Face detection")
+ plt.xlabel("False positive \%")
+ plt.ylabel("True positive \%")
np.random.seed()
- var = map(float,sys.argv[2].split(","))
+ std = map(float,sys.argv[2].split(","))
sk_data = np.loadtxt(sys.argv[1],comments="#",delimiter=",")
- for v in var:
- result = gen_pairs(v,sk_data)
- thresholds = np.square(np.arange(0,10,0.01))
+ for s in std:
+ result = gen_pairs(s,sk_data)
+ thresholds = np.square(np.arange(0,10,0.001))
true_pos = []
false_pos = []
for threshold in thresholds:
- true_values = []
- false_values = []
- for i in range(4):
- true = 0
- false = 0
- min_j = i*300
- max_j = min(min_j+300,sk_data.shape[0])
- for j in range(min_j,max_j):
- if result[j][0] < threshold:
- true += 1
- if result[j][1] < threshold:
- false += 1
- true_values += [float(true)/(max_j-min_j)]
- false_values += [float(false)/(max_j-min_j)]
- true_pos += [sum(true_values)/4]
- false_pos += [sum(false_values)/4]
- plt.plot(false_pos,true_pos)
+ true = 0
+ false = 0
+ for j in range(len(result)):
+ if result[j][0] < threshold:
+ true += 1
+ if result[j][1] < threshold:
+ false += 1
+ true_pos += [float(true)/len(result)]
+ false_pos += [float(false)/len(result)]
+ indices = [i for i in range(len(false_pos)) if false_pos[i]<0.1]
+ false_pos = np.array(false_pos)
+ false_pos = false_pos[indices]
+ true_pos = np.array(true_pos)
+ true_pos = true_pos[indices]
+ plt.plot(false_pos,true_pos,label="$\sigma$ = "+str(s))
+ plt.legend(loc="lower right")
+ plt.savefig("roc.pdf")
plt.show()