diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 54b7b8651bc..b33430a6bca 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -276,6 +276,7 @@ def register_binary_op(features: OpFeatures): exir_ops.edge.aten.sqrt.default, exir_ops.edge.aten.rsqrt.default, exir_ops.edge.aten.tanh.default, + exir_ops.edge.aten.round.default, ] ) def register_unary_op(features: OpFeatures): diff --git a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml index 6757d2a6d45..f13393ce6c7 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml @@ -44,3 +44,5 @@ unary_op: OPERATOR: hardsigmoid(X) - NAME: leaky_relu OPERATOR: leaky_relu(X, A) + - NAME: round + OPERATOR: round(X) diff --git a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp index 4bf73fad5a1..9a3ab002403 100644 --- a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp @@ -149,6 +149,7 @@ DEFINE_HARDSHRINK_FN(hardshrink); DEFINE_ACTIVATION_FN(hardswish); DEFINE_ACTIVATION_FN(hardsigmoid); DEFINE_LEAKY_RELU_FN(leaky_relu); +DEFINE_ACTIVATION_FN(round); REGISTER_OPERATORS { VK_REGISTER_OP(aten.abs.default, abs); @@ -168,6 +169,7 @@ REGISTER_OPERATORS { VK_REGISTER_OP(aten.hardswish.default, hardswish); VK_REGISTER_OP(aten.hardsigmoid.default, hardsigmoid); VK_REGISTER_OP(aten.leaky_relu.default, leaky_relu); + VK_REGISTER_OP(aten.round.default, round); } } // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 88f5bea5c3e..85008a52ff0 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -1092,6 +1092,7 @@ def get_reduce_op_inputs(): "aten.hardswish.default", "aten.hardsigmoid.default", "aten.leaky_relu.default", + "aten.round.default", ] ) def get_unary_ops_inputs():