aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--jpa_test/algorithms.py27
1 files changed, 20 insertions, 7 deletions
diff --git a/jpa_test/algorithms.py b/jpa_test/algorithms.py
index cae16b8..0e240c9 100644
--- a/jpa_test/algorithms.py
+++ b/jpa_test/algorithms.py
@@ -58,14 +58,27 @@ def correctness_measure(G, G_hat, print_values=False):
"""
edges = set(G.edges())
edges_hat = set(G_hat.edges())
- fp = edges_hat - edges
- fn = edges - edges_hat
- gp = edges | edges_hat
+ fp = len(edges_hat - edges)
+ fn = len(edges - edges_hat)
+ tp = len(edges | edges_hat)
+ tn = G.number_of_nodes() ** 2 - fp - fn - tp
+
+ #Other metrics
+ precision = 1. * tp / (tp + fp)
+ recall = 1. * tp / (tp + fn)
+ f1_score = 2.* tp / (2 * tp + fp + fn)
+
if print_values:
- print "False Positives: {}".format(len(fp))
- print "False Negatives: {}".format(len(fn))
- print "Good Positives: {}".format(len(gp))
- return fp, fn, gp
+ print "False Positives: {}".format(fp)
+ print "False Negatives: {}".format(fn)
+ print "True Positives: {}".format(tp)
+ print "True Negatives: {}".format(tn)
+ print "-------------------------------"
+ print "Precision: {}".format(precision)
+ print "Recall: {}".format(recall)
+ print "F1 score: {}".format(f1_score)
+
+ return fp, fn, tp, tn
def test():