Skip to content

Commit

Permalink
Merge pull request #1409 from ParallelogramPal/master
Browse files Browse the repository at this point in the history
Calculate Conv2D Output Shape in Accordance with PyTorch Docs
  • Loading branch information
mkskeller authored May 23, 2024
2 parents 5ba7e71 + 10484c0 commit f8fc839
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions Compiler/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -2967,11 +2967,11 @@ def apply_padding(input_shape, kernel_size, strides, padding):
if isinstance(padding, int):
padding = [padding, padding]
if isinstance(padding, (tuple, list)):
input_shape = [x + sum(padding) for x in input_shape]
input_shape = [input_shape[i] + 2*padding[i] for i in range(len(input_shape))]
padding = 'valid'
if padding.lower() == 'valid':
res = (input_shape[0] - kernel_size[0] + 1) // strides[0], \
(input_shape[1] - kernel_size[1] + 1) // strides[1],
res = (input_shape[0] - kernel_size[0]) // strides[0] + 1, \
(input_shape[1] - kernel_size[1]) // strides[1] + 1,
assert min(res) > 0, (input_shape, kernel_size, strides, padding)
return res
elif padding.lower() == 'same':
Expand Down

0 comments on commit f8fc839

Please sign in to comment.