From 0fc1aa731e26683d21ed1a91f56d047d1682fd7e Mon Sep 17 00:00:00 2001 From: Thibaut Horel Date: Wed, 2 Dec 2015 16:30:37 -0500 Subject: Fix errors computation --- simulation/utils_blocks.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'simulation/utils_blocks.py') diff --git a/simulation/utils_blocks.py b/simulation/utils_blocks.py index 00b429e..2dc9f85 100644 --- a/simulation/utils_blocks.py +++ b/simulation/utils_blocks.py @@ -85,18 +85,18 @@ def rmse_error(graph, params): diff = (graph - params) ** 2 subarray = tsr.arange(n_nodes) tsr.set_subtensor(diff[subarray, subarray], 0) - rmse = tsr.sum(diff) / (n_nodes ** 2) + rmse = tsr.sqrt(tsr.sum(diff) / (n_nodes ** 2)) rmse.name = 'rmse' return rmse -def relative_error(graph, params): +def absolute_error(graph, params): n_nodes = graph.shape[0] diff = abs(graph - params) subarray = tsr.arange(n_nodes) tsr.set_subtensor(diff[subarray, subarray], 0) - error = tsr.sum(tsr.switch(tsr.eq(graph, 0.), 0., diff / graph)) / n_nodes - error.name = 'rel_error' + error = tsr.sum(diff) / (n_nodes ** 2) + error.name = 'abs_error' return error -- cgit v1.2.3-70-g09d2