From 9c0a23b61350ca58abb9b7a31ca004da65eb7923 Mon Sep 17 00:00:00 2001 From: Jorge Pineda Date: Thu, 7 Mar 2024 14:32:09 -0800 Subject: [PATCH] [ET-VK][EZ] Clean up OperatorRegistry Align `OperatorRegistry` with the style of `ShaderRegistry` in https://github.com/pytorch/executorch/pull/2222 This means - Improve comments and comment formatting. - Use snake case, even if it deviates from the original registry I was following. Snake case is more consistent with the Vulkan backend code. https://www.internalfb.com/code/fbsource/[a97f9ed1a715231bb61b05942273f1e8f8631503]/fbcode/executorch/runtime/kernel/operator_registry.h?lines=208%2C213 - Move `using` declarations and member variables to top of class definition. - Place static `OperatorRegistry` instance declaration in a global function `operator_registry()` instead of in member function `getInstance()`. - Use macros to wrap `OperatorRegistry` functions instead of global functions. - For simplicity, remove unneeded ctor and assignment operator deletion/hiding. Note users can now create their own non-static `OperatorRegistry` instance and we can consider hiding this again later. Differential Revision: [D54640160](https://our.internmc.facebook.com/intern/diff/D54640160/) [ghstack-poisoned] --- backends/vulkan/runtime/VulkanBackend.cpp | 4 +- .../runtime/graph/ops/OperatorRegistry.cpp | 25 +++------ .../runtime/graph/ops/OperatorRegistry.h | 56 ++++++++++--------- 3 files changed, 40 insertions(+), 45 deletions(-) diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp index b5d34418867..a073919c696 100644 --- a/backends/vulkan/runtime/VulkanBackend.cpp +++ b/backends/vulkan/runtime/VulkanBackend.cpp @@ -173,7 +173,7 @@ class GraphBuilder { // Parse the operators for (OpCallPtr op_call : *(flatbuffer_->chain())) { std::string op_name = op_call->name()->str(); - ET_CHECK_MSG(hasOpsFn(op_name), "Missing operator: %s", op_name.c_str()); + ET_CHECK_MSG(VK_HAS_OP(op_name), "Missing operator: %s", op_name.c_str()); const std::vector arg_fb_ids( op_call->args()->cbegin(), op_call->args()->cend()); @@ -183,7 +183,7 @@ class GraphBuilder { args.push_back(get_fb_id_valueref(arg_fb_id)); } - auto vkFn = getOpsFn(op_name); + auto vkFn = VK_GET_OP_FN(op_name); vkFn(*compute_graph_, args); } diff --git a/backends/vulkan/runtime/graph/ops/OperatorRegistry.cpp b/backends/vulkan/runtime/graph/ops/OperatorRegistry.cpp index 0d46e5b3514..922e0d87044 100644 --- a/backends/vulkan/runtime/graph/ops/OperatorRegistry.cpp +++ b/backends/vulkan/runtime/graph/ops/OperatorRegistry.cpp @@ -14,31 +14,19 @@ namespace at { namespace native { namespace vulkan { -bool hasOpsFn(const std::string& name) { - return OperatorRegistry::getInstance().hasOpsFn(name); -} - -OpFunction& getOpsFn(const std::string& name) { - return OperatorRegistry::getInstance().getOpsFn(name); -} - -OperatorRegistry& OperatorRegistry::getInstance() { - static OperatorRegistry instance; - return instance; -} - -bool OperatorRegistry::hasOpsFn(const std::string& name) { +bool OperatorRegistry::has_op(const std::string& name) { return OperatorRegistry::kTable.count(name) > 0; } -OpFunction& OperatorRegistry::getOpsFn(const std::string& name) { +OperatorRegistry::OpFunction& OperatorRegistry::get_op_fn( + const std::string& name) { return OperatorRegistry::kTable.find(name)->second; } // @lint-ignore-every CLANGTIDY modernize-avoid-bind // clang-format off #define OPERATOR_ENTRY(name, function) \ - { #name, std::bind(&at::native::vulkan::function, std::placeholders::_1, std::placeholders::_2) } + { #name, std::bind(&function, std::placeholders::_1, std::placeholders::_2) } // clang-format on const OperatorRegistry::OpTable OperatorRegistry::kTable = { @@ -50,6 +38,11 @@ const OperatorRegistry::OpTable OperatorRegistry::kTable = { OPERATOR_ENTRY(aten.pow.Tensor_Tensor, pow), }; +OperatorRegistry& operator_registry() { + static OperatorRegistry registry; + return registry; +} + } // namespace vulkan } // namespace native } // namespace at diff --git a/backends/vulkan/runtime/graph/ops/OperatorRegistry.h b/backends/vulkan/runtime/graph/ops/OperatorRegistry.h index 06245d889e7..9157d9d35a1 100644 --- a/backends/vulkan/runtime/graph/ops/OperatorRegistry.h +++ b/backends/vulkan/runtime/graph/ops/OperatorRegistry.h @@ -15,42 +15,44 @@ #include #include +#define VK_HAS_OP(name) ::at::native::vulkan::operator_registry().has_op(name) + +#define VK_GET_OP_FN(name) \ + ::at::native::vulkan::operator_registry().get_op_fn(name) + namespace at { namespace native { namespace vulkan { -using OpFunction = - const std::function&)>; - -bool hasOpsFn(const std::string& name); - -OpFunction& getOpsFn(const std::string& name); - -// The Vulkan operator registry is a simplified version of -// fbcode/executorch/runtime/kernel/operator_registry.h -// that uses the C++ Standard Library. -class OperatorRegistry { - public: - static OperatorRegistry& getInstance(); - - bool hasOpsFn(const std::string& name); - OpFunction& getOpsFn(const std::string& name); - - OperatorRegistry(const OperatorRegistry&) = delete; - OperatorRegistry(OperatorRegistry&&) = delete; - OperatorRegistry& operator=(const OperatorRegistry&) = delete; - OperatorRegistry& operator=(OperatorRegistry&&) = delete; - - private: - // TODO: Input string corresponds to target_name. We may need to pass kwargs. +/* + * The Vulkan operator registry maps ATen operator names to their Vulkan + * delegate function implementation. It is a simplified version of + * executorch/runtime/kernel/operator_registry.h that uses the C++ Standard + * Library. + */ +class OperatorRegistry final { + using OpFunction = + const std::function&)>; using OpTable = std::unordered_map; - // @lint-ignore CLANGTIDY facebook-hte-NonPodStaticDeclaration + static const OpTable kTable; - OperatorRegistry() = default; - ~OperatorRegistry() = default; + public: + /* + * Check if the registry has an operator registered under the given name + */ + bool has_op(const std::string& name); + + /* + * Given an operator name, return the Vulkan delegate function + */ + OpFunction& get_op_fn(const std::string& name); }; +// The Vulkan operator registry is global. It is retrieved using this function, +// where it is declared as a static local variable. +OperatorRegistry& operator_registry(); + } // namespace vulkan } // namespace native } // namespace at