From 8f2800b43239494bb86bef0987dcb1472536d0e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Thu, 28 Sep 2023 09:20:59 +0000 Subject: [PATCH] [mlir][transform] Make variable names in interpreter consistent. (NFC) This commit renames the arguments of several static implementation functions of the transform interpreter base class to match the names of the corresponding member variables in order to clarify their intent. Similarly, it renames some local variables to reflect their relationship with corresponding member variables. Finally, this commit also asserts in `interpreterBaseRunOnOperationImpl` that at most one of shared and library module are set (which the initialization function guarantees) and simplifies some related `if` conditions. --- .../TransformInterpreterPassBase.cpp | 51 +++++++++++-------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp index d5c65b23e3a21..47fa5cde11907 100644 --- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp +++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp @@ -379,7 +379,7 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) { LogicalResult transform::detail::interpreterBaseRunOnOperationImpl( Operation *target, StringRef passName, const std::shared_ptr> &sharedTransformModule, - const std::shared_ptr> &libraryModule, + const std::shared_ptr> &transformLibraryModule, const RaggedArray &extraMappings, const TransformOptions &options, const Pass::Option &transformFileName, @@ -387,6 +387,12 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl( const Pass::Option &debugPayloadRootTag, const Pass::Option &debugTransformRootTag, StringRef binaryName) { + bool hasSharedTransformModule = + sharedTransformModule && *sharedTransformModule; + bool hasTransformLibraryModule = + transformLibraryModule && *transformLibraryModule; + assert((!hasSharedTransformModule || !hasTransformLibraryModule) && + "at most one of shared or library transform module can be set"); // Step 1 // ------ @@ -407,9 +413,8 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl( // transform is embedded in the payload IR. If debugTransformRootTag was // passed, then we are in user-specified selection of the transforming IR. // This corresponds to REPL debug mode. - bool sharedTransform = (sharedTransformModule && *sharedTransformModule); Operation *transformContainer = - sharedTransform ? sharedTransformModule->get() : target; + hasSharedTransformModule ? sharedTransformModule->get() : target; Operation *transformRoot = debugTransformRootTag.empty() ? findTopLevelTransform(transformContainer, @@ -430,7 +435,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl( // Copy external defintions for symbols if provided. Be aware of potential // concurrent execution (normally, the error shouldn't be triggered unless the // transform IR modifies itself in a pass, which is also forbidden elsewhere). - if (!sharedTransform && libraryModule && *libraryModule) { + if (hasTransformLibraryModule) { if (!target->isProperAncestor(transformRoot)) { InFlightDiagnostic diag = transformRoot->emitError() @@ -439,7 +444,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl( return diag; } if (failed(defineDeclaredSymbols(*transformRoot->getBlock(), - libraryModule->get()))) + transformLibraryModule->get()))) return failure(); } @@ -461,25 +466,27 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl( LogicalResult transform::detail::interpreterBaseInitializeImpl( MLIRContext *context, StringRef transformFileName, StringRef transformLibraryFileName, - std::shared_ptr> &module, - std::shared_ptr> &libraryModule, + std::shared_ptr> &sharedTransformModule, + std::shared_ptr> &transformLibraryModule, function_ref(OpBuilder &, Location)> moduleBuilder) { - OwningOpRef parsed; - if (failed(parseTransformModuleFromFile(context, transformFileName, parsed))) + OwningOpRef parsedTransformModule; + if (failed(parseTransformModuleFromFile(context, transformFileName, + parsedTransformModule))) return failure(); - if (parsed && failed(mlir::verify(*parsed))) + if (parsedTransformModule && failed(mlir::verify(*parsedTransformModule))) return failure(); - OwningOpRef parsedLibrary; + OwningOpRef parsedLibraryModule; if (failed(parseTransformModuleFromFile(context, transformLibraryFileName, - parsedLibrary))) + parsedLibraryModule))) return failure(); - if (parsedLibrary && failed(mlir::verify(*parsedLibrary))) + if (parsedLibraryModule && failed(mlir::verify(*parsedLibraryModule))) return failure(); - if (parsed) { - module = std::make_shared>(std::move(parsed)); + if (parsedTransformModule) { + sharedTransformModule = std::make_shared>( + std::move(parsedTransformModule)); } else if (moduleBuilder) { // TODO: better location story. auto location = UnknownLoc::get(context); @@ -491,20 +498,20 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl( if (std::optional result = moduleBuilder(b, location)) { if (failed(*result)) return failure(); - module = std::move(localModule); + sharedTransformModule = std::move(localModule); } } - if (!parsedLibrary || !*parsedLibrary) + if (!parsedLibraryModule || !*parsedLibraryModule) return success(); - if (module && *module) { - if (failed(defineDeclaredSymbols(*module->get().getBody(), - parsedLibrary.get()))) + if (sharedTransformModule && *sharedTransformModule) { + if (failed(defineDeclaredSymbols(*sharedTransformModule->get().getBody(), + parsedLibraryModule.get()))) return failure(); } else { - libraryModule = - std::make_shared>(std::move(parsedLibrary)); + transformLibraryModule = + std::make_shared>(std::move(parsedLibraryModule)); } return success(); }