diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp index b5d34418867..a073919c696 100644 --- a/backends/vulkan/runtime/VulkanBackend.cpp +++ b/backends/vulkan/runtime/VulkanBackend.cpp @@ -173,7 +173,7 @@ class GraphBuilder { // Parse the operators for (OpCallPtr op_call : *(flatbuffer_->chain())) { std::string op_name = op_call->name()->str(); - ET_CHECK_MSG(hasOpsFn(op_name), "Missing operator: %s", op_name.c_str()); + ET_CHECK_MSG(VK_HAS_OP(op_name), "Missing operator: %s", op_name.c_str()); const std::vector arg_fb_ids( op_call->args()->cbegin(), op_call->args()->cend()); @@ -183,7 +183,7 @@ class GraphBuilder { args.push_back(get_fb_id_valueref(arg_fb_id)); } - auto vkFn = getOpsFn(op_name); + auto vkFn = VK_GET_OP_FN(op_name); vkFn(*compute_graph_, args); } diff --git a/backends/vulkan/runtime/graph/ops/OperatorRegistry.cpp b/backends/vulkan/runtime/graph/ops/OperatorRegistry.cpp index 0d46e5b3514..922e0d87044 100644 --- a/backends/vulkan/runtime/graph/ops/OperatorRegistry.cpp +++ b/backends/vulkan/runtime/graph/ops/OperatorRegistry.cpp @@ -14,31 +14,19 @@ namespace at { namespace native { namespace vulkan { -bool hasOpsFn(const std::string& name) { - return OperatorRegistry::getInstance().hasOpsFn(name); -} - -OpFunction& getOpsFn(const std::string& name) { - return OperatorRegistry::getInstance().getOpsFn(name); -} - -OperatorRegistry& OperatorRegistry::getInstance() { - static OperatorRegistry instance; - return instance; -} - -bool OperatorRegistry::hasOpsFn(const std::string& name) { +bool OperatorRegistry::has_op(const std::string& name) { return OperatorRegistry::kTable.count(name) > 0; } -OpFunction& OperatorRegistry::getOpsFn(const std::string& name) { +OperatorRegistry::OpFunction& OperatorRegistry::get_op_fn( + const std::string& name) { return OperatorRegistry::kTable.find(name)->second; } // @lint-ignore-every CLANGTIDY modernize-avoid-bind // clang-format off #define OPERATOR_ENTRY(name, function) \ - { #name, std::bind(&at::native::vulkan::function, std::placeholders::_1, std::placeholders::_2) } + { #name, std::bind(&function, std::placeholders::_1, std::placeholders::_2) } // clang-format on const OperatorRegistry::OpTable OperatorRegistry::kTable = { @@ -50,6 +38,11 @@ const OperatorRegistry::OpTable OperatorRegistry::kTable = { OPERATOR_ENTRY(aten.pow.Tensor_Tensor, pow), }; +OperatorRegistry& operator_registry() { + static OperatorRegistry registry; + return registry; +} + } // namespace vulkan } // namespace native } // namespace at diff --git a/backends/vulkan/runtime/graph/ops/OperatorRegistry.h b/backends/vulkan/runtime/graph/ops/OperatorRegistry.h index 06245d889e7..9157d9d35a1 100644 --- a/backends/vulkan/runtime/graph/ops/OperatorRegistry.h +++ b/backends/vulkan/runtime/graph/ops/OperatorRegistry.h @@ -15,42 +15,44 @@ #include #include +#define VK_HAS_OP(name) ::at::native::vulkan::operator_registry().has_op(name) + +#define VK_GET_OP_FN(name) \ + ::at::native::vulkan::operator_registry().get_op_fn(name) + namespace at { namespace native { namespace vulkan { -using OpFunction = - const std::function&)>; - -bool hasOpsFn(const std::string& name); - -OpFunction& getOpsFn(const std::string& name); - -// The Vulkan operator registry is a simplified version of -// fbcode/executorch/runtime/kernel/operator_registry.h -// that uses the C++ Standard Library. -class OperatorRegistry { - public: - static OperatorRegistry& getInstance(); - - bool hasOpsFn(const std::string& name); - OpFunction& getOpsFn(const std::string& name); - - OperatorRegistry(const OperatorRegistry&) = delete; - OperatorRegistry(OperatorRegistry&&) = delete; - OperatorRegistry& operator=(const OperatorRegistry&) = delete; - OperatorRegistry& operator=(OperatorRegistry&&) = delete; - - private: - // TODO: Input string corresponds to target_name. We may need to pass kwargs. +/* + * The Vulkan operator registry maps ATen operator names to their Vulkan + * delegate function implementation. It is a simplified version of + * executorch/runtime/kernel/operator_registry.h that uses the C++ Standard + * Library. + */ +class OperatorRegistry final { + using OpFunction = + const std::function&)>; using OpTable = std::unordered_map; - // @lint-ignore CLANGTIDY facebook-hte-NonPodStaticDeclaration + static const OpTable kTable; - OperatorRegistry() = default; - ~OperatorRegistry() = default; + public: + /* + * Check if the registry has an operator registered under the given name + */ + bool has_op(const std::string& name); + + /* + * Given an operator name, return the Vulkan delegate function + */ + OpFunction& get_op_fn(const std::string& name); }; +// The Vulkan operator registry is global. It is retrieved using this function, +// where it is declared as a static local variable. +OperatorRegistry& operator_registry(); + } // namespace vulkan } // namespace native } // namespace at