diff options
| author | Thibaut Horel <thibaut.horel@gmail.com> | 2015-10-30 17:16:32 -0400 |
|---|---|---|
| committer | Thibaut Horel <thibaut.horel@gmail.com> | 2015-10-30 17:16:32 -0400 |
| commit | 61f644a6a7d36dc5c15d957c48d10675ab3627ae (patch) | |
| tree | e765c3ac2b1239ea2728a625a7a19196c370adbe /hw3/build_train_test.py | |
| parent | 6a969e7afb0b796996f63b8d341f8891f187ca8e (diff) | |
| download | cs281-61f644a6a7d36dc5c15d957c48d10675ab3627ae.tar.gz | |
[hw3]
Diffstat (limited to 'hw3/build_train_test.py')
| -rw-r--r-- | hw3/build_train_test.py | 28 |
1 files changed, 28 insertions, 0 deletions
diff --git a/hw3/build_train_test.py b/hw3/build_train_test.py new file mode 100644 index 0000000..2498ce0 --- /dev/null +++ b/hw3/build_train_test.py @@ -0,0 +1,28 @@ +import sys +from random import shuffle +from pickle import dump + + +def get_ratings(filename): + d = {} + with open(filename) as fh: + for line in fh: + i, j, r = map(int, line.strip().split()) + d[(i, j)] = r + n = max(i for (i, j) in d) + m = max(j for (i, j) in d) + return d, n, m + + +def split_train_test(): + keys = ratings.keys() + s = int(1e5) + shuffle(keys) + test = {k: ratings[k] for k in keys[:s]} + train = {k: ratings[k] for k in keys[s:]} + return train, test + +if __name__ == "__main__": + ratings, n, m = get_ratings(sys.argv[1]) + train, test = split_train_test() + dump((n, m, train, test), open("data.pickle", "wb")) |
