From 59b1331158d27032cf8a9346ef72f1ee28a5578f Mon Sep 17 00:00:00 2001 From: APassbyDreg <36123017+APassbyDreg@users.noreply.github.com> Date: Sat, 7 Oct 2023 20:57:05 +0800 Subject: [PATCH] fix kv check of avg pool --- spconv/pytorch/pool.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/spconv/pytorch/pool.py b/spconv/pytorch/pool.py index 2384507..5659993 100644 --- a/spconv/pytorch/pool.py +++ b/spconv/pytorch/pool.py @@ -314,8 +314,10 @@ def __init__(self, self.record_voxel_count = record_voxel_count self.dilation = expand_nd(ndim, dilation) self.indice_key = indice_key - kv = int(np.prod(kernel_size)) - assert kv <= 32, "avg pool only support implicit-gemm style indice gen with kv <= 32 limit" + kv = int(np.prod(self.kernel_size)) + assert kv <= 128, "avg pool only support implicit-gemm style indice gen with kv <= 128 limit" + if kv >= 32: + pass # maybe show some warnings here self.algo = ConvAlgo.MaskImplicitGemm def extra_repr(self): @@ -581,4 +583,4 @@ def __init__(self, ALL_POOL_LAYERS = set([ SparseAvgPool3d, SparseAvgPool2d, SparseAvgPool1d, SparseMaxPool1d, SparseMaxPool2d, SparseMaxPool3d, SparseMaxPool4d, SparseAvgPool, SparseMaxPool -]) \ No newline at end of file +])