summaryrefslogtreecommitdiffstats
path: root/hw3/plot.py
blob: 979798535adb6d42004b9ebd6684f87f9f206841 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import matplotlib.pyplot as plt
import seaborn as sb
from math import sqrt

sb.set_style("white")


def _load_results(path):
    """Read ``(K, train error, validation error)`` columns from *path*.

    The file is whitespace-separated text with exactly three numbers per
    non-blank line. Blank lines are skipped.

    Returns:
        A tuple of three lists of floats: the K values, the training
        errors and the validation errors, in file order.

    Raises:
        ValueError: if a line does not hold exactly three numbers, a token
            is not a valid float, or the file contains no data rows.
    """
    rows = []
    # ``with`` guarantees the file handle is closed even if parsing fails.
    with open(path) as f:
        for lineno, line in enumerate(f, 1):
            fields = line.split()
            if not fields:
                continue
            if len(fields) != 3:
                # Without this check, ``zip`` would silently drop extra
                # columns or produce a confusing unpacking error.
                raise ValueError(
                    "%s:%d: expected 3 columns, got %d" % (path, lineno, len(fields))
                )
            rows.append([float(tok) for tok in fields])
    if not rows:
        raise ValueError("%s: no data rows" % path)
    # Materialise each column as a list: on Python 3, ``zip``/``map`` return
    # one-shot iterators, which matplotlib cannot plot as sequences.
    ks, train, validation = (list(col) for col in zip(*rows))
    return ks, train, validation


# Plot train/validation RMSE against K and save the figure to rmse.pdf.
# Columns 2-3 are square-rooted to obtain the RMSE shown on the y-axis, so
# they are presumably squared errors (MSE) -- confirm against the producer.
k_values, train_err, val_err = _load_results("results.txt")
train_rmse = [sqrt(v) for v in train_err]
val_rmse = [sqrt(v) for v in val_err]

plt.figure(figsize=(9, 6))
plt.plot(k_values, train_rmse, label="train")
plt.plot(k_values, val_rmse, label="validation")
plt.legend()
plt.xlabel("K")
plt.ylabel("RMSE")
plt.savefig("rmse.pdf")
# Release the figure once it has been written out.
plt.close()