summary | refs | log | tree | commit | diff | stats
path: root/hw3/build_train_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'hw3/build_train_test.py')
-rw-r--r-- hw3/build_train_test.py 28
1 files changed, 28 insertions, 0 deletions
diff --git a/hw3/build_train_test.py b/hw3/build_train_test.py
new file mode 100644
index 0000000..2498ce0
--- /dev/null
+++ b/hw3/build_train_test.py
@@ -0,0 +1,28 @@
+import sys
+from random import shuffle
+from pickle import dump
+
+
def get_ratings(filename):
    """Load a ratings file into a dictionary.

    Each non-blank line of the file must hold three whitespace-separated
    integers ``i j r``. Blank or whitespace-only lines are skipped. If the
    same ``(i, j)`` pair appears more than once, the last occurrence wins.

    Args:
        filename: Path of the text file to read.

    Returns:
        A tuple ``(d, n, m)`` where ``d`` maps ``(i, j)`` to ``r``, ``n`` is
        the largest ``i`` seen and ``m`` is the largest ``j`` seen. When the
        file contains no ratings the result is ``({}, 0, 0)``.

    Raises:
        ValueError: If a non-blank line does not consist of exactly three
            integers.
    """
    d = {}
    with open(filename) as fh:
        for line in fh:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            i, j, r = map(int, fields)
            d[(i, j)] = r
    if not d:
        # max() raises ValueError on an empty sequence; report zero sizes.
        return d, 0, 0
    n = max(i for i, _ in d)
    m = max(j for _, j in d)
    return d, n, m
+
+
def split_train_test(data=None, test_size=100000):
    """Randomly partition a ratings dictionary into train and test parts.

    The keys are shuffled in place with ``random.shuffle``; the first
    ``test_size`` keys form the test set and the remainder the train set.

    Args:
        data: Dictionary to split. When omitted, the module-level
            ``ratings`` dictionary is used, which keeps the original
            zero-argument call working.
        test_size: Number of entries to put in the test set. If it is at
            least ``len(data)``, every entry lands in the test set and the
            train set is empty.

    Returns:
        A tuple ``(train, test)`` of two dictionaries with disjoint keys
        whose union is ``data``.
    """
    if data is None:
        data = ratings
    # list(...) is required on Python 3, where dict.keys() is a view that
    # can be neither shuffled nor sliced.
    keys = list(data)
    shuffle(keys)
    test = {k: data[k] for k in keys[:test_size]}
    train = {k: data[k] for k in keys[test_size:]}
    return train, test
+
if __name__ == "__main__":
    # Script entry point: read the ratings file named on the command line,
    # split it, and pickle (n, m, train, test) to "data.pickle".
    if len(sys.argv) < 2:
        sys.exit("usage: build_train_test.py RATINGS_FILE")
    # `ratings` must stay a module-level name: split_train_test() reads it
    # as a global when called without arguments.
    ratings, n, m = get_ratings(sys.argv[1])
    train, test = split_train_test()
    # Use a context manager so the pickle file is flushed and closed
    # deterministically instead of relying on garbage collection.
    with open("data.pickle", "wb") as out:
        dump((n, m, train, test), out)