summaryrefslogtreecommitdiffstats
path: root/data/svm/classification.py
blob: 5515364d490eb48a01c05ceff6dc3b8eeb059683 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#! /usr/bin/python
import copy
import sys
from svmutil import *
import numpy as np
import matplotlib.pyplot as plt
from sets import Set
import itertools
import random

def normalize(a, weights=None):
    """Z-score-normalize selected columns of *a* in place.

    Parameters
    ----------
    a : 2-D numpy array (modified in place and also returned).
    weights : dict or None
        Maps column index -> (mean, std) or None. Columns mapped to None
        (or every column, when *weights* itself is None) get statistics
        computed from *a*; otherwise the stored (mean, std) is applied,
        so training-set scaling can be reused on a test set.

    Returns
    -------
    (a, weights) : the normalized array and the statistics actually used.
    """
    if weights is None:
        # No statistics supplied: normalize every column from scratch.
        weights = {i: None for i in range(a.shape[1])}

    for i in weights:
        column = a[:, i]
        if weights[i] is None:
            weights[i] = (np.mean(column), np.std(column))
        mean, std = weights[i]
        a[:, i] = (column - mean) / std
    return a, weights

def read_filter(filename):
    """Load a comma-separated data file and return the usable rows.

    Reads columns (1, 4..15) of *filename* (``#`` starts a comment line),
    drops every row that contains the missing-value sentinel -1 in any of
    those columns, then keeps only rows whose distance lies in (2, 3.2)
    and whose diff is below 0.5.

    Returns a plain 2-D numpy array of the surviving rows.
    """
    a = np.loadtxt(filename, comments="#", delimiter=",",
                   usecols=(1, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15))

    # -1 marks a missing value: mask each occurrence, extend the mask to
    # the whole row, then compress the masked rows away.
    a = np.ma.masked_equal(a, -1)
    a = np.ma.mask_rows(a)
    a = np.ma.compress_rows(a)

    # After the usecols remapping: a[:,1] = distance, a[:,3] = diff.
    # (a[:,2] is a variance column, not used by this filter.)
    distance = a[:, 1]
    diff = a[:, 3]
    return a[(distance > 2) & (distance < 3.2) & (diff < 0.5)]

def normalize_filter(a, weights=None, nameset=None):
    """Normalize the feature columns of *a* and unpack it for libsvm.

    Parameters
    ----------
    a : 2-D array; column 0 holds the sample name/label, columns 4-12
        hold the features fed to the SVM.
    weights : dict or None
        Per-column (mean, std) pairs from a previous call, so a test set
        can be scaled with training-set statistics; None computes fresh
        statistics for columns 4-12.
    nameset : collection or None
        When given, keep only rows whose column-0 value is in the set.

    Returns
    -------
    (labels, features, weights) where *labels* is the list of column-0
    values and *features* is a list of {feature-index: value} dicts in
    the form expected by svm_train/svm_predict.
    """
    a = np.ma.masked_array(a)
    if weights is None:
        # Normalize the nine feature columns 4..12.
        weights = dict.fromkeys(range(4, 13))
    a, weights = normalize(a, weights)

    if nameset is not None:
        keep = [i for i, name in enumerate(a[:, 0]) if name in nameset]
        a = a[keep]

    # a[:,4:] has 9 columns, so zip truncates range(1, 11) and the
    # feature indices actually run 1..9.
    labels = list(a[:, 0])
    features = [dict(zip(range(1, 11), row)) for row in a[:, 4:]]
    return labels, features, weights

def perform_svm(a, b, nameset=None):
    """Train an SVM on *a* and evaluate it on *b*.

    The normalization statistics computed from the training set are
    reused to scale the test set; *nameset*, when given, restricts both
    sets to the named samples. Returns svm_predict's
    (labels, accuracy, values) triple.
    """
    train_labels, train_feats, train_weights = normalize_filter(a, nameset=nameset)
    model = svm_train(train_labels, train_feats)
    test_labels, test_feats, _ = normalize_filter(b, weights=train_weights,
                                                  nameset=nameset)
    return svm_predict(test_labels, test_feats, model)

def accuracy_subsets(n):
    """Log the SVM accuracy for every size-*n* subset of main_set.

    Relies on module globals set up in __main__: ``main_set``, the
    train/test arrays ``a`` and ``b``, and the open ``log_filename``
    file. Each log line has the form "n#subset#accuracy".
    """
    for s in itertools.combinations(main_set, n):
        # BUG FIX: the original called perform_svm(s), but perform_svm
        # requires (a, b) and the subset plays the role of the nameset —
        # the one-argument call raised TypeError.
        p_acc = perform_svm(a, b, nameset=s)[1]
        log_filename.write(str(n) + "#" + str(s) + "#" + str(p_acc[0]) + "\n")
        log_filename.flush()

def accuracy_sample(n, m):
    """Log the SVM accuracy for *m* random size-*n* samples of main_set.

    Relies on module globals set up in __main__: ``main_set``, the
    train/test arrays ``a`` and ``b``, and the open ``log_filename``
    file. Each log line has the form "n#sample#accuracy".
    """
    for _ in range(m):
        s = random.sample(main_set, n)
        # BUG FIX: the original called perform_svm(s), but perform_svm
        # requires (a, b) and the sample plays the role of the nameset —
        # the one-argument call raised TypeError.
        p_acc = perform_svm(a, b, nameset=s)[1]
        log_filename.write(str(n) + "#" + str(s) + "#" + str(p_acc[0]) + "\n")
        log_filename.flush()

if __name__ == "__main__":
    # Usage: classification.py <train-file> <test-file> <log-file>
    random.seed()
    train_filename = sys.argv[1]
    test_filename = sys.argv[2]
    # Kept open (and global) for the whole run: the accuracy_* helpers
    # write and flush to it. Never explicitly closed.
    log_filename = open(sys.argv[3],"w")
    a = read_filter(train_filename)
    b = read_filter(test_filename)
    # Sample names 1..25 minus three excluded samples (3, 13, 19).
    # NOTE(review): `sets.Set` exists only in Python 2 — presumably this
    # script targets Python 2; the builtin `set` would be needed on 3.
    main_set = Set(range(1,26)).difference(Set([13,19,3]))
    perform_svm(a,b,nameset=main_set)

#for i in [6]:
#    accuracy_subsets(i)

#for i in [6]:
#    accuracy_sample(i,200)