-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest_metrics.py
42 lines (34 loc) · 1.28 KB
/
test_metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import unittest
import torch
from metrics import BinaryExpectedCost
class TestMetrics(unittest.TestCase):
    """Unit tests for ``BinaryExpectedCost``.

    The expected values asserted below encode a per-sample cost of 0 for a
    correct prediction, 1 for a false positive and 5 for a false negative.
    For a batch these tests expect the mean of the per-sample costs, e.g.
    targets ``[1, 1, 0, 0]`` vs. predictions ``[0, 1, 1, 0]`` give
    ``(5 + 0 + 1 + 0) / 4 == 1.5``.
    """

    def _expected_cost(self, preds, target):
        """Return the metric value for ``preds`` vs. ``target`` as a float.

        A fresh metric instance is created on every call (as the original
        tests did) so no accumulated state can leak between assertions.
        Converting with ``float()`` makes ``assert*`` compare plain numbers:
        failures print readable values, and a result with more than one
        element fails loudly here instead of relying on tensor truthiness.
        """
        metric = BinaryExpectedCost()
        return float(metric(preds, target))

    def test_binary_expected_cost_tn(self):
        """A true negative (target 0, prediction 0) costs nothing."""
        target = torch.tensor([0])
        preds = torch.tensor([0])
        self.assertAlmostEqual(self._expected_cost(preds, target), 0.0)

    def test_binary_expected_cost_fp(self):
        """A false positive (target 0, prediction 1) costs 1."""
        target = torch.tensor([0])
        preds = torch.tensor([1])
        self.assertAlmostEqual(self._expected_cost(preds, target), 1.0)

    def test_binary_expected_cost_fn(self):
        """A false negative (target 1, prediction 0) costs 5."""
        target = torch.tensor([1])
        preds = torch.tensor([0])
        self.assertAlmostEqual(self._expected_cost(preds, target), 5.0)

    def test_binary_expected_cost_tp(self):
        """A true positive (target 1, prediction 1) costs nothing."""
        target = torch.tensor([1])
        preds = torch.tensor([1])
        self.assertAlmostEqual(self._expected_cost(preds, target), 0.0)

    def test_binary_expected_cost(self):
        """A mixed batch (one FN, one TP, one FP, one TN) averages to 1.5."""
        target = torch.tensor([1, 1, 0, 0])
        preds = torch.tensor([0, 1, 1, 0])
        # Fractional result: compare with a tolerance, not exact equality.
        self.assertAlmostEqual(self._expected_cost(preds, target), 1.5)


if __name__ == "__main__":
    unittest.main()