diff options
Diffstat (limited to 'data/svm/classification.py')
| -rwxr-xr-x | data/svm/classification.py | 86 |
1 files changed, 86 insertions, 0 deletions
#! /usr/bin/python
# Recovered from the diff that added data/svm/classification.py
# (new file, mode 100755).
"""Train a LIBSVM classifier on one CSV file and evaluate it on another.

Usage: classification.py TRAIN_CSV TEST_CSV LOG_FILE

Thirteen columns are read from each CSV file (see ``_USED_COLUMNS``).  After
loading, column 0 is used as the class label, columns 1 and 3 drive the row
filter in ``read_filter`` and columns 4..12 are the features handed to LIBSVM.
"""
import copy
import sys
import numpy as np
import itertools
import random

try:
    # LIBSVM's Python interface; provides svm_train and svm_predict.
    from svmutil import *  # noqa: F401,F403
except ImportError:
    # Keep the NumPy-only helpers importable; perform_svm reports the problem.
    svm_train = svm_predict = None

try:
    # Not used by any function in this file; made optional so that a missing
    # matplotlib does not prevent the script from running.
    import matplotlib.pyplot as plt  # noqa: F401
except ImportError:
    plt = None

# The "sets" module does not exist on Python 3; the builtin set replaces it.
# The old name is kept as an alias.
Set = set

# CSV columns handed to np.loadtxt; their positions in the loaded array are
# 0..12 in this order.
_USED_COLUMNS = (1, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)


def normalize(a, weights=None):
    """Standardise columns of ``a`` in place to zero mean and unit variance.

    a       -- 2-D float array; it is modified in place and also returned.
    weights -- dict mapping column index -> ``(mean, std)`` or ``None``.
               Only the listed columns are normalised.  A ``None`` entry is
               replaced (in the dict itself) by the statistics computed from
               ``a``, so the dict returned for training data can be passed
               back in to normalise test data with the same parameters.
               When omitted, every column is normalised from its own
               statistics.

    Returns ``(a, weights)``.
    """
    if weights is None:
        weights = dict.fromkeys(range(a.shape[1]))

    for i in list(weights):
        column = a[:, i]
        if weights[i] is None:
            weights[i] = np.mean(column), np.std(column)
        mean, std = weights[i]
        # A constant column has std == 0; dividing by it would fill the
        # column with NaN, so such a column is only centred.
        scale = std if std != 0 else 1.0
        a[:, i] = (column - mean) / scale
    return a, weights


def read_filter(filename):
    """Load ``filename`` and return the rows that pass the data filters.

    The file is comma separated with ``#`` comments.  Rows containing the
    value -1 in any loaded column are dropped (missing values), and only rows
    with ``2 < column 1 < 3.2`` and ``column 3 < 0.5`` are kept; the original
    code names these columns "distance" and "diff".
    """
    # ndmin=2 keeps a single-row file two-dimensional.
    a = np.loadtxt(filename, comments="#", delimiter=",",
                   usecols=_USED_COLUMNS, ndmin=2)

    # remove rows with missing values, filter data
    a = a[~(a == -1).any(axis=1)]
    distance = a[:, 1]
    diff = a[:, 3]
    return a[(distance > 2) & (distance < 3.2) & (diff < 0.5)]


def normalize_filter(a, weights=None, nameset=None):
    """Normalise the feature columns of ``a`` and convert it for LIBSVM.

    a       -- array as returned by ``read_filter``; it is not modified.
    weights -- normalisation parameters from a previous call (see
               ``normalize``); when omitted they are computed from ``a`` for
               columns 4..12.
    nameset -- optional iterable of labels; only rows whose column 0 value is
               in it are kept.  Filtering happens after normalisation, so the
               statistics always come from all rows of ``a``.

    Returns ``(labels, features, weights)`` where ``labels`` is the list of
    column 0 values and ``features`` holds one ``{1: v1, 2: v2, ...}`` dict
    per row built from columns 4 onwards.
    """
    # Work on a copy: normalize() writes into the array it is given, and the
    # same train/test arrays are reused across perform_svm calls.
    a = np.ma.masked_array(a, dtype=float, copy=True)
    # normalize data
    if weights is None:
        weights = dict.fromkeys(range(4, 13))
    a, weights = normalize(a, weights)

    if nameset is not None:
        allowed = frozenset(nameset)
        indexes = [i for i, v in enumerate(a[:, 0]) if v in allowed]
        a = a[indexes]

    # zip stops at the shorter input, so at most ten features get a key.
    features = [dict(zip(range(1, 11), r)) for r in a[:, 4:]]
    return list(a[:, 0]), features, weights


def perform_svm(a, b, nameset=None):
    """Train on ``a``, predict on ``b`` and return LIBSVM's prediction result.

    ``b`` is normalised with the statistics computed from ``a``.  ``nameset``
    restricts both arrays to rows with those labels.

    Returns the ``(p_labels, p_acc, p_vals)`` tuple from ``svm_predict``.
    Raises ImportError when LIBSVM's ``svmutil`` module is not installed.
    """
    if svm_train is None or svm_predict is None:
        raise ImportError("LIBSVM's svmutil module is required by perform_svm")
    y1, x1, weights = normalize_filter(a, nameset=nameset)
    model = svm_train(y1, x1)
    y2, x2, weights = normalize_filter(b, weights=weights, nameset=nameset)
    p_labels, p_acc, p_vals = svm_predict(y2, x2, model)
    return p_labels, p_acc, p_vals


def _module_default(value, name):
    """Return ``value``, or the module-level variable ``name`` if it is None."""
    if value is not None:
        return value
    try:
        return globals()[name]
    except KeyError:
        raise NameError("no value passed and no module-level %r is defined"
                        % name)


def _log_accuracy(n, subset, train, test, log):
    """Run the SVM restricted to ``subset`` and append one line to ``log``."""
    p_acc = perform_svm(train, test, nameset=subset)[1]
    log.write(str(n) + "#" + str(subset) + "#" + str(p_acc[0]) + "\n")
    log.flush()


def accuracy_subsets(n, train=None, test=None, names=None, log=None):
    """Log the accuracy for every ``n``-element subset of ``names``.

    One ``n#subset#accuracy`` line is written to ``log`` per subset.  The
    optional arguments default to the script's module-level ``a``, ``b``,
    ``main_set`` and ``log_filename``.
    """
    train = _module_default(train, "a")
    test = _module_default(test, "b")
    names = _module_default(names, "main_set")
    log = _module_default(log, "log_filename")
    for s in itertools.combinations(sorted(names), n):
        _log_accuracy(n, s, train, test, log)


def accuracy_sample(n, m, train=None, test=None, names=None, log=None):
    """Log the accuracy for ``m`` random ``n``-element subsets of ``names``.

    Output format and default arguments are the same as for
    ``accuracy_subsets``.
    """
    train = _module_default(train, "a")
    test = _module_default(test, "b")
    # random.sample no longer accepts a set on recent Python versions.
    population = sorted(_module_default(names, "main_set"))
    log = _module_default(log, "log_filename")
    for _ in range(m):
        s = random.sample(population, n)
        _log_accuracy(n, s, train, test, log)


if __name__ == "__main__":
    if len(sys.argv) != 4:
        sys.exit("usage: %s TRAIN_CSV TEST_CSV LOG_FILE" % sys.argv[0])
    random.seed()
    train_filename = sys.argv[1]
    test_filename = sys.argv[2]
    # Despite its name this is the open log file object, not a file name.
    with open(sys.argv[3], "w") as log_filename:
        a = read_filter(train_filename)
        b = read_filter(test_filename)
        main_set = set(range(1, 26)).difference([13, 19, 3])
        perform_svm(a, b, nameset=main_set)

        # Disabled experiments from the original script:
        # for i in [6]:
        #     accuracy_subsets(i)

        # for i in [6]:
        #     accuracy_sample(i, 200)
