-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtest_resnet.py
56 lines (36 loc) · 1.34 KB
/
test_resnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import pytest
from keras import backend as K
from resnet import ResnetBuilder
# Dim-ordering identifiers passed to K.set_image_dim_ordering by the compile
# helper. A tuple (not a set) so the orderings are always tried in the same
# order, independent of string hash randomization.
DIM_ORDERING = ('th', 'tf')
def _test_model_compile(model):
    """Compile *model* once under each dim ordering in DIM_ORDERING.

    Fails the calling test, with the offending ordering in the message, if
    ``model.compile`` raises. The global Keras dim ordering that was active
    on entry is restored afterwards, so the setting does not leak into other
    tests.

    NOTE(review): *model* is built by the caller before the ordering is
    switched, so this checks that ``compile`` succeeds under each setting,
    not that the architecture builds under each ordering -- confirm this is
    the intended coverage.
    """
    original_ordering = K.image_dim_ordering()
    try:
        # sorted() gives a stable order even if DIM_ORDERING is a set.
        for ordering in sorted(DIM_ORDERING):
            K.set_image_dim_ordering(ordering)
            try:
                model.compile(loss="categorical_crossentropy", optimizer="sgd")
            except Exception:
                # Raised inside the handler, so the original error is chained
                # into the report; the message names the failing ordering.
                pytest.fail("Failed to compile with '{}' dim ordering".format(ordering))
    finally:
        K.set_image_dim_ordering(original_ordering)
def test_resnet18():
    """ResNet-18 built for shape (3, 224, 224) with 100 outputs compiles."""
    input_shape = (3, 224, 224)
    net = ResnetBuilder.build_resnet_18(input_shape, 100)
    _test_model_compile(net)
def test_resnet34():
    """ResNet-34 built for shape (3, 224, 224) with 100 outputs compiles."""
    input_shape = (3, 224, 224)
    net = ResnetBuilder.build_resnet_34(input_shape, 100)
    _test_model_compile(net)
def test_resnet50():
    """ResNet-50 built for shape (3, 224, 224) with 100 outputs compiles."""
    input_shape = (3, 224, 224)
    net = ResnetBuilder.build_resnet_50(input_shape, 100)
    _test_model_compile(net)
def test_resnet101():
    """ResNet-101 built for shape (3, 224, 224) with 100 outputs compiles."""
    input_shape = (3, 224, 224)
    net = ResnetBuilder.build_resnet_101(input_shape, 100)
    _test_model_compile(net)
def test_resnet152():
    """ResNet-152 built for shape (3, 224, 224) with 100 outputs compiles."""
    input_shape = (3, 224, 224)
    net = ResnetBuilder.build_resnet_152(input_shape, 100)
    _test_model_compile(net)
def test_custom1():
    """ResNet-152 built for shape (3, 300, 300) with 100 outputs compiles.

    See https://github.com/raghakot/keras-resnet/issues/34
    """
    input_shape = (3, 300, 300)
    net = ResnetBuilder.build_resnet_152(input_shape, 100)
    _test_model_compile(net)
def test_custom2():
    """ResNet-152 built for shape (3, 512, 512) with 2 outputs compiles.

    See https://github.com/raghakot/keras-resnet/issues/34
    """
    input_shape = (3, 512, 512)
    net = ResnetBuilder.build_resnet_152(input_shape, 2)
    _test_model_compile(net)
if __name__ == '__main__':
    # Forward pytest's exit code so a failing run exits non-zero when this
    # file is executed directly (e.g. from CI or a shell script).
    raise SystemExit(pytest.main([__file__]))