diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp index d5c65b23e3a21..47fa5cde11907 100644 --- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp +++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp @@ -379,7 +379,7 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) { LogicalResult transform::detail::interpreterBaseRunOnOperationImpl( Operation *target, StringRef passName, const std::shared_ptr> &sharedTransformModule, - const std::shared_ptr> &libraryModule, + const std::shared_ptr> &transformLibraryModule, const RaggedArray &extraMappings, const TransformOptions &options, const Pass::Option &transformFileName, @@ -387,6 +387,12 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl( const Pass::Option &debugPayloadRootTag, const Pass::Option &debugTransformRootTag, StringRef binaryName) { + bool hasSharedTransformModule = + sharedTransformModule && *sharedTransformModule; + bool hasTransformLibraryModule = + transformLibraryModule && *transformLibraryModule; + assert((!hasSharedTransformModule || !hasTransformLibraryModule) && + "at most one of shared or library transform module can be set"); // Step 1 // ------ @@ -407,9 +413,8 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl( // transform is embedded in the payload IR. If debugTransformRootTag was // passed, then we are in user-specified selection of the transforming IR. // This corresponds to REPL debug mode. - bool sharedTransform = (sharedTransformModule && *sharedTransformModule); Operation *transformContainer = - sharedTransform ? sharedTransformModule->get() : target; + hasSharedTransformModule ? sharedTransformModule->get() : target; Operation *transformRoot = debugTransformRootTag.empty() ? findTopLevelTransform(transformContainer, @@ -430,7 +435,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl( // Copy external defintions for symbols if provided. Be aware of potential // concurrent execution (normally, the error shouldn't be triggered unless the // transform IR modifies itself in a pass, which is also forbidden elsewhere). - if (!sharedTransform && libraryModule && *libraryModule) { + if (hasTransformLibraryModule) { if (!target->isProperAncestor(transformRoot)) { InFlightDiagnostic diag = transformRoot->emitError() @@ -439,7 +444,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl( return diag; } if (failed(defineDeclaredSymbols(*transformRoot->getBlock(), - libraryModule->get()))) + transformLibraryModule->get()))) return failure(); } @@ -461,25 +466,27 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl( LogicalResult transform::detail::interpreterBaseInitializeImpl( MLIRContext *context, StringRef transformFileName, StringRef transformLibraryFileName, - std::shared_ptr> &module, - std::shared_ptr> &libraryModule, + std::shared_ptr> &sharedTransformModule, + std::shared_ptr> &transformLibraryModule, function_ref(OpBuilder &, Location)> moduleBuilder) { - OwningOpRef parsed; - if (failed(parseTransformModuleFromFile(context, transformFileName, parsed))) + OwningOpRef parsedTransformModule; + if (failed(parseTransformModuleFromFile(context, transformFileName, + parsedTransformModule))) return failure(); - if (parsed && failed(mlir::verify(*parsed))) + if (parsedTransformModule && failed(mlir::verify(*parsedTransformModule))) return failure(); - OwningOpRef parsedLibrary; + OwningOpRef parsedLibraryModule; if (failed(parseTransformModuleFromFile(context, transformLibraryFileName, - parsedLibrary))) + parsedLibraryModule))) return failure(); - if (parsedLibrary && failed(mlir::verify(*parsedLibrary))) + if (parsedLibraryModule && failed(mlir::verify(*parsedLibraryModule))) return failure(); - if (parsed) { - module = std::make_shared>(std::move(parsed)); + if (parsedTransformModule) { + sharedTransformModule = std::make_shared>( + std::move(parsedTransformModule)); } else if (moduleBuilder) { // TODO: better location story. auto location = UnknownLoc::get(context); @@ -491,20 +498,20 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl( if (std::optional result = moduleBuilder(b, location)) { if (failed(*result)) return failure(); - module = std::move(localModule); + sharedTransformModule = std::move(localModule); } } - if (!parsedLibrary || !*parsedLibrary) + if (!parsedLibraryModule || !*parsedLibraryModule) return success(); - if (module && *module) { - if (failed(defineDeclaredSymbols(*module->get().getBody(), - parsedLibrary.get()))) + if (sharedTransformModule && *sharedTransformModule) { + if (failed(defineDeclaredSymbols(*sharedTransformModule->get().getBody(), + parsedLibraryModule.get()))) return failure(); } else { - libraryModule = - std::make_shared>(std::move(parsedLibrary)); + transformLibraryModule = + std::make_shared>(std::move(parsedLibraryModule)); } return success(); }