diff --git a/mlir/include/mlir/Interfaces/FunctionInterfaces.h b/mlir/include/mlir/Interfaces/FunctionInterfaces.h index e10e9bd342702..f121a6823711a 100644 --- a/mlir/include/mlir/Interfaces/FunctionInterfaces.h +++ b/mlir/include/mlir/Interfaces/FunctionInterfaces.h @@ -20,6 +20,7 @@ #include "mlir/IR/SymbolTable.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "llvm/ADT/BitVector.h" #include "llvm/ADT/SmallString.h" diff --git a/mlir/include/mlir/Interfaces/FunctionInterfaces.td b/mlir/include/mlir/Interfaces/FunctionInterfaces.td index 697f951748c67..edb0b620f5038 100644 --- a/mlir/include/mlir/Interfaces/FunctionInterfaces.td +++ b/mlir/include/mlir/Interfaces/FunctionInterfaces.td @@ -108,6 +108,30 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [ } } + // FunctionOpInterface is tied to a ReturnLike. + Operation *terminator = entryBlock.getTerminator(); + if (!terminator->hasTrait()) { + return $_op.emitOpError("The body of a FunctionOpInterface must have ") + << "a ReturnLike terminator, but the current terminator does not " + << "have this trait."; + } + + // Match ReturnLike's operand types and FunctionOpInterface's + // result types. + auto returnOperandTypes = terminator->getOperandTypes(); + auto funcResultTypes = $_op->getResultTypes(); + if (funcResultTypes.size() != returnOperandTypes.size()) { + return $_op.emitOpError("The number of a FunctionOpInterface's") + << "results must match that of the ReturnLike operands."; + } + + for (unsigned i = 0; i < funcResultTypes.size(); ++i) { + if (funcResultTypes[i] != returnOperandTypes[i]) { + return $_op.emitOpError("The result types of a FunctionOpInterface") + << "must match the operand types of the ReturnLike."; + } + } + return success(); }]>, InterfaceMethod<[{ diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp index a0e5c7fff7690..bb1cf2fc3fd20 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -302,8 +302,8 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp, Operation *returnOp = getAssumedUniqueReturnOp(funcOp); if (!returnOp) return funcOp->emitError() - << "cannot bufferize a FuncOp with tensors and " - "without a unique ReturnOp"; + << "cannot bufferize a FunctionOpInterface with tensors and " + "without a unique ReturnLike"; } // Collect function calls and populate the caller map.