diff --git a/backends/vulkan/runtime/graph/ops/glsl/unary_op.glsl b/backends/vulkan/runtime/graph/ops/glsl/unary_op.glsl index f7480c7b0a..0cad62d38c 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/unary_op.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/unary_op.glsl @@ -36,6 +36,18 @@ ${layout_declare_ubo(4, "float", "maximum")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +float hardswish(float x){ + if(x <= -3) { + return 0; + } + else if(x >= 3) { + return x; + } + else { + return x * (x + 3)/6; + } +} + #ifdef USING_BUFFER void main() { diff --git a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml index cc41a579c0..f39abc2134 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml @@ -36,3 +36,5 @@ unary_op: OPERATOR: tanh(clamp(X, -15.0, 15.0)) - NAME: hardshrink OPERATOR: X * (vec4(greaterThan(X, vec4(A))) + vec4(lessThan(X, vec4(B)))) + - NAME: hardswish + OPERATOR: vec4(hardswish(X.x),hardswish(X.y),hardswish(X.z),hardswish(X.w)) diff --git a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp index 10a05c3d36..d64f82ee63 100644 --- a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp @@ -113,6 +113,12 @@ float get_val_or_inf(ComputeGraph& graph, const ValueRef& val, bool max) { "hardshrink"); \ } +#define DEFINE_HARDSWISH_FN(op_name) \ + void op_name(ComputeGraph& graph, const std::vector& args) { \ + return add_unary_op_node( \ + graph, args[0], kDummyFloat, kDummyFloat, args[1], #op_name); \ + } + void gelu(ComputeGraph& graph, const std::vector& args) { // args[1] is the `approximate` string // https://fburl.com/code/9omngmyo @@ -133,6 +139,7 @@ DEFINE_CLAMP_FN(clamp); DEFINE_CLAMP_FN(hardtanh); DEFINE_RELU_FN(relu); DEFINE_HARDSHRINK_FN(hardshrink); +DEFINE_HARDSWISH_FN(hardswish); REGISTER_OPERATORS { VK_REGISTER_OP(aten.abs.default, abs); @@ -148,6 +155,7 @@ REGISTER_OPERATORS { VK_REGISTER_OP(aten.sqrt.default, sqrt); VK_REGISTER_OP(aten.tanh.default, tanh); VK_REGISTER_OP(aten.hardshrink.default, hardshrink); + VK_REGISTER_OP(aten.hardswish.default, hardswish); } } // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 8f7cb59740..f50e44be72 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -876,6 +876,7 @@ def get_softmax_inputs(): "aten.sin.default", "aten.neg.default", "aten.cos.default", + "aten.hardswish.default", ] ) def get_unary_ops_inputs():