From 3b35edfa524eb28908b86415aeb3b0186a6aa423 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Fri, 22 Sep 2023 15:10:30 +0000 Subject: [PATCH 01/13] [mlir][transform] Fix handling of transitive include in interpreter. Until now, the interpreter would only load those symbols from the provided library files that were declared in the main transform module. However, sequences in the library may include other sequences on their own. Until now, if such sequences were not *also* declared in the main transform module, the interpreter would fail to resolve them. Forward declaring all of them is undesirable as it defeats the purpose of encapsulation into library modules. This PR implements a kind of linker for transform scripts to solve this problem. The linker merges all symbols of the library module into the main module before interpreting the latter. Symbols whose names collide are handled as follows: (1) if they are both functions (in the sense of `FunctionOpInterface`) with compatible signatures, one is external, and the other one is public, then they are merged; (2) of one of them is private, that one is renamed; and (3) an error is raised otherwise. --- .../TransformInterpreterPassBase.cpp | 284 ++++++++++++++---- ...ter-external-symbol-decl-and-schedule.mlir | 6 + ...erpreter-external-symbol-decl-invalid.mlir | 23 +- ...test-interpreter-external-symbol-decl.mlir | 61 +++- ...terpreter-external-symbol-def-invalid.mlir | 14 + .../test-interpreter-external-symbol-def.mlir | 43 ++- 6 files changed, 353 insertions(+), 78 deletions(-) create mode 100644 mlir/test/Dialect/Transform/test-interpreter-external-symbol-def-invalid.mlir diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp index 23640c92457a8..7d09be07f66d9 100644 --- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp +++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp @@ -302,77 +302,236 @@ static void performOptionalDebugActions( transform->removeAttr(kTransformDialectTagAttrName); } -/// Replaces external symbols in `block` with their (non-external) definitions -/// from the given module. -static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) { - MLIRContext &ctx = *definitions->getContext(); - auto consumedName = - StringAttr::get(&ctx, transform::TransformDialect::kArgConsumedAttrName); - auto readOnlyName = - StringAttr::get(&ctx, transform::TransformDialect::kArgReadOnlyAttrName); - - for (Operation &op : llvm::make_early_inc_range(block)) { - LLVM_DEBUG(DBGS() << op << "\n"); - auto symbol = dyn_cast(op); - if (!symbol) - continue; - if (symbol->getNumRegions() == 1 && !symbol->getRegion(0).empty()) - continue; - - LLVM_DEBUG(DBGS() << "looking for definition of symbol " - << symbol.getNameAttr() << ":"); - SymbolTable symbolTable(definitions); - Operation *externalSymbol = symbolTable.lookup(symbol.getNameAttr()); - if (!externalSymbol || externalSymbol->getNumRegions() != 1 || - externalSymbol->getRegion(0).empty()) { - LLVM_DEBUG(llvm::dbgs() << "not found\n"); - continue; - } +/// Merge all symbols from `other` into `target`. Both ops need to implement the +/// `SymbolTable` trait. Operations are moved from `other`, i.e., `other` may be +/// modified by this function. Upon merging, private symbols may be renamed in +/// order to avoid collisions in the result. Public symbols may not collide, +/// with the exception of `SymbolInterfaceOp`s, where collisions are allowed if +/// at least one of the two is external, in which case the other op preserved +/// (or one of the two if both are external). The `target` op might not verify +/// after this function returns. +// XXX: Make `other` argument an `OwningOpRef`? +static LogicalResult mergeSymbolsInto(Operation *target, Operation *other) { + assert(target->hasTrait() && + "requires target to implement the 'SymbolTable' trait"); + assert(other->hasTrait() && + "requires target to implement the 'SymbolTable' trait"); + + MLIRContext *context = other->getContext(); + auto consumedName = StringAttr::get( + context, transform::TransformDialect::kArgConsumedAttrName); + auto readOnlyName = StringAttr::get( + context, transform::TransformDialect::kArgReadOnlyAttrName); + + int uniqueId = 0; + + auto canBeMerged = [](FunctionOpInterface func1, FunctionOpInterface func2) { + return func1.isExternal() && (func2.isPublic() || func2.isExternal()); + ; + }; + + // Rename private symbols in both ops in order to resolve conflicts that can + // be resolved that way. + LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n"); + for (auto [symbolTableOp, otherSymbolTableOp] : + llvm::zip(SmallVector{target, other}, + SmallVector{other, target})) { + SymbolTable symbolTable(symbolTableOp); // XXX: build only once + SymbolTable otherSymbolTable(otherSymbolTableOp); + for (Operation &op : symbolTableOp->getRegion(0).front()) { + auto symbolOp = dyn_cast(op); + if (!symbolOp) + continue; + StringAttr name = symbolOp.getNameAttr(); + LLVM_DEBUG(DBGS() << " found @" << name.getValue() << "\n"); + + // Check if there is a colliding op in the other module. + auto collidingOp = + cast_or_null(otherSymbolTable.lookup(name)); + if (!collidingOp) + continue; + + LLVM_DEBUG(DBGS() << " collision found for @" << name.getValue()); + + // Collisions are fine if both opt are functions and can be merged. + if (auto funcOp = dyn_cast(op), + collidingFuncOp = + dyn_cast(collidingOp.getOperation()); + funcOp && collidingFuncOp) { + if (canBeMerged(funcOp, collidingFuncOp) || + canBeMerged(collidingFuncOp, funcOp)) { + LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and " + "will be merged\n"); + continue; + } + + // If they can't be merged, proceed like any other collision. + LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions"); + } + + /// Rename `op` inside `symbolTableOp` with symbol table `symbolTable` + /// to avoid a collision with `otherOp`. + auto renameToUnique = + [&uniqueId = + uniqueId](SymbolOpInterface op, SymbolOpInterface otherOp, + Operation *symbolTableOp, SymbolTable &symbolTable, + SymbolTable &otherSymbolTable) -> LogicalResult { + assert(SymbolTable::getNearestSymbolTable(op) == symbolTableOp && + "expected 'op' to be inside of 'symbolTableOp'"); + MLIRContext *context = op->getContext(); + + // Determine new name that is unique in both symbol tables. + StringAttr newName; + { + SmallString<64> prefix = op.getNameAttr().getValue(); + prefix.push_back('_'); + while (true) { + newName = StringAttr::get(context, prefix + Twine(uniqueId++)); + if (!symbolTable.lookup(newName) && + !otherSymbolTable.lookup(newName)) { + break; + } + } + } + + // Apply renaming. + LLVM_DEBUG(llvm::dbgs() + << ", renaming to @" << newName.getValue() << "\n"); + if (failed(SymbolTable::replaceAllSymbolUses(op, newName, + symbolTableOp))) { + InFlightDiagnostic diag = + emitError(op->getLoc(), Twine("failed to rename symbol to @") + + newName.getValue()); + diag.attachNote(otherOp->getLoc()) + << "renaming due to collision with this op"; + return diag; + } + op.setName(newName); // XXX: Why is this necessary? Why does + // SymbolTable::renameAllSymbolUses not do it? + return success(); + }; + + // Collision can be resolved if one of the ops is private. + if (symbolOp.isPrivate()) { + if (failed(renameToUnique(symbolOp, collidingOp, symbolTableOp, + symbolTable, otherSymbolTable))) + return failure(); + continue; + } + if (collidingOp.isPrivate()) { + if (failed(renameToUnique(collidingOp, symbolOp, otherSymbolTableOp, + symbolTable, otherSymbolTable))) + return failure(); + continue; + } - auto symbolFunc = dyn_cast(op); - auto externalSymbolFunc = dyn_cast(externalSymbol); - if (!symbolFunc || !externalSymbolFunc) { - LLVM_DEBUG(llvm::dbgs() << "cannot compare types\n"); - continue; + LLVM_DEBUG(llvm::dbgs() << ", emitting error\n"); + InFlightDiagnostic diag = + emitError(symbolOp->getLoc(), + Twine("doubly defined symbol @") + name.getValue()); + diag.attachNote(collidingOp->getLoc()) << "previously defined here"; + return diag; } + } + + for (auto *op : SmallVector{target, other}) { + if (failed(mlir::verify(op))) + return emitError(op->getLoc(), + "failed to verify input op after renaming"); + } - LLVM_DEBUG(llvm::dbgs() << "found @" << externalSymbol << "\n"); - if (symbolFunc.getFunctionType() != externalSymbolFunc.getFunctionType()) { - return symbolFunc.emitError() - << "external definition has a mismatching signature (" - << externalSymbolFunc.getFunctionType() << ")"; + LLVM_DEBUG(DBGS() << "moving all symbols into target\n"); + { + SmallVector opsToMove; + for (Operation &op : other->getRegion(0).front()) { + if (auto symbol = dyn_cast(op)) + opsToMove.push_back(symbol); } - for (unsigned i = 0, e = symbolFunc.getNumArguments(); i < e; ++i) { - bool isExternalConsumed = - externalSymbolFunc.getArgAttr(i, consumedName) != nullptr; - bool isExternalReadonly = - externalSymbolFunc.getArgAttr(i, readOnlyName) != nullptr; - bool isConsumed = symbolFunc.getArgAttr(i, consumedName) != nullptr; - bool isReadonly = symbolFunc.getArgAttr(i, readOnlyName) != nullptr; - if (!isExternalConsumed && !isExternalReadonly) { - if (isConsumed) - externalSymbolFunc.setArgAttr(i, consumedName, UnitAttr::get(&ctx)); - else if (isReadonly) - externalSymbolFunc.setArgAttr(i, readOnlyName, UnitAttr::get(&ctx)); + SymbolTable symbolTable(target); + for (SymbolOpInterface op : opsToMove) { + // Remember potentially colliding op in the target module. + auto collidingOp = + cast_or_null(symbolTable.lookup(op.getNameAttr())); + + // Move op even if we get a collision. + LLVM_DEBUG(DBGS() << " moving @" << op.getName()); + op->moveAfter(&target->getRegion(0).front(), + target->getRegion(0).front().begin()); + + // If there is no collision, we are done. + if (!collidingOp) { + LLVM_DEBUG(llvm::dbgs() << " without collision\n"); continue; } - if ((isExternalConsumed && !isConsumed) || - (isExternalReadonly && !isReadonly)) { + // We now have a collision that we resolve through merging. The merging + // may bring the symbol table out of date but we don't need to access the + // table for that symbol anymore. + + // The two colliding ops must bot be functions because we have already + // emitted errors otherwise earlier. + auto symbolFunc = cast(op.getOperation()); + auto externalSymbolFunc = + cast(collidingOp.getOperation()); + + // Both ops are in the target module now and can be treated symmetrically, + // so w.l.o.g. we can reduce to merging `funcOp` into `collidingFuncOp`. + if (!canBeMerged(symbolFunc, externalSymbolFunc)) + std::swap(symbolFunc, externalSymbolFunc); + assert(canBeMerged(symbolFunc, externalSymbolFunc)); + + LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at " + << externalSymbolFunc.getLoc() << ":\n" + << externalSymbolFunc << "\n"); + + // Check that function signatures match. + // XXX: Do that check earlier? + if (symbolFunc.getFunctionType() != + externalSymbolFunc.getFunctionType()) { return symbolFunc.emitError() - << "external definition has mismatching consumption annotations " - "for argument #" - << i; + << "external definition has a mismatching signature (" + << externalSymbolFunc.getFunctionType() << ")"; } - } - OpBuilder builder(&op); - builder.setInsertionPoint(&op); - builder.clone(*externalSymbol); - symbol->erase(); + // Check and merge argument attributes. + for (unsigned i = 0, e = symbolFunc.getNumArguments(); i < e; ++i) { + bool isExternalConsumed = + externalSymbolFunc.getArgAttr(i, consumedName) != nullptr; + bool isExternalReadonly = + externalSymbolFunc.getArgAttr(i, readOnlyName) != nullptr; + bool isConsumed = symbolFunc.getArgAttr(i, consumedName) != nullptr; + bool isReadonly = symbolFunc.getArgAttr(i, readOnlyName) != nullptr; + if (!isExternalConsumed && !isExternalReadonly) { + if (isConsumed) + externalSymbolFunc.setArgAttr(i, consumedName, + UnitAttr::get(context)); + else if (isReadonly) + externalSymbolFunc.setArgAttr(i, readOnlyName, + UnitAttr::get(context)); + continue; + } + + if ((isExternalConsumed && !isConsumed) || + (isExternalReadonly && !isReadonly)) { + return symbolFunc.emitError() + << "external definition has mismatching consumption " + "annotations for argument #" + << i; + } + } + + // `funcOp` is the external one, so we can remove it. + assert(symbolFunc.isExternal()); + symbolFunc->erase(); + } } + if (failed(mlir::verify(target))) + return emitError(target->getLoc(), + "failed to verify target op after merging symbols"); + + LLVM_DEBUG(DBGS() << "done merging ops\n"); return success(); } @@ -443,8 +602,9 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl( diag.attachNote(target->getLoc()) << "pass anchor op"; return diag; } - if (failed(defineDeclaredSymbols(*transformRoot->getBlock(), - transformLibraryModule->get()))) + if (failed( + mergeSymbolsInto(SymbolTable::getNearestSymbolTable(transformRoot), + transformLibraryModule->get()))) return failure(); } @@ -506,8 +666,8 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl( return success(); if (sharedTransformModule && *sharedTransformModule) { - if (failed(defineDeclaredSymbols(*sharedTransformModule->get().getBody(), - parsedLibraryModule.get()))) + if (failed(mergeSymbolsInto(sharedTransformModule->get(), + parsedLibraryModule.get()))) return failure(); } else { transformLibraryModule = diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir index 3d4cb07769829..dd8d141e994da 100644 --- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir @@ -11,4 +11,10 @@ // expected-remark @below {{message}} // expected-remark @below {{unannotated}} +// expected-remark @below {{internal colliding (without suffix)}} +// expected-remark @below {{internal colliding_0}} +// expected-remark @below {{internal colliding_1}} +// expected-remark @below {{internal colliding_3}} +// expected-remark @below {{internal colliding_4}} +// expected-remark @below {{internal colliding_5}} module {} diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir index b21abbbdfd6d0..7452deb39b6c1 100644 --- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir @@ -1,16 +1,16 @@ -// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \ +// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def-invalid.mlir}, test-transform-dialect-interpreter)" \ // RUN: --verify-diagnostics --split-input-file -// The definition of the @foo named sequence is provided in another file. It +// The definition of the @print_message named sequence is provided in another file. It // will be included because of the pass option. module attributes {transform.with_named_sequence} { // expected-error @below {{external definition has a mismatching signature}} - transform.named_sequence private @foo(!transform.op<"builtin.module"> {transform.readonly}) + transform.named_sequence private @print_message(!transform.op<"builtin.module"> {transform.readonly}) transform.sequence failures(propagate) { ^bb0(%arg0: !transform.op<"builtin.module">): - include @foo failures(propagate) (%arg0) : (!transform.op<"builtin.module">) -> () + include @print_message failures(propagate) (%arg0) : (!transform.op<"builtin.module">) -> () } } @@ -37,3 +37,18 @@ module attributes {transform.with_named_sequence} { include @consuming failures(suppress) (%arg0) : (!transform.any_op) -> () } } + +// ----- + +module attributes {transform.with_named_sequence} { + // expected-error @below {{doubly defined symbol @print_message}} + transform.named_sequence @print_message(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "message" : !transform.any_op + transform.yield + } + + transform.sequence failures(suppress) { + ^bb0(%arg0: !transform.any_op): + include @print_message failures(propagate) (%arg0) : (!transform.any_op) -> () + } +} diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir index 04b6c5a02e0ad..d14c55e6b7be8 100644 --- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir @@ -4,29 +4,68 @@ // RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \ // RUN: --verify-diagnostics --split-input-file | FileCheck %s -// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \ -// RUN: --verify-diagnostics --split-input-file | FileCheck %s +// XXX: This currently fails. +// RoooooN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \ +// RoooooN: --verify-diagnostics --split-input-file | FileCheck %s -// The definition of the @foo named sequence is provided in another file. It -// will be included because of the pass option. Repeated application of the -// same pass, with or without the library option, should not be a problem. +// The definition of the @print_message named sequence is provided in another +// file. It will be included because of the pass option. Repeated application of +// the same pass, with or without the library option, should not be a problem. // Note that the same diagnostic produced twice at the same location only // needs to be matched once. // expected-remark @below {{message}} // expected-remark @below {{unannotated}} +// expected-remark @below {{internal colliding (without suffix)}} +// expected-remark @below {{internal colliding_0}} +// expected-remark @below {{internal colliding_1}} +// expected-remark @below {{internal colliding_3}} +// expected-remark @below {{internal colliding_4}} +// expected-remark @below {{internal colliding_5}} module attributes {transform.with_named_sequence} { - // CHECK: transform.named_sequence @foo - // CHECK: test_print_remark_at_operand %{{.*}}, "message" - transform.named_sequence private @foo(!transform.any_op {transform.readonly}) + // CHECK: transform.named_sequence @print_message( + // CHECK: transform.include @private_helper + transform.named_sequence private @print_message(!transform.any_op {transform.readonly}) + + // These ops collide with ops from the other module before or after renaming. + transform.named_sequence private @colliding(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "internal colliding (without suffix)" : !transform.any_op + transform.yield + } + transform.named_sequence private @colliding_0(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "internal colliding_0" : !transform.any_op + transform.yield + } + transform.named_sequence private @colliding_1(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "internal colliding_1" : !transform.any_op + transform.yield + } + transform.named_sequence private @colliding_3(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "internal colliding_3" : !transform.any_op + transform.yield + } + transform.named_sequence @colliding_4(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "internal colliding_4" : !transform.any_op + transform.yield + } + transform.named_sequence private @colliding_5(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "internal colliding_5" : !transform.any_op + transform.yield + } - // CHECK: transform.named_sequence @unannotated + // CHECK: transform.named_sequence @unannotated( // CHECK: test_print_remark_at_operand %{{.*}}, "unannotated" - transform.named_sequence private @unannotated(!transform.any_op {transform.readonly}) + transform.named_sequence @unannotated(!transform.any_op {transform.readonly}) transform.sequence failures(propagate) { ^bb0(%arg0: !transform.any_op): - include @foo failures(propagate) (%arg0) : (!transform.any_op) -> () + include @print_message failures(propagate) (%arg0) : (!transform.any_op) -> () include @unannotated failures(propagate) (%arg0) : (!transform.any_op) -> () + include @colliding failures(propagate) (%arg0) : (!transform.any_op) -> () + include @colliding_0 failures(propagate) (%arg0) : (!transform.any_op) -> () + include @colliding_1 failures(propagate) (%arg0) : (!transform.any_op) -> () + include @colliding_3 failures(propagate) (%arg0) : (!transform.any_op) -> () + include @colliding_4 failures(propagate) (%arg0) : (!transform.any_op) -> () + include @colliding_5 failures(propagate) (%arg0) : (!transform.any_op) -> () } } diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def-invalid.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def-invalid.mlir new file mode 100644 index 0000000000000..1d9ef1dbead63 --- /dev/null +++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def-invalid.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-opt %s + +module attributes {transform.with_named_sequence} { + // expected-note @below {{previously defined here}} + transform.named_sequence @print_message(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "message" : !transform.any_op + transform.yield + } + + transform.named_sequence @consuming(%arg0: !transform.any_op {transform.consumed}) { + transform.test_consume_operand %arg0 : !transform.any_op + transform.yield + } +} \ No newline at end of file diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir index 1149bda98ab85..66f0f1f62683b 100644 --- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir @@ -1,11 +1,42 @@ // RUN: mlir-opt %s module attributes {transform.with_named_sequence} { - transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) { + transform.named_sequence private @private_helper(%arg0: !transform.any_op {transform.readonly}) { transform.test_print_remark_at_operand %arg0, "message" : !transform.any_op transform.yield } + // These ops collide with ops from the other module before or after renaming. + transform.named_sequence private @colliding(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "external colliding (without suffix)" : !transform.any_op + transform.yield + } + transform.named_sequence private @colliding_0(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "external colliding_0" : !transform.any_op + transform.yield + } + transform.named_sequence private @colliding_2(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "external colliding_2" : !transform.any_op + transform.yield + } + transform.named_sequence private @colliding_3(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "external colliding_3" : !transform.any_op + transform.yield + } + transform.named_sequence private @colliding_4(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "external colliding_4" : !transform.any_op + transform.yield + } + transform.named_sequence @colliding_5(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "external colliding_5" : !transform.any_op + transform.yield + } + + transform.named_sequence @print_message(%arg0: !transform.any_op {transform.readonly}) { + transform.include @private_helper failures(propagate) (%arg0) : (!transform.any_op) -> () + transform.yield + } + transform.named_sequence @consuming(%arg0: !transform.any_op {transform.consumed}) { transform.test_consume_operand %arg0 : !transform.any_op transform.yield @@ -15,4 +46,14 @@ module attributes {transform.with_named_sequence} { transform.test_print_remark_at_operand %arg0, "unannotated" : !transform.any_op transform.yield } + + transform.named_sequence @symbol_user(%arg0: !transform.any_op {transform.readonly}) { + transform.include @colliding failures(propagate) (%arg0) : (!transform.any_op) -> () + transform.include @colliding_0 failures(propagate) (%arg0) : (!transform.any_op) -> () + transform.include @colliding_2 failures(propagate) (%arg0) : (!transform.any_op) -> () + transform.include @colliding_3 failures(propagate) (%arg0) : (!transform.any_op) -> () + transform.include @colliding_4 failures(propagate) (%arg0) : (!transform.any_op) -> () + transform.include @colliding_5 failures(propagate) (%arg0) : (!transform.any_op) -> () + transform.yield + } } From 0d799d51038a1411fdfc3ba12e10e33b07491434 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Fri, 29 Sep 2023 07:37:51 +0000 Subject: [PATCH 02/13] Several minor and major clean-ups: * Move all private functions of the CPP file into anonymous namespace. * Remove test with second interpreter pass that reloads the library. I think that this shouldn't be possible. * Factor out `renameToUnique`, `canMergeInto`, and `mergeInto` into proper functions. * Use a single symbol table per input op and update it correctly whenever symbols or ops change. * Make `other` arg an `OwningOpRef` and clone the arguments where necessary. * Improve comments. --- .../TransformInterpreterPassBase.cpp | 285 ++++++++++-------- ...test-interpreter-external-symbol-decl.mlir | 10 +- 2 files changed, 163 insertions(+), 132 deletions(-) diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp index 7d09be07f66d9..7983ec6acece7 100644 --- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp +++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp @@ -50,6 +50,8 @@ constexpr static llvm::StringLiteral kTransformDialectTagPayloadRootValue = constexpr static llvm::StringLiteral kTransformDialectTagTransformContainerValue = "transform_container"; +namespace { + /// Utility to parse the content of a `transformFileName` MLIR file containing /// a transform dialect specification. static LogicalResult @@ -302,42 +304,142 @@ static void performOptionalDebugActions( transform->removeAttr(kTransformDialectTagAttrName); } +/// Rename `op` to avoid a collision with `otherOp`. `symbolTable` and +/// `otherSymbolTable` are the symbol tables of the two ops, respectively. +/// `uniqueId` is used to generate a unique name in the context of the caller. +LogicalResult renameToUnique(SymbolOpInterface op, SymbolOpInterface otherOp, + SymbolTable &symbolTable, + SymbolTable &otherSymbolTable, int &uniqueId) { + assert(symbolTable.lookup(op.getNameAttr()) == op && + "symbol table does not contain op"); + assert(otherSymbolTable.lookup(otherOp.getNameAttr()) == otherOp && + "other symbol table does not contain other op"); + + // Determine new name that is unique in both symbol tables. + StringAttr oldName = op.getNameAttr(); + StringAttr newName; + { + MLIRContext *context = op->getContext(); + SmallString<64> prefix = oldName.getValue(); + prefix.push_back('_'); + while (true) { + newName = StringAttr::get(context, prefix + Twine(uniqueId++)); + if (!symbolTable.lookup(newName) && !otherSymbolTable.lookup(newName)) { + break; + } + } + } + + // Apply renaming. + LLVM_DEBUG(llvm::dbgs() << ", renaming to @" << newName.getValue() << "\n"); + Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(op); + if (failed(SymbolTable::replaceAllSymbolUses(op, newName, symbolTableOp))) { + InFlightDiagnostic diag = + emitError(op->getLoc(), + Twine("failed to rename symbol to @") + newName.getValue()); + diag.attachNote(otherOp->getLoc()) + << "attempted renaming due to collision with this op"; + return diag; + } + + // Change the symbol in the op itself and update the symbol table. + symbolTable.remove(op); + SymbolTable::setSymbolName(op, newName); + symbolTable.insert(op); + + assert(symbolTable.lookup(newName) == op && + "symbol table does not resolve to renamed op"); + assert(symbolTable.lookup(oldName) == nullptr && + "symbol table still resolves old name"); + + return success(); +} + +/// Return whether `func1` can be merged into `func2`. +bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) { + return func1.isExternal() && (func2.isPublic() || func2.isExternal()); +} + +/// Merge `func1` into `func2`. The two ops must be inside the same parent op +/// and mergable according to `canMergeInto`. The function erases `func1` such +/// that only `func2` exists when the function returns. +LogicalResult mergeInto(FunctionOpInterface func1, FunctionOpInterface func2) { + assert(canMergeInto(func1, func2)); + assert(func1->getParentOp() == func2->getParentOp() && + "expected func1 and func2 to be in the same parent op"); + + MLIRContext *context = func1->getContext(); + auto consumedName = StringAttr::get( + context, transform::TransformDialect::kArgConsumedAttrName); + auto readOnlyName = StringAttr::get( + context, transform::TransformDialect::kArgReadOnlyAttrName); + + // Check that function signatures match. + if (func1.getFunctionType() != func2.getFunctionType()) { + return func1.emitError() + << "external definition has a mismatching signature (" + << func2.getFunctionType() << ")"; + } + + // Check and merge argument attributes. + for (unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) { + bool isExternalConsumed = func2.getArgAttr(i, consumedName) != nullptr; + bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) != nullptr; + bool isConsumed = func1.getArgAttr(i, consumedName) != nullptr; + bool isReadonly = func1.getArgAttr(i, readOnlyName) != nullptr; + if (!isExternalConsumed && !isExternalReadonly) { + if (isConsumed) + func2.setArgAttr(i, consumedName, UnitAttr::get(context)); + else if (isReadonly) + func2.setArgAttr(i, readOnlyName, UnitAttr::get(context)); + continue; + } + + if ((isExternalConsumed && !isConsumed) || + (isExternalReadonly && !isReadonly)) { + return func1.emitError() + << "external definition has mismatching consumption " + "annotations for argument #" + << i; + } + } + + // `func1` is the external one, so we can remove it. + assert(func1.isExternal()); + func1->erase(); + + return success(); +} + /// Merge all symbols from `other` into `target`. Both ops need to implement the /// `SymbolTable` trait. Operations are moved from `other`, i.e., `other` may be -/// modified by this function. Upon merging, private symbols may be renamed in -/// order to avoid collisions in the result. Public symbols may not collide, -/// with the exception of `SymbolInterfaceOp`s, where collisions are allowed if -/// at least one of the two is external, in which case the other op preserved -/// (or one of the two if both are external). The `target` op might not verify -/// after this function returns. -// XXX: Make `other` argument an `OwningOpRef`? -static LogicalResult mergeSymbolsInto(Operation *target, Operation *other) { +/// modified by this function and might not verify after the function returns. +/// Upon merging, private symbols may be renamed in order to avoid collisions in +/// the result. Public symbols may not collide, with the exception of +/// instances of `SymbolOpInterface`, where collisions are allowed if at least +/// one of the two is external, in which case the other op preserved (or any one +/// of the two if both are external). +static LogicalResult mergeSymbolsInto(Operation *target, + OwningOpRef other) { assert(target->hasTrait() && "requires target to implement the 'SymbolTable' trait"); assert(other->hasTrait() && "requires target to implement the 'SymbolTable' trait"); - MLIRContext *context = other->getContext(); - auto consumedName = StringAttr::get( - context, transform::TransformDialect::kArgConsumedAttrName); - auto readOnlyName = StringAttr::get( - context, transform::TransformDialect::kArgReadOnlyAttrName); + SymbolTable targetSymbolTable(target); + SymbolTable otherSymbolTable(*other); int uniqueId = 0; - auto canBeMerged = [](FunctionOpInterface func1, FunctionOpInterface func2) { - return func1.isExternal() && (func2.isPublic() || func2.isExternal()); - ; - }; - + // Step 1: + // // Rename private symbols in both ops in order to resolve conflicts that can // be resolved that way. LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n"); - for (auto [symbolTableOp, otherSymbolTableOp] : - llvm::zip(SmallVector{target, other}, - SmallVector{other, target})) { - SymbolTable symbolTable(symbolTableOp); // XXX: build only once - SymbolTable otherSymbolTable(otherSymbolTableOp); + for (auto [symbolTable, otherSymbolTable] : llvm::zip( + SmallVector{&targetSymbolTable, &otherSymbolTable}, + SmallVector{&otherSymbolTable, &targetSymbolTable})) { + Operation *symbolTableOp = symbolTable->getOp(); for (Operation &op : symbolTableOp->getRegion(0).front()) { auto symbolOp = dyn_cast(op); if (!symbolOp) @@ -347,7 +449,7 @@ static LogicalResult mergeSymbolsInto(Operation *target, Operation *other) { // Check if there is a colliding op in the other module. auto collidingOp = - cast_or_null(otherSymbolTable.lookup(name)); + cast_or_null(otherSymbolTable->lookup(name)); if (!collidingOp) continue; @@ -358,8 +460,8 @@ static LogicalResult mergeSymbolsInto(Operation *target, Operation *other) { collidingFuncOp = dyn_cast(collidingOp.getOperation()); funcOp && collidingFuncOp) { - if (canBeMerged(funcOp, collidingFuncOp) || - canBeMerged(collidingFuncOp, funcOp)) { + if (canMergeInto(funcOp, collidingFuncOp) || + canMergeInto(collidingFuncOp, funcOp)) { LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and " "will be merged\n"); continue; @@ -369,58 +471,16 @@ static LogicalResult mergeSymbolsInto(Operation *target, Operation *other) { LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions"); } - /// Rename `op` inside `symbolTableOp` with symbol table `symbolTable` - /// to avoid a collision with `otherOp`. - auto renameToUnique = - [&uniqueId = - uniqueId](SymbolOpInterface op, SymbolOpInterface otherOp, - Operation *symbolTableOp, SymbolTable &symbolTable, - SymbolTable &otherSymbolTable) -> LogicalResult { - assert(SymbolTable::getNearestSymbolTable(op) == symbolTableOp && - "expected 'op' to be inside of 'symbolTableOp'"); - MLIRContext *context = op->getContext(); - - // Determine new name that is unique in both symbol tables. - StringAttr newName; - { - SmallString<64> prefix = op.getNameAttr().getValue(); - prefix.push_back('_'); - while (true) { - newName = StringAttr::get(context, prefix + Twine(uniqueId++)); - if (!symbolTable.lookup(newName) && - !otherSymbolTable.lookup(newName)) { - break; - } - } - } - - // Apply renaming. - LLVM_DEBUG(llvm::dbgs() - << ", renaming to @" << newName.getValue() << "\n"); - if (failed(SymbolTable::replaceAllSymbolUses(op, newName, - symbolTableOp))) { - InFlightDiagnostic diag = - emitError(op->getLoc(), Twine("failed to rename symbol to @") + - newName.getValue()); - diag.attachNote(otherOp->getLoc()) - << "renaming due to collision with this op"; - return diag; - } - op.setName(newName); // XXX: Why is this necessary? Why does - // SymbolTable::renameAllSymbolUses not do it? - return success(); - }; - // Collision can be resolved if one of the ops is private. if (symbolOp.isPrivate()) { - if (failed(renameToUnique(symbolOp, collidingOp, symbolTableOp, - symbolTable, otherSymbolTable))) + if (failed(renameToUnique(symbolOp, collidingOp, *symbolTable, + *otherSymbolTable, uniqueId))) return failure(); continue; } if (collidingOp.isPrivate()) { - if (failed(renameToUnique(collidingOp, symbolOp, otherSymbolTableOp, - symbolTable, otherSymbolTable))) + if (failed(renameToUnique(collidingOp, symbolOp, *otherSymbolTable, + *symbolTable, uniqueId))) return failure(); continue; } @@ -434,12 +494,15 @@ static LogicalResult mergeSymbolsInto(Operation *target, Operation *other) { } } - for (auto *op : SmallVector{target, other}) { + for (auto *op : SmallVector{target, *other}) { if (failed(mlir::verify(op))) return emitError(op->getLoc(), "failed to verify input op after renaming"); } + // Step 2: + // + // Move all ops from `other` into target and merge public symbols. LLVM_DEBUG(DBGS() << "moving all symbols into target\n"); { SmallVector opsToMove; @@ -448,11 +511,10 @@ static LogicalResult mergeSymbolsInto(Operation *target, Operation *other) { opsToMove.push_back(symbol); } - SymbolTable symbolTable(target); for (SymbolOpInterface op : opsToMove) { // Remember potentially colliding op in the target module. - auto collidingOp = - cast_or_null(symbolTable.lookup(op.getNameAttr())); + auto collidingOp = cast_or_null( + targetSymbolTable.lookup(op.getNameAttr())); // Move op even if we get a collision. LLVM_DEBUG(DBGS() << " moving @" << op.getName()); @@ -465,65 +527,34 @@ static LogicalResult mergeSymbolsInto(Operation *target, Operation *other) { continue; } - // We now have a collision that we resolve through merging. The merging - // may bring the symbol table out of date but we don't need to access the - // table for that symbol anymore. - - // The two colliding ops must bot be functions because we have already + // The two colliding ops must both be functions because we have already // emitted errors otherwise earlier. - auto symbolFunc = cast(op.getOperation()); - auto externalSymbolFunc = + auto funcOp = cast(op.getOperation()); + auto collidingFuncOp = cast(collidingOp.getOperation()); // Both ops are in the target module now and can be treated symmetrically, // so w.l.o.g. we can reduce to merging `funcOp` into `collidingFuncOp`. - if (!canBeMerged(symbolFunc, externalSymbolFunc)) - std::swap(symbolFunc, externalSymbolFunc); - assert(canBeMerged(symbolFunc, externalSymbolFunc)); + if (!canMergeInto(funcOp, collidingFuncOp)) { + std::swap(funcOp, collidingFuncOp); + } + assert(canMergeInto(funcOp, collidingFuncOp)); LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at " - << externalSymbolFunc.getLoc() << ":\n" - << externalSymbolFunc << "\n"); - - // Check that function signatures match. - // XXX: Do that check earlier? - if (symbolFunc.getFunctionType() != - externalSymbolFunc.getFunctionType()) { - return symbolFunc.emitError() - << "external definition has a mismatching signature (" - << externalSymbolFunc.getFunctionType() << ")"; - } + << collidingFuncOp.getLoc() << ":\n" + << collidingFuncOp << "\n"); - // Check and merge argument attributes. - for (unsigned i = 0, e = symbolFunc.getNumArguments(); i < e; ++i) { - bool isExternalConsumed = - externalSymbolFunc.getArgAttr(i, consumedName) != nullptr; - bool isExternalReadonly = - externalSymbolFunc.getArgAttr(i, readOnlyName) != nullptr; - bool isConsumed = symbolFunc.getArgAttr(i, consumedName) != nullptr; - bool isReadonly = symbolFunc.getArgAttr(i, readOnlyName) != nullptr; - if (!isExternalConsumed && !isExternalReadonly) { - if (isConsumed) - externalSymbolFunc.setArgAttr(i, consumedName, - UnitAttr::get(context)); - else if (isReadonly) - externalSymbolFunc.setArgAttr(i, readOnlyName, - UnitAttr::get(context)); - continue; - } + // Update symbol table. This works with or without the previous `swap`. + targetSymbolTable.remove(funcOp); + targetSymbolTable.insert(collidingFuncOp); + assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp); - if ((isExternalConsumed && !isConsumed) || - (isExternalReadonly && !isReadonly)) { - return symbolFunc.emitError() - << "external definition has mismatching consumption " - "annotations for argument #" - << i; - } + // Do the actual merging. + if (failed(mergeInto(funcOp, collidingFuncOp))) { + return failure(); } - // `funcOp` is the external one, so we can remove it. - assert(symbolFunc.isExternal()); - symbolFunc->erase(); + assert(succeeded(mlir::verify(target))); } } @@ -535,6 +566,8 @@ static LogicalResult mergeSymbolsInto(Operation *target, Operation *other) { return success(); } +} // namespace + LogicalResult transform::detail::interpreterBaseRunOnOperationImpl( Operation *target, StringRef passName, const std::shared_ptr> &sharedTransformModule, @@ -604,7 +637,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl( } if (failed( mergeSymbolsInto(SymbolTable::getNearestSymbolTable(transformRoot), - transformLibraryModule->get()))) + transformLibraryModule->get()->clone()))) return failure(); } @@ -667,7 +700,7 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl( if (sharedTransformModule && *sharedTransformModule) { if (failed(mergeSymbolsInto(sharedTransformModule->get(), - parsedLibraryModule.get()))) + std::move(parsedLibraryModule)))) return failure(); } else { transformLibraryModule = diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir index d14c55e6b7be8..a9083fe3e7078 100644 --- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir @@ -4,13 +4,11 @@ // RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \ // RUN: --verify-diagnostics --split-input-file | FileCheck %s -// XXX: This currently fails. -// RoooooN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \ -// RoooooN: --verify-diagnostics --split-input-file | FileCheck %s - // The definition of the @print_message named sequence is provided in another -// file. It will be included because of the pass option. Repeated application of -// the same pass, with or without the library option, should not be a problem. +// file. It will be included because of the pass option. Subsequent application +// of the same pass works but only without the library file (since the first +// application loads external symbols and loading them again woul make them +// clash). // Note that the same diagnostic produced twice at the same location only // needs to be matched once. From 54853136d82cf5a7f51919946ec9f6a8823a585a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Fri, 29 Sep 2023 13:55:20 +0000 Subject: [PATCH 03/13] Minor fixes. * Use `moveBefore` instead of `moveAfter` in order to work on empty targets as well. * Do not verify the target after moving each op, since the last op may use symbols of ops that still have to be moved. --- .../Transform/Transforms/TransformInterpreterPassBase.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp index 7983ec6acece7..89a56df35587e 100644 --- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp +++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp @@ -518,8 +518,8 @@ static LogicalResult mergeSymbolsInto(Operation *target, // Move op even if we get a collision. LLVM_DEBUG(DBGS() << " moving @" << op.getName()); - op->moveAfter(&target->getRegion(0).front(), - target->getRegion(0).front().begin()); + op->moveBefore(&target->getRegion(0).front(), + target->getRegion(0).front().end()); // If there is no collision, we are done. if (!collidingOp) { @@ -553,8 +553,6 @@ static LogicalResult mergeSymbolsInto(Operation *target, if (failed(mergeInto(funcOp, collidingFuncOp))) { return failure(); } - - assert(succeeded(mlir::verify(target))); } } From eb68b4fc75d0f8c61e36fea6b3485652c2475990 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Wed, 4 Oct 2023 09:20:58 +0000 Subject: [PATCH 04/13] Add `SymbolTable::rename`. --- mlir/include/mlir/IR/SymbolTable.h | 8 ++++++ mlir/lib/IR/SymbolTable.cpp | 40 ++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h index 33427788a075e..e0da27c167aed 100644 --- a/mlir/include/mlir/IR/SymbolTable.h +++ b/mlir/include/mlir/IR/SymbolTable.h @@ -55,6 +55,14 @@ class SymbolTable { /// after insertion as attribute. StringAttr insert(Operation *symbol, Block::iterator insertPt = {}); + /// Renames the given op or the op refered to by the given name to the given + /// new name and updates the symbol table and all usages of the symbol + /// accordingly. Fails if the updating of the usages fails. + LogicalResult rename(StringAttr from, StringAttr to); + LogicalResult rename(Operation *op, StringAttr to); + LogicalResult rename(StringAttr from, StringRef to); + LogicalResult rename(Operation *op, StringRef to); + /// Return the name of the attribute used for symbol names. static StringRef getSymbolAttrName() { return "sym_name"; } diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp index 2494cb7086f0d..139fb043e9807 100644 --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -218,6 +218,46 @@ StringAttr SymbolTable::insert(Operation *symbol, Block::iterator insertPt) { return getSymbolName(symbol); } +LogicalResult SymbolTable::rename(StringAttr from, StringAttr to) { + Operation *op = lookup(from); + return rename(op, to); +} + +LogicalResult SymbolTable::rename(Operation *op, StringAttr to) { + StringAttr from = getNameIfSymbol(op); + + assert(from && "expected valid 'name' attribute"); + assert(op->getParentOp() == symbolTableOp && + "expected this operation to be inside of the operation with this " + "SymbolTable"); + assert(lookup(from) == op && "current name does not resolve to op"); + assert(lookup(to) == nullptr && "new name already exists"); + + if (failed(SymbolTable::replaceAllSymbolUses(op, to, getOp()))) + return failure(); + + // Remove op with old name, change name, add with new name. The order is + // important here due to how `remove` and `insert` rely on the op name. + remove(op); + setSymbolName(op, to); + insert(op); + + assert(lookup(to) == op && "new name does not resolve to renamed op"); + assert(lookup(from) == nullptr && "old name still exists"); + + return success(); +} + +LogicalResult SymbolTable::rename(StringAttr from, StringRef to) { + auto toAttr = StringAttr::get(getOp()->getContext(), to); + return rename(from, toAttr); +} + +LogicalResult SymbolTable::rename(Operation *op, StringRef to) { + auto toAttr = StringAttr::get(getOp()->getContext(), to); + return rename(op, toAttr); +} + /// Returns the name of the given symbol operation. StringAttr SymbolTable::getSymbolName(Operation *symbol) { StringAttr name = getNameIfSymbol(symbol); From 02e1ef5fbc297ff8c0acbe25d5924da77a714213 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Wed, 4 Oct 2023 09:42:41 +0000 Subject: [PATCH 05/13] Add `SymbolTable::renameToUnique`. --- mlir/include/mlir/IR/SymbolTable.h | 9 ++++++++ mlir/lib/IR/SymbolTable.cpp | 33 ++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h index e0da27c167aed..7f21f22eba951 100644 --- a/mlir/include/mlir/IR/SymbolTable.h +++ b/mlir/include/mlir/IR/SymbolTable.h @@ -63,6 +63,15 @@ class SymbolTable { LogicalResult rename(StringAttr from, StringRef to); LogicalResult rename(Operation *op, StringRef to); + /// Renames the given op or the op refered to by the given name to the a name + /// that is unique within this and the provided other symbol tables and + /// updates the symbol table and all usages of the symbol accordingly. Returns + /// the new name or failure if the renaming fails. + FailureOr renameToUnique(StringAttr from, + ArrayRef others); + FailureOr renameToUnique(Operation *op, + ArrayRef others); + /// Return the name of the attribute used for symbol names. static StringRef getSymbolAttrName() { return "sym_name"; } diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp index 139fb043e9807..8ff1859e1383f 100644 --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -258,6 +258,39 @@ LogicalResult SymbolTable::rename(Operation *op, StringRef to) { return rename(op, toAttr); } +FailureOr +SymbolTable::renameToUnique(StringAttr oldName, + ArrayRef others) { + + // Determine new name that is unique in all symbol tables. + StringAttr newName; + { + MLIRContext *context = oldName.getContext(); + SmallString<64> prefix = oldName.getValue(); + int uniqueId = 0; + prefix.push_back('_'); + while (true) { + newName = StringAttr::get(context, prefix + Twine(uniqueId++)); + auto lookupNewName = [&](SymbolTable *st) { return st->lookup(newName); }; + if (!lookupNewName(this) && llvm::none_of(others, lookupNewName)) { + break; + } + } + } + + // Apply renaming. + if (failed(rename(oldName, newName))) + return failure(); + return newName; +} + +FailureOr +SymbolTable::renameToUnique(Operation *op, ArrayRef others) { + StringAttr from = getNameIfSymbol(op); + assert(from && "expected valid 'name' attribute"); + return renameToUnique(from, others); +} + /// Returns the name of the given symbol operation. StringAttr SymbolTable::getSymbolName(Operation *symbol) { StringAttr name = getNameIfSymbol(symbol); From 4ad5c20633d1d65af388c8761a565bf8fe2c0c40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Wed, 4 Oct 2023 09:31:10 +0000 Subject: [PATCH 06/13] Use new functions in `SymbolTable`. --- .../TransformInterpreterPassBase.cpp | 77 +++++-------------- 1 file changed, 21 insertions(+), 56 deletions(-) diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp index 89a56df35587e..2805125f629b4 100644 --- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp +++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp @@ -304,57 +304,6 @@ static void performOptionalDebugActions( transform->removeAttr(kTransformDialectTagAttrName); } -/// Rename `op` to avoid a collision with `otherOp`. `symbolTable` and -/// `otherSymbolTable` are the symbol tables of the two ops, respectively. -/// `uniqueId` is used to generate a unique name in the context of the caller. -LogicalResult renameToUnique(SymbolOpInterface op, SymbolOpInterface otherOp, - SymbolTable &symbolTable, - SymbolTable &otherSymbolTable, int &uniqueId) { - assert(symbolTable.lookup(op.getNameAttr()) == op && - "symbol table does not contain op"); - assert(otherSymbolTable.lookup(otherOp.getNameAttr()) == otherOp && - "other symbol table does not contain other op"); - - // Determine new name that is unique in both symbol tables. - StringAttr oldName = op.getNameAttr(); - StringAttr newName; - { - MLIRContext *context = op->getContext(); - SmallString<64> prefix = oldName.getValue(); - prefix.push_back('_'); - while (true) { - newName = StringAttr::get(context, prefix + Twine(uniqueId++)); - if (!symbolTable.lookup(newName) && !otherSymbolTable.lookup(newName)) { - break; - } - } - } - - // Apply renaming. - LLVM_DEBUG(llvm::dbgs() << ", renaming to @" << newName.getValue() << "\n"); - Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(op); - if (failed(SymbolTable::replaceAllSymbolUses(op, newName, symbolTableOp))) { - InFlightDiagnostic diag = - emitError(op->getLoc(), - Twine("failed to rename symbol to @") + newName.getValue()); - diag.attachNote(otherOp->getLoc()) - << "attempted renaming due to collision with this op"; - return diag; - } - - // Change the symbol in the op itself and update the symbol table. - symbolTable.remove(op); - SymbolTable::setSymbolName(op, newName); - symbolTable.insert(op); - - assert(symbolTable.lookup(newName) == op && - "symbol table does not resolve to renamed op"); - assert(symbolTable.lookup(oldName) == nullptr && - "symbol table still resolves old name"); - - return success(); -} - /// Return whether `func1` can be merged into `func2`. bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) { return func1.isExternal() && (func2.isPublic() || func2.isExternal()); @@ -429,8 +378,6 @@ static LogicalResult mergeSymbolsInto(Operation *target, SymbolTable targetSymbolTable(target); SymbolTable otherSymbolTable(*other); - int uniqueId = 0; - // Step 1: // // Rename private symbols in both ops in order to resolve conflicts that can @@ -471,16 +418,34 @@ static LogicalResult mergeSymbolsInto(Operation *target, LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions"); } - // Collision can be resolved if one of the ops is private. + // Collision can be resolved by renaming if one of the ops is private. + auto renameToUnique = + [&](SymbolOpInterface op, SymbolOpInterface otherOp, + SymbolTable &symbolTable, + SymbolTable &otherSymbolTable) -> LogicalResult { + LLVM_DEBUG(llvm::dbgs() << ", renaming\n"); + FailureOr maybeNewName = + symbolTable.renameToUnique(op, {&otherSymbolTable}); + if (failed(maybeNewName)) { + InFlightDiagnostic diag = op->emitError("failed to rename symbol"); + diag.attachNote(otherOp->getLoc()) + << "attempted renaming due to collision with this op"; + return diag; + } + LLVM_DEBUG(DBGS() << " renamed to @" << maybeNewName->getValue() + << "\n"); + return success(); + }; + if (symbolOp.isPrivate()) { if (failed(renameToUnique(symbolOp, collidingOp, *symbolTable, - *otherSymbolTable, uniqueId))) + *otherSymbolTable))) return failure(); continue; } if (collidingOp.isPrivate()) { if (failed(renameToUnique(collidingOp, symbolOp, *otherSymbolTable, - *symbolTable, uniqueId))) + *symbolTable))) return failure(); continue; } From e7bd1048db52a562c3ce412b5b4296599809e472 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Wed, 4 Oct 2023 10:29:26 +0000 Subject: [PATCH 07/13] Factor out transform dialect attribute name construction. --- .../mlir/Dialect/Transform/IR/TransformDialect.td | 8 ++++++++ .../Transforms/TransformInterpreterPassBase.cpp | 10 ++++------ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td index 70a76ab9670f9..e6a6385161352 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td @@ -43,6 +43,14 @@ def Transform_Dialect : Dialect { constexpr const static ::llvm::StringLiteral kArgReadOnlyAttrName = "transform.readonly"; + /// Above attribute names as `StringAttr`. + StringAttr getConsumedAttrName() const { + return StringAttr::get(getContext(), kArgConsumedAttrName); + } + StringAttr getReadOnlyAttrName() const { + return StringAttr::get(getContext(), kArgReadOnlyAttrName); + } + template const DataTy &getExtraData() const { return *static_cast( diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp index 2805125f629b4..cfbef7f0aa517 100644 --- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp +++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp @@ -317,12 +317,6 @@ LogicalResult mergeInto(FunctionOpInterface func1, FunctionOpInterface func2) { assert(func1->getParentOp() == func2->getParentOp() && "expected func1 and func2 to be in the same parent op"); - MLIRContext *context = func1->getContext(); - auto consumedName = StringAttr::get( - context, transform::TransformDialect::kArgConsumedAttrName); - auto readOnlyName = StringAttr::get( - context, transform::TransformDialect::kArgReadOnlyAttrName); - // Check that function signatures match. if (func1.getFunctionType() != func2.getFunctionType()) { return func1.emitError() @@ -331,6 +325,10 @@ LogicalResult mergeInto(FunctionOpInterface func1, FunctionOpInterface func2) { } // Check and merge argument attributes. + MLIRContext *context = func1->getContext(); + auto td = context->getLoadedDialect(); + StringAttr consumedName = td->getConsumedAttrName(); + StringAttr readOnlyName = td->getReadOnlyAttrName(); for (unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) { bool isExternalConsumed = func2.getArgAttr(i, consumedName) != nullptr; bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) != nullptr; From e80cf793ce6a15f5a6830403c53717e3dd6ba45e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Wed, 4 Oct 2023 10:20:11 +0000 Subject: [PATCH 08/13] Address remaining issues raised by @ftynse. --- .../TransformInterpreterPassBase.cpp | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp index cfbef7f0aa517..17bc37ab4cbf6 100644 --- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp +++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp @@ -304,7 +304,10 @@ static void performOptionalDebugActions( transform->removeAttr(kTransformDialectTagAttrName); } -/// Return whether `func1` can be merged into `func2`. +/// Return whether `func1` can be merged into `func2`. For that to work `func1` +/// has to be a declaration (aka has to be external) and `func2` either has to +/// be a declaration as well, or it has to be public (otherwise, it wouldn't +/// be visible by `func1`). bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) { return func1.isExternal() && (func2.isPublic() || func2.isExternal()); } @@ -381,9 +384,10 @@ static LogicalResult mergeSymbolsInto(Operation *target, // Rename private symbols in both ops in order to resolve conflicts that can // be resolved that way. LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n"); - for (auto [symbolTable, otherSymbolTable] : llvm::zip( - SmallVector{&targetSymbolTable, &otherSymbolTable}, - SmallVector{&otherSymbolTable, &targetSymbolTable})) { + for (auto &&[symbolTable, otherSymbolTable] : llvm::zip( + SmallVector{&targetSymbolTable, &otherSymbolTable}, + SmallVector{&otherSymbolTable, + &targetSymbolTable})) { Operation *symbolTableOp = symbolTable->getOp(); for (Operation &op : symbolTableOp->getRegion(0).front()) { auto symbolOp = dyn_cast(op); @@ -449,9 +453,8 @@ static LogicalResult mergeSymbolsInto(Operation *target, } LLVM_DEBUG(llvm::dbgs() << ", emitting error\n"); - InFlightDiagnostic diag = - emitError(symbolOp->getLoc(), - Twine("doubly defined symbol @") + name.getValue()); + InFlightDiagnostic diag = symbolOp.emitError() + << "doubly defined symbol @" << name.getValue(); diag.attachNote(collidingOp->getLoc()) << "previously defined here"; return diag; } @@ -459,8 +462,7 @@ static LogicalResult mergeSymbolsInto(Operation *target, for (auto *op : SmallVector{target, *other}) { if (failed(mlir::verify(op))) - return emitError(op->getLoc(), - "failed to verify input op after renaming"); + return op->emitError() << "failed to verify input op after renaming"; } // Step 2: @@ -520,8 +522,8 @@ static LogicalResult mergeSymbolsInto(Operation *target, } if (failed(mlir::verify(target))) - return emitError(target->getLoc(), - "failed to verify target op after merging symbols"); + return target->emitError() + << "failed to verify target op after merging symbols"; LLVM_DEBUG(DBGS() << "done merging ops\n"); return success(); From 8c63dc040cebd40295c279e09d85932b2bbc19d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Wed, 4 Oct 2023 10:14:51 +0000 Subject: [PATCH 09/13] Change tests to make sure external declaration is preserved. --- .../test-interpreter-external-symbol-decl.mlir | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir index a9083fe3e7078..7d0837abebde3 100644 --- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir @@ -21,8 +21,8 @@ // expected-remark @below {{internal colliding_4}} // expected-remark @below {{internal colliding_5}} module attributes {transform.with_named_sequence} { - // CHECK: transform.named_sequence @print_message( - // CHECK: transform.include @private_helper + // CHECK-DAG: transform.named_sequence @print_message( + // CHECK-DAG: transform.include @private_helper transform.named_sequence private @print_message(!transform.any_op {transform.readonly}) // These ops collide with ops from the other module before or after renaming. @@ -42,6 +42,8 @@ module attributes {transform.with_named_sequence} { transform.test_print_remark_at_operand %arg0, "internal colliding_3" : !transform.any_op transform.yield } + // This symbol is public and thus can't be renamed. + // CHECK-DAG: transform.named_sequence @colliding_4( transform.named_sequence @colliding_4(%arg0: !transform.any_op {transform.readonly}) { transform.test_print_remark_at_operand %arg0, "internal colliding_4" : !transform.any_op transform.yield @@ -51,8 +53,8 @@ module attributes {transform.with_named_sequence} { transform.yield } - // CHECK: transform.named_sequence @unannotated( - // CHECK: test_print_remark_at_operand %{{.*}}, "unannotated" + // CHECK-DAG: transform.named_sequence @unannotated( + // CHECK-DAG: test_print_remark_at_operand %{{.*}}, "unannotated" transform.named_sequence @unannotated(!transform.any_op {transform.readonly}) transform.sequence failures(propagate) { From 738cc714bf56fb0c8c01726db1f48499ca2be66b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Wed, 4 Oct 2023 12:08:46 +0000 Subject: [PATCH 10/13] Remove library file name from the repro call. Since the pass injects all definitions, providing the library again isn't needed. Since that injection isn't idempotent, it actually isn't even *possible* anymore, so this commits removes that argument. --- .../TransformInterpreterPassBase.cpp | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp index 17bc37ab4cbf6..d2d860aca4d18 100644 --- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp +++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp @@ -163,17 +163,9 @@ static llvm::raw_ostream & printReproCall(llvm::raw_ostream &os, StringRef rootOpName, StringRef passName, const Pass::Option &debugPayloadRootTag, const Pass::Option &debugTransformRootTag, - const Pass::Option &transformLibraryFileName, StringRef binaryName) { - std::string transformLibraryOption = ""; - if (!transformLibraryFileName.empty()) { - transformLibraryOption = - llvm::formatv(" {0}={1}", transformLibraryFileName.getArgStr(), - transformLibraryFileName.getValue()) - .str(); - } os << llvm::formatv( - "{7} --pass-pipeline=\"{0}({1}{{{2}={3} {4}={5}{6}})\"", rootOpName, + "{6} --pass-pipeline=\"{0}({1}{{{2}={3} {4}={5}})\"", rootOpName, passName, debugPayloadRootTag.getArgStr(), debugPayloadRootTag.empty() ? StringRef(kTransformDialectTagPayloadRootValue) @@ -182,7 +174,7 @@ printReproCall(llvm::raw_ostream &os, StringRef rootOpName, StringRef passName, debugTransformRootTag.empty() ? StringRef(kTransformDialectTagTransformContainerValue) : debugTransformRootTag, - transformLibraryOption, binaryName); + binaryName); return os; } @@ -228,8 +220,7 @@ void saveReproToTempFile( os << "=== Transform Interpreter Repro ===\n"; printReproCall(os, root->getName().getStringRef(), passName, - debugPayloadRootTag, debugTransformRootTag, - transformLibraryFileName, binaryName) + debugPayloadRootTag, debugTransformRootTag, binaryName) << " " << filename << "\n"; os << "===================================\n"; } @@ -283,8 +274,7 @@ static void performOptionalDebugActions( llvm::dbgs() << "=== Transform Interpreter Repro ===\n"; printReproCall(llvm::dbgs() << "cat <getName().getStringRef(), passName, - debugPayloadRootTag, debugTransformRootTag, - transformLibraryFileName, binaryName) + debugPayloadRootTag, debugTransformRootTag, binaryName) << "\n"; printModuleForRepro(llvm::dbgs(), root, transform); llvm::dbgs() << "\nEOF\n"; From 2f5a2968c1033ba8bb73f76101966f842ab6f03a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Wed, 4 Oct 2023 12:17:18 +0000 Subject: [PATCH 11/13] Update function argument and CLI descriptions. --- .../Transforms/TransformInterpreterPassBase.h | 10 ++++++---- .../Transform/TestTransformDialectInterpreter.cpp | 6 +++--- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h index 6102417ceda1a..a6f0dddebd7ea 100644 --- a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h +++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h @@ -62,9 +62,11 @@ LogicalResult interpreterBaseRunOnOperationImpl( /// transform script. If empty, `debugTransformRootTag` is considered or the /// pass root operation must contain a single top-level transform op that /// will be interpreted. -/// - transformLibraryFileName: if non-empty, the name of the file containing -/// definitions of external symbols referenced in the transform script. -/// These definitions will be used to replace declarations. +/// - transformLibraryFileName: if non-empty, the module in this file will be +/// merged into the main transform script run by the interpreter before +/// execution. This allows to provide definitions for external functions +/// used in the main script. Other public symbols in the library module may +/// lead to collisions with public symbols in the main script. /// - debugPayloadRootTag: if non-empty, the value of the attribute named /// `kTransformDialectTagAttrName` indicating the single op that is /// considered the payload root of the transform interpreter; otherwise, the @@ -85,7 +87,7 @@ LogicalResult interpreterBaseRunOnOperationImpl( /// as template arguments. They are *not* expected to to implement `initialize` /// or `runOnOperation`. They *are* expected to call the copy constructor of /// this class in their copy constructors, short of which the file-based -/// transform dialect script injection facility will become nonoperational. +/// transform dialect script injection facility will become non-operational. /// /// Concrete passes may implement the `runBeforeInterpreter` and /// `runAfterInterpreter` to customize the behavior of the pass. diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp index f73deef9d5fd4..578d9abe4a56e 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp @@ -218,9 +218,9 @@ class TestTransformDialectInterpreterPass "select the container of the top-level transform op.")}; Option transformLibraryFileName{ *this, "transform-library-file-name", llvm::cl::init(""), - llvm::cl::desc( - "Optional name of the file containing transform dialect symbol " - "definitions to be injected into the transform module.")}; + llvm::cl::desc("Optional name of a file with a module that should be " + "merged into the transform module to provide the " + "definitions of external named sequences.")}; Option testModuleGeneration{ *this, "test-module-generation", llvm::cl::init(false), From 6976ca65420a57ffdc886fa1738d2c88e98619d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Thu, 5 Oct 2023 13:21:09 +0000 Subject: [PATCH 12/13] Address comments from final review of @ftynse. --- .../Dialect/Transform/IR/TransformDialect.td | 3 +- .../TransformInterpreterPassBase.cpp | 34 +++++++++++-------- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td index e6a6385161352..f28205a255070 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td @@ -43,7 +43,8 @@ def Transform_Dialect : Dialect { constexpr const static ::llvm::StringLiteral kArgReadOnlyAttrName = "transform.readonly"; - /// Above attribute names as `StringAttr`. + /// Names of the attributes indicating whether an argument of an external + /// transform dialect symbol is consumed or only read. StringAttr getConsumedAttrName() const { return StringAttr::get(getContext(), kArgConsumedAttrName); } diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp index d2d860aca4d18..6c0899984f41e 100644 --- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp +++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp @@ -50,8 +50,6 @@ constexpr static llvm::StringLiteral kTransformDialectTagPayloadRootValue = constexpr static llvm::StringLiteral kTransformDialectTagTransformContainerValue = "transform_container"; -namespace { - /// Utility to parse the content of a `transformFileName` MLIR file containing /// a transform dialect specification. static LogicalResult @@ -180,8 +178,9 @@ printReproCall(llvm::raw_ostream &os, StringRef rootOpName, StringRef passName, /// Prints the module rooted at `root` to `os` and appends /// `transformContainer` if it is not nested in `root`. -llvm::raw_ostream &printModuleForRepro(llvm::raw_ostream &os, Operation *root, - Operation *transform) { +static llvm::raw_ostream &printModuleForRepro(llvm::raw_ostream &os, + Operation *root, + Operation *transform) { root->print(os); if (!root->isAncestor(transform)) transform->print(os); @@ -190,12 +189,13 @@ llvm::raw_ostream &printModuleForRepro(llvm::raw_ostream &os, Operation *root, /// Saves the payload and the transform IR into a temporary file and reports /// the file name to `os`. -void saveReproToTempFile( - llvm::raw_ostream &os, Operation *target, Operation *transform, - StringRef passName, const Pass::Option &debugPayloadRootTag, - const Pass::Option &debugTransformRootTag, - const Pass::Option &transformLibraryFileName, - StringRef binaryName) { +static void +saveReproToTempFile(llvm::raw_ostream &os, Operation *target, + Operation *transform, StringRef passName, + const Pass::Option &debugPayloadRootTag, + const Pass::Option &debugTransformRootTag, + const Pass::Option &transformLibraryFileName, + StringRef binaryName) { using llvm::sys::fs::TempFile; Operation *root = getRootOperation(target); @@ -298,14 +298,15 @@ static void performOptionalDebugActions( /// has to be a declaration (aka has to be external) and `func2` either has to /// be a declaration as well, or it has to be public (otherwise, it wouldn't /// be visible by `func1`). -bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) { +static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) { return func1.isExternal() && (func2.isPublic() || func2.isExternal()); } /// Merge `func1` into `func2`. The two ops must be inside the same parent op /// and mergable according to `canMergeInto`. The function erases `func1` such /// that only `func2` exists when the function returns. -LogicalResult mergeInto(FunctionOpInterface func1, FunctionOpInterface func2) { +static LogicalResult mergeInto(FunctionOpInterface func1, + FunctionOpInterface func2) { assert(canMergeInto(func1, func2)); assert(func1->getParentOp() == func2->getParentOp() && "expected func1 and func2 to be in the same parent op"); @@ -319,7 +320,7 @@ LogicalResult mergeInto(FunctionOpInterface func1, FunctionOpInterface func2) { // Check and merge argument attributes. MLIRContext *context = func1->getContext(); - auto td = context->getLoadedDialect(); + auto *td = context->getLoadedDialect(); StringAttr consumedName = td->getConsumedAttrName(); StringAttr readOnlyName = td->getReadOnlyAttrName(); for (unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) { @@ -359,6 +360,10 @@ LogicalResult mergeInto(FunctionOpInterface func1, FunctionOpInterface func2) { /// instances of `SymbolOpInterface`, where collisions are allowed if at least /// one of the two is external, in which case the other op preserved (or any one /// of the two if both are external). +// TODO: Reconsider cloning individual ops rather than forcing users of the +// function to clone (or move) `other` in order to improve efficiency. +// This might primarily make sense if we can also prune the symbols that +// are merged to a subset (such as those that are actually used). static LogicalResult mergeSymbolsInto(Operation *target, OwningOpRef other) { assert(target->hasTrait() && @@ -374,6 +379,7 @@ static LogicalResult mergeSymbolsInto(Operation *target, // Rename private symbols in both ops in order to resolve conflicts that can // be resolved that way. LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n"); + // TODO: Do we *actually* need to test in both directions? for (auto &&[symbolTable, otherSymbolTable] : llvm::zip( SmallVector{&targetSymbolTable, &otherSymbolTable}, SmallVector{&otherSymbolTable, @@ -519,8 +525,6 @@ static LogicalResult mergeSymbolsInto(Operation *target, return success(); } -} // namespace - LogicalResult transform::detail::interpreterBaseRunOnOperationImpl( Operation *target, StringRef passName, const std::shared_ptr> &sharedTransformModule, From 71b40b1aeddf871f6097f223fc012c2f91592a93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Thu, 5 Oct 2023 14:35:31 +0000 Subject: [PATCH 13/13] Add comment about verification in the middle of the pass. --- .../Transform/Transforms/TransformInterpreterPassBase.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp index 6c0899984f41e..68a735e7ef8e0 100644 --- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp +++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp @@ -456,6 +456,8 @@ static LogicalResult mergeSymbolsInto(Operation *target, } } + // TODO: This duplicates pass infrastructure. We should split this pass into + // several and let the pass infrastructure do the verification. for (auto *op : SmallVector{target, *other}) { if (failed(mlir::verify(op))) return op->emitError() << "failed to verify input op after renaming";