diff options
| author | Thibaut Horel <thibaut.horel@gmail.com> | 2015-12-02 16:30:37 -0500 |
|---|---|---|
| committer | Thibaut Horel <thibaut.horel@gmail.com> | 2015-12-02 16:30:37 -0500 |
| commit | 0fc1aa731e26683d21ed1a91f56d047d1682fd7e (patch) | |
| tree | c80ad05a80b8aab39f917495ce5aeb2e01ea5dad /simulation/utils_blocks.py | |
| parent | 0e90119296f6bbbaf28fbaa329556d6d9cd86f3f (diff) | |
| download | cascades-0fc1aa731e26683d21ed1a91f56d047d1682fd7e.tar.gz | |
Fix errors computation
Diffstat (limited to 'simulation/utils_blocks.py')
| -rw-r--r-- | simulation/utils_blocks.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/simulation/utils_blocks.py b/simulation/utils_blocks.py index 00b429e..2dc9f85 100644 --- a/simulation/utils_blocks.py +++ b/simulation/utils_blocks.py @@ -85,18 +85,18 @@ def rmse_error(graph, params): diff = (graph - params) ** 2 subarray = tsr.arange(n_nodes) tsr.set_subtensor(diff[subarray, subarray], 0) - rmse = tsr.sum(diff) / (n_nodes ** 2) + rmse = tsr.sqrt(tsr.sum(diff) / (n_nodes ** 2)) rmse.name = 'rmse' return rmse -def relative_error(graph, params): +def absolute_error(graph, params): n_nodes = graph.shape[0] diff = abs(graph - params) subarray = tsr.arange(n_nodes) tsr.set_subtensor(diff[subarray, subarray], 0) - error = tsr.sum(tsr.switch(tsr.eq(graph, 0.), 0., diff / graph)) / n_nodes - error.name = 'rel_error' + error = tsr.sum(diff) / (n_nodes ** 2) + error.name = 'abs_error' return error |
