diff options
Diffstat (limited to 'simulation/active_blocks.py')
| -rw-r--r-- | simulation/active_blocks.py | 29 |
1 files changed, 18 insertions, 11 deletions
diff --git a/simulation/active_blocks.py b/simulation/active_blocks.py index 40266b2..32ce927 100644 --- a/simulation/active_blocks.py +++ b/simulation/active_blocks.py @@ -27,10 +27,7 @@ class LearnedDataset(fuel.datasets.Dataset): self.source = lambda graph: utils.random_source(graph, self.node_p) def get_data(self, state=None, request=None): - # floatX = 'int8' - x_obs, s_obs = utils.simulate_cascades(request, self.graph, self.source) - - return (x_obs, s_obs) + return utils.simulate_cascades(request, self.graph, self.source) class ActiveLearning(blocks.extensions.SimpleExtension): @@ -87,15 +84,23 @@ def create_mle_model(graph): def rmse_error(graph, params): - n_nodes = len(graph) - g_shared = theano.shared(value=graph, name='graph') - diff = (g_shared - params) ** 2 - subarray = tsr.arange(g_shared.shape[0]) + n_nodes = graph.shape[0] + diff = (graph - params) ** 2 + subarray = tsr.arange(n_nodes) tsr.set_subtensor(diff[subarray, subarray], 0) rmse = tsr.sum(diff) / (n_nodes ** 2) rmse.name = 'rmse' - g_shared.name = 'graph' - return rmse, g_shared + return rmse + + +def relative_error(graph, params): + n_nodes = graph.shape[0] + diff = abs(g_shared - params) + subarray = tsr.arange(n_nodes) + tsr.set_subtensor(diff[subarray, subarray], 0) + error = tsr.sum(tsr.switch(tsr.eq(graph, 0.), 0., diff / graph)) / n_nodes + error.name = 'rel_error' + return error def create_fixed_data_stream(n_obs, graph, batch_size, shuffle=True): @@ -128,8 +133,10 @@ if __name__ == "__main__": graph = utils.create_wheel(10) print('GRAPH:\n', graph, '\n-------------\n') + g_shared = theano.shared(value=graph, name='graph') x, s, params, cost = create_mle_model(graph) - rmse, g_shared = rmse_error(graph, params) + rmse = rmse_error(g_shared, params) + error = relative_error(g_shared, params) alg = algorithms.GradientDescent( cost=-cost, parameters=[params], step_rule=blocks.algorithms.AdaDelta() |
