diff options
| author | Thibaut Horel <thibaut.horel@gmail.com> | 2015-12-01 00:00:46 -0500 |
|---|---|---|
| committer | Thibaut Horel <thibaut.horel@gmail.com> | 2015-12-01 00:00:46 -0500 |
| commit | 600251accf79333d487c7186dcc5354e310c84c7 (patch) | |
| tree | 5c8f0d5324301d8816d46caa88e3e17f42b0c6ae /simulation | |
| parent | 6b067d273e07f487d774009083d6c18b6e7c6e06 (diff) | |
| download | cascades-600251accf79333d487c7186dcc5354e310c84c7.tar.gz | |
Add extension to dump the log to a file, will make plotting easier
Diffstat (limited to 'simulation')
| -rw-r--r-- | simulation/active_blocks.py | 24 |
1 files changed, 22 insertions, 2 deletions
diff --git a/simulation/active_blocks.py b/simulation/active_blocks.py index 77a4f2d..1495eb8 100644 --- a/simulation/active_blocks.py +++ b/simulation/active_blocks.py @@ -9,6 +9,7 @@ import picklable_itertools import numpy as np import fuel import fuel.datasets +from json import dumps import collections @@ -45,6 +46,24 @@ class ActiveLearning(blocks.extensions.SimpleExtension): print(self.dataset.node_p) +class JSONDump(blocks.extensions.SimpleExtension): + """Dump a JSON-serialized version of the log to a file.""" + + def __init__(self, filename, **kwargs): + super(JSONDump, self).__init__(**kwargs) + self.fh = open(filename, "w") + + def do(self, which_callback, *args): + log = self.main_loop.log + d = {k: v for (k, v) in log.current_row.items() + if not k.startswith("_")} + d["time"] = log.status["iterations_done"] + self.fh.write(dumps(d, default=lambda o: str(o)) + "\n") + + def __del__(self): + self.fh.close() + + class ShuffledBatchesScheme(fuel.schemes.ShuffledScheme): """Iteration scheme over finite dataset: -shuffles batches but not within batch @@ -150,7 +169,8 @@ if __name__ == "__main__": bm.TrainingDataMonitoring([cost, params, rmse, error], every_n_batches=10), be.Printing(every_n_batches=10), - #ActiveLearning(data_stream.dataset), - ] + JSONDump("log.json", every_n_batches=10) + # ActiveLearning(data_stream.dataset), + ], ) loop.run() |
