aboutsummaryrefslogtreecommitdiffstats
path: root/simulation/utils_blocks.py
diff options
context:
space:
mode:
Diffstat (limited to 'simulation/utils_blocks.py')
-rw-r--r--simulation/utils_blocks.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/simulation/utils_blocks.py b/simulation/utils_blocks.py
index 00b429e..2dc9f85 100644
--- a/simulation/utils_blocks.py
+++ b/simulation/utils_blocks.py
@@ -85,18 +85,18 @@ def rmse_error(graph, params):
diff = (graph - params) ** 2
subarray = tsr.arange(n_nodes)
tsr.set_subtensor(diff[subarray, subarray], 0)
- rmse = tsr.sum(diff) / (n_nodes ** 2)
+ rmse = tsr.sqrt(tsr.sum(diff) / (n_nodes ** 2))
rmse.name = 'rmse'
return rmse
-def relative_error(graph, params):
+def absolute_error(graph, params):
n_nodes = graph.shape[0]
diff = abs(graph - params)
subarray = tsr.arange(n_nodes)
tsr.set_subtensor(diff[subarray, subarray], 0)
- error = tsr.sum(tsr.switch(tsr.eq(graph, 0.), 0., diff / graph)) / n_nodes
- error.name = 'rel_error'
+ error = tsr.sum(diff) / (n_nodes ** 2)
+ error.name = 'abs_error'
return error