diff options
Diffstat (limited to 'simulation/utils_blocks.py')
| -rw-r--r-- | simulation/utils_blocks.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/simulation/utils_blocks.py b/simulation/utils_blocks.py index 00b429e..2dc9f85 100644 --- a/simulation/utils_blocks.py +++ b/simulation/utils_blocks.py @@ -85,18 +85,18 @@ def rmse_error(graph, params): diff = (graph - params) ** 2 subarray = tsr.arange(n_nodes) tsr.set_subtensor(diff[subarray, subarray], 0) - rmse = tsr.sum(diff) / (n_nodes ** 2) + rmse = tsr.sqrt(tsr.sum(diff) / (n_nodes ** 2)) rmse.name = 'rmse' return rmse -def relative_error(graph, params): +def absolute_error(graph, params): n_nodes = graph.shape[0] diff = abs(graph - params) subarray = tsr.arange(n_nodes) tsr.set_subtensor(diff[subarray, subarray], 0) - error = tsr.sum(tsr.switch(tsr.eq(graph, 0.), 0., diff / graph)) / n_nodes - error.name = 'rel_error' + error = tsr.sum(diff) / (n_nodes ** 2) + error.name = 'abs_error' return error |
