author     Thibaut Horel <thibaut.horel@gmail.com>  2012-02-23 17:23:33 -0800
committer  Thibaut Horel <thibaut.horel@gmail.com>  2012-02-23 17:23:33 -0800
commit     5ae9358487b6266d2a221ebc2bcf72d735e09b25 (patch)
tree       300082ceab9f7938cbe717cd67cc8e5a8d88d1c7 /data/svm/classification.py
parent     ac9fe3539eec8c4ec0852a21e74fa30576c02d0b (diff)
download   kinect-5ae9358487b6266d2a221ebc2bcf72d735e09b25.tar.gz
Clean up the SVM code; analyze its performance.
Diffstat (limited to 'data/svm/classification.py')
-rwxr-xr-x  data/svm/classification.py  86
1 file changed, 86 insertions(+), 0 deletions(-)
diff --git a/data/svm/classification.py b/data/svm/classification.py
new file mode 100755
index 0000000..5515364
--- /dev/null
+++ b/data/svm/classification.py
@@ -0,0 +1,86 @@
+#! /usr/bin/python
+# Train an SVM on features read from a training file and evaluate it on a
+# test file; optionally measure accuracy over subsets of the class labels.
+import sys
+import itertools
+import random
+
+import numpy as np
+from svmutil import *
+
+def normalize(a, weights=None):
+    # Z-score the columns listed in `weights`, fitting the (mean, std)
+    # pair of any column whose entry is None.
+    if weights is None:
+        # Default: normalize every column.
+        weights = dict((i, None) for i in range(a.shape[1]))
+    for i in weights:
+        column = a[:, i]
+        if weights[i] is None:
+            weights[i] = (np.mean(column), np.std(column))
+        a[:, i] = (column - weights[i][0]) / weights[i][1]
+    return a, weights
+
+def read_filter(filename):
+    a = np.loadtxt(filename, comments="#", delimiter=",",
+                   usecols=(1, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15))
+
+    # Drop rows containing missing values (encoded as -1).
+    a = np.ma.masked_equal(a, -1)
+    a = np.ma.mask_rows(a)
+    a = np.ma.compress_rows(a)
+    # Keep samples within the distance range of interest and with small diff.
+    distance = a[:, 1]
+    diff = a[:, 3]
+    a = a[(distance > 2) & (distance < 3.2) & (diff < 0.5)]
+    return a
+
+def normalize_filter(a, weights=None, nameset=None):
+    a = np.ma.masked_array(a)
+    # Normalize feature columns 4..12; when `weights` is given (test data),
+    # reuse the (mean, std) pairs fitted on the training data.
+    if weights is None:
+        weights = dict(zip(range(4, 13), [None] * 9))
+    a, weights = normalize(a, weights)
+
+    # Optionally keep only samples whose label (column 0) is in `nameset`.
+    if nameset is not None:
+        indexes = [i for i, v in enumerate(a[:, 0]) if v in nameset]
+        a = a[indexes]
+
+    # libsvm wants labels as a list and features as {index: value} dicts.
+    return list(a[:, 0]), [dict(zip(range(1, 11), r)) for r in a[:, 4:]], weights
+
+def perform_svm(a, b, nameset=None):
+    # Train on `a`, then evaluate on `b` normalized with the training weights.
+    y1, x1, weights = normalize_filter(a, nameset=nameset)
+    model = svm_train(y1, x1)
+    y2, x2, weights = normalize_filter(b, weights=weights, nameset=nameset)
+    p_labels, p_acc, p_vals = svm_predict(y2, x2, model)
+    return p_labels, p_acc, p_vals
+
+def accuracy_subsets(n):
+    # Log the accuracy obtained on every size-n subset of labels. Relies on
+    # the globals a, b and log_file bound in __main__.
+    for s in itertools.combinations(main_set, n):
+        p_acc = perform_svm(a, b, nameset=s)[1]
+        log_file.write(str(n) + "#" + str(s) + "#" + str(p_acc[0]) + "\n")
+        log_file.flush()
+
+def accuracy_sample(n, m):
+    # Same measurement on m random size-n subsets instead of all of them.
+    for i in range(m):
+        s = random.sample(main_set, n)
+        p_acc = perform_svm(a, b, nameset=s)[1]
+        log_file.write(str(n) + "#" + str(s) + "#" + str(p_acc[0]) + "\n")
+        log_file.flush()
+
+if __name__ == "__main__":
+    random.seed()
+    train_filename = sys.argv[1]
+    test_filename = sys.argv[2]
+    log_file = open(sys.argv[3], "w")
+    a = read_filter(train_filename)
+    b = read_filter(test_filename)
+    # Use labels 1..25 minus the excluded labels 3, 13 and 19.
+    main_set = set(range(1, 26)) - set([13, 19, 3])
+    perform_svm(a, b, nameset=main_set)
+
+# Uncomment one of the loops below to log accuracies over label subsets:
+#for i in [6]:
+#    accuracy_subsets(i)
+
+#for i in [6]:
+#    accuracy_sample(i, 200)
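
Two notes on the code above.

First, normalize() implements the usual train/test normalization contract:
the per-column (mean, std) pairs are fitted once on the training matrix and
then reused verbatim on the test matrix, via the weights dict that
normalize_filter() and perform_svm() thread through. A minimal,
self-contained sketch of that pattern (the data here is made up):

    import numpy as np

    def zscore(a, stats=None):
        # stats maps a column index to its (mean, std); None means "fit it".
        if stats is None:
            stats = dict((i, None) for i in range(a.shape[1]))
        for i in stats:
            col = a[:, i]
            if stats[i] is None:
                stats[i] = (col.mean(), col.std())
            a[:, i] = (col - stats[i][0]) / stats[i][1]
        return a, stats

    train = np.random.rand(100, 3) * 10.0
    test = np.random.rand(20, 3) * 10.0
    train, stats = zscore(train)       # fit the statistics on training data
    test, _ = zscore(test, stats)      # reuse them unchanged on test data

Second, the svm_train()/svm_predict() calls follow the convention of
libsvm's Python bindings: labels as a plain list, features as sparse
{index: value} dicts, and svm_predict() returning (labels, (accuracy, MSE,
SCC), values), which is why the logging helpers read p_acc[0]. A toy
illustration, assuming the libsvm Python bindings are importable as svmutil:

    from svmutil import *

    y = [1, 1, -1, -1]
    x = [{1: 0.9, 2: 0.8}, {1: 0.7, 2: 0.9},
         {1: -0.8, 2: -0.7}, {1: -0.9, 2: -0.6}]
    model = svm_train(y, x, '-q')                 # '-q': quiet training
    labels, acc, vals = svm_predict(y, x, model)  # acc[0]: accuracy in %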