From f954c3e37dd044f8d697b6555fdab868502350dd Mon Sep 17 00:00:00 2001 From: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com> Date: Mon, 31 Mar 2025 21:08:05 -0700 Subject: [PATCH] [ET-VK] Adding round op support. This diff adds support for the round op in the Vulkan backend for Executorch. Differential Revision: [D72218482](https://our.internmc.facebook.com/intern/diff/D72218482/) [ghstack-poisoned] --- backends/vulkan/op_registry.py | 1 + backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml | 2 ++ backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp | 2 ++ backends/vulkan/test/op_tests/cases.py | 1 + 4 files changed, 6 insertions(+) diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 54b7b8651bc..b33430a6bca 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -276,6 +276,7 @@ def register_binary_op(features: OpFeatures): exir_ops.edge.aten.sqrt.default, exir_ops.edge.aten.rsqrt.default, exir_ops.edge.aten.tanh.default, + exir_ops.edge.aten.round.default, ] ) def register_unary_op(features: OpFeatures): diff --git a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml index 6757d2a6d45..f13393ce6c7 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml @@ -44,3 +44,5 @@ unary_op: OPERATOR: hardsigmoid(X) - NAME: leaky_relu OPERATOR: leaky_relu(X, A) + - NAME: round + OPERATOR: round(X) diff --git a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp index 4bf73fad5a1..9a3ab002403 100644 --- a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp @@ -149,6 +149,7 @@ DEFINE_HARDSHRINK_FN(hardshrink); DEFINE_ACTIVATION_FN(hardswish); DEFINE_ACTIVATION_FN(hardsigmoid); DEFINE_LEAKY_RELU_FN(leaky_relu); +DEFINE_ACTIVATION_FN(round); REGISTER_OPERATORS { VK_REGISTER_OP(aten.abs.default, abs); @@ -168,6 +169,7 @@ REGISTER_OPERATORS { VK_REGISTER_OP(aten.hardswish.default, hardswish); VK_REGISTER_OP(aten.hardsigmoid.default, hardsigmoid); VK_REGISTER_OP(aten.leaky_relu.default, leaky_relu); + VK_REGISTER_OP(aten.round.default, round); } } // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 418ef9cd208..625c5e6fbc5 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -1090,6 +1090,7 @@ def get_reduce_op_inputs(): "aten.hardswish.default", "aten.hardsigmoid.default", "aten.leaky_relu.default", + "aten.round.default", ] ) def get_unary_ops_inputs():