Diffstat (limited to 'simulation/utils_blocks.py')
-rw-r--r--  simulation/utils_blocks.py  154
1 file changed, 154 insertions, 0 deletions
diff --git a/simulation/utils_blocks.py b/simulation/utils_blocks.py
new file mode 100644
index 0000000..0d30786
--- /dev/null
+++ b/simulation/utils_blocks.py
@@ -0,0 +1,154 @@
+import collections
+from json import dumps
+
+import numpy as np
+import picklable_itertools
+import theano
+from theano import tensor as tsr
+import fuel.datasets
+import fuel.schemes
+import fuel.streams
+import blocks.algorithms
+import blocks.extensions as be
+import blocks.extensions.monitoring as bm
+from blocks import main_loop
+import utils
+
+
+class LearnedDataset(fuel.datasets.Dataset):
+    """
+    Dynamically generated dataset (for active learning).
+    Compatible with ConstantScheme, whose request corresponds to a
+    batch size.
+    """
+    provides_sources = ('x', 's')
+
+    def __init__(self, node_p, graph, **kwargs):
+        super(LearnedDataset, self).__init__(**kwargs)
+        self.node_p = node_p
+        self.graph = graph
+        # Re-read self.node_p on every call so that updates made by the
+        # ActiveLearning extension are taken into account.
+        self.source = lambda graph: utils.random_source(graph, self.node_p)
+
+    def get_data(self, state=None, request=None):
+        return utils.simulate_cascades(request, self.graph, self.source)
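+
+
+# Minimal usage sketch (illustrative only; the helper name is hypothetical
+# and it assumes utils.simulate_cascades returns an (x, s) pair of arrays).
+def _example_learned_dataset(graph, batch_size=10):
+    node_p = np.ones(len(graph)) / len(graph)    # uniform source distribution
+    dataset = LearnedDataset(node_p, graph)
+    x, s = dataset.get_data(request=batch_size)  # request acts as batch size
+    return x, s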
+
+
+class ActiveLearning(be.SimpleExtension):
+    """
+    Extension which updates the node_p array used by the get_data method of
+    LearnedDataset.  The sampling weight of each node is set proportional to
+    its out-degree in the dataset's graph.
+    """
+    def __init__(self, dataset, **kwargs):
+        super(ActiveLearning, self).__init__(**kwargs)
+        self.dataset = dataset
+
+    def do(self, which_callback, *args):
+        out_degree = np.sum(self.dataset.graph, axis=1)
+        self.dataset.node_p = out_degree / np.sum(out_degree)
+
+
+class JSONDump(be.SimpleExtension):
+    """Dump a JSON-serialized version of the log to a file."""
+
+    def __init__(self, filename, **kwargs):
+        super(JSONDump, self).__init__(**kwargs)
+        self.fh = open(filename, "w")
+
+    def do(self, which_callback, *args):
+        log = self.main_loop.log
+        d = {k: v for (k, v) in log.current_row.items()
+             if not k.startswith("_")}
+        d["time"] = log.status["iterations_done"]
+        self.fh.write(dumps(d, default=lambda o: str(o)) + "\n")
+
+    def __del__(self):
+        self.fh.close()
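+
+
+# Reading the dump back (sketch): JSONDump writes one JSON object per line,
+# so the file can be parsed line by line.  The helper name is hypothetical.
+def _read_json_log(filename):
+    from json import loads
+    with open(filename) as fh:
+        return [loads(line) for line in fh]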
+
+
+class ShuffledBatchesScheme(fuel.schemes.ShuffledScheme):
+    """Iteration scheme over a finite dataset.
+
+    Shuffles the order of the batches, but not the examples within a batch.
+    Arguments: dataset_size (int) and batch_size (int).
+    """
+    def get_request_iterator(self):
+        indices = list(self.indices)  # self.indices = xrange(dataset_size)
+        # Random offset so that batch boundaries differ between epochs.
+        start = np.random.randint(self.batch_size)
+        batches = list(map(
+            list,
+            picklable_itertools.extras.partition_all(self.batch_size,
+                                                     indices[start:])
+        ))
+        if indices[:start]:
+            batches.append(indices[:start])
+        # Batches can have unequal lengths, so shuffle the list directly
+        # rather than going through a (ragged) numpy array.
+        np.random.shuffle(batches)
+        return iter(batches)
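+
+
+# Sketch of what the scheme yields (the helper name is hypothetical): with
+# dataset_size=10, batch_size=4 and a random offset of 2, the batches are
+# [2, 3, 4, 5], [6, 7, 8, 9] and [0, 1], returned in a random order.
+def _example_shuffled_batches():
+    scheme = ShuffledBatchesScheme(10, batch_size=4)
+    return list(scheme.get_request_iterator())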
+
+
+def rmse_error(graph, params):
+    """Root-mean-square error between graph and params, ignoring the diagonal."""
+    n_nodes = graph.shape[0]
+    diff = (graph - params) ** 2
+    diag = tsr.arange(n_nodes)
+    # set_subtensor returns a new variable; its result must be kept.
+    diff = tsr.set_subtensor(diff[diag, diag], 0)
+    rmse = tsr.sqrt(tsr.sum(diff) / (n_nodes ** 2))
+    rmse.name = 'rmse'
+    return rmse
+
+
+def relative_error(graph, params):
+    """Relative error between graph and params, ignoring diagonal and zero entries."""
+    n_nodes = graph.shape[0]
+    diff = abs(graph - params)
+    diag = tsr.arange(n_nodes)
+    diff = tsr.set_subtensor(diff[diag, diag], 0)
+    error = tsr.sum(tsr.switch(tsr.eq(graph, 0.), 0., diff / graph)) / n_nodes
+    error.name = 'rel_error'
+    return error
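+
+
+# Sketch of evaluating the symbolic metrics on concrete arrays (illustrative
+# only; the helper name is hypothetical).
+def _example_metrics():
+    graph = tsr.dmatrix('graph')
+    params = tsr.dmatrix('params')
+    f = theano.function([graph, params],
+                        [rmse_error(graph, params),
+                         relative_error(graph, params)])
+    g = np.ones((3, 3)) * 0.3
+    return f(g, 0.5 * g)  # returns [rmse, rel_error] as numpy scalars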
+
+
+def fixed_data_stream(n_obs, graph, batch_size, shuffle=True):
+    """Create a data stream over a fixed (not actively learned) dataset.
+
+    If shuffle is True, minibatches are shuffled (but not the examples
+    within a minibatch); otherwise sequential batches are used.
+    """
+    x_obs, s_obs = utils.simulate_cascades(n_obs, graph)
+    data_set = fuel.datasets.base.IndexableDataset(collections.OrderedDict(
+        [('x', x_obs), ('s', s_obs)]
+    ))
+    if shuffle:
+        scheme = ShuffledBatchesScheme(n_obs, batch_size=batch_size)
+    else:
+        scheme = fuel.schemes.SequentialScheme(n_obs, batch_size=batch_size)
+    return fuel.streams.DataStream(dataset=data_set, iteration_scheme=scheme)
+
+
+def dynamic_data_stream(graph, batch_size):
+    """Create a data stream over a LearnedDataset with a uniform initial node_p."""
+    node_p = np.ones(len(graph)) / len(graph)
+    data_set = LearnedDataset(node_p, graph)
+    scheme = fuel.schemes.ConstantScheme(batch_size)
+    return fuel.streams.DataStream(dataset=data_set, iteration_scheme=scheme)
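+
+
+# Sketch of pulling one batch from the dynamic stream (the helper name is
+# hypothetical; it assumes the 'x' and 's' sources are returned in order).
+def _example_stream_batch(graph, batch_size=10):
+    stream = dynamic_data_stream(graph, batch_size)
+    x, s = next(stream.get_epoch_iterator())
+    return x, s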
+
+
+if __name__ == "__main__":
+    batch_size = 100
+    n_obs = 1000
+    frequency = 1
+    graph = utils.create_wheel(1000)
+    print('GRAPH:\n', graph, '\n-------------\n')
+
+    g_shared = theano.shared(value=graph, name='graph')
+    # create_mle_model is assumed to be provided elsewhere in the project
+    # (it builds the symbolic inputs, parameters and log-likelihood).
+    x, s, params, cost = create_mle_model(graph)
+    rmse = rmse_error(g_shared, params)
+    error = relative_error(g_shared, params)
+
+    alg = blocks.algorithms.GradientDescent(
+        cost=-cost, parameters=[params],
+        step_rule=blocks.algorithms.AdaDelta()
+    )
+    data_stream = dynamic_data_stream(graph, batch_size)
+    # data_stream = fixed_data_stream(n_obs, graph, batch_size)
+    loop = main_loop.MainLoop(
+        alg, data_stream,
+        extensions=[
+            be.FinishAfter(after_n_batches=10**4),
+            bm.TrainingDataMonitoring([cost, rmse, error],
+                                      every_n_batches=frequency),
+            be.Printing(every_n_batches=frequency),
+            JSONDump("tmpactive_log.json", every_n_batches=frequency),
+            ActiveLearning(data_stream.dataset, every_n_batches=frequency),
+        ],
+    )
+    loop.run()