forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Pull Request resolved: pytorch#3770 ## The Operator `nn.Module` invocations of [`torch.nn.AvgPool2d`](https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html) get compiled to `aten.avg_pool2d.default` in the Edge Dialect, which carries the following signature. ``` - func: avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor ``` ## Implementation This is a full C-packing implementation including dynamic shape support. We start with [LiteInterpreter's `avg_pool2d.glsl` logic](https://github.com/pytorch/pytorch/blob/9257a0698b57acc5607ee6fe31a16fdd93af1731/aten/src/ATen/native/vulkan/glsl/avg_pool2d.glsl), which is incomplete, and cover `ceil_mode=True`, `count_include_pad=True`, and `divisor_override` cases for full support. As a result, the divisor's computation is now a bit complex. If needed, we can simplify it into separate shaders in the future. ghstack-source-id: 228476264 Reviewed By: copyrightly Differential Revision: D57918523 fbshipit-source-id: 8069c4a2dcc5d46da7221d58661e57bf2055b521
- Loading branch information
1 parent
8c8d965
commit a463f0b
Showing
8 changed files
with
300 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#version 450 core | ||
|
||
#define PRECISION ${PRECISION} | ||
|
||
#define VEC4_T ${texel_type(DTYPE)} | ||
|
||
#include "indexing_utils.h" | ||
|
||
layout(std430) buffer; | ||
|
||
${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)} | ||
${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)} | ||
${layout_declare_ubo(2, "ivec3", "out_limits")} | ||
${layout_declare_ubo(3, "ivec4", "in_sizes")} | ||
${layout_declare_ubo(4, "ivec2", "kernel_size", "ivec2", "stride", "ivec2", "padding", "ivec2", "dilation")} | ||
${layout_declare_ubo(5, "int", "divisor_override", "int", "count_include_pad")} | ||
|
||
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; | ||
|
||
void main() { | ||
const ivec3 pos = ivec3(gl_GlobalInvocationID); | ||
|
||
if (any(greaterThanEqual(pos, out_limits))) { | ||
return; | ||
} | ||
|
||
const ivec2 ipos = pos.xy * stride - padding; | ||
|
||
const ivec2 start = max(ivec2(0), ipos); | ||
const ivec2 end = min(ipos + kernel_size, ivec2(in_sizes.xy)); | ||
|
||
VEC4_T sum = VEC4_T(0); | ||
for (int y = start.y; y < end.y; ++y) { | ||
for (int x = start.x; x < end.x; ++x) { | ||
sum += texelFetch(t_in, ivec3(x, y, pos.z), 0); | ||
} | ||
} | ||
|
||
int div; | ||
if (divisor_override > 0) { | ||
div = divisor_override; | ||
} else if (count_include_pad > 0) { | ||
ivec2 empty = max(ipos + kernel_size - padding - ivec2(in_sizes.xy), ivec2(0)); | ||
div = (kernel_size.y - empty.y) * (kernel_size.x - empty.x); | ||
} else { | ||
div = (end.y - start.y) * (end.x - start.x); | ||
} | ||
imageStore(t_out, pos, sum / div); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
avg_pool2d: | ||
parameter_names_with_default_values: | ||
DTYPE: float | ||
NDIM: 3 | ||
STORAGE: texture3d | ||
generate_variant_forall: | ||
DTYPE: | ||
- VALUE: half | ||
- VALUE: float | ||
- VALUE: int | ||
shader_variants: | ||
- NAME: avg_pool2d |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.