Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
4import math
5import random
6import unittest
8import numpy
9import pytest
10import torch
12from ..engine.evaluator import sample_measures_for_threshold
13from ..utils.measure import (
14 auc,
15 base_measures,
16 bayesian_measures,
17 beta_credible_region,
18)
class TestFrequentist(unittest.TestCase):
    """
    Unit test for frequentist base measures
    """

    def setUp(self):
        # draw a fresh, strictly positive confusion matrix for every test
        self.tp = random.randint(1, 100)
        self.fp = random.randint(1, 100)
        self.tn = random.randint(1, 100)
        self.fn = random.randint(1, 100)

    def test_precision(self):
        precision, *_ = base_measures(self.tp, self.fp, self.tn, self.fn)
        self.assertEqual(self.tp / (self.tp + self.fp), precision)

    def test_recall(self):
        _, recall, *_ = base_measures(self.tp, self.fp, self.tn, self.fn)
        self.assertEqual(self.tp / (self.tp + self.fn), recall)

    def test_specificity(self):
        _, _, specificity, *_ = base_measures(self.tp, self.fp, self.tn, self.fn)
        self.assertEqual(self.tn / (self.tn + self.fp), specificity)

    def test_accuracy(self):
        _, _, _, accuracy, *_ = base_measures(self.tp, self.fp, self.tn, self.fn)
        total = self.tp + self.tn + self.fp + self.fn
        self.assertEqual((self.tp + self.tn) / total, accuracy)

    def test_jaccard(self):
        *_, jaccard, _f1 = base_measures(self.tp, self.fp, self.tn, self.fn)
        self.assertEqual(self.tp / (self.tp + self.fp + self.fn), jaccard)

    def test_f1(self):
        p, r, _, _, _, f1 = base_measures(self.tp, self.fp, self.tn, self.fn)
        self.assertEqual(
            (2.0 * self.tp) / (2.0 * self.tp + self.fp + self.fn), f1
        )
        # f1 must also match its base definition: the harmonic mean of
        # precision and recall
        self.assertAlmostEqual((2 * p * r) / (p + r), f1)  # base definition
63class TestBayesian:
64 """
65 Unit test for bayesian base measures
66 """
68 def mean(self, k, l, lambda_):
69 return (k + lambda_) / (k + l + 2 * lambda_)
71 def mode1(self, k, l, lambda_): # (k+lambda_), (l+lambda_) > 1
72 return (k + lambda_ - 1) / (k + l + 2 * lambda_ - 2)
74 def test_beta_credible_region_base(self):
75 k = 40
76 l = 10
77 lambda_ = 0.5
78 cover = 0.95
79 got = beta_credible_region(k, l, lambda_, cover)
80 # mean, mode, lower, upper
81 exp = (
82 self.mean(k, l, lambda_),
83 self.mode1(k, l, lambda_),
84 0.6741731038857685,
85 0.8922659692341358,
86 )
87 assert numpy.isclose(got, exp).all(), f"{got} <> {exp}"
89 def test_beta_credible_region_small_k(self):
91 k = 4
92 l = 1
93 lambda_ = 0.5
94 cover = 0.95
95 got = beta_credible_region(k, l, lambda_, cover)
96 # mean, mode, lower, upper
97 exp = (
98 self.mean(k, l, lambda_),
99 self.mode1(k, l, lambda_),
100 0.37137359936800574,
101 0.9774872340008449,
102 )
103 assert numpy.isclose(got, exp).all(), f"{got} <> {exp}"
105 def test_beta_credible_region_precision_jeffrey(self):
107 # simulation of situation for precision TP == FP == 0, Jeffrey's prior
108 k = 0
109 l = 0
110 lambda_ = 0.5
111 cover = 0.95
112 got = beta_credible_region(k, l, lambda_, cover)
113 # mean, mode, lower, upper
114 exp = (
115 self.mean(k, l, lambda_),
116 0.0,
117 0.0015413331334360135,
118 0.998458666866564,
119 )
120 assert numpy.isclose(got, exp).all(), f"{got} <> {exp}"
122 def test_beta_credible_region_precision_flat(self):
124 # simulation of situation for precision TP == FP == 0, flat prior
125 k = 0
126 l = 0
127 lambda_ = 1.0
128 cover = 0.95
129 got = beta_credible_region(k, l, lambda_, cover)
130 # mean, mode, lower, upper
131 exp = (self.mean(k, l, lambda_), 0.0, 0.025000000000000022, 0.975)
132 assert numpy.isclose(got, exp).all(), f"{got} <> {exp}"
134 def test_bayesian_measures(self):
136 tp = random.randint(100000, 1000000)
137 fp = random.randint(100000, 1000000)
138 tn = random.randint(100000, 1000000)
139 fn = random.randint(100000, 1000000)
141 _prec, _rec, _spec, _acc, _jac, _f1 = base_measures(tp, fp, tn, fn)
142 prec, rec, spec, acc, jac, f1 = bayesian_measures(
143 tp, fp, tn, fn, 0.5, 0.95
144 )
146 # Notice that for very large k and l, the base frequentist measures
147 # should be approximately the same as the bayesian mean and mode
148 # extracted from the beta posterior. We test that here.
149 assert numpy.isclose(
150 _prec, prec[0]
151 ), f"freq: {_prec} <> bays: {prec[0]}"
152 assert numpy.isclose(
153 _prec, prec[1]
154 ), f"freq: {_prec} <> bays: {prec[1]}"
155 assert numpy.isclose(_rec, rec[0]), f"freq: {_rec} <> bays: {rec[0]}"
156 assert numpy.isclose(_rec, rec[1]), f"freq: {_rec} <> bays: {rec[1]}"
157 assert numpy.isclose(
158 _spec, spec[0]
159 ), f"freq: {_spec} <> bays: {spec[0]}"
160 assert numpy.isclose(
161 _spec, spec[1]
162 ), f"freq: {_spec} <> bays: {spec[1]}"
163 assert numpy.isclose(_acc, acc[0]), f"freq: {_acc} <> bays: {acc[0]}"
164 assert numpy.isclose(_acc, acc[1]), f"freq: {_acc} <> bays: {acc[1]}"
165 assert numpy.isclose(_jac, jac[0]), f"freq: {_jac} <> bays: {jac[0]}"
166 assert numpy.isclose(_jac, jac[1]), f"freq: {_jac} <> bays: {jac[1]}"
167 assert numpy.isclose(_f1, f1[0]), f"freq: {_f1} <> bays: {f1[0]}"
168 assert numpy.isclose(_f1, f1[1]), f"freq: {_f1} <> bays: {f1[1]}"
170 # We also test that the interval in question includes the mode and the
171 # mean in this case.
172 assert (prec[2] < prec[1]) and (
173 prec[1] < prec[3]
174 ), f"precision is out of bounds {_prec[2]} < {_prec[1]} < {_prec[3]}"
175 assert (rec[2] < rec[1]) and (
176 rec[1] < rec[3]
177 ), f"recall is out of bounds {_rec[2]} < {_rec[1]} < {_rec[3]}"
178 assert (spec[2] < spec[1]) and (
179 spec[1] < spec[3]
180 ), f"specif. is out of bounds {_spec[2]} < {_spec[1]} < {_spec[3]}"
181 assert (acc[2] < acc[1]) and (
182 acc[1] < acc[3]
183 ), f"accuracy is out of bounds {_acc[2]} < {_acc[1]} < {_acc[3]}"
184 assert (jac[2] < jac[1]) and (
185 jac[1] < jac[3]
186 ), f"jaccard is out of bounds {_jac[2]} < {_jac[1]} < {_jac[3]}"
187 assert (f1[2] < f1[1]) and (
188 f1[1] < f1[3]
189 ), f"f1-score is out of bounds {_f1[2]} < {_f1[1]} < {_f1[3]}"
def test_auc():
    """Check auc() against hand-computed trapezoid areas on 3-point curves."""

    # basic tests
    assert math.isclose(auc([0.0, 0.5, 1.0], [1.0, 1.0, 1.0]), 1.0)
    assert math.isclose(
        auc([0.0, 0.5, 1.0], [1.0, 0.5, 0.0]), 0.5, rel_tol=0.001
    )
    assert math.isclose(
        auc([0.0, 0.5, 1.0], [0.0, 0.0, 0.0]), 0.0, rel_tol=0.001
    )
    assert math.isclose(
        auc([0.0, 0.5, 1.0], [0.0, 1.0, 0.0]), 0.5, rel_tol=0.001
    )
    # NOTE(fix): this triangle case was asserted twice verbatim; the
    # duplicate has been removed (same below for the reversed series).
    assert math.isclose(
        auc([0.0, 0.5, 1.0], [0.0, 0.5, 0.0]), 0.25, rel_tol=0.001
    )

    # reversing the series must yield the same areas
    assert math.isclose(auc([0.0, 0.5, 1.0][::-1], [1.0, 1.0, 1.0][::-1]), 1.0)
    assert math.isclose(
        auc([0.0, 0.5, 1.0][::-1], [1.0, 0.5, 0.0][::-1]), 0.5, rel_tol=0.001
    )
    assert math.isclose(
        auc([0.0, 0.5, 1.0][::-1], [0.0, 0.0, 0.0][::-1]), 0.0, rel_tol=0.001
    )
    assert math.isclose(
        auc([0.0, 0.5, 1.0][::-1], [0.0, 1.0, 0.0][::-1]), 0.5, rel_tol=0.001
    )
    assert math.isclose(
        auc([0.0, 0.5, 1.0][::-1], [0.0, 0.5, 0.0][::-1]), 0.25, rel_tol=0.001
    )
def test_auc_raises_value_error():
    """auc() must reject an x-series that is neither sorted up nor down."""

    # x is **not** monotonically increasing or decreasing
    x = [0.0, 0.5, 0.0]
    y = [1.0, 1.0, 1.0]
    with pytest.raises(
        ValueError, match=r".*neither increasing nor decreasing.*"
    ):
        assert math.isclose(auc(x, y), 1.0)
def test_auc_raises_assertion_error():
    """auc() must reject x and y series of different lengths."""

    # x is **not** the same size as y
    x = [0.0, 0.5, 1.0]
    y = [1.0, 1.0]
    with pytest.raises(AssertionError, match=r".*must have the same length.*"):
        assert math.isclose(auc(x, y), 1.0)
def test_sample_measures_mask_checkerbox():
    """Counting restricted to a central 2x2 mask over a checkerbox truth."""

    threshold = 0.5

    # prediction is all "on"; ground truth has the lower-left and upper-right
    # 2x2 quadrants switched off
    prediction = torch.ones((4, 4), dtype=float)
    ground_truth = torch.ones((4, 4), dtype=float)
    ground_truth[2:, :2] = 0.0
    ground_truth[:2, 2:] = 0.0

    # only the central 2x2 window is evaluated
    mask = torch.zeros((4, 4), dtype=float)
    mask[1:3, 1:3] = 1.0

    # inside the mask: two pixels agree (tp) and two are over-predicted (fp)
    expected = (2, 2, 0, 0)  # (tp, fp, tn, fn)
    assert expected == sample_measures_for_threshold(
        prediction, ground_truth, mask, threshold
    )
def test_sample_measures_mask_cross():
    """Counting restricted to an X-shaped (both diagonals) mask."""

    threshold = 0.5

    prediction = torch.ones((10, 10), dtype=float)
    prediction[0, :] = 0.0
    prediction[9, :] = 0.0

    ground_truth = torch.ones((10, 10), dtype=float)
    ground_truth[:5, :] = 0.0  # lower part is not to be set

    # mask covers the main diagonal and the anti-diagonal only
    mask = torch.zeros((10, 10), dtype=float)
    idx = tuple(range(10))
    mask[idx, idx] = 1.0
    mask[idx, tuple(reversed(idx))] = 1.0

    expected = (8, 8, 2, 2)  # (tp, fp, tn, fn)
    assert expected == sample_measures_for_threshold(
        prediction, ground_truth, mask, threshold
    )
def test_sample_measures_mask_border():
    """Counting with a one-pixel border excluded by the mask."""

    threshold = 0.5

    # prediction: two-column vertical stripe with two misses and one extra hit
    prediction = torch.zeros((10, 10), dtype=float)
    prediction[:, 4:6] = 1.0
    prediction[0, 4] = 0.0  # miss, but row 0 is masked out anyway
    prediction[8, 4] = 0.0  # miss inside the mask -> fn
    prediction[1, 6] = 1.0  # extra hit inside the mask -> fp

    # ground truth: the clean two-column stripe
    ground_truth = torch.zeros((10, 10), dtype=float)
    ground_truth[:, 4:6] = 1.0

    # mask strips the one-pixel outer border
    mask = torch.ones((10, 10), dtype=float)
    mask[0, :] = 0.0
    mask[9, :] = 0.0
    mask[:, 0] = 0.0
    mask[:, 9] = 0.0

    expected = (15, 1, 47, 1)  # (tp, fp, tn, fn)
    assert expected == sample_measures_for_threshold(
        prediction, ground_truth, mask, threshold
    )