aboutsummaryrefslogtreecommitdiffstats
path: root/simulation
diff options
context:
space:
mode:
authorjeanpouget-abadie <jean.pougetabadie@gmail.com>2015-12-02 12:31:05 -0500
committerjeanpouget-abadie <jean.pougetabadie@gmail.com>2015-12-02 12:31:05 -0500
commit5fbd4664e76d25de95f89329ee5f0f912fee4259 (patch)
tree05812c8f49cd44aca4ebf19a9b836df41ea33396 /simulation
parent600251accf79333d487c7186dcc5354e310c84c7 (diff)
downloadcascades-5fbd4664e76d25de95f89329ee5f0f912fee4259.tar.gz
frequency param introduced + plot_utils file sketch
Diffstat (limited to 'simulation')
-rw-r--r--simulation/active_blocks.py23
-rw-r--r--simulation/plot_utils.py26
2 files changed, 39 insertions, 10 deletions
diff --git a/simulation/active_blocks.py b/simulation/active_blocks.py
index 1495eb8..be1fc3d 100644
--- a/simulation/active_blocks.py
+++ b/simulation/active_blocks.py
@@ -43,7 +43,9 @@ class ActiveLearning(blocks.extensions.SimpleExtension):
def do(self, which_callback, *args):
out_degree = np.sum(self.dataset.graph, axis=1)
self.dataset.node_p = out_degree / np.sum(out_degree)
- print(self.dataset.node_p)
+
+# def do(self, which_callback, *args):
+
class JSONDump(blocks.extensions.SimpleExtension):
@@ -149,7 +151,8 @@ def create_learned_data_stream(graph, batch_size):
if __name__ == "__main__":
batch_size = 100
n_obs = 1000
- graph = utils.create_wheel(10)
+ frequency = 1
+ graph = utils.create_wheel(1000)
print('GRAPH:\n', graph, '\n-------------\n')
g_shared = theano.shared(value=graph, name='graph')
@@ -160,17 +163,17 @@ if __name__ == "__main__":
alg = algorithms.GradientDescent(
cost=-cost, parameters=[params], step_rule=blocks.algorithms.AdaDelta()
)
- # data_stream = create_learned_data_stream(graph, batch_size)
- data_stream = create_fixed_data_stream(n_obs, graph, batch_size)
+ data_stream = create_learned_data_stream(graph, batch_size)
+ #data_stream = create_fixed_data_stream(n_obs, graph, batch_size)
loop = main_loop.MainLoop(
alg, data_stream,
extensions=[
- be.FinishAfter(after_n_batches=10**3),
- bm.TrainingDataMonitoring([cost, params,
- rmse, error], every_n_batches=10),
- be.Printing(every_n_batches=10),
- JSONDump("log.json", every_n_batches=10)
- # ActiveLearning(data_stream.dataset),
+ be.FinishAfter(after_n_batches=10**4),
+ bm.TrainingDataMonitoring([cost, rmse, error],
+ every_n_batches=frequency),
+ be.Printing(every_n_batches=frequency),
+ JSONDump("tmpactive_log.json", every_n_batches=frequency),
+ ActiveLearning(data_stream.dataset, every_n_batches=frequency)
],
)
loop.run()
diff --git a/simulation/plot_utils.py b/simulation/plot_utils.py
new file mode 100644
index 0000000..af5269c
--- /dev/null
+++ b/simulation/plot_utils.py
@@ -0,0 +1,26 @@
+import matplotlib.pyplot as plt
+import argparse
+import json
+import seaborn
+seaborn.set_style('whitegrid')
+
+parser = argparse.ArgumentParser(description='Process logs')
+parser.add_argument('-x', help='name of parameters on x axis', default='time')
+parser.add_argument('-y', help='name of parameters on y axis', default='rmse')
+parser.add_argument('f', help='list of logs to parse', nargs='+')
+parser.add_argument('-dest', help='name of figure to save', default='fig.png')
+args = parser.parse_args()
+
+for file_name in args.f:
+ x, y = [], []
+ with open(file_name) as f:
+ for line in f:
+ jason = json.loads(line)
+ x.append(jason[args.x])
+ y.append(jason[args.y])
+ plt.plot(x, y, label=file_name)
+
+plt.legend()
+plt.xlabel(args.x)
+plt.ylabel(args.y)
+plt.savefig(args.dest)