From 10484c08e39aa267215eca1ac17060bb485a17f6 Mon Sep 17 00:00:00 2001 From: Parker Diamond Date: Mon, 20 May 2024 15:50:45 -0400 Subject: [PATCH] Calculate output shape in accordance to PyTorch docs --- Compiler/ml.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Compiler/ml.py b/Compiler/ml.py index 4cb699579..63f815380 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -2967,11 +2967,11 @@ def apply_padding(input_shape, kernel_size, strides, padding): if isinstance(padding, int): padding = [padding, padding] if isinstance(padding, (tuple, list)): - input_shape = [x + sum(padding) for x in input_shape] + input_shape = [input_shape[i] + 2*padding[i] for i in range(len(input_shape))] padding = 'valid' if padding.lower() == 'valid': - res = (input_shape[0] - kernel_size[0] + 1) // strides[0], \ - (input_shape[1] - kernel_size[1] + 1) // strides[1], + res = (input_shape[0] - kernel_size[0]) // strides[0] + 1, \ + (input_shape[1] - kernel_size[1]) // strides[1] + 1, assert min(res) > 0, (input_shape, kernel_size, strides, padding) return res elif padding.lower() == 'same':