Diffstat (limited to 'simulation/utils_blocks.py')
| -rw-r--r-- | simulation/utils_blocks.py | 121 |
1 file changed, 121 insertions, 0 deletions
diff --git a/simulation/utils_blocks.py b/simulation/utils_blocks.py
new file mode 100644
index 0000000..5e91658
--- /dev/null
+++ b/simulation/utils_blocks.py
@@ -0,0 +1,121 @@
from theano import tensor as tsr
import fuel.datasets
import fuel.schemes
import fuel.streams
import blocks.extensions as be
import picklable_itertools.extras
import numpy as np
from json import dumps
import collections
import utils


class LearnedDataset(fuel.datasets.Dataset):
    """
    Dynamically generated dataset (for active learning).

    Compatible with ConstantScheme, whose request corresponds to a
    batch size.
    """
    provides_sources = ('x', 's')

    def __init__(self, node_p, graph, **kwargs):
        super(LearnedDataset, self).__init__(**kwargs)
        self.node_p = node_p
        self.graph = graph
        # The source sampler reads self.node_p at call time, so updates
        # made by the ActiveLearning extension affect the next batch.
        self.source = lambda graph: utils.random_source(graph, self.node_p)

    def get_data(self, state=None, request=None):
        # `request` is the batch size handed over by ConstantScheme.
        return utils.simulate_cascades(request, self.graph, self.source)


class ActiveLearning(be.SimpleExtension):
    """
    Extension which updates the node_p array used by the get_data method
    of LearnedDataset.
    """
    def __init__(self, dataset, **kwargs):
        super(ActiveLearning, self).__init__(**kwargs)
        self.dataset = dataset

    def do(self, which_callback, *args):
        # Reweight source-node sampling proportionally to out-degree.
        out_degree = np.sum(self.dataset.graph, axis=1)
        self.dataset.node_p = out_degree / np.sum(out_degree)
        print(self.dataset.node_p)


class JSONDump(be.SimpleExtension):
    """Dump a JSON-serialized version of the log to a file."""

    def __init__(self, filename, **kwargs):
        super(JSONDump, self).__init__(**kwargs)
        self.fh = open(filename, "w")

    def do(self, which_callback, *args):
        log = self.main_loop.log
        # Skip private log entries (keys starting with an underscore).
        d = {k: v for (k, v) in log.current_row.items()
             if not k.startswith("_")}
        d["time"] = log.status["iterations_done"]
        self.fh.write(dumps(d, default=lambda o: str(o)) + "\n")

    def __del__(self):
        self.fh.close()


class ShuffledBatchesScheme(fuel.schemes.ShuffledScheme):
    """Iteration scheme over a finite dataset:
    - shuffles the order of batches, but not the order within a batch
    - arguments: dataset_size (int); batch_size (int)"""
    def get_request_iterator(self):
        indices = list(self.indices)  # self.indices = xrange(dataset_size)
        # Random offset so batch boundaries differ from epoch to epoch.
        start = np.random.randint(self.batch_size)
        batches = list(map(
            list,
            picklable_itertools.extras.partition_all(self.batch_size,
                                                     indices[start:])
        ))
        if indices[:start]:
            batches.append(indices[:start])
        # Shuffle the batches (not the examples inside them).
        order = np.random.permutation(len(batches))
        return iter([batches[i] for i in order])


def rmse_error(graph, params):
    n_nodes = graph.shape[0]
    diff = (graph - params) ** 2
    # Zero out the diagonal: self-edges are not modelled.
    subarray = tsr.arange(n_nodes)
    diff = tsr.set_subtensor(diff[subarray, subarray], 0)
    rmse = tsr.sqrt(tsr.sum(diff) / (n_nodes ** 2))
    rmse.name = 'rmse'
    return rmse


def relative_error(graph, params):
    n_nodes = graph.shape[0]
    diff = abs(graph - params)
    # Zero out the diagonal: self-edges are not modelled.
    subarray = tsr.arange(n_nodes)
    diff = tsr.set_subtensor(diff[subarray, subarray], 0)
    # Edges absent from the true graph contribute no relative error.
    error = tsr.sum(tsr.switch(tsr.eq(graph, 0.), 0., diff / graph)) / n_nodes
    error.name = 'rel_error'
    return error


def fixed_data_stream(n_obs, graph, batch_size, shuffle=True):
    """
    Creates a data stream for a fixed (not learned) dataset.

    shuffle (bool): shuffle minibatches but not within a minibatch;
    otherwise sequential (non-shuffled) batches are used.
    """
    x_obs, s_obs = utils.simulate_cascades(n_obs, graph)
    data_set = fuel.datasets.base.IndexableDataset(collections.OrderedDict(
        [('x', x_obs), ('s', s_obs)]
    ))
    if shuffle:
        scheme = ShuffledBatchesScheme(n_obs, batch_size=batch_size)
    else:
        scheme = fuel.schemes.SequentialScheme(n_obs, batch_size=batch_size)
    return fuel.streams.DataStream(dataset=data_set, iteration_scheme=scheme)


def dynamic_data_stream(graph, batch_size):
    # Start from a uniform source distribution; the ActiveLearning
    # extension may reweight it during training.
    node_p = np.ones(len(graph)) / len(graph)
    data_set = LearnedDataset(node_p, graph)
    scheme = fuel.schemes.ConstantScheme(batch_size)
    return fuel.streams.DataStream(dataset=data_set, iteration_scheme=scheme)
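As a sanity check on the two symbolic metrics, they can be compiled into a Theano function and evaluated on a hand-made graph. The sketch below is not part of the commit: the toy matrices are invented for illustration, and the import path assumes the file is importable as simulation.utils_blocks.

import numpy as np
import theano
from theano import tensor as tsr

from simulation.utils_blocks import rmse_error, relative_error

# Symbolic inputs: true influence matrix and its learned estimate.
graph_var = tsr.matrix('graph')
params_var = tsr.matrix('params')
metrics = theano.function(
    [graph_var, params_var],
    [rmse_error(graph_var, params_var),
     relative_error(graph_var, params_var)])

# Toy 2-node example; diagonal entries are ignored by both metrics.
true_graph = np.array([[0., 0.5],
                       [0.2, 0.]], dtype=theano.config.floatX)
estimate = np.array([[0., 0.4],
                     [0.3, 0.]], dtype=theano.config.floatX)
rmse, rel = metrics(true_graph, estimate)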

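For context, a hypothetical sketch of how the streams and extensions could be wired together follows. The 3-node graph and all numeric settings are made up, and it assumes utils.simulate_cascades and utils.random_source are importable from the repository's own utils module; every_n_batches and after_batch are standard SimpleExtension trigger keywords in Blocks.

import numpy as np
from simulation.utils_blocks import (ActiveLearning, JSONDump,
                                     dynamic_data_stream, fixed_data_stream)

# Made-up 3-node influence graph (entry [i, j] = weight of edge i -> j).
graph = np.array([[0., .5, .2],
                  [.1, 0., .4],
                  [.3, .2, 0.]])

# Fixed dataset: pre-simulate 1000 cascades, served in shuffled batches.
fixed_stream = fixed_data_stream(n_obs=1000, graph=graph, batch_size=32)

# Dynamic dataset: cascades are simulated on demand by LearnedDataset;
# ActiveLearning periodically reweights the source-node distribution and
# JSONDump writes one JSON record per batch.
stream = dynamic_data_stream(graph, batch_size=32)
extensions = [
    ActiveLearning(stream.dataset, every_n_batches=100),
    JSONDump('train_log.jsonl', after_batch=True),
]
# These would then be passed to blocks.main_loop.MainLoop, e.g.
# MainLoop(algorithm=..., data_stream=stream, extensions=extensions).run()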