aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--simulation/active_blocks.py24
1 files changed, 22 insertions, 2 deletions
diff --git a/simulation/active_blocks.py b/simulation/active_blocks.py
index 77a4f2d..1495eb8 100644
--- a/simulation/active_blocks.py
+++ b/simulation/active_blocks.py
@@ -9,6 +9,7 @@ import picklable_itertools
import numpy as np
import fuel
import fuel.datasets
+from json import dumps
import collections
@@ -45,6 +46,24 @@ class ActiveLearning(blocks.extensions.SimpleExtension):
print(self.dataset.node_p)
+class JSONDump(blocks.extensions.SimpleExtension):
+ """Dump a JSON-serialized version of the log to a file."""
+
+ def __init__(self, filename, **kwargs):
+ super(JSONDump, self).__init__(**kwargs)
+ self.fh = open(filename, "w")
+
+ def do(self, which_callback, *args):
+ log = self.main_loop.log
+ d = {k: v for (k, v) in log.current_row.items()
+ if not k.startswith("_")}
+ d["time"] = log.status["iterations_done"]
+ self.fh.write(dumps(d, default=lambda o: str(o)) + "\n")
+
+ def __del__(self):
+ self.fh.close()
+
+
class ShuffledBatchesScheme(fuel.schemes.ShuffledScheme):
"""Iteration scheme over finite dataset:
-shuffles batches but not within batch
@@ -150,7 +169,8 @@ if __name__ == "__main__":
bm.TrainingDataMonitoring([cost, params,
rmse, error], every_n_batches=10),
be.Printing(every_n_batches=10),
- #ActiveLearning(data_stream.dataset),
- ]
+ JSONDump("log.json", every_n_batches=10)
+ # ActiveLearning(data_stream.dataset),
+ ],
)
loop.run()