Skip to content

Commit

Permalink
Merge pull request #4 from ViswanathaReddyGajjala/deepsource-fix-6307…
Browse files Browse the repository at this point in the history
…007c

Remove assert statement from non-test files
  • Loading branch information
ViswanathaReddyGajjala authored Sep 27, 2020
2 parents d9b83b7 + c3ae87b commit 58e2ad6
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
6 changes: 4 additions & 2 deletions efficientnet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,10 @@ class EfficientNet(nn.Module):

def __init__(self, blocks_args=None, global_params=None):
super().__init__()
assert isinstance(blocks_args, list), 'blocks_args should be a list'
assert len(blocks_args) > 0, 'block args must be greater than 0'
if not isinstance(blocks_args, list):
raise AssertionError('blocks_args should be a list')
if len(blocks_args) <= 0:
raise AssertionError('block args must be greater than 0')
self._global_params = global_params
self._blocks_args = blocks_args

Expand Down
17 changes: 11 additions & 6 deletions efficientnet_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kw
self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2

# Calculate padding based on image size and save it
assert image_size is not None
if image_size is None:
raise AssertionError
ih, iw = image_size if type(image_size) == list else [image_size, image_size]
kh, kw = self.weight.size()[-2:]
sh, sw = self.stride
Expand Down Expand Up @@ -163,7 +164,8 @@ class BlockDecoder(object):
@staticmethod
def _decode_block_string(block_string):
""" Gets a block through a string notation of arguments. """
assert isinstance(block_string, str)
if not isinstance(block_string, str):
raise AssertionError

ops = block_string.split('_')
options = {}
Expand All @@ -174,8 +176,9 @@ def _decode_block_string(block_string):
options[key] = value

# Check stride
assert (('s' in options and len(options['s']) == 1) or
(len(options['s']) == 2 and options['s'][0] == options['s'][1]))
if not (('s' in options and len(options['s']) == 1) or
(len(options['s']) == 2 and options['s'][0] == options['s'][1])):
raise AssertionError

return BlockArgs(
kernel_size=int(options['k']),
Expand Down Expand Up @@ -212,7 +215,8 @@ def decode(string_list):
:param string_list: a list of strings, each string is a notation of block
:return: a list of BlockArgs namedtuples of block args
"""
assert isinstance(string_list, list)
if not isinstance(string_list, list):
raise AssertionError
blocks_args = []
for block_string in string_list:
blocks_args.append(BlockDecoder._decode_block_string(block_string))
Expand Down Expand Up @@ -307,5 +311,6 @@ def load_pretrained_weights(model, model_name, load_fc=True):
state_dict.pop('_fc.weight')
state_dict.pop('_fc.bias')
res = model.load_state_dict(state_dict, strict=False)
assert str(res.missing_keys) == str(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights'
if str(res.missing_keys) != str(['_fc.weight', '_fc.bias']):
raise AssertionError('issue loading pretrained weights')
print('Loaded pretrained weights for {}'.format(model_name))

0 comments on commit 58e2ad6

Please sign in to comment.