-
Notifications
You must be signed in to change notification settings - Fork 627
/
Copy pathtest_fbnet.py
executable file
·84 lines (66 loc) · 2.77 KB
/
test_fbnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import unittest
import numpy as np
import torch
import fcos_core.modeling.backbone.fbnet_builder as fbnet_builder
# Whether torch reports an available CUDA device; used with unittest.skipIf
# to skip the CUDA-only tests on machines without a GPU.
TEST_CUDA = torch.cuda.is_available()
def _test_primitive(self, device, op_name, op_func, N, C_in, C_out, expand, stride):
op = op_func(C_in, C_out, expand, stride).to(device)
input = torch.rand([N, C_in, 7, 7], dtype=torch.float32).to(device)
output = op(input)
self.assertEqual(
output.shape[:2], torch.Size([N, C_out]),
'Primitive {} failed for shape {}.'.format(op_name, input.shape)
)
class TestFBNetBuilder(unittest.TestCase):
    """Smoke tests for ``fbnet_builder.Identity`` and ``fbnet_builder.PRIMITIVES``."""

    def _check_all_primitives(self, device, N):
        """Run ``_test_primitive`` on ``device`` for every registered primitive.

        Each primitive is checked in its own ``subTest`` so a failure is
        reported with the primitive's name and the remaining primitives
        still run.
        """
        for op_name, op_func in fbnet_builder.PRIMITIVES.items():
            print('Testing {}'.format(op_name))
            with self.subTest(op_name=op_name):
                _test_primitive(
                    self, device,
                    op_name, op_func,
                    N=N, C_in=16, C_out=32, expand=4, stride=1
                )

    def test_identity(self):
        # Same channel count and stride 1: output must equal the input exactly.
        id_op = fbnet_builder.Identity(20, 20, 1)
        inputs = torch.rand([10, 20, 7, 7], dtype=torch.float32)
        output = id_op(inputs)
        np.testing.assert_array_equal(np.array(inputs), np.array(output))

        # 20 -> 40 channels with stride 2: only the output shape is checked.
        id_op = fbnet_builder.Identity(20, 40, 2)
        inputs = torch.rand([10, 20, 7, 7], dtype=torch.float32)
        output = id_op(inputs)
        np.testing.assert_array_equal(output.shape, [10, 40, 4, 4])

    def test_primitives(self):
        """Make sure every primitive runs on CPU."""
        self._check_all_primitives("cpu", N=20)

    @unittest.skipIf(not TEST_CUDA, "no CUDA detected")
    def test_primitives_cuda(self):
        """Make sure every primitive runs on CUDA."""
        self._check_all_primitives("cuda", N=20)

    def test_primitives_empty_batch(self):
        """Make sure every primitive runs on CPU with an empty (N=0) batch."""
        self._check_all_primitives("cpu", N=0)

    @unittest.skipIf(not TEST_CUDA, "no CUDA detected")
    def test_primitives_cuda_empty_batch(self):
        """Make sure every primitive runs on CUDA with an empty (N=0) batch."""
        self._check_all_primitives("cuda", N=0)
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()