summaryrefslogtreecommitdiffstats
path: root/hw3/build_train_test.py
blob: 2498ce0f6c02d2886c91a32a2f5396fe2922c94f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import sys
from random import shuffle
from pickle import dump


def get_ratings(filename):
    """Read a whitespace-separated ratings file.

    Each non-blank line must contain exactly three integers ``i j r``.
    Blank lines are skipped.

    Args:
        filename: path of the ratings file to read.

    Returns:
        A tuple ``(d, n, m)``:
          * ``d`` maps ``(i, j)`` to ``r``. If a pair appears more than once,
            the later line wins.
          * ``n`` is the largest ``i`` seen and ``m`` the largest ``j`` seen.
            Both are 0 when the file holds no data lines.

    Raises:
        ValueError: if a non-blank line does not hold exactly three integers.
    """
    d = {}
    with open(filename) as fh:
        for line in fh:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. an extra empty line at EOF)
                # instead of failing on the 3-way unpack below.
                continue
            i, j, r = map(int, fields)
            d[(i, j)] = r
    # Guard the empty case: max() of an empty sequence raises ValueError.
    n = max(i for (i, _) in d) if d else 0
    m = max(j for (_, j) in d) if d else 0
    return d, n, m


def split_train_test(data=None, test_size=100000):
    """Randomly split a ratings dict into disjoint train and test dicts.

    Args:
        data: mapping of ``(i, j)`` -> rating to split. If omitted, the
            module-level ``ratings`` (bound in the ``__main__`` block) is
            used, so the original zero-argument call still works.
        test_size: number of entries placed in the test set. The default,
            100000, matches the original ``int(1e5)``. If it is at least
            ``len(data)``, every entry goes to the test set and train is
            empty.

    Returns:
        ``(train, test)``: two dicts that together partition ``data``.
    """
    if data is None:
        data = ratings  # module global set by the __main__ block
    # list() is required: on Python 3 dict.keys() is a view, which
    # random.shuffle cannot permute in place and which cannot be sliced.
    keys = list(data)
    shuffle(keys)
    test = {k: data[k] for k in keys[:test_size]}
    train = {k: data[k] for k in keys[test_size:]}
    return train, test

if __name__ == "__main__":
    # Usage: python build_train_test.py RATINGS_FILE
    # Writes (n, m, train, test) to data.pickle in the current directory.
    if len(sys.argv) < 2:
        sys.exit("usage: %s RATINGS_FILE" % sys.argv[0])
    ratings, n, m = get_ratings(sys.argv[1])
    # split_train_test() reads the module-level `ratings` bound above.
    train, test = split_train_test()
    # Use a context manager so the pickle is flushed and the handle closed.
    with open("data.pickle", "wb") as out:
        dump((n, m, train, test), out)