diff options
Diffstat (limited to 'hw3/build_train_test.py')
| -rw-r--r-- | hw3/build_train_test.py | 28 |
1 files changed, 28 insertions, 0 deletions
"""Split a ratings file into train/test sets and pickle the result.

Usage: python build_train_test.py RATINGS_FILE

Each non-comment line of RATINGS_FILE is expected to hold three
whitespace-separated integers ``i j r``.  The script writes the tuple
``(n, m, train, test)`` to ``data.pickle`` in the current directory.
"""
import sys
from random import shuffle
from pickle import dump

# Number of ratings held out for the test set by default.
TEST_SIZE = int(1e5)


def get_ratings(filename):
    """Read ratings from *filename*.

    Every line must contain three whitespace-separated integers
    ``i j r``; the rating ``r`` is stored under the key ``(i, j)``.  If
    a pair occurs more than once, the last occurrence wins.

    Returns:
        A tuple ``(d, n, m)`` where ``d`` maps ``(i, j)`` to ``r``,
        ``n`` is the largest ``i`` seen and ``m`` is the largest ``j``.

    Raises:
        ValueError: if a line does not hold exactly three integers, or
            if the file contains no ratings at all (``max`` of an empty
            sequence).
    """
    d = {}
    with open(filename) as fh:
        for line in fh:
            i, j, r = map(int, line.split())
            d[(i, j)] = r
    n = max(i for (i, _) in d)
    m = max(j for (_, j) in d)
    return d, n, m


def split_train_test(ratings=None, test_size=TEST_SIZE):
    """Randomly partition *ratings* into train and test dictionaries.

    Args:
        ratings: mapping of ``(i, j)`` to rating.  When omitted, the
            module-level ``ratings`` variable is used, which is how the
            function was originally called from the script entry point.
        test_size: number of entries placed in the test set.  If the
            mapping has fewer entries, all of them end up in the test
            set and the train set is empty.

    Returns:
        A tuple ``(train, test)`` of two disjoint dictionaries whose
        union equals *ratings*.  The input mapping is not modified.
    """
    if ratings is None:
        # Backward compatibility: fall back to the global set by the
        # ``__main__`` section below.
        ratings = globals()["ratings"]
    # ``dict.keys()`` is a view on Python 3: it can be neither shuffled
    # in place nor sliced, so materialise the keys as a list first.
    keys = list(ratings)
    shuffle(keys)
    test = {k: ratings[k] for k in keys[:test_size]}
    train = {k: ratings[k] for k in keys[test_size:]}
    return train, test


if __name__ == "__main__":
    if len(sys.argv) < 2:
        sys.exit("usage: build_train_test.py RATINGS_FILE")
    ratings, n, m = get_ratings(sys.argv[1])
    train, test = split_train_test(ratings)
    # Use a context manager so the pickle file is flushed and closed.
    with open("data.pickle", "wb") as out:
        dump((n, m, train, test), out)
