From 5dc3b2b4c5968adc360bae4aea3fdd1bf9818348 Mon Sep 17 00:00:00 2001 From: Esteban Padilla Cerdio Date: Tue, 2 Jul 2024 08:48:44 -0700 Subject: [PATCH] aten.hardswish.default in unary_ops (#4087) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/4087 Adds [hardswish](https://pytorch.org/docs/stable/generated/torch.nn.Hardswish.html?fbclid=IwZXh0bgNhZW0CMTEAAR0uHmWFireZynV9UbC1qtfr774yAf5B6GYQuL2ESD51SscE0OCzr_0ueMg_aem_OjRsr-yJzL_xjusQKqGDVA) to unary_ops. There is no hardswish function in GLSL so it has to be hardcoded in. Reviewed By: SS-JIA Differential Revision: D59117722 fbshipit-source-id: 342e0a194ddc7eda87d2e08e68075d6d3538644e --- backends/vulkan/runtime/graph/ops/glsl/unary_op.glsl | 12 ++++++++++++ backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml | 2 ++ backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp | 8 ++++++++ backends/vulkan/test/op_tests/cases.py | 1 + 4 files changed, 23 insertions(+) diff --git a/backends/vulkan/runtime/graph/ops/glsl/unary_op.glsl b/backends/vulkan/runtime/graph/ops/glsl/unary_op.glsl index f7480c7b0a..0cad62d38c 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/unary_op.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/unary_op.glsl @@ -36,6 +36,18 @@ ${layout_declare_ubo(4, "float", "maximum")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +float hardswish(float x){ + if(x <= -3) { + return 0; + } + else if(x >= 3) { + return x; + } + else { + return x * (x + 3)/6; + } +} + #ifdef USING_BUFFER void main() { diff --git a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml index cc41a579c0..f39abc2134 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml @@ -36,3 +36,5 @@ unary_op: OPERATOR: tanh(clamp(X, -15.0, 15.0)) - NAME: hardshrink OPERATOR: X * (vec4(greaterThan(X, vec4(A))) + vec4(lessThan(X, vec4(B)))) + - NAME: hardswish + OPERATOR: vec4(hardswish(X.x),hardswish(X.y),hardswish(X.z),hardswish(X.w)) diff --git a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp index 10a05c3d36..d64f82ee63 100644 --- a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp @@ -113,6 +113,12 @@ float get_val_or_inf(ComputeGraph& graph, const ValueRef& val, bool max) { "hardshrink"); \ } +#define DEFINE_HARDSWISH_FN(op_name) \ + void op_name(ComputeGraph& graph, const std::vector& args) { \ + return add_unary_op_node( \ + graph, args[0], kDummyFloat, kDummyFloat, args[1], #op_name); \ + } + void gelu(ComputeGraph& graph, const std::vector& args) { // args[1] is the `approximate` string // https://fburl.com/code/9omngmyo @@ -133,6 +139,7 @@ DEFINE_CLAMP_FN(clamp); DEFINE_CLAMP_FN(hardtanh); DEFINE_RELU_FN(relu); DEFINE_HARDSHRINK_FN(hardshrink); +DEFINE_HARDSWISH_FN(hardswish); REGISTER_OPERATORS { VK_REGISTER_OP(aten.abs.default, abs); @@ -148,6 +155,7 @@ REGISTER_OPERATORS { VK_REGISTER_OP(aten.sqrt.default, sqrt); VK_REGISTER_OP(aten.tanh.default, tanh); VK_REGISTER_OP(aten.hardshrink.default, hardshrink); + VK_REGISTER_OP(aten.hardswish.default, hardswish); } } // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 8f7cb59740..f50e44be72 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -876,6 +876,7 @@ def get_softmax_inputs(): "aten.sin.default", "aten.neg.default", "aten.cos.default", + "aten.hardswish.default", ] ) def get_unary_ops_inputs():