aten.avg_pool2d (pytorch#3770)
Summary:
Pull Request resolved: pytorch#3770

## The Operator
`nn.Module` invocations of [`torch.nn.AvgPool2d`](https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html) are compiled to `aten.avg_pool2d.default` in the Edge Dialect, which has the following signature.
```
- func: avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor
```
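
For reference, a minimal sketch of how such an invocation reaches the Edge Dialect, assuming the standard `torch.export` + `to_edge` flow (the module and input shapes here are illustrative, not part of this diff):
```python
import torch
from executorch.exir import to_edge


class Pool(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pool = torch.nn.AvgPool2d(kernel_size=2, stride=1)

    def forward(self, x):
        return self.pool(x)


# The AvgPool2d call surfaces as aten.avg_pool2d.default in the Edge graph.
edge = to_edge(torch.export.export(Pool(), (torch.randn(1, 8, 16, 16),)))
print(edge.exported_program().graph)
```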

## Implementation
This is a full C-packing (channels-packed) implementation, including dynamic shape support. We start from [LiteInterpreter's `avg_pool2d.glsl` logic](https://github.com/pytorch/pytorch/blob/9257a0698b57acc5607ee6fe31a16fdd93af1731/aten/src/ATen/native/vulkan/glsl/avg_pool2d.glsl), which is incomplete, and add the `ceil_mode=True`, `count_include_pad=True`, and `divisor_override` cases for full support. As a result, the divisor computation is now a bit complex; if needed, we can split it into separate shaders in the future.
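
To make that concrete, here is a minimal Python sketch of the per-output divisor selection the shader implements (the helper name and argument layout are illustrative, not part of this diff; the shader encodes "no override" as `divisor_override = 0`):
```python
def avg_pool2d_divisor(
    oy, ox,                   # output position
    kernel, stride, padding,  # (h, w) pairs
    in_sizes,                 # (in_h, in_w)
    count_include_pad, divisor_override,
):
    # Top-left of the pooling window in input coordinates (may be negative).
    iy, ix = oy * stride[0] - padding[0], ox * stride[1] - padding[1]
    # Portion of the window that actually overlaps the input.
    start = (max(0, iy), max(0, ix))
    end = (min(iy + kernel[0], in_sizes[0]), min(ix + kernel[1], in_sizes[1]))

    if divisor_override > 0:
        return divisor_override
    if count_include_pad:
        # Count padded cells too, but clip cells that ceil_mode pushed past
        # the padded extent of the input.
        empty_y = max(iy + kernel[0] - padding[0] - in_sizes[0], 0)
        empty_x = max(ix + kernel[1] - padding[1] - in_sizes[1], 0)
        return (kernel[0] - empty_y) * (kernel[1] - empty_x)
    # Only the cells that overlap the input.
    return (end[0] - start[0]) * (end[1] - start[1])


# Example: a 2x2 window fully inside the input -> divisor 4.
assert avg_pool2d_divisor(1, 1, (2, 2), (1, 1), (0, 0), (8, 8), False, 0) == 4
```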
ghstack-source-id: 228476264

Reviewed By: copyrightly

Differential Revision: D57918523

fbshipit-source-id: 8069c4a2dcc5d46da7221d58661e57bf2055b521
jorgep31415 authored and facebook-github-bot committed May 31, 2024
1 parent 8c8d965 commit a463f0b
Showing 8 changed files with 300 additions and 19 deletions.
1 change: 1 addition & 0 deletions backends/vulkan/partitioner/supported_ops.py
@@ -73,6 +73,7 @@ def __contains__(self, op):
]

POOLING_OPS = [
    exir_ops.edge.aten.avg_pool2d.default,
    exir_ops.edge.aten.max_pool2d_with_indices.default,
]

57 changes: 57 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/avg_pool2d.glsl
@@ -0,0 +1,57 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_type(DTYPE)}

#include "indexing_utils.h"

layout(std430) buffer;

${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)}
${layout_declare_ubo(2, "ivec3", "out_limits")}
${layout_declare_ubo(3, "ivec4", "in_sizes")}
${layout_declare_ubo(4, "ivec2", "kernel_size", "ivec2", "stride", "ivec2", "padding", "ivec2", "dilation")}
${layout_declare_ubo(5, "int", "divisor_override", "int", "count_include_pad")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  if (any(greaterThanEqual(pos, out_limits))) {
    return;
  }

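  // Top-left corner of the pooling window in input coordinates; this can
  // be negative when padding extends past the input's edge.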
  const ivec2 ipos = pos.xy * stride - padding;

  const ivec2 start = max(ivec2(0), ipos);
  const ivec2 end = min(ipos + kernel_size, ivec2(in_sizes.xy));

  VEC4_T sum = VEC4_T(0);
  for (int y = start.y; y < end.y; ++y) {
    for (int x = start.x; x < end.x; ++x) {
      sum += texelFetch(t_in, ivec3(x, y, pos.z), 0);
    }
  }

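  // Select the divisor: an explicit override wins; otherwise count either
  // the full (possibly padded) window or only the cells overlapping the input.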
  int div;
  if (divisor_override > 0) {
    div = divisor_override;
  } else if (count_include_pad > 0) {
    ivec2 empty = max(ipos + kernel_size - padding - ivec2(in_sizes.xy), ivec2(0));
    div = (kernel_size.y - empty.y) * (kernel_size.x - empty.x);
  } else {
    div = (end.y - start.y) * (end.x - start.x);
  }
  imageStore(t_out, pos, sum / div);
}
18 changes: 18 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/avg_pool2d.yaml
@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

avg_pool2d:
  parameter_names_with_default_values:
    DTYPE: float
    NDIM: 3
    STORAGE: texture3d
  generate_variant_forall:
    DTYPE:
      - VALUE: half
      - VALUE: float
      - VALUE: int
  shader_variants:
    - NAME: avg_pool2d
112 changes: 103 additions & 9 deletions backends/vulkan/runtime/graph/ops/impl/Pool.cpp
@@ -17,12 +17,18 @@

namespace vkcompute {

void resize_max_pool2d_node(
void check_pool2d_args(const vTensor& in, const vTensor& out) {
  VK_CHECK_COND(check_memory_layout_is(in, api::kChannelsPacked));
  VK_CHECK_COND(check_memory_layout_is(out, api::kChannelsPacked));
}

void resize_pool2d_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
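  // avg_pool2d passes kDummyValueRef in the dilation slot, so a real
  // dilation value identifies max_pool2d (which also resizes its indices).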
  bool is_max_pool2d = extra_args[3] != kDummyValueRef;

  vTensorPtr out = graph->get_tensor(args[0].refs[0]);
  vTensorPtr indices = graph->get_tensor(args[0].refs[1]);
  vTensorPtr self = graph->get_tensor(args[1].refs[0]);

  size_t ndim = self->sizes().size();
@@ -45,14 +51,17 @@ void resize_max_pool2d_node(
  new_out_sizes.at(ndim - 1) = new_out_sizes_hw.at(1);

  out->virtual_resize(new_out_sizes);
  indices->virtual_resize(new_out_sizes);
}

void check_max_pool2d_args(const vTensor& in, const vTensor& out) {
  VK_CHECK_COND(check_memory_layout_is(in, api::kChannelsPacked));
  VK_CHECK_COND(check_memory_layout_is(out, api::kChannelsPacked));
  if (is_max_pool2d) {
    vTensorPtr indices = graph->get_tensor(args[0].refs[1]);
    indices->virtual_resize(new_out_sizes);
  }
}

//
// max_pool2d
//

void add_max_pool2d_node(
    ComputeGraph& graph,
    const ValueRef in,
@@ -68,7 +77,7 @@ void add_max_pool2d_node(
  const auto out_val = graph.get_value_list(out);
  vTensorPtr t_out = graph.get_tensor(out_val->at(0));

  check_max_pool2d_args(*t_in, *t_out);
  check_pool2d_args(*t_in, *t_out);

  api::utils::uvec3 global_size = t_out->image_extents();
  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
@@ -101,7 +110,7 @@
      // Specialization Constants
      {},
      // Resizing Logic
      resize_max_pool2d_node,
      resize_pool2d_node,
      {kernel_size, stride, padding, dilation, ceil_mode}));
}

@@ -110,7 +119,92 @@ void max_pool2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
      graph, args[0], args[1], args[2], args[3], args[4], args[5], args[6]);
}

//
// avg_pool2d
//

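// Packed into a params UBO and read by avg_pool2d.glsl when selecting the
// divisor; divisor_override == 0 encodes "no override".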
struct DivisorParams final {
  int32_t divisor_override;
  bool count_include_pad;
};

DivisorParams create_divisor_params(
    ComputeGraph& graph,
    const ValueRef divisor_override,
    const ValueRef count_include_pad) {
  return {
      graph.val_is_int(divisor_override)
          ? static_cast<int32_t>(graph.get_int(divisor_override))
          : 0,
      graph.get_bool(count_include_pad)};
}

void add_avg_pool2d_node(
    ComputeGraph& graph,
    const ValueRef in,
    const ValueRef kernel_size,
    const ValueRef stride,
    const ValueRef padding,
    const ValueRef ceil_mode,
    const ValueRef count_include_pad,
    const ValueRef divisor_override,
    const ValueRef out) {
  ValueRef arg = prepack_if_tensor_ref(graph, in);
  vTensorPtr t_in = graph.get_tensor(arg);
  vTensorPtr t_out = graph.get_tensor(out);

  check_pool2d_args(*t_in, *t_out);

  api::utils::uvec3 global_size = t_out->image_extents();
  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

  std::string kernel_name("avg_pool2d");
  add_dtype_suffix(kernel_name, *t_out);

  Kernel2dParams kernel_params =
      create_kernel2d_params(graph, kernel_size, stride, padding);

  DivisorParams divisor_params =
      create_divisor_params(graph, divisor_override, count_include_pad);

  graph.execute_nodes().emplace_back(new ExecuteNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      global_size,
      local_size,
      // Inputs and Outputs
      {{out, api::MemoryAccessType::WRITE}, {arg, api::MemoryAccessType::READ}},
      // Shader params buffers
      {t_out->texture_limits_ubo(),
       t_in->sizes_ubo(),
       graph.create_params_buffer(kernel_params),
       graph.create_params_buffer(divisor_params)},
      // Specialization Constants
      {},
      // Resizing Logic
      resize_pool2d_node,
      {kernel_size,
       stride,
       padding,
       /*dilation=*/kDummyValueRef,
       ceil_mode}));
}

void avg_pool2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  return add_avg_pool2d_node(
      graph,
      args[0],
      args[1],
      args[2],
      args[3],
      args[4],
      args[5],
      args[6],
      args[7]);
}

REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.avg_pool2d.default, avg_pool2d);
  VK_REGISTER_OP(aten.max_pool2d_with_indices.default, max_pool2d);
}

17 changes: 16 additions & 1 deletion backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.cpp
@@ -41,6 +41,19 @@ Kernel2dParams create_kernel2d_params(
};
}

Kernel2dParams create_kernel2d_params(
    ComputeGraph& graph,
    const ValueRef kernel_size,
    const ValueRef stride,
    const ValueRef padding) {
  return {
      make_ivec2_kernel_size(graph, kernel_size, /*kernel_size_only = */ true),
      make_ivec2_from_list(graph, stride),
      make_ivec2_from_list(graph, padding),
      {},
  };
}

int64_t calc_out_size(
    const int64_t in_size,
    const int64_t kernel_size,
@@ -143,7 +156,9 @@ std::vector<int64_t> calc_out_sizes_hw(
      make_ivec2_kernel_size(graph, weight, kernel_size_only);
  const auto stride = make_ivec2_from_list(graph, args[0]);
  const auto padding = make_ivec2_from_list(graph, args[1]);
  const auto dilation = make_ivec2_from_list(graph, args[2]);
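  // avg_pool2d carries no dilation; treat the dummy ref as the default {1, 1}.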
  const auto dilation = args[2] == kDummyValueRef
      ? api::utils::ivec2{1, 1}
      : make_ivec2_from_list(graph, args[2]);

  if (transposed) {
    const auto output_padding = make_ivec2_from_list(graph, args[3]);
20 changes: 13 additions & 7 deletions backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h
@@ -16,13 +16,6 @@

namespace vkcompute {

struct Kernel2dParams final {
  api::utils::ivec2 kernel_size;
  api::utils::ivec2 stride;
  api::utils::ivec2 padding;
  api::utils::ivec2 dilation;
};

struct Kernel1dParams final {
  int kernel_size;
  int stride;
@@ -32,6 +25,13 @@ struct Kernel1dParams final {
  int out_group_size;
};

struct Kernel2dParams final {
  api::utils::ivec2 kernel_size;
  api::utils::ivec2 stride;
  api::utils::ivec2 padding;
  api::utils::ivec2 dilation;
};

Kernel2dParams create_kernel2d_params(
    ComputeGraph& graph,
    const ValueRef weight,
@@ -40,6 +40,12 @@ Kernel2dParams create_kernel2d_params(
    const ValueRef padding,
    const ValueRef dilation);

Kernel2dParams create_kernel2d_params(
    ComputeGraph& graph,
    const ValueRef kernel_size,
    const ValueRef stride,
    const ValueRef padding);

int64_t calc_out_size(
    const int64_t in_size,
    const int64_t kernel_size,
60 changes: 58 additions & 2 deletions backends/vulkan/test/op_tests/cases.py
@@ -114,7 +114,62 @@ def get_linear_inputs():
    return test_suite


def get_pool2d_inputs():
def get_avg_pool2d_inputs():
    Test = namedtuple(
        "VkAvgPoolTest",
        [
            "self",
            "kernel_size",
            "stride",
            "padding",
            "ceil_mode",
            "count_include_pad",
            "divisor_override",
        ],
    )
    Test.__new__.__defaults__ = (None, None)

    test_cases = []

    for ceil_mode in [True, False]:
        for count_include_pad in [True, False]:
            for divisor_override in [None, 5]:
                test_cases += [
                    Test(
                        self=(S, M1, M2),
                        kernel_size=[2, 2],
                        stride=[1, 1],
                        padding=[0, 0],
                        ceil_mode=ceil_mode,
                        count_include_pad=count_include_pad,
                        divisor_override=divisor_override,
                    ),
                ]
        test_cases += [
            Test(
                self=(S, M1, M2),
                kernel_size=[5, 4],
                stride=[3, 1],
                padding=[2, 1],
                ceil_mode=ceil_mode,
                count_include_pad=True,
                divisor_override=None,
            ),
            Test(
                self=(S, M1, M2),
                kernel_size=[4, 5],
                stride=[1, 3],
                padding=[2, 1],
                ceil_mode=ceil_mode,
                count_include_pad=True,
                divisor_override=None,
            ),
        ]
    test_suite = VkTestSuite([tuple(tc) for tc in test_cases])
    return test_suite


def get_max_pool2d_inputs():
    test_suite = VkTestSuite(
        [
            ((S, M1, M2), [2, 2], [1, 1], [0, 0], [1, 1]),
@@ -869,7 +924,8 @@ def get_arange_inputs():
"aten.bmm.default": get_bmm_inputs(),
"aten.mm.default": get_mm_inputs(),
"aten.linear.default": get_linear_inputs(),
"aten.max_pool2d_with_indices.default": get_pool2d_inputs(),
"aten.avg_pool2d.default": get_avg_pool2d_inputs(),
"aten.max_pool2d_with_indices.default": get_max_pool2d_inputs(),
"aten.convolution.default": get_conv_inputs(),
"aten.native_layer_norm.default": get_native_layer_norm_inputs(),
"aten.full.default": get_full_inputs(),