diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h index e8e6226460ac7..51f3c0843569d 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h @@ -69,6 +69,9 @@ struct FuncAnalysisState : public OneShotAnalysisState::Extension { /// analyzed. DenseMap analyzedFuncOps; + /// A collection of cached SymbolTables used for faster function lookup. + mutable SymbolTableCollection symbolTables; + /// This function is called right before analyzing the given FuncOp. It /// initializes the data structures for the FuncOp in this state object. void startFunctionAnalysis(FuncOp funcOp); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp index c45678f1e4b4d..76039fb7dd485 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -76,13 +76,29 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index, } /// Return the FuncOp called by `callOp`. -static FuncOp getCalledFunction(CallOpInterface callOp) { +static FuncOp getCalledFunction(CallOpInterface callOp, + SymbolTableCollection &symbolTables) { SymbolRefAttr sym = llvm::dyn_cast_if_present(callOp.getCallableForCallee()); if (!sym) return nullptr; return dyn_cast_or_null( - SymbolTable::lookupNearestSymbolFrom(callOp, sym)); + symbolTables.lookupNearestSymbolFrom(callOp, sym)); +} + +/// Return the FuncOp called by `callOp`. +static FuncOp getCalledFunction(CallOpInterface callOp, + const AnalysisState &state) { + auto &oneShotAnalysisState = static_cast(state); + + if (auto *funcAnalysisState = + oneShotAnalysisState.getExtension()) { + // Use the cached symbol tables. + return getCalledFunction(callOp, funcAnalysisState->symbolTables); + } + + SymbolTableCollection symbolTables; + return getCalledFunction(callOp, symbolTables); } /// Get FuncAnalysisState. @@ -135,7 +151,7 @@ struct CallOpInterface bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { func::CallOp callOp = cast(op); - FuncOp funcOp = getCalledFunction(callOp); + FuncOp funcOp = getCalledFunction(callOp, state); assert(funcOp && "expected CallOp to a FuncOp"); if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed) @@ -150,7 +166,7 @@ struct CallOpInterface bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { func::CallOp callOp = cast(op); - FuncOp funcOp = getCalledFunction(callOp); + FuncOp funcOp = getCalledFunction(callOp, state); assert(funcOp && "expected CallOp to a FuncOp"); if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed) @@ -165,7 +181,7 @@ struct CallOpInterface AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { func::CallOp callOp = cast(op); - FuncOp funcOp = getCalledFunction(callOp); + FuncOp funcOp = getCalledFunction(callOp, state); assert(funcOp && "expected CallOp to a FuncOp"); if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed) // FuncOp not analyzed yet. Any OpResult may be aliasing. @@ -199,7 +215,11 @@ struct CallOpInterface getBufferType(Operation *op, Value value, const BufferizationOptions &options, SmallVector &invocationStack) const { auto callOp = cast(op); - FuncOp funcOp = getCalledFunction(callOp); + + // TODO Avoid recomputing the symbol tables every time. + SymbolTableCollection symbolTable; + + FuncOp funcOp = getCalledFunction(callOp, symbolTable); assert(funcOp && "expected CallOp to a FuncOp"); // If the callee was already bufferized, we can directly take the type from @@ -243,7 +263,11 @@ struct CallOpInterface // 2. Rewrite tensor operands as memrefs based on type of the already // bufferized callee. SmallVector newOperands; - FuncOp funcOp = getCalledFunction(callOp); + + // TODO Avoid recomputing the symbol tables every time. + SymbolTableCollection symbolTable; + + FuncOp funcOp = getCalledFunction(callOp, symbolTable); assert(funcOp && "expected CallOp to a FuncOp"); FunctionType funcType = funcOp.getFunctionType(); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp index edd6bcf84f460..a025da8635135 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -280,13 +280,15 @@ static void removeBufferizationAttributes(BlockArgument bbArg) { } /// Return the func::FuncOp called by `callOp`. -static func::FuncOp getCalledFunction(func::CallOp callOp) { +static func::FuncOp +getCalledFunction(func::CallOp callOp, + mlir::SymbolTableCollection &symbolTable) { SymbolRefAttr sym = llvm::dyn_cast_if_present(callOp.getCallableForCallee()); if (!sym) return nullptr; return dyn_cast_or_null( - SymbolTable::lookupNearestSymbolFrom(callOp, sym)); + symbolTable.lookupNearestSymbolFrom(callOp, sym)); } /// Return "true" if the given function signature has tensor semantics. @@ -314,11 +316,15 @@ static LogicalResult getFuncOpsOrderedByCalls( DenseMap> calledBy; // For each FuncOp, the number of func::CallOp it contains. DenseMap numberCallOpsContainedInFuncOp; + + // TODO Avoid recomputing the symbol tables every time. + mlir::SymbolTableCollection symbolTable; + for (func::FuncOp funcOp : moduleOp.getOps()) { // Collect function calls and populate the caller map. numberCallOpsContainedInFuncOp[funcOp] = 0; WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult { - func::FuncOp calledFunction = getCalledFunction(callOp); + func::FuncOp calledFunction = getCalledFunction(callOp, symbolTable); assert(calledFunction && "could not retrieved called func::FuncOp"); // If the called function does not have any tensors in its signature, then // it is not necessary to bufferize the callee before the caller.