diff --git a/backends/vulkan/runtime/graph/ops/OperatorRegistry.cpp b/backends/vulkan/runtime/graph/ops/OperatorRegistry.cpp index 922e0d87044..ba559d870c5 100644 --- a/backends/vulkan/runtime/graph/ops/OperatorRegistry.cpp +++ b/backends/vulkan/runtime/graph/ops/OperatorRegistry.cpp @@ -15,28 +15,17 @@ namespace native { namespace vulkan { bool OperatorRegistry::has_op(const std::string& name) { - return OperatorRegistry::kTable.count(name) > 0; + return table_.count(name) > 0; } OperatorRegistry::OpFunction& OperatorRegistry::get_op_fn( const std::string& name) { - return OperatorRegistry::kTable.find(name)->second; + return table_.find(name)->second; } -// @lint-ignore-every CLANGTIDY modernize-avoid-bind -// clang-format off -#define OPERATOR_ENTRY(name, function) \ - { #name, std::bind(&function, std::placeholders::_1, std::placeholders::_2) } -// clang-format on - -const OperatorRegistry::OpTable OperatorRegistry::kTable = { - OPERATOR_ENTRY(aten.add.Tensor, add), - OPERATOR_ENTRY(aten.sub.Tensor, sub), - OPERATOR_ENTRY(aten.mul.Tensor, mul), - OPERATOR_ENTRY(aten.div.Tensor, div), - OPERATOR_ENTRY(aten.div.Tensor_mode, floor_div), - OPERATOR_ENTRY(aten.pow.Tensor_Tensor, pow), -}; +void OperatorRegistry::register_op(const std::string& name, OpFunction& fn) { + table_.insert(std::make_pair(name, fn)); +} OperatorRegistry& operator_registry() { static OperatorRegistry registry; diff --git a/backends/vulkan/runtime/graph/ops/OperatorRegistry.h b/backends/vulkan/runtime/graph/ops/OperatorRegistry.h index 9157d9d35a1..1088ab2e44f 100644 --- a/backends/vulkan/runtime/graph/ops/OperatorRegistry.h +++ b/backends/vulkan/runtime/graph/ops/OperatorRegistry.h @@ -20,6 +20,16 @@ #define VK_GET_OP_FN(name) \ ::at::native::vulkan::operator_registry().get_op_fn(name) +#define VK_REGISTER_OP(name, function) \ + ::at::native::vulkan::operator_registry().register_op( \ + #name, \ + std::bind(&function, std::placeholders::_1, std::placeholders::_2)) + +#define REGISTER_OPERATORS \ + static void register_ops(); \ + static const OperatorRegisterInit reg(®ister_ops); \ + static void register_ops() + namespace at { namespace native { namespace vulkan { @@ -35,7 +45,7 @@ class OperatorRegistry final { const std::function&)>; using OpTable = std::unordered_map; - static const OpTable kTable; + OpTable table_; public: /* @@ -47,6 +57,20 @@ class OperatorRegistry final { * Given an operator name, return the Vulkan delegate function */ OpFunction& get_op_fn(const std::string& name); + + /* + * Register a function to a given operator name + */ + void register_op(const std::string& name, OpFunction& fn); +}; + +class OperatorRegisterInit final { + using InitFn = void(); + + public: + explicit OperatorRegisterInit(InitFn* init_fn) { + init_fn(); + } }; // The Vulkan operator registry is global. It is retrieved using this function, diff --git a/backends/vulkan/runtime/graph/ops/Utils.h b/backends/vulkan/runtime/graph/ops/Utils.h index 918318178b3..b79c95eb934 100644 --- a/backends/vulkan/runtime/graph/ops/Utils.h +++ b/backends/vulkan/runtime/graph/ops/Utils.h @@ -16,9 +16,6 @@ namespace at { namespace native { namespace vulkan { -#define DECLARE_OP_FN(function) \ - void function(ComputeGraph& graph, const std::vector& args); - api::utils::ivec4 get_size_as_ivec4(const vTensor& t); void bind_tensor_to_descriptor_set( diff --git a/backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp b/backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp index 108ff2b2dc0..ff9b4ff2b2d 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp @@ -9,6 +9,7 @@ #include #include +#include #include @@ -81,6 +82,15 @@ void add_arithmetic_node( std::move(params))); } +REGISTER_OPERATORS { + VK_REGISTER_OP(aten.add.Tensor, add); + VK_REGISTER_OP(aten.sub.Tensor, sub); + VK_REGISTER_OP(aten.mul.Tensor, mul); + VK_REGISTER_OP(aten.div.Tensor, div); + VK_REGISTER_OP(aten.div.Tensor_mode, floor_div); + VK_REGISTER_OP(aten.pow.Tensor_Tensor, pow); +} + } // namespace vulkan } // namespace native } // namespace at diff --git a/backends/vulkan/runtime/graph/ops/impl/Arithmetic.h b/backends/vulkan/runtime/graph/ops/impl/Arithmetic.h index 3ef3cb3e426..8e5c345a92c 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Arithmetic.h +++ b/backends/vulkan/runtime/graph/ops/impl/Arithmetic.h @@ -18,13 +18,6 @@ namespace at { namespace native { namespace vulkan { -DECLARE_OP_FN(add); -DECLARE_OP_FN(sub); -DECLARE_OP_FN(mul); -DECLARE_OP_FN(div); -DECLARE_OP_FN(floor_div); -DECLARE_OP_FN(pow); - void add_arithmetic_node( ComputeGraph& graph, const ValueRef in1, diff --git a/backends/vulkan/targets.bzl b/backends/vulkan/targets.bzl index 2045cb3725e..105ef501fae 100644 --- a/backends/vulkan/targets.bzl +++ b/backends/vulkan/targets.bzl @@ -56,6 +56,10 @@ def define_common_targets(): "//caffe2:torch_vulkan_spv", ], define_static_target = False, + # Static initialization is used to register operators to the global operator registry, + # therefore link_whole must be True to make sure unused symbols are not discarded. + # @lint-ignore BUCKLINT: Avoid `link_whole=True` + link_whole = True, ) runtime.cxx_library( @@ -81,4 +85,6 @@ def define_common_targets(): # VulkanBackend.cpp needs to compile with executor as whole # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole) link_whole = True, + # Define an soname that can be used for dynamic loading in Java, Python, etc. + soname = "libvulkan_graph_runtime.$(ext)", ) diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index a646e0c4ed5..334a2937105 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -11,6 +11,8 @@ #include #include + +#include #include #include @@ -585,8 +587,8 @@ TEST(VulkanComputeGraphTest, test_simple_graph) { out.value = graph.add_tensor(size_big, api::kFloat); - add_arithmetic_node( - graph, a.value, b.value, kDummyValueRef, out.value, VK_KERNEL(add)); + auto addFn = VK_GET_OP_FN("aten.add.Tensor"); + addFn(graph, {a.value, b.value, kDummyValueRef, out.value}); out.staging = graph.set_output_tensor(out.value); @@ -636,8 +638,11 @@ TEST(VulkanComputeGraphTest, test_simple_prepacked_graph) { ValueRef c = graph.add_tensor(size_big, api::kFloat); ValueRef e = graph.add_tensor(size_big, api::kFloat); - add_arithmetic_node(graph, a.value, w1, kDummyValueRef, c, VK_KERNEL(add)); - add_arithmetic_node(graph, c, w2, kDummyValueRef, e, VK_KERNEL(mul)); + auto addFn = VK_GET_OP_FN("aten.add.Tensor"); + addFn(graph, {a.value, w1, kDummyValueRef, c}); + + auto mulFn = VK_GET_OP_FN("aten.mul.Tensor"); + mulFn(graph, {c, w2, e}); IOValueRef out = {}; out.value = e; @@ -697,8 +702,8 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects) { api::kFloat, /*shared_object_idx = */ 6); - add_arithmetic_node( - graph, a.value, b.value, kDummyValueRef, c, VK_KERNEL(add)); + auto addFn = VK_GET_OP_FN("aten.add.Tensor"); + addFn(graph, {a.value, b.value, kDummyValueRef, c}); IOValueRef d = graph.add_input_tensor( size_small, @@ -716,7 +721,8 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects) { api::kFloat, /*shared_object_idx = */ 4); - add_arithmetic_node(graph, c, d.value, kDummyValueRef, e, VK_KERNEL(mul)); + auto mulFn = VK_GET_OP_FN("aten.mul.Tensor"); + mulFn(graph, {c, d.value, e}); IOValueRef out = {}; out.value = e;