summaryrefslogtreecommitdiffstats
path: root/data/face-train.py
blob: 80ca4fa0608d8ae22dfd743804e70155e49bd4f4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#!/usr/bin/python
from subprocess import Popen
import os
import sys
from face_client import FaceClient

# Credentials for the service wrapped by FaceClient.  Each value can be
# overridden from the environment; the literal defaults are kept only for
# backward compatibility.  They are secrets committed to source control and
# should be rotated and supplied via the environment instead.
api_key = os.environ.get('FACE_API_KEY', '34a84a7835bf24df2d84b4bded84e838')
api_secret = os.environ.get('FACE_API_SECRET', '5bc9e8c5a9e3a2d916abbe3659a1b3f8')
# Base URL under which training images are served as <user>/train/<pic>;
# must end with '/' because train() appends path segments directly.
url_base = os.environ.get('FACE_URL_BASE',
                          'http://74.95.195.225/static/img/jon/users/')
# Shared API client used by train() for detection, tag saving and training.
client = FaceClient(api_key, api_secret)

def train(ns, dataset, dirs):
    """Detect faces in each user's training pictures and train one model per user.

    For every directory in *dirs*, the last path component is used as the user
    name. The file '<dir>/<dataset>-train' lists that user's pictures, one per
    line, with the image path in the second-to-last comma-separated field.

    Each picture is fetched from ``url_base + user + '/train/' + pic`` and passed
    to ``client.faces_detect``. The first tag of every detected photo that has a
    tag is saved for uid '<user>@<ns>'. ``client.faces_train`` is then called for
    that uid.

    Per-picture and per-user records are written to 'train.log', tab-separated
    with one record per line. One CSV line per user is printed to stdout:
    ``user,pictures_submitted,tags_saved,training_set_size``.

    Args:
        ns: namespace appended to each user name to form its uid ('user@ns').
        dataset: prefix of the '<dataset>-train' list file in each directory.
        dirs: iterable of per-user directories.
    """
    with open('train.log', 'w') as log:
        for d in dirs:
            user = d.rstrip('/').split('/')[-1]
            uid = user + '@' + ns
            list_path = os.path.join(d, dataset + '-train')
            tids, count = _detect_user_tags(user, list_path, log)

            saved = client.tags_save(tids=','.join(tids), uid=uid)
            sys.stderr.write(str(saved) + "\n")
            trained = client.faces_train(uid)
            sys.stderr.write(str(trained) + "\n")

            nsaved = len(saved['saved_tags'])
            log.write("\t".join([user, str(nsaved), str(trained)]) + "\n")
            ntrain = _training_set_size(trained)
            # Single-argument print(...) behaves the same on Python 2 and 3.
            print(",".join([user, str(count), str(nsaved), str(ntrain)]))


def _detect_user_tags(user, list_path, log):
    """Run face detection on every picture listed in *list_path* for *user*.

    Returns a tuple ``(tids, count)``:
      * tids: the first tag id of each detected photo that had at least one tag.
      * count: the number of pictures submitted. 'face.jpeg' entries are skipped.

    For each photo in a detection response, one record
    'user<TAB>pic<TAB>face|no-face' is written to *log*. A picture whose
    detection fails is reported on stderr and skipped, so one bad picture does
    not abort the run.
    """
    tids = []
    count = 0
    with open(list_path, 'r') as listing:
        for line in listing:
            if not line.strip():
                continue  # blank lines have no fields to parse
            pic = _picture_name(line)
            if pic == "face.jpeg":
                continue
            count += 1
            url = url_base + user + '/train/' + pic
            try:
                response = client.faces_detect(url)
                for photo in response['photos']:
                    tags = photo['tags']
                    sys.stderr.write(str(tags) + "\n")
                    if tags:
                        tids.append(tags[0]['tid'])
                        log.write("\t".join([user, pic, "face"]) + "\n")
                    else:
                        log.write("\t".join([user, pic, "no-face"]) + "\n")
            except Exception as err:
                # Best effort per picture, but let KeyboardInterrupt/SystemExit through.
                sys.stderr.write("Unexpected error for %s: %r\n" % (url, err))
    return tids, count


def _picture_name(line):
    """Return the '.jpeg' file name for one line of a '<dataset>-train' list.

    The second-to-last comma-separated field is taken as an image path. Its
    base name is cut at the first '.' and '.jpeg' is appended,
    e.g. 'a,dir/img01.png,x' -> 'img01.jpeg'.
    """
    path = line.split(',')[-2]
    return path.split('/')[-1].split('.')[0] + '.jpeg'


def _training_set_size(trained):
    """Return 'training_set_size' from a faces_train() response, or None.

    Presumably the response holds a 'status' entry plus one other key whose
    value is a list of per-uid dicts (TODO confirm against the API). The first
    such dict is used. None is returned when no such entry exists, instead of
    crashing mid-run.
    """
    for key, value in trained.items():
        if key != u'status' and isinstance(value, list) and value:
            return value[0].get(u'training_set_size')
    return None


if __name__ == "__main__":
    # Usage: face-train.py NAMESPACE DATASET USER_DIR [USER_DIR ...]
    # Without at least one user dir, train() would only truncate train.log.
    if len(sys.argv) < 4:
        sys.stderr.write("usage: %s NAMESPACE DATASET USER_DIR [USER_DIR ...]\n"
                         % sys.argv[0])
        sys.exit(2)
    train(sys.argv[1], sys.argv[2], sys.argv[3:])