From 8a5aca204bb7ed1a0a05f14994274a70f732b3d6 Mon Sep 17 00:00:00 2001 From: Tzung-Han Juang Date: Wed, 4 Sep 2024 15:04:36 -0400 Subject: [PATCH 01/17] Make OneShotModuleBufferize accept FunctionOpInterface and CallOpInterface --- .../Transforms/OneShotModuleBufferize.cpp | 81 ++++++++++++------- 1 file changed, 50 insertions(+), 31 deletions(-) diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp index 0a4072605c265..2983af0fcbf3f 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -75,7 +75,7 @@ using namespace mlir::bufferization; using namespace mlir::bufferization::func_ext; /// A mapping of FuncOps to their callers. -using FuncCallerMap = DenseMap>; +using FuncCallerMap = DenseMap>; /// Get or create FuncAnalysisState. static FuncAnalysisState & @@ -247,6 +247,15 @@ static func::FuncOp getCalledFunction(func::CallOp callOp) { SymbolTable::lookupNearestSymbolFrom(callOp, sym)); } +static FunctionOpInterface getCalledFunction(CallOpInterface callOp) { + SymbolRefAttr sym = + llvm::dyn_cast_if_present(callOp.getCallableForCallee()); + if (!sym) + return nullptr; + return dyn_cast_or_null( + SymbolTable::lookupNearestSymbolFrom(callOp, sym)); +} + /// Gather equivalence info of CallOps. /// Note: This only adds new equivalence info if the called function was already /// analyzed. @@ -277,10 +286,10 @@ static void equivalenceAnalysis(func::FuncOp funcOp, } /// Return "true" if the given function signature has tensor semantics. -static bool hasTensorSignature(func::FuncOp funcOp) { - return llvm::any_of(funcOp.getFunctionType().getInputs(), +static bool hasTensorSignature(FunctionOpInterface funcOp) { + return llvm::any_of(funcOp.getArgumentTypes(), llvm::IsaPred) || - llvm::any_of(funcOp.getFunctionType().getResults(), + llvm::any_of(funcOp.getResultTypes(), llvm::IsaPred); } @@ -291,26 +300,30 @@ static bool hasTensorSignature(func::FuncOp funcOp) { /// retrieve the called FuncOp from any func::CallOp. static LogicalResult getFuncOpsOrderedByCalls(ModuleOp moduleOp, - SmallVectorImpl &orderedFuncOps, + SmallVectorImpl &orderedFuncOps, FuncCallerMap &callerMap) { // For each FuncOp, the set of functions called by it (i.e. the union of // symbols of all nested func::CallOp). - DenseMap> calledBy; + DenseMap> calledBy; // For each FuncOp, the number of func::CallOp it contains. - DenseMap numberCallOpsContainedInFuncOp; - WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult { - if (!funcOp.getBody().empty()) { - func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); - if (!returnOp) - return funcOp->emitError() - << "cannot bufferize a FuncOp with tensors and " - "without a unique ReturnOp"; + DenseMap numberCallOpsContainedInFuncOp; + WalkResult res = moduleOp.walk([&](FunctionOpInterface funcOp) -> WalkResult { + // Only handle ReturnOp if funcOp is exactly the FuncOp type. + if(isa(funcOp)) { + FuncOp funcOpCasted = cast(funcOp); + if (!funcOpCasted.getBody().empty()) { + func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOpCasted); + if (!returnOp) + return funcOp->emitError() + << "cannot bufferize a FuncOp with tensors and " + "without a unique ReturnOp"; + } } // Collect function calls and populate the caller map. numberCallOpsContainedInFuncOp[funcOp] = 0; - return funcOp.walk([&](func::CallOp callOp) -> WalkResult { - func::FuncOp calledFunction = getCalledFunction(callOp); + return funcOp.walk([&](CallOpInterface callOp) -> WalkResult { + FunctionOpInterface calledFunction = getCalledFunction(callOp); assert(calledFunction && "could not retrieved called func::FuncOp"); // If the called function does not have any tensors in its signature, then // it is not necessary to bufferize the callee before the caller. @@ -379,7 +392,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp, FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state); // A list of functions in the order in which they are analyzed + bufferized. - SmallVector orderedFuncOps; + SmallVector orderedFuncOps; // A mapping of FuncOps to their callers. FuncCallerMap callerMap; @@ -388,27 +401,33 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp, return failure(); // Analyze ops. - for (func::FuncOp funcOp : orderedFuncOps) { - if (!state.getOptions().isOpAllowed(funcOp)) + for (FunctionOpInterface funcOp : orderedFuncOps) { + + // The following analysis is specific to the FuncOp type. + if(!isa(funcOp)) + continue; + FuncOp funcOpCasted = cast(funcOp); + + if (!state.getOptions().isOpAllowed(funcOpCasted)) continue; // Now analyzing function. - funcState.startFunctionAnalysis(funcOp); + funcState.startFunctionAnalysis(funcOpCasted); // Gather equivalence info for CallOps. - equivalenceAnalysis(funcOp, state, funcState); + equivalenceAnalysis(funcOpCasted, state, funcState); // Analyze funcOp. - if (failed(analyzeOp(funcOp, state, statistics))) + if (failed(analyzeOp(funcOpCasted, state, statistics))) return failure(); // Run some extra function analyses. - if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state, funcState)) || - failed(funcOpBbArgReadWriteAnalysis(funcOp, state, funcState))) + if (failed(aliasingFuncOpBBArgsAnalysis(funcOpCasted, state, funcState)) || + failed(funcOpBbArgReadWriteAnalysis(funcOpCasted, state, funcState))) return failure(); // Mark op as fully analyzed. - funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed; + funcState.analyzedFuncOps[funcOpCasted] = FuncOpAnalysisState::Analyzed; } return success(); @@ -430,20 +449,20 @@ LogicalResult mlir::bufferization::bufferizeModuleOp( IRRewriter rewriter(moduleOp.getContext()); // A list of functions in the order in which they are analyzed + bufferized. - SmallVector orderedFuncOps; + SmallVector orderedFuncOps; // A mapping of FuncOps to their callers. FuncCallerMap callerMap; if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap))) return failure(); + SmallVector ops; // Bufferize functions. - for (func::FuncOp funcOp : orderedFuncOps) { + for (FunctionOpInterface funcOp : orderedFuncOps) { // Note: It would be good to apply cleanups here but we cannot as aliasInfo // would be invalidated. - - if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) { + if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getName())) { // This function was not analyzed and RaW conflicts were not resolved. // Buffer copies must be inserted before every write. OneShotBufferizationOptions updatedOptions = options; @@ -456,8 +475,8 @@ LogicalResult mlir::bufferization::bufferizeModuleOp( } // Change buffer return types to more precise layout maps. - if (options.inferFunctionResultLayout) - foldMemRefCasts(funcOp); + if (options.inferFunctionResultLayout && isa(funcOp)) + foldMemRefCasts(cast(funcOp)); } // Bufferize all other ops. From 5153af3ee72d4322273b1614a6637a952b10cdcc Mon Sep 17 00:00:00 2001 From: Tzung-Han Juang Date: Wed, 4 Sep 2024 15:42:08 -0400 Subject: [PATCH 02/17] Cleanup --- .../Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp index 2983af0fcbf3f..5231fe8605537 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -456,12 +456,12 @@ LogicalResult mlir::bufferization::bufferizeModuleOp( if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap))) return failure(); - SmallVector ops; // Bufferize functions. for (FunctionOpInterface funcOp : orderedFuncOps) { // Note: It would be good to apply cleanups here but we cannot as aliasInfo // would be invalidated. + if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getName())) { // This function was not analyzed and RaW conflicts were not resolved. // Buffer copies must be inserted before every write. From 1f8d847077716be2f0115c4fadcb7c2d4eafe945 Mon Sep 17 00:00:00 2001 From: Tzung-Han Juang Date: Fri, 6 Sep 2024 10:37:18 -0400 Subject: [PATCH 03/17] Make getAssumedUniqueReturnOp detect ReturnLike and FuncAnalysisState use FunctionOpInterface --- .../FuncBufferizableOpInterfaceImpl.h | 12 +- .../FuncBufferizableOpInterfaceImpl.cpp | 2 +- .../Transforms/OneShotModuleBufferize.cpp | 117 ++++++++---------- 3 files changed, 59 insertions(+), 72 deletions(-) diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h index 0b91d3d675b7c..8bed0dfc5814b 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h @@ -50,24 +50,24 @@ struct FuncAnalysisState : public OneShotAnalysisState::Extension { /// A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg /// indices. - DenseMap equivalentFuncArgs; + DenseMap equivalentFuncArgs; /// A mapping of FuncOp BBArg indices to aliasing ReturnOp OpOperand indices. - DenseMap aliasingReturnVals; + DenseMap aliasingReturnVals; /// A set of all read BlockArguments of FuncOps. - DenseMap readBbArgs; + DenseMap readBbArgs; /// A set of all written-to BlockArguments of FuncOps. - DenseMap writtenBbArgs; + DenseMap writtenBbArgs; /// Keep track of which FuncOps are fully analyzed or currently being /// analyzed. - DenseMap analyzedFuncOps; + DenseMap analyzedFuncOps; /// This function is called right before analyzing the given FuncOp. It /// initializes the data structures for the FuncOp in this state object. - void startFunctionAnalysis(FuncOp funcOp); + void startFunctionAnalysis(FunctionOpInterface funcOp); }; void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp index 9fbe574ec392d..9749a71f3514b 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -22,7 +22,7 @@ namespace mlir { namespace bufferization { namespace func_ext { -void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) { +void FuncAnalysisState::startFunctionAnalysis(FunctionOpInterface funcOp) { analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress; auto createdEquiv = equivalentFuncArgs.try_emplace(funcOp, IndexMapping()); auto createdAliasingResults = diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp index 5231fe8605537..cfb87aef6e64b 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -88,10 +88,11 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) { /// Return the unique ReturnOp that terminates `funcOp`. /// Return nullptr if there is no such unique ReturnOp. -static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) { - func::ReturnOp returnOp; - for (Block &b : funcOp.getBody()) { - if (auto candidateOp = dyn_cast(b.getTerminator())) { +static Operation* getAssumedUniqueReturnOp(FunctionOpInterface funcOp) { + Operation *returnOp = nullptr; + for (Block &b : funcOp.getFunctionBody()) { + auto candidateOp = b.getTerminator(); + if (candidateOp && candidateOp->hasTrait()) { if (returnOp) return nullptr; returnOp = candidateOp; @@ -126,16 +127,15 @@ static void annotateEquivalentReturnBbArg(OpOperand &returnVal, /// Store function BlockArguments that are equivalent to/aliasing a returned /// value in FuncAnalysisState. static LogicalResult -aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state, +aliasingFuncOpBBArgsAnalysis(FunctionOpInterface funcOp, OneShotAnalysisState &state, FuncAnalysisState &funcState) { - if (funcOp.getBody().empty()) { + if (funcOp.getFunctionBody().empty()) { // No function body available. Conservatively assume that every tensor // return value may alias with any tensor bbArg. - FunctionType type = funcOp.getFunctionType(); - for (const auto &inputIt : llvm::enumerate(type.getInputs())) { + for (const auto &inputIt : llvm::enumerate(funcOp.getArgumentTypes())) { if (!isa(inputIt.value())) continue; - for (const auto &resultIt : llvm::enumerate(type.getResults())) { + for (const auto &resultIt : llvm::enumerate(funcOp.getResultTypes())) { if (!isa(resultIt.value())) continue; int64_t returnIdx = resultIt.index(); @@ -147,7 +147,9 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state, } // Support only single return-terminated block in the function. - func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); + if (!isa(funcOp)) + return success(); + Operation *returnOp = getAssumedUniqueReturnOp(funcOp); assert(returnOp && "expected func with single return op"); for (OpOperand &returnVal : returnOp->getOpOperands()) @@ -168,7 +170,7 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state, return success(); } -static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead, +static void annotateFuncArgAccess(FunctionOpInterface funcOp, int64_t idx, bool isRead, bool isWritten) { OpBuilder b(funcOp.getContext()); Attribute accessType; @@ -189,12 +191,12 @@ static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead, /// function with unknown ops, we conservatively assume that such ops bufferize /// to a read + write. static LogicalResult -funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state, +funcOpBbArgReadWriteAnalysis(FunctionOpInterface funcOp, OneShotAnalysisState &state, FuncAnalysisState &funcState) { - for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e; + for (int64_t idx = 0, e = funcOp.getNumArguments(); idx < e; ++idx) { // Skip non-tensor arguments. - if (!isa(funcOp.getFunctionType().getInput(idx))) + if (!isa(funcOp.getArgumentTypes()[idx])) continue; bool isRead; bool isWritten; @@ -204,7 +206,7 @@ funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state, StringRef str = accessAttr.getValue(); isRead = str == "read" || str == "read-write"; isWritten = str == "write" || str == "read-write"; - } else if (funcOp.getBody().empty()) { + } else if (funcOp.getFunctionBody().empty()) { // If the function has no body, conservatively assume that all args are // read + written. isRead = true; @@ -230,23 +232,13 @@ funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state, /// Remove bufferization attributes on FuncOp arguments. static void removeBufferizationAttributes(BlockArgument bbArg) { - auto funcOp = cast(bbArg.getOwner()->getParentOp()); + auto funcOp = cast(bbArg.getOwner()->getParentOp()); funcOp.removeArgAttr(bbArg.getArgNumber(), BufferizationDialect::kBufferLayoutAttrName); funcOp.removeArgAttr(bbArg.getArgNumber(), BufferizationDialect::kWritableAttrName); } -/// Return the func::FuncOp called by `callOp`. -static func::FuncOp getCalledFunction(func::CallOp callOp) { - SymbolRefAttr sym = - llvm::dyn_cast_if_present(callOp.getCallableForCallee()); - if (!sym) - return nullptr; - return dyn_cast_or_null( - SymbolTable::lookupNearestSymbolFrom(callOp, sym)); -} - static FunctionOpInterface getCalledFunction(CallOpInterface callOp) { SymbolRefAttr sym = llvm::dyn_cast_if_present(callOp.getCallableForCallee()); @@ -260,12 +252,12 @@ static FunctionOpInterface getCalledFunction(CallOpInterface callOp) { /// Note: This only adds new equivalence info if the called function was already /// analyzed. // TODO: This does not handle cyclic function call graphs etc. -static void equivalenceAnalysis(func::FuncOp funcOp, +static void equivalenceAnalysis(FunctionOpInterface funcOp, OneShotAnalysisState &state, FuncAnalysisState &funcState) { - funcOp->walk([&](func::CallOp callOp) { - func::FuncOp calledFunction = getCalledFunction(callOp); - assert(calledFunction && "could not retrieved called func::FuncOp"); + funcOp->walk([&](CallOpInterface callOp) { + FunctionOpInterface calledFunction = getCalledFunction(callOp); + assert(calledFunction && "could not retrieved called FunctionOpInterface"); // No equivalence info available for the called function. if (!funcState.equivalentFuncArgs.count(calledFunction)) @@ -276,7 +268,7 @@ static void equivalenceAnalysis(func::FuncOp funcOp, int64_t bbargIdx = it.second; if (!state.isInPlace(callOp->getOpOperand(bbargIdx))) continue; - Value returnVal = callOp.getResult(returnIdx); + Value returnVal = callOp->getResult(returnIdx); Value argVal = callOp->getOperand(bbargIdx); state.unionEquivalenceClasses(returnVal, argVal); } @@ -308,23 +300,19 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp, // For each FuncOp, the number of func::CallOp it contains. DenseMap numberCallOpsContainedInFuncOp; WalkResult res = moduleOp.walk([&](FunctionOpInterface funcOp) -> WalkResult { - // Only handle ReturnOp if funcOp is exactly the FuncOp type. - if(isa(funcOp)) { - FuncOp funcOpCasted = cast(funcOp); - if (!funcOpCasted.getBody().empty()) { - func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOpCasted); - if (!returnOp) - return funcOp->emitError() - << "cannot bufferize a FuncOp with tensors and " - "without a unique ReturnOp"; - } + if (!funcOp.getFunctionBody().empty() && isa(funcOp)) { + Operation *returnOp = getAssumedUniqueReturnOp(funcOp); + if (!returnOp) + return funcOp->emitError() + << "cannot bufferize a FuncOp with tensors and " + "without a unique ReturnOp"; } // Collect function calls and populate the caller map. numberCallOpsContainedInFuncOp[funcOp] = 0; return funcOp.walk([&](CallOpInterface callOp) -> WalkResult { FunctionOpInterface calledFunction = getCalledFunction(callOp); - assert(calledFunction && "could not retrieved called func::FuncOp"); + assert(calledFunction && "could not retrieved called FunctionOpInterface"); // If the called function does not have any tensors in its signature, then // it is not necessary to bufferize the callee before the caller. if (!hasTensorSignature(calledFunction)) @@ -362,11 +350,15 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp, /// most generic layout map as function return types. After bufferizing the /// entire function body, a more concise memref type can potentially be used for /// the return type of the function. -static void foldMemRefCasts(func::FuncOp funcOp) { - if (funcOp.getBody().empty()) +static void foldMemRefCasts(FunctionOpInterface funcOp) { + if (funcOp.getFunctionBody().empty()) + return; + + Operation *returnOp = getAssumedUniqueReturnOp(funcOp); + + if (!returnOp) return; - func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); SmallVector resultTypes; for (OpOperand &operand : returnOp->getOpOperands()) { @@ -379,7 +371,7 @@ static void foldMemRefCasts(func::FuncOp funcOp) { } auto newFuncType = FunctionType::get( - funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes); + funcOp.getContext(), funcOp.getArgumentTypes(), resultTypes); funcOp.setType(newFuncType); } @@ -403,31 +395,26 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp, // Analyze ops. for (FunctionOpInterface funcOp : orderedFuncOps) { - // The following analysis is specific to the FuncOp type. - if(!isa(funcOp)) - continue; - FuncOp funcOpCasted = cast(funcOp); - - if (!state.getOptions().isOpAllowed(funcOpCasted)) + if (!state.getOptions().isOpAllowed(funcOp)) continue; // Now analyzing function. - funcState.startFunctionAnalysis(funcOpCasted); + funcState.startFunctionAnalysis(funcOp); // Gather equivalence info for CallOps. - equivalenceAnalysis(funcOpCasted, state, funcState); + equivalenceAnalysis(funcOp, state, funcState); // Analyze funcOp. - if (failed(analyzeOp(funcOpCasted, state, statistics))) + if (failed(analyzeOp(funcOp, state, statistics))) return failure(); // Run some extra function analyses. - if (failed(aliasingFuncOpBBArgsAnalysis(funcOpCasted, state, funcState)) || - failed(funcOpBbArgReadWriteAnalysis(funcOpCasted, state, funcState))) + if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state, funcState)) || + failed(funcOpBbArgReadWriteAnalysis(funcOp, state, funcState))) return failure(); // Mark op as fully analyzed. - funcState.analyzedFuncOps[funcOpCasted] = FuncOpAnalysisState::Analyzed; + funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed; } return success(); @@ -435,7 +422,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp, void mlir::bufferization::removeBufferizationAttributesInModule( ModuleOp moduleOp) { - moduleOp.walk([&](func::FuncOp op) { + moduleOp.walk([&](FunctionOpInterface op) { for (BlockArgument bbArg : op.getArguments()) removeBufferizationAttributes(bbArg); }); @@ -475,14 +462,14 @@ LogicalResult mlir::bufferization::bufferizeModuleOp( } // Change buffer return types to more precise layout maps. - if (options.inferFunctionResultLayout && isa(funcOp)) - foldMemRefCasts(cast(funcOp)); + if (options.inferFunctionResultLayout) + foldMemRefCasts(funcOp); } // Bufferize all other ops. for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) { // Functions were already bufferized. - if (isa(&op)) + if (isa(&op)) continue; if (failed(bufferizeOp(&op, options, statistics))) return failure(); @@ -509,12 +496,12 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize( // FuncOps whose names are specified in options.noAnalysisFuncFilter will // not be analyzed. Ops in these FuncOps will not be analyzed as well. OpFilter::Entry::FilterFn analysisFilterFn = [=](Operation *op) { - auto func = dyn_cast(op); + auto func = dyn_cast(op); if (!func) - func = op->getParentOfType(); + func = op->getParentOfType(); if (func) return llvm::is_contained(options.noAnalysisFuncFilter, - func.getSymName()); + func.getName()); return false; }; OneShotBufferizationOptions updatedOptions(options); From 26e69ad35197b7c1b7a2084810b714898af2aeb7 Mon Sep 17 00:00:00 2001 From: Tzung-Han Juang Date: Fri, 6 Sep 2024 10:56:50 -0400 Subject: [PATCH 04/17] Make getAssumedUniqueReturnOp return funcOp if there is no return --- .../Transforms/OneShotModuleBufferize.cpp | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp index cfb87aef6e64b..bd054ac4e7b87 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -88,6 +88,7 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) { /// Return the unique ReturnOp that terminates `funcOp`. /// Return nullptr if there is no such unique ReturnOp. +/// Return `funcOp` it self if there is no ReturnOp. static Operation* getAssumedUniqueReturnOp(FunctionOpInterface funcOp) { Operation *returnOp = nullptr; for (Block &b : funcOp.getFunctionBody()) { @@ -98,6 +99,8 @@ static Operation* getAssumedUniqueReturnOp(FunctionOpInterface funcOp) { returnOp = candidateOp; } } + if (!returnOp) + return funcOp; return returnOp; } @@ -147,9 +150,10 @@ aliasingFuncOpBBArgsAnalysis(FunctionOpInterface funcOp, OneShotAnalysisState &s } // Support only single return-terminated block in the function. - if (!isa(funcOp)) - return success(); + // If funcOp has no returnOp, skip the following analysis. Operation *returnOp = getAssumedUniqueReturnOp(funcOp); + if (returnOp == funcOp) + return success(); assert(returnOp && "expected func with single return op"); for (OpOperand &returnVal : returnOp->getOpOperands()) @@ -300,9 +304,9 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp, // For each FuncOp, the number of func::CallOp it contains. DenseMap numberCallOpsContainedInFuncOp; WalkResult res = moduleOp.walk([&](FunctionOpInterface funcOp) -> WalkResult { - if (!funcOp.getFunctionBody().empty() && isa(funcOp)) { + if (!funcOp.getFunctionBody().empty()) { Operation *returnOp = getAssumedUniqueReturnOp(funcOp); - if (!returnOp) + if (!returnOp && returnOp != funcOp) return funcOp->emitError() << "cannot bufferize a FuncOp with tensors and " "without a unique ReturnOp"; @@ -356,7 +360,7 @@ static void foldMemRefCasts(FunctionOpInterface funcOp) { Operation *returnOp = getAssumedUniqueReturnOp(funcOp); - if (!returnOp) + if (!returnOp || returnOp == funcOp) return; SmallVector resultTypes; From 074192ca0e62ba600f63de4e914d44fb4bf86ffb Mon Sep 17 00:00:00 2001 From: Tzung-Han Juang Date: Fri, 6 Sep 2024 14:35:56 -0400 Subject: [PATCH 05/17] Use getNumResults to guard functions without any return type --- .../Transforms/OneShotModuleBufferize.cpp | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp index bd054ac4e7b87..6933fde7f9565 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -88,7 +88,6 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) { /// Return the unique ReturnOp that terminates `funcOp`. /// Return nullptr if there is no such unique ReturnOp. -/// Return `funcOp` it self if there is no ReturnOp. static Operation* getAssumedUniqueReturnOp(FunctionOpInterface funcOp) { Operation *returnOp = nullptr; for (Block &b : funcOp.getFunctionBody()) { @@ -99,8 +98,6 @@ static Operation* getAssumedUniqueReturnOp(FunctionOpInterface funcOp) { returnOp = candidateOp; } } - if (!returnOp) - return funcOp; return returnOp; } @@ -132,7 +129,7 @@ static void annotateEquivalentReturnBbArg(OpOperand &returnVal, static LogicalResult aliasingFuncOpBBArgsAnalysis(FunctionOpInterface funcOp, OneShotAnalysisState &state, FuncAnalysisState &funcState) { - if (funcOp.getFunctionBody().empty()) { + if (funcOp.getFunctionBody().empty() || funcOp.getNumResults() == 0) { // No function body available. Conservatively assume that every tensor // return value may alias with any tensor bbArg. for (const auto &inputIt : llvm::enumerate(funcOp.getArgumentTypes())) { @@ -150,10 +147,7 @@ aliasingFuncOpBBArgsAnalysis(FunctionOpInterface funcOp, OneShotAnalysisState &s } // Support only single return-terminated block in the function. - // If funcOp has no returnOp, skip the following analysis. Operation *returnOp = getAssumedUniqueReturnOp(funcOp); - if (returnOp == funcOp) - return success(); assert(returnOp && "expected func with single return op"); for (OpOperand &returnVal : returnOp->getOpOperands()) @@ -304,9 +298,9 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp, // For each FuncOp, the number of func::CallOp it contains. DenseMap numberCallOpsContainedInFuncOp; WalkResult res = moduleOp.walk([&](FunctionOpInterface funcOp) -> WalkResult { - if (!funcOp.getFunctionBody().empty()) { + if (!funcOp.getFunctionBody().empty() && funcOp.getNumResults() != 0) { Operation *returnOp = getAssumedUniqueReturnOp(funcOp); - if (!returnOp && returnOp != funcOp) + if (!returnOp) return funcOp->emitError() << "cannot bufferize a FuncOp with tensors and " "without a unique ReturnOp"; @@ -355,14 +349,10 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp, /// entire function body, a more concise memref type can potentially be used for /// the return type of the function. static void foldMemRefCasts(FunctionOpInterface funcOp) { - if (funcOp.getFunctionBody().empty()) + if (funcOp.getFunctionBody().empty() || funcOp.getNumResults() == 0) return; Operation *returnOp = getAssumedUniqueReturnOp(funcOp); - - if (!returnOp || returnOp == funcOp) - return; - SmallVector resultTypes; for (OpOperand &operand : returnOp->getOpOperands()) { @@ -398,7 +388,6 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp, // Analyze ops. for (FunctionOpInterface funcOp : orderedFuncOps) { - if (!state.getOptions().isOpAllowed(funcOp)) continue; From 4ba535b93e607698f3319cc5d13a3432fb0c67c4 Mon Sep 17 00:00:00 2001 From: Tzung-Han Juang Date: Tue, 10 Sep 2024 14:30:18 -0400 Subject: [PATCH 06/17] Update mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp Co-authored-by: erick-xanadu <110487834+erick-xanadu@users.noreply.github.com> --- .../Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp index 6933fde7f9565..bf29b7e86a46d 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -349,7 +349,7 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp, /// entire function body, a more concise memref type can potentially be used for /// the return type of the function. static void foldMemRefCasts(FunctionOpInterface funcOp) { - if (funcOp.getFunctionBody().empty() || funcOp.getNumResults() == 0) + if (funcOp.getFunctionBody().empty()) return; Operation *returnOp = getAssumedUniqueReturnOp(funcOp); From caa69cdfcde278bda7da41b78c668610e8a6c519 Mon Sep 17 00:00:00 2001 From: Tzung-Han Juang Date: Tue, 10 Sep 2024 14:30:31 -0400 Subject: [PATCH 07/17] Update mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp Co-authored-by: erick-xanadu <110487834+erick-xanadu@users.noreply.github.com> --- .../Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp index bf29b7e86a46d..67323715ee424 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -129,7 +129,7 @@ static void annotateEquivalentReturnBbArg(OpOperand &returnVal, static LogicalResult aliasingFuncOpBBArgsAnalysis(FunctionOpInterface funcOp, OneShotAnalysisState &state, FuncAnalysisState &funcState) { - if (funcOp.getFunctionBody().empty() || funcOp.getNumResults() == 0) { + if (funcOp.getFunctionBody().empty()) { // No function body available. Conservatively assume that every tensor // return value may alias with any tensor bbArg. for (const auto &inputIt : llvm::enumerate(funcOp.getArgumentTypes())) { From eb2f8884ca6a95ff3e8d74b155d89c812a2ee866 Mon Sep 17 00:00:00 2001 From: Tzung-Han Juang Date: Tue, 10 Sep 2024 14:30:39 -0400 Subject: [PATCH 08/17] Update mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp Co-authored-by: erick-xanadu <110487834+erick-xanadu@users.noreply.github.com> --- .../Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp index 67323715ee424..ce90d907b4ca5 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -298,7 +298,7 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp, // For each FuncOp, the number of func::CallOp it contains. DenseMap numberCallOpsContainedInFuncOp; WalkResult res = moduleOp.walk([&](FunctionOpInterface funcOp) -> WalkResult { - if (!funcOp.getFunctionBody().empty() && funcOp.getNumResults() != 0) { + if (!funcOp.getFunctionBody().empty()) { Operation *returnOp = getAssumedUniqueReturnOp(funcOp); if (!returnOp) return funcOp->emitError() From 695b945ce5259d888643339d66546bcedb9e6043 Mon Sep 17 00:00:00 2001 From: Tzung-Han Juang Date: Tue, 10 Sep 2024 14:54:50 -0400 Subject: [PATCH 09/17] Add debug-payload-root-tag to transform.named_sequence tests --- .../Transforms/transform-ops.mlir | 142 +++++++++-------- mlir/test/Dialect/LLVM/transform-e2e.mlir | 22 +-- .../Linalg/matmul-shared-memory-padding.mlir | 52 +++--- .../Linalg/pad-to-specific-memory-space.mlir | 148 +++++++++--------- .../test/Dialect/Vector/transform-vector.mlir | 84 +++++----- 5 files changed, 241 insertions(+), 207 deletions(-) diff --git a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir index 3c50a9e72d9d9..588aa8a85a84e 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt --transform-interpreter %s -split-input-file -verify-diagnostics | FileCheck %s +// RUN: mlir-opt --transform-interpreter="debug-payload-root-tag=payload" %s -split-input-file -verify-diagnostics | FileCheck %s // Test One-Shot Bufferize. @@ -12,19 +12,21 @@ module attributes {transform.with_named_sequence} { // CHECK-LABEL: func @test_function( // CHECK-SAME: %[[A:.*]]: tensor -func.func @test_function(%A : tensor, %v : vector<4xf32>) -> (tensor) { - %c0 = arith.constant 0 : index +module @payload attributes { transform.target_tag = "payload" } { + func.func @test_function(%A : tensor, %v : vector<4xf32>) -> (tensor) { + %c0 = arith.constant 0 : index - // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]] - // CHECK: %[[dim:.*]] = memref.dim %[[A_memref]] - // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]]) - // CHECK: memref.copy %[[A_memref]], %[[alloc]] - // CHECK: vector.transfer_write %{{.*}}, %[[alloc]] - // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]] - %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor + // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]] + // CHECK: %[[dim:.*]] = memref.dim %[[A_memref]] + // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]]) + // CHECK: memref.copy %[[A_memref]], %[[alloc]] + // CHECK: vector.transfer_write %{{.*}}, %[[alloc]] + // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]] + %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor - // CHECK: return %[[res_tensor]] - return %0 : tensor + // CHECK: return %[[res_tensor]] + return %0 : tensor + } } // ----- @@ -42,19 +44,21 @@ module attributes {transform.with_named_sequence} { // CHECK-LABEL: func @test_function( // CHECK-SAME: %[[A:.*]]: tensor // CHECK-NOT: memref.copy -func.func @test_function(%A : tensor, %v : vector<4xf32>) -> (tensor) { - %c0 = arith.constant 0 : index +module @payload attributes { transform.target_tag = "payload" } { + func.func @test_function(%A : tensor, %v : vector<4xf32>) -> (tensor) { + %c0 = arith.constant 0 : index - // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]] - // CHECK: %[[dim:.*]] = memref.dim %[[A_memref]] - // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]]) - // CHECK: linalg.copy ins(%[[A_memref]] : memref<{{.*}}>) outs(%[[alloc]] - // CHECK: vector.transfer_write %{{.*}}, %[[alloc]] - // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]] - %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor + // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]] + // CHECK: %[[dim:.*]] = memref.dim %[[A_memref]] + // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]]) + // CHECK: linalg.copy ins(%[[A_memref]] : memref<{{.*}}>) outs(%[[alloc]] + // CHECK: vector.transfer_write %{{.*}}, %[[alloc]] + // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]] + %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor - // CHECK: return %[[res_tensor]] - return %0 : tensor + // CHECK: return %[[res_tensor]] + return %0 : tensor + } } // ----- @@ -72,13 +76,15 @@ module attributes {transform.with_named_sequence} { // CHECK-LABEL: func @test_function_analysis( // CHECK-SAME: %[[A:.*]]: tensor -func.func @test_function_analysis(%A : tensor, %v : vector<4xf32>) -> (tensor) { - %c0 = arith.constant 0 : index - // CHECK: vector.transfer_write - // CHECK-SAME: {__inplace_operands_attr__ = ["none", "false", "none"]} - // CHECK-SAME: tensor - %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor - return %0 : tensor +module @payload attributes { transform.target_tag = "payload" } { + func.func @test_function_analysis(%A : tensor, %v : vector<4xf32>) -> (tensor) { + %c0 = arith.constant 0 : index + // CHECK: vector.transfer_write + // CHECK-SAME: {__inplace_operands_attr__ = ["none", "false", "none"]} + // CHECK-SAME: tensor + %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor + return %0 : tensor + } } // ----- @@ -95,10 +101,12 @@ module attributes {transform.with_named_sequence} { } } -func.func @test_unknown_op_failure() -> (tensor) { - // expected-error @+1 {{op was not bufferized}} - %0 = "test.dummy_op"() : () -> (tensor) - return %0 : tensor +module @payload attributes { transform.target_tag = "payload" } { + func.func @test_unknown_op_failure() -> (tensor) { + // expected-error @+1 {{op was not bufferized}} + %0 = "test.dummy_op"() : () -> (tensor) + return %0 : tensor + } } // ----- @@ -111,7 +119,7 @@ module attributes {transform.with_named_sequence} { } } -module { +module @payload attributes { transform.target_tag = "payload" } { // CHECK-LABEL: func @test_function( // CHECK-SAME: %[[A:.*]]: tensor func.func @test_function(%A : tensor, %v : vector<4xf32>) -> (tensor) { @@ -146,11 +154,13 @@ module attributes {transform.with_named_sequence} { // CHECK-SAME: %[[A:.*]]: memref<12x9xf32>, // CHECK-SAME: %[[B:.*]]: memref<9x6xf32>, // CHECK-SAME: %[[C:.*]]: memref<12x6xf32>) -> memref<12x6xf32> { -func.func @matmul(%A: tensor<12x9xf32>, %B: tensor<9x6xf32>, %C: tensor<12x6xf32>) -> tensor<12x6xf32> { - // CHECK: linalg.matmul ins(%[[A]], %[[B]] : memref<12x9xf32>, memref<9x6xf32>) outs(%[[C]] : memref<12x6xf32>) - %D = linalg.matmul ins(%A, %B: tensor<12x9xf32>, tensor<9x6xf32>) outs(%C: tensor<12x6xf32>) -> tensor<12x6xf32> - // CHECK: return %[[C]] : memref<12x6xf32> - return %D : tensor<12x6xf32> +module @payload attributes { transform.target_tag = "payload" } { + func.func @matmul(%A: tensor<12x9xf32>, %B: tensor<9x6xf32>, %C: tensor<12x6xf32>) -> tensor<12x6xf32> { + // CHECK: linalg.matmul ins(%[[A]], %[[B]] : memref<12x9xf32>, memref<9x6xf32>) outs(%[[C]] : memref<12x6xf32>) + %D = linalg.matmul ins(%A, %B: tensor<12x9xf32>, tensor<9x6xf32>) outs(%C: tensor<12x6xf32>) -> tensor<12x6xf32> + // CHECK: return %[[C]] : memref<12x6xf32> + return %D : tensor<12x6xf32> + } } // ----- @@ -165,10 +175,12 @@ module attributes {transform.with_named_sequence} { } // Expect `bufferization.empty_tensor_to_alloc_tensor` to replace the tensor.empty. -func.func @empty_to_tensor_alloc() -> tensor<2x2xf32> { - // CHECK: bufferization.alloc_tensor - %0 = tensor.empty() : tensor<2x2xf32> - return %0 : tensor<2x2xf32> +module @payload attributes { transform.target_tag = "payload" } { + func.func @empty_to_tensor_alloc() -> tensor<2x2xf32> { + // CHECK: bufferization.alloc_tensor + %0 = tensor.empty() : tensor<2x2xf32> + return %0 : tensor<2x2xf32> + } } // ----- @@ -185,13 +197,15 @@ module attributes {transform.with_named_sequence} { // CHECK: tensor.extract_slice // CHECK: linalg.fill // CHECK: tensor.insert_slice -func.func @empty_tensor_elimination( - %t: tensor<10xf32>, %f: f32) -> tensor<10xf32> { - %0 = tensor.empty() : tensor<5xf32> - %1 = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32> - %2 = tensor.insert_slice %1 into %t [1][5][1] - : tensor<5xf32> into tensor<10xf32> - return %2 : tensor<10xf32> +module @payload attributes { transform.target_tag = "payload" } { + func.func @empty_tensor_elimination( + %t: tensor<10xf32>, %f: f32) -> tensor<10xf32> { + %0 = tensor.empty() : tensor<5xf32> + %1 = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32> + %2 = tensor.insert_slice %1 into %t [1][5][1] + : tensor<5xf32> into tensor<10xf32> + return %2 : tensor<10xf32> + } } // ----- @@ -208,12 +222,14 @@ module attributes {transform.with_named_sequence} { // CHECK: memref.alloca // CHECK: scf.for // CHECK: memref.store -func.func @buffer_loop_hoisting(%lb: index, %ub: index, %step: index, %f: f32, %pos: index) { - scf.for %iv = %lb to %ub step %step { - %0 = memref.alloca() : memref<5xf32> - memref.store %f, %0[%pos] : memref<5xf32> +module @payload attributes { transform.target_tag = "payload" } { + func.func @buffer_loop_hoisting(%lb: index, %ub: index, %step: index, %f: f32, %pos: index) { + scf.for %iv = %lb to %ub step %step { + %0 = memref.alloca() : memref<5xf32> + memref.store %f, %0[%pos] : memref<5xf32> + } + return } - return } // ----- @@ -231,10 +247,12 @@ module attributes {transform.with_named_sequence} { // Expect `bufferization.bufferize_to_allocation` to create an alloc. // CHECK-LABEL: func.func @empty_to_tensor_alloc() -func.func @empty_to_tensor_alloc() -> tensor<2x2xf32> { - // CHECK-NEXT: %[[alloca:.*]] = memref.alloca() : memref<2x2xf32> - // CHECK-NEXT: %[[tensor:.*]] = bufferization.to_tensor %[[alloca]] restrict writable : memref<2x2xf32> - // CHECK-NEXT: return %[[tensor]] : tensor<2x2xf32> - %0 = bufferization.alloc_tensor() : tensor<2x2xf32> - return %0 : tensor<2x2xf32> +module @payload attributes { transform.target_tag = "payload" } { + func.func @empty_to_tensor_alloc() -> tensor<2x2xf32> { + // CHECK-NEXT: %[[alloca:.*]] = memref.alloca() : memref<2x2xf32> + // CHECK-NEXT: %[[tensor:.*]] = bufferization.to_tensor %[[alloca]] restrict writable : memref<2x2xf32> + // CHECK-NEXT: return %[[tensor]] : tensor<2x2xf32> + %0 = bufferization.alloc_tensor() : tensor<2x2xf32> + return %0 : tensor<2x2xf32> + } } diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir index c00b47fb936e9..3e637a3ec49a4 100644 --- a/mlir/test/Dialect/LLVM/transform-e2e.mlir +++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir @@ -1,15 +1,17 @@ -// RUN: mlir-opt %s --transform-interpreter -test-transform-dialect-erase-schedule --test-lower-to-llvm --split-input-file | FileCheck %s +// RUN: mlir-opt %s --transform-interpreter="debug-payload-root-tag=payload" -test-transform-dialect-erase-schedule --test-lower-to-llvm --split-input-file | FileCheck %s // CHECK-LABEL: llvm.func @matmul_tensors -func.func @matmul_tensors( - %arg0: tensor<2x4xf32>, %arg1: tensor<4x6xf32>, %arg2: tensor<2x6xf32>) - -> tensor<2x6xf32> { -// CHECK-NOT: linalg -// CHECK: llvm.intr.fmuladd{{.*}} - %0 = linalg.matmul ins(%arg0, %arg1: tensor<2x4xf32>, tensor<4x6xf32>) - outs(%arg2: tensor<2x6xf32>) - -> tensor<2x6xf32> - return %0 : tensor<2x6xf32> +module @payload attributes { transform.target_tag = "payload" } { + func.func @matmul_tensors( + %arg0: tensor<2x4xf32>, %arg1: tensor<4x6xf32>, %arg2: tensor<2x6xf32>) + -> tensor<2x6xf32> { + // CHECK-NOT: linalg + // CHECK: llvm.intr.fmuladd{{.*}} + %0 = linalg.matmul ins(%arg0, %arg1: tensor<2x4xf32>, tensor<4x6xf32>) + outs(%arg2: tensor<2x6xf32>) + -> tensor<2x6xf32> + return %0 : tensor<2x6xf32> + } } module attributes {transform.with_named_sequence} { diff --git a/mlir/test/Dialect/Linalg/matmul-shared-memory-padding.mlir b/mlir/test/Dialect/Linalg/matmul-shared-memory-padding.mlir index 3f8d2ea06641e..9c223737750a9 100644 --- a/mlir/test/Dialect/Linalg/matmul-shared-memory-padding.mlir +++ b/mlir/test/Dialect/Linalg/matmul-shared-memory-padding.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt --split-input-file --transform-interpreter %s | FileCheck %s +// RUN: mlir-opt --split-input-file --transform-interpreter="debug-payload-root-tag=payload" %s | FileCheck %s // CHECK-LABEL: func @matmul_divisible // CHECK: scf.forall @@ -24,19 +24,21 @@ // CHECK: scf.forall // CHECK: vector.transfer_read // CHECK: vector.transfer_write -func.func @matmul_divisible(%A: tensor<1024x1024xf32>, - %B: tensor<1024x1024xf32>, - %C: tensor<1024x1024xf32>) - -> tensor<1024x1024xf32> -{ - %cst = arith.constant 0.000000e+00 : f32 - %0 = linalg.fill ins(%cst : f32) - outs(%C : tensor<1024x1024xf32>) +module @payload attributes { transform.target_tag = "payload" } { + func.func @matmul_divisible(%A: tensor<1024x1024xf32>, + %B: tensor<1024x1024xf32>, + %C: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> - %1 = linalg.matmul ins(%A, %B : tensor<1024x1024xf32>, tensor<1024x1024xf32>) - outs(%0 : tensor<1024x1024xf32>) - -> tensor<1024x1024xf32> - return %1 : tensor<1024x1024xf32> + { + %cst = arith.constant 0.000000e+00 : f32 + %0 = linalg.fill ins(%cst : f32) + outs(%C : tensor<1024x1024xf32>) + -> tensor<1024x1024xf32> + %1 = linalg.matmul ins(%A, %B : tensor<1024x1024xf32>, tensor<1024x1024xf32>) + outs(%0 : tensor<1024x1024xf32>) + -> tensor<1024x1024xf32> + return %1 : tensor<1024x1024xf32> + } } module attributes {transform.with_named_sequence} { @@ -143,19 +145,21 @@ module attributes {transform.with_named_sequence} { // CHECK: linalg.matmul // CHECK: vector.transfer_read // CHECK: vector.transfer_write +module @payload attributes { transform.target_tag = "payload" } { func.func @matmul_not_divisible(%A: tensor<1023x1023xf32>, - %B: tensor<1023x1023xf32>, - %C: tensor<1023x1023xf32>) - -> tensor<1023x1023xf32> -{ - %cst = arith.constant 0.000000e+00 : f32 - %0 = linalg.fill ins(%cst : f32) - outs(%C : tensor<1023x1023xf32>) + %B: tensor<1023x1023xf32>, + %C: tensor<1023x1023xf32>) -> tensor<1023x1023xf32> - %1 = linalg.matmul ins(%A, %B : tensor<1023x1023xf32>, tensor<1023x1023xf32>) - outs(%0 : tensor<1023x1023xf32>) - -> tensor<1023x1023xf32> - return %1 : tensor<1023x1023xf32> + { + %cst = arith.constant 0.000000e+00 : f32 + %0 = linalg.fill ins(%cst : f32) + outs(%C : tensor<1023x1023xf32>) + -> tensor<1023x1023xf32> + %1 = linalg.matmul ins(%A, %B : tensor<1023x1023xf32>, tensor<1023x1023xf32>) + outs(%0 : tensor<1023x1023xf32>) + -> tensor<1023x1023xf32> + return %1 : tensor<1023x1023xf32> + } } module attributes {transform.with_named_sequence} { diff --git a/mlir/test/Dialect/Linalg/pad-to-specific-memory-space.mlir b/mlir/test/Dialect/Linalg/pad-to-specific-memory-space.mlir index f2e9e839b7c46..5e5657980ba12 100644 --- a/mlir/test/Dialect/Linalg/pad-to-specific-memory-space.mlir +++ b/mlir/test/Dialect/Linalg/pad-to-specific-memory-space.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-opt --transform-interpreter -cse -canonicalize -split-input-file -verify-diagnostics %s | FileCheck %s +// RUN: mlir-opt --transform-interpreter="debug-payload-root-tag=payload" -cse -canonicalize -split-input-file -verify-diagnostics %s | FileCheck %s #map = affine_map<()[s0] -> (-s0 + 12, 7)> @@ -7,43 +7,45 @@ // CHECK-SAME: %[[arg0:.*]]: memref<24x12xf32, strided<[?, ?], offset: ?>>, // CHECK-SAME: %[[arg1:.*]]: memref<12x25xf32, strided<[?, ?], offset: ?>>, // CHECK-SAME: %[[arg2:.*]]: memref<24x25xf32, strided<[?, ?], offset: ?>>, -func.func @pad_to_memory_space(%arg0: tensor<24x12xf32>, - %arg1: tensor<12x25xf32>, - %arg2: tensor<24x25xf32>, - %iv0 : index, %iv1 : index, - %iv2 : index) -> tensor<24x25xf32> { - %0 = affine.min #map()[%iv2] - - // CHECK: %[[s0:.*]] = memref.subview %[[arg0]] - %1 = tensor.extract_slice %arg0[%iv0, %iv2] [4, %0] [1, 1] : tensor<24x12xf32> to tensor<4x?xf32> - // CHECK: %[[s1:.*]] = memref.subview %[[arg1]] - %2 = tensor.extract_slice %arg1[%iv2, %iv1] [%0, 5] [1, 1] : tensor<12x25xf32> to tensor - // CHECK: %[[s2:.*]] = memref.subview %[[arg2]] - %3 = tensor.extract_slice %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<24x25xf32> to tensor<4x5xf32> - - // CHECK: %[[alloc0:.*]] = memref.alloc() : memref<4x7xf32, 3> - // CHECK: linalg.fill {{.*}} outs(%[[alloc0]] - // CHECK: %[[alloc0_view:.*]] = memref.subview %[[alloc0]][0, 0] [4, %{{.*}}] [1, 1] - // CHECK: memref.copy %[[s0]], %[[alloc0_view]] - - // CHECK: %[[alloc1:.*]] = memref.alloc() : memref<7x5xf32, 3> - // CHECK: linalg.fill {{.*}} outs(%[[alloc1]] - // CHECK: %[[alloc1_view:.*]] = memref.subview %[[alloc1]][0, 0] [%{{.*}}, 5] [1, 1] - // CHECK: memref.copy %[[s1]], %[[alloc1_view]] - - // CHECK: %[[alloc2:.*]] = memref.alloc() : memref<4x5xf32, 3> - // CHECK-NOT: linalg.fill {{.*}} outs(%[[alloc2]] - // No subview because there is 0 padding - // CHECK: memref.copy %[[s2]], %[[alloc2]] - - // CHECK: linalg.matmul ins(%[[alloc0]], %[[alloc1]] : {{.*}}) outs(%[[alloc2]] : {{.*}}) - // Copy back result. - // CHECK: memref.copy %[[alloc2]], %[[s2]] - %4 = linalg.matmul ins(%1, %2 : tensor<4x?xf32>, tensor) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32> - - // insert_slice bufferizes to a no-op. - %5 = tensor.insert_slice %4 into %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32> - func.return %5 : tensor<24x25xf32> +module @payload attributes { transform.target_tag = "payload" } { + func.func @pad_to_memory_space(%arg0: tensor<24x12xf32>, + %arg1: tensor<12x25xf32>, + %arg2: tensor<24x25xf32>, + %iv0 : index, %iv1 : index, + %iv2 : index) -> tensor<24x25xf32> { + %0 = affine.min #map()[%iv2] + + // CHECK: %[[s0:.*]] = memref.subview %[[arg0]] + %1 = tensor.extract_slice %arg0[%iv0, %iv2] [4, %0] [1, 1] : tensor<24x12xf32> to tensor<4x?xf32> + // CHECK: %[[s1:.*]] = memref.subview %[[arg1]] + %2 = tensor.extract_slice %arg1[%iv2, %iv1] [%0, 5] [1, 1] : tensor<12x25xf32> to tensor + // CHECK: %[[s2:.*]] = memref.subview %[[arg2]] + %3 = tensor.extract_slice %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<24x25xf32> to tensor<4x5xf32> + + // CHECK: %[[alloc0:.*]] = memref.alloc() : memref<4x7xf32, 3> + // CHECK: linalg.fill {{.*}} outs(%[[alloc0]] + // CHECK: %[[alloc0_view:.*]] = memref.subview %[[alloc0]][0, 0] [4, %{{.*}}] [1, 1] + // CHECK: memref.copy %[[s0]], %[[alloc0_view]] + + // CHECK: %[[alloc1:.*]] = memref.alloc() : memref<7x5xf32, 3> + // CHECK: linalg.fill {{.*}} outs(%[[alloc1]] + // CHECK: %[[alloc1_view:.*]] = memref.subview %[[alloc1]][0, 0] [%{{.*}}, 5] [1, 1] + // CHECK: memref.copy %[[s1]], %[[alloc1_view]] + + // CHECK: %[[alloc2:.*]] = memref.alloc() : memref<4x5xf32, 3> + // CHECK-NOT: linalg.fill {{.*}} outs(%[[alloc2]] + // No subview because there is 0 padding + // CHECK: memref.copy %[[s2]], %[[alloc2]] + + // CHECK: linalg.matmul ins(%[[alloc0]], %[[alloc1]] : {{.*}}) outs(%[[alloc2]] : {{.*}}) + // Copy back result. + // CHECK: memref.copy %[[alloc2]], %[[s2]] + %4 = linalg.matmul ins(%1, %2 : tensor<4x?xf32>, tensor) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32> + + // insert_slice bufferizes to a no-op. + %5 = tensor.insert_slice %4 into %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32> + func.return %5 : tensor<24x25xf32> + } } module attributes {transform.with_named_sequence} { @@ -69,40 +71,42 @@ module attributes {transform.with_named_sequence} { // CHECK-SAME: %[[arg0:.*]]: memref<24x12xf32, strided<[?, ?], offset: ?>>, // CHECK-SAME: %[[arg1:.*]]: memref<12x25xf32, strided<[?, ?], offset: ?>>, // CHECK-SAME: %[[arg2:.*]]: memref<24x25xf32, strided<[?, ?], offset: ?>>, -func.func @vectorize_and_bufferize_pad(%arg0: tensor<24x12xf32>, - %arg1: tensor<12x25xf32>, - %arg2: tensor<24x25xf32>, - %iv0 : index, %iv1 : index, - %iv2 : index) -> tensor<24x25xf32> { - %0 = affine.min #map()[%iv2] - - // CHECK: %[[s0:.*]] = memref.subview %[[arg0]] - %1 = tensor.extract_slice %arg0[%iv0, %iv2] [4, %0] [1, 1] : tensor<24x12xf32> to tensor<4x?xf32> - // CHECK: %[[s1:.*]] = memref.subview %[[arg1]] - %2 = tensor.extract_slice %arg1[%iv2, %iv1] [%0, 5] [1, 1] : tensor<12x25xf32> to tensor - // CHECK: %[[s2:.*]] = memref.subview %[[arg2]] - %3 = tensor.extract_slice %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<24x25xf32> to tensor<4x5xf32> - - // CHECK: %[[v0:.*]] = vector.mask {{.*}} { vector.transfer_read %[[s0]] - // CHECK: %[[alloc0:.*]] = memref.alloc() : memref<4x7xf32, 3> - // CHECK: vector.mask {{.*}} { vector.transfer_write %[[v0]], %[[alloc0]] - - // CHECK: %[[v1:.*]] = vector.mask {{.*}} { vector.transfer_read %[[s1]] - // CHECK: %[[alloc1:.*]] = memref.alloc() : memref<7x5xf32, 3> - // CHECK: vector.mask {{.*}} { vector.transfer_write %[[v1]], %[[alloc1]] - - // CHECK: %[[v2:.*]] = vector.mask {{.*}} { vector.transfer_read %[[s2]] - // CHECK: %[[alloc2:.*]] = memref.alloc() : memref<4x5xf32, 3> - // CHECK: vector.mask {{.*}} { vector.transfer_write %[[v2]], %[[alloc0]] - - // CHECK: linalg.matmul ins(%[[alloc0]], %[[alloc1]] : {{.*}}) outs(%[[alloc2]] : {{.*}}) - // Copy back result. - // CHECK: memref.copy %[[alloc2]], %[[s2]] - %4 = linalg.matmul ins(%1, %2 : tensor<4x?xf32>, tensor) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32> - - // insert_slice bufferizes to a no-op. - %5 = tensor.insert_slice %4 into %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32> - func.return %5 : tensor<24x25xf32> +module @payload attributes { transform.target_tag = "payload" } { + func.func @vectorize_and_bufferize_pad(%arg0: tensor<24x12xf32>, + %arg1: tensor<12x25xf32>, + %arg2: tensor<24x25xf32>, + %iv0 : index, %iv1 : index, + %iv2 : index) -> tensor<24x25xf32> { + %0 = affine.min #map()[%iv2] + + // CHECK: %[[s0:.*]] = memref.subview %[[arg0]] + %1 = tensor.extract_slice %arg0[%iv0, %iv2] [4, %0] [1, 1] : tensor<24x12xf32> to tensor<4x?xf32> + // CHECK: %[[s1:.*]] = memref.subview %[[arg1]] + %2 = tensor.extract_slice %arg1[%iv2, %iv1] [%0, 5] [1, 1] : tensor<12x25xf32> to tensor + // CHECK: %[[s2:.*]] = memref.subview %[[arg2]] + %3 = tensor.extract_slice %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<24x25xf32> to tensor<4x5xf32> + + // CHECK: %[[v0:.*]] = vector.mask {{.*}} { vector.transfer_read %[[s0]] + // CHECK: %[[alloc0:.*]] = memref.alloc() : memref<4x7xf32, 3> + // CHECK: vector.mask {{.*}} { vector.transfer_write %[[v0]], %[[alloc0]] + + // CHECK: %[[v1:.*]] = vector.mask {{.*}} { vector.transfer_read %[[s1]] + // CHECK: %[[alloc1:.*]] = memref.alloc() : memref<7x5xf32, 3> + // CHECK: vector.mask {{.*}} { vector.transfer_write %[[v1]], %[[alloc1]] + + // CHECK: %[[v2:.*]] = vector.mask {{.*}} { vector.transfer_read %[[s2]] + // CHECK: %[[alloc2:.*]] = memref.alloc() : memref<4x5xf32, 3> + // CHECK: vector.mask {{.*}} { vector.transfer_write %[[v2]], %[[alloc0]] + + // CHECK: linalg.matmul ins(%[[alloc0]], %[[alloc1]] : {{.*}}) outs(%[[alloc2]] : {{.*}}) + // Copy back result. + // CHECK: memref.copy %[[alloc2]], %[[s2]] + %4 = linalg.matmul ins(%1, %2 : tensor<4x?xf32>, tensor) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32> + + // insert_slice bufferizes to a no-op. + %5 = tensor.insert_slice %4 into %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32> + func.return %5 : tensor<24x25xf32> + } } module attributes {transform.with_named_sequence} { diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir index 4b38db79bff3e..0439844dc66ca 100644 --- a/mlir/test/Dialect/Vector/transform-vector.mlir +++ b/mlir/test/Dialect/Vector/transform-vector.mlir @@ -1,16 +1,18 @@ -// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s +// RUN: mlir-opt --transform-interpreter="debug-payload-root-tag=payload" %s --split-input-file | FileCheck %s // CHECK-LABEL: func @matmul_tensors -func.func @matmul_tensors( - %arg0: tensor<8x16xf32>, %arg1: tensor<16x32xf32>, %arg2: tensor<8x32xf32>) - -> tensor<8x32xf32> { -// CHECK-NOT: linalg -// CHECK: vector.extract {{.*}} : vector<4xf32> from vector<8x4xf32> -// CHECK: vector.store {{.*}} : memref<8x32xf32>, vector<4xf32> - %0 = linalg.matmul ins(%arg0, %arg1: tensor<8x16xf32>, tensor<16x32xf32>) - outs(%arg2: tensor<8x32xf32>) - -> tensor<8x32xf32> - return %0 : tensor<8x32xf32> +module @payload attributes { transform.target_tag = "payload" } { + func.func @matmul_tensors( + %arg0: tensor<8x16xf32>, %arg1: tensor<16x32xf32>, %arg2: tensor<8x32xf32>) + -> tensor<8x32xf32> { + // CHECK-NOT: linalg + // CHECK: vector.extract {{.*}} : vector<4xf32> from vector<8x4xf32> + // CHECK: vector.store {{.*}} : memref<8x32xf32>, vector<4xf32> + %0 = linalg.matmul ins(%arg0, %arg1: tensor<8x16xf32>, tensor<16x32xf32>) + outs(%arg2: tensor<8x32xf32>) + -> tensor<8x32xf32> + return %0 : tensor<8x32xf32> + } } module attributes {transform.with_named_sequence} { @@ -76,11 +78,13 @@ module attributes {transform.with_named_sequence} { // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<64x64xf16>, vector<64x64xf16> into vector<64x64xf32> // CHECK-NEXT: return %[[R]] : vector<64x64xf32> -func.func @fold_arith_extf_into_contract(%arg0: vector<64x64xf16>, %arg1: vector<64x64xf16>, %arg2: vector<64x64xf32>) -> vector<64x64xf32> { - %lhs_f32 = arith.extf %arg0 : vector<64x64xf16> to vector<64x64xf32> - %rhs_f32 = arith.extf %arg1 : vector<64x64xf16> to vector<64x64xf32> - %result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %lhs_f32, %rhs_f32, %arg2 : vector<64x64xf32>, vector<64x64xf32> into vector<64x64xf32> - return %result : vector<64x64xf32> +module @payload attributes { transform.target_tag = "payload" } { + func.func @fold_arith_extf_into_contract(%arg0: vector<64x64xf16>, %arg1: vector<64x64xf16>, %arg2: vector<64x64xf32>) -> vector<64x64xf32> { + %lhs_f32 = arith.extf %arg0 : vector<64x64xf16> to vector<64x64xf32> + %rhs_f32 = arith.extf %arg1 : vector<64x64xf16> to vector<64x64xf32> + %result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %lhs_f32, %rhs_f32, %arg2 : vector<64x64xf32>, vector<64x64xf32> into vector<64x64xf32> + return %result : vector<64x64xf32> + } } module attributes {transform.with_named_sequence} { @@ -95,30 +99,32 @@ module attributes {transform.with_named_sequence} { // ----- -// CHECK-LABEL: func.func @arith_to_outerproduct_scalable_i32 -// CHECK-SAME: %[[LHS:.*]]: vector<[4]xi32>, -// CHECK-SAME: %[[RHS:.*]]: vector<[4]xi32>) -> vector<[4]x[4]xi32> { -// CHECK: %[[RES:.*]] = vector.outerproduct %[[LHS]], %[[RHS]] : vector<[4]xi32>, vector<[4]xi32> -// CHECK: return %[[RES]] : vector<[4]x[4]xi32> -func.func @arith_to_outerproduct_scalable_i32(%lhs: vector<[4]xi32>, %rhs: vector<[4]xi32>) -> vector<[4]x[4]xi32> { - %lhsBcast = vector.broadcast %lhs : vector<[4]xi32> to vector<[4]x[4]xi32> - %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32> - %rhsBcast = vector.broadcast %rhs : vector<[4]xi32> to vector<[4]x[4]xi32> - %mul = arith.muli %lhsT, %rhsBcast : vector<[4]x[4]xi32> - return %mul: vector<[4]x[4]xi32> -} +module @payload attributes { transform.target_tag = "payload" } { + // CHECK-LABEL: func.func @arith_to_outerproduct_scalable_i32 + // CHECK-SAME: %[[LHS:.*]]: vector<[4]xi32>, + // CHECK-SAME: %[[RHS:.*]]: vector<[4]xi32>) -> vector<[4]x[4]xi32> { + // CHECK: %[[RES:.*]] = vector.outerproduct %[[LHS]], %[[RHS]] : vector<[4]xi32>, vector<[4]xi32> + // CHECK: return %[[RES]] : vector<[4]x[4]xi32> + func.func @arith_to_outerproduct_scalable_i32(%lhs: vector<[4]xi32>, %rhs: vector<[4]xi32>) -> vector<[4]x[4]xi32> { + %lhsBcast = vector.broadcast %lhs : vector<[4]xi32> to vector<[4]x[4]xi32> + %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32> + %rhsBcast = vector.broadcast %rhs : vector<[4]xi32> to vector<[4]x[4]xi32> + %mul = arith.muli %lhsT, %rhsBcast : vector<[4]x[4]xi32> + return %mul: vector<[4]x[4]xi32> + } -// CHECK-LABEL: func.func @arith_to_outerproduct_trans_rhs_f32 -// CHECK-SAME: %[[LHS:.*]]: vector<16xf32>, -// CHECK-SAME: %[[RHS:.*]]: vector<8xf32>) -> vector<8x16xf32> { -// CHECK: %[[RES:.*]] = vector.outerproduct %[[RHS]], %[[LHS]] : vector<8xf32>, vector<16xf32> -// CHECK: return %[[RES]] : vector<8x16xf32> -func.func @arith_to_outerproduct_trans_rhs_f32(%lhs: vector<16xf32>, %rhs: vector<8xf32>) -> vector<8x16xf32> { - %rhsBcast = vector.broadcast %rhs : vector<8xf32> to vector<16x8xf32> - %rhsT = vector.transpose %rhsBcast, [1, 0] : vector<16x8xf32> to vector<8x16xf32> - %lhsBcast = vector.broadcast %lhs : vector<16xf32> to vector<8x16xf32> - %mul = arith.mulf %lhsBcast, %rhsT : vector<8x16xf32> - return %mul: vector<8x16xf32> + // CHECK-LABEL: func.func @arith_to_outerproduct_trans_rhs_f32 + // CHECK-SAME: %[[LHS:.*]]: vector<16xf32>, + // CHECK-SAME: %[[RHS:.*]]: vector<8xf32>) -> vector<8x16xf32> { + // CHECK: %[[RES:.*]] = vector.outerproduct %[[RHS]], %[[LHS]] : vector<8xf32>, vector<16xf32> + // CHECK: return %[[RES]] : vector<8x16xf32> + func.func @arith_to_outerproduct_trans_rhs_f32(%lhs: vector<16xf32>, %rhs: vector<8xf32>) -> vector<8x16xf32> { + %rhsBcast = vector.broadcast %rhs : vector<8xf32> to vector<16x8xf32> + %rhsT = vector.transpose %rhsBcast, [1, 0] : vector<16x8xf32> to vector<8x16xf32> + %lhsBcast = vector.broadcast %lhs : vector<16xf32> to vector<8x16xf32> + %mul = arith.mulf %lhsBcast, %rhsT : vector<8x16xf32> + return %mul: vector<8x16xf32> + } } module attributes {transform.with_named_sequence} { From aca0e78354db91fa51fa3239d8d82b586adc7a77 Mon Sep 17 00:00:00 2001 From: Tzung-Han Juang Date: Wed, 11 Sep 2024 15:51:49 -0400 Subject: [PATCH 10/17] Add transform.target_tag to CHH/full.mlir --- mlir/test/Examples/transform/ChH/full.mlir | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mlir/test/Examples/transform/ChH/full.mlir b/mlir/test/Examples/transform/ChH/full.mlir index 259475ebdbf49..005ac5a9ba8ec 100644 --- a/mlir/test/Examples/transform/ChH/full.mlir +++ b/mlir/test/Examples/transform/ChH/full.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --transform-interpreter \ +// RUN: mlir-opt %s --transform-interpreter="debug-payload-root-tag=payload" \ // RUN: --test-transform-dialect-erase-schedule \ // RUN: --math-uplift-to-fma \ // RUN: --convert-bufferization-to-memref \ @@ -19,6 +19,7 @@ // tensors annotated with attributes from the `bufferization` dialect. These // attributes hint the bufferization pass to assume buffers can be directly // used for these tensors without reshaping. +module @payload attributes { transform.target_tag = "payload" } { func.func @conv( %input: !tinput {bufferization.writable = false, bufferization.access = "read", @@ -84,7 +85,7 @@ func.func @conv( return %relued : !toutput } - +} // Module containing the transformation script to be applied. The attribute // is required to correctly verify the use of named (macro-like) sequences. module attributes { transform.with_named_sequence } { From 913ab762c527cc16509605b4d8f27ea8bd6bd157 Mon Sep 17 00:00:00 2001 From: Tzung-Han Juang Date: Mon, 16 Sep 2024 15:43:22 -0400 Subject: [PATCH 11/17] Create a pass pipeline on nested modules in mlir/test/Examples/transform/ChH/full.mlir --- mlir/test/Examples/transform/ChH/full.mlir | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/mlir/test/Examples/transform/ChH/full.mlir b/mlir/test/Examples/transform/ChH/full.mlir index 005ac5a9ba8ec..85dbf67023323 100644 --- a/mlir/test/Examples/transform/ChH/full.mlir +++ b/mlir/test/Examples/transform/ChH/full.mlir @@ -1,8 +1,6 @@ // RUN: mlir-opt %s --transform-interpreter="debug-payload-root-tag=payload" \ -// RUN: --test-transform-dialect-erase-schedule \ -// RUN: --math-uplift-to-fma \ -// RUN: --convert-bufferization-to-memref \ -// RUN: --test-lower-to-llvm |\ +// RUN: --test-transform-dialect-erase-schedule |\ +// RUN: mlir-opt -pass-pipeline='builtin.module(builtin.module(math-uplift-to-fma,convert-bufferization-to-memref,test-lower-to-llvm))' - |\ // RUN: FileCheck %s // Fixed-size tensor types to be used in convolution. From 1811994699517d14b659405ff979b91f8e5e37a0 Mon Sep 17 00:00:00 2001 From: Tzung-Han Juang Date: Tue, 17 Sep 2024 16:34:17 -0400 Subject: [PATCH 12/17] Make FunctionArgTypeConverterFn use FunctionOpInterface --- .../mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h | 3 ++- .../lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h index 2fda091e412ae..ba28596d1f97d 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_ #define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_ +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Support/LLVM.h" @@ -262,7 +263,7 @@ struct BufferizationOptions { /// Parameters: Value, memory space, func op, bufferization options using FunctionArgTypeConverterFn = std::function; + FunctionOpInterface, const BufferizationOptions &)>; /// Tensor -> MemRef type converter. /// Parameters: Value, memory space, bufferization options using UnknownTypeConverterFn = std::function Date: Tue, 24 Sep 2024 14:56:07 -0400 Subject: [PATCH 13/17] Fix formatting issues --- .../IR/BufferizableOpInterface.h | 8 +++--- .../Transforms/OneShotModuleBufferize.cpp | 28 +++++++++---------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h index ba28596d1f97d..eb0df1d92d6ad 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -9,9 +9,9 @@ #ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_ #define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_ -#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/DenseMapInfoVariant.h" #include "llvm/ADT/SetVector.h" @@ -261,9 +261,9 @@ struct BufferizationOptions { using AnalysisStateInitFn = std::function; /// Tensor -> MemRef type converter. /// Parameters: Value, memory space, func op, bufferization options - using FunctionArgTypeConverterFn = - std::function; + using FunctionArgTypeConverterFn = std::function; /// Tensor -> MemRef type converter. /// Parameters: Value, memory space, bufferization options using UnknownTypeConverterFn = std::function(funcOp.getArgumentTypes()[idx])) continue; @@ -277,10 +278,8 @@ static void equivalenceAnalysis(FunctionOpInterface funcOp, /// Return "true" if the given function signature has tensor semantics. static bool hasTensorSignature(FunctionOpInterface funcOp) { - return llvm::any_of(funcOp.getArgumentTypes(), - llvm::IsaPred) || - llvm::any_of(funcOp.getResultTypes(), - llvm::IsaPred); + return llvm::any_of(funcOp.getArgumentTypes(), llvm::IsaPred) || + llvm::any_of(funcOp.getResultTypes(), llvm::IsaPred); } /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by @@ -310,7 +309,8 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp, numberCallOpsContainedInFuncOp[funcOp] = 0; return funcOp.walk([&](CallOpInterface callOp) -> WalkResult { FunctionOpInterface calledFunction = getCalledFunction(callOp); - assert(calledFunction && "could not retrieved called FunctionOpInterface"); + assert(calledFunction && + "could not retrieved called FunctionOpInterface"); // If the called function does not have any tensors in its signature, then // it is not necessary to bufferize the callee before the caller. if (!hasTensorSignature(calledFunction)) @@ -364,8 +364,8 @@ static void foldMemRefCasts(FunctionOpInterface funcOp) { } } - auto newFuncType = FunctionType::get( - funcOp.getContext(), funcOp.getArgumentTypes(), resultTypes); + auto newFuncType = FunctionType::get(funcOp.getContext(), + funcOp.getArgumentTypes(), resultTypes); funcOp.setType(newFuncType); } From 8d780007625108a7f34e40efb8604b858e04c60c Mon Sep 17 00:00:00 2001 From: Tzung-Han Juang Date: Fri, 27 Sep 2024 15:46:11 -0400 Subject: [PATCH 14/17] Skip function call analysis for CallIndirectOp --- .../Bufferization/Transforms/OneShotModuleBufferize.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp index f934ae88277b6..36499ed2ad74a 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -255,6 +255,9 @@ static void equivalenceAnalysis(FunctionOpInterface funcOp, OneShotAnalysisState &state, FuncAnalysisState &funcState) { funcOp->walk([&](CallOpInterface callOp) { + if (isa(callOp)) + return WalkResult::skip(); + FunctionOpInterface calledFunction = getCalledFunction(callOp); assert(calledFunction && "could not retrieved called FunctionOpInterface"); @@ -308,6 +311,9 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp, // Collect function calls and populate the caller map. numberCallOpsContainedInFuncOp[funcOp] = 0; return funcOp.walk([&](CallOpInterface callOp) -> WalkResult { + if (isa(callOp)) + return WalkResult::skip(); + FunctionOpInterface calledFunction = getCalledFunction(callOp); assert(calledFunction && "could not retrieved called FunctionOpInterface"); From f3b55a1dea1954b6a8908378f292b939425db976 Mon Sep 17 00:00:00 2001 From: Tzung-Han Juang Date: Mon, 30 Sep 2024 09:50:34 -0400 Subject: [PATCH 15/17] Update mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp Co-authored-by: erick-xanadu <110487834+erick-xanadu@users.noreply.github.com> --- .../Bufferization/Transforms/OneShotModuleBufferize.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp index 36499ed2ad74a..465679c1fa120 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -255,11 +255,9 @@ static void equivalenceAnalysis(FunctionOpInterface funcOp, OneShotAnalysisState &state, FuncAnalysisState &funcState) { funcOp->walk([&](CallOpInterface callOp) { - if (isa(callOp)) - return WalkResult::skip(); - FunctionOpInterface calledFunction = getCalledFunction(callOp); - assert(calledFunction && "could not retrieved called FunctionOpInterface"); + if (!calledFunction) + return WalkResult::skip(); // No equivalence info available for the called function. if (!funcState.equivalentFuncArgs.count(calledFunction)) From 6c6f88984e86b89acb28f4c97d58a458ca7d6ef6 Mon Sep 17 00:00:00 2001 From: Tzung-Han Juang Date: Mon, 30 Sep 2024 09:57:35 -0400 Subject: [PATCH 16/17] Patch another function call assertion --- .../Bufferization/Transforms/OneShotModuleBufferize.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp index 465679c1fa120..bec5a95883452 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -313,8 +313,8 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp, return WalkResult::skip(); FunctionOpInterface calledFunction = getCalledFunction(callOp); - assert(calledFunction && - "could not retrieved called FunctionOpInterface"); + if (!calledFunction) + return WalkResult::skip(); // If the called function does not have any tensors in its signature, then // it is not necessary to bufferize the callee before the caller. if (!hasTensorSignature(calledFunction)) From 4a60687b1f744f3444855b39ae457c8efe6ae1bb Mon Sep 17 00:00:00 2001 From: Tzung-Han Juang Date: Mon, 30 Sep 2024 14:35:28 -0400 Subject: [PATCH 17/17] Remove CallIndirectOp skip --- .../Bufferization/Transforms/OneShotModuleBufferize.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp index bec5a95883452..a0e5c7fff7690 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -309,9 +309,6 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp, // Collect function calls and populate the caller map. numberCallOpsContainedInFuncOp[funcOp] = 0; return funcOp.walk([&](CallOpInterface callOp) -> WalkResult { - if (isa(callOp)) - return WalkResult::skip(); - FunctionOpInterface calledFunction = getCalledFunction(callOp); if (!calledFunction) return WalkResult::skip();