Skip to content

[MLIR] Add bufferization state to getBufferType and resolveConflicts interface methods #141466

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -598,13 +598,14 @@ class BufferizationState {
FailureOr<Value>
allocateTensorForShapedValue(OpBuilder &b, Location loc, Value shapedValue,
const BufferizationOptions &options,
bool copy = true);
const BufferizationState &state, bool copy = true);

/// Look up the buffer for the given value. If the value was not bufferized
/// yet, wrap it in a ToBufferOp. Otherwise, it is the result of a ToTensorOp,
/// from which the memref operand is returned.
FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
const BufferizationOptions &options);
const BufferizationOptions &options,
const BufferizationState &state);

/// Return the buffer type for a given Value (tensor) after bufferization
/// without bufferizing any IR.
Expand All @@ -615,7 +616,8 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
///
/// This function is a wrapper around BufferizableOpInterface::getBufferType.
FailureOr<BaseMemRefType> getBufferType(Value value,
const BufferizationOptions &options);
const BufferizationOptions &options,
const BufferizationState &state);

/// Return the buffer type for a given Value (tensor) after bufferization
/// without bufferizing any IR. This function (and not the other overload
Expand All @@ -629,6 +631,7 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
/// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
FailureOr<BaseMemRefType> getBufferType(Value value,
const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack);

/// Return "true" if the given op has tensor semantics and should be bufferized.
Expand Down Expand Up @@ -709,6 +712,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
/// places.
FailureOr<BaseMemRefType>
defaultGetBufferType(Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack);

/// This is the default implementation of
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,13 +381,14 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"::llvm::LogicalResult",
/*methodName=*/"resolveConflicts",
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
"const ::mlir::bufferization::AnalysisState &":$state),
"const ::mlir::bufferization::AnalysisState &":$analysisState,
"const ::mlir::bufferization::BufferizationState &":$bufferizationState),
/*methodBody=*/"",
/*defaultImplementation=*/[{
auto bufferizableOp =
::llvm::cast<BufferizableOpInterface>($_op.getOperation());
return bufferizableOp.resolveTensorOpOperandConflicts(
rewriter, state);
rewriter, analysisState, bufferizationState);
}]
>,
InterfaceMethod<
Expand Down Expand Up @@ -528,6 +529,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*methodName=*/"getBufferType",
/*args=*/(ins "::mlir::Value":$value,
"const ::mlir::bufferization::BufferizationOptions &":$options,
"const ::mlir::bufferization::BufferizationState &":$state,
"::llvm::SmallVector<::mlir::Value> &":$invocationStack),
/*methodBody=*/"",
/*defaultImplementation=*/[{
Expand All @@ -536,7 +538,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
assert(invocationStack.back() == value &&
"inconsistant invocation stack");
return ::mlir::bufferization::detail::defaultGetBufferType(
value, options, invocationStack);
value, options, state, invocationStack);
}]
>,
InterfaceMethod<
Expand Down Expand Up @@ -621,7 +623,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/// form of `bufferization.alloc_tensor` ops.
::llvm::LogicalResult resolveTensorOpOperandConflicts(
::mlir::RewriterBase &rewriter,
const ::mlir::bufferization::AnalysisState &state);
const ::mlir::bufferization::AnalysisState &analysisState,
const ::mlir::bufferization::BufferizationState &bufferizationState);

/// Return `true` if the given OpOperand creates an alias but does neither
/// read nor write. This implies that `bufferizesToMemoryRead` and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",

FailureOr<BaseMemRefType> getBufferType(
Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack);

RankedTensorType getType() {
Expand Down Expand Up @@ -471,7 +472,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [

FailureOr<BaseMemRefType> getBufferType(
Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) {
const BufferizationState &state, SmallVector<Value> &invocationStack) {
return ::llvm::cast<BaseMemRefType>(getMemref().getType());
}
}];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,13 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel

FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
// Note: The user may want to override this function for OpResults in
// case the bufferized result type is different from the bufferized type of
// the aliasing OpOperand (if any).
if (isa<OpResult>(value))
return bufferization::detail::defaultGetBufferType(value, options,
return bufferization::detail::defaultGetBufferType(value, options, state,
invocationStack);

// Compute the buffer type of the block argument by computing the bufferized
Expand All @@ -65,7 +66,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
callerType = memrefType;
} else {
FailureOr<BaseMemRefType> maybeCallerType =
bufferization::getBufferType(opOperand->get(), options,
bufferization::getBufferType(opOperand->get(), options, state,
invocationStack);
if (failed(maybeCallerType))
return failure();
Expand All @@ -81,9 +82,9 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
if (bufferType == callerType)
continue;

// If the computed buffer type does not match the computed buffer type
// of the earlier forwarded operands, fall back to a buffer type with a
// fully dynamic layout map.
// If the computed buffer type does not match the computed buffer type
// of the earlier forwarded operands, fall back to a buffer type with a
// fully dynamic layout map.
#ifndef NDEBUG
if (auto rankedTensorType = dyn_cast<RankedTensorType>(tensorType)) {
assert(bufferType.hasRank() && callerType.hasRank() &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
/// `BufferizableOpInterface`. The buffer types of tensor block arguments are
/// computed with `BufferizableOpInterface::getBufferType`.
LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
const BufferizationOptions &options);
const BufferizationOptions &options,
BufferizationState &state);

} // namespace bufferization
} // namespace mlir
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,15 @@ void hoistBuffersFromLoops(Operation *op);
/// additional buffer allocations.
LogicalResult insertTensorCopies(Operation *op,
const OneShotBufferizationOptions &options,
const BufferizationState &bufferizationState,
BufferizationStatistics *statistics = nullptr);

/// Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
/// After applying this transform, the IR can be bufferized without inserting
/// additional buffer allocations.
LogicalResult insertTensorCopies(Operation *op, const AnalysisState &state);
LogicalResult insertTensorCopies(Operation *op,
const AnalysisState &analysisState,
const BufferizationState &bufferizationState);

/// Populate patterns to lower tensor.empty ops to bufferization.alloc_tensor
/// ops.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ struct IndexCastOpInterface
auto castOp = cast<arith::IndexCastOp>(op);
auto resultTensorType = cast<TensorType>(castOp.getType());

FailureOr<Value> source = getBuffer(rewriter, castOp.getIn(), options);
FailureOr<Value> source =
getBuffer(rewriter, castOp.getIn(), options, state);
if (failed(source))
return failure();
auto sourceType = cast<BaseMemRefType>(source->getType());
Expand Down Expand Up @@ -151,9 +152,9 @@ struct SelectOpInterface
// the moment (one for each tensor). When copying the op result, only one
// copy would be needed.
FailureOr<Value> maybeTrueBuffer =
getBuffer(rewriter, selectOp.getTrueValue(), options);
getBuffer(rewriter, selectOp.getTrueValue(), options, state);
FailureOr<Value> maybeFalseBuffer =
getBuffer(rewriter, selectOp.getFalseValue(), options);
getBuffer(rewriter, selectOp.getFalseValue(), options, state);
if (failed(maybeTrueBuffer) || failed(maybeFalseBuffer))
return failure();
Value trueBuffer = *maybeTrueBuffer;
Expand All @@ -164,7 +165,7 @@ struct SelectOpInterface
// both of them to the most dynamic MemRef type.
if (trueBuffer.getType() != falseBuffer.getType()) {
auto targetType =
bufferization::getBufferType(selectOp.getResult(), options);
bufferization::getBufferType(selectOp.getResult(), options, state);
if (failed(targetType))
return failure();
if (trueBuffer.getType() != *targetType)
Expand All @@ -182,13 +183,14 @@ struct SelectOpInterface

FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
auto selectOp = cast<arith::SelectOp>(op);
assert(value == selectOp.getResult() && "invalid value");
auto trueType = bufferization::getBufferType(selectOp.getTrueValue(),
options, invocationStack);
auto falseType = bufferization::getBufferType(selectOp.getFalseValue(),
options, invocationStack);
auto trueType = bufferization::getBufferType(
selectOp.getTrueValue(), options, state, invocationStack);
auto falseType = bufferization::getBufferType(
selectOp.getFalseValue(), options, state, invocationStack);
if (failed(trueType) || failed(falseType))
return failure();
if (*trueType == *falseType)
Expand Down
52 changes: 31 additions & 21 deletions mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,8 @@ Operation *bufferization::getOwnerOfValue(Value value) {
/// allocated.
FailureOr<Value> bufferization::allocateTensorForShapedValue(
OpBuilder &b, Location loc, Value shapedValue,
const BufferizationOptions &options, bool copy) {
const BufferizationOptions &options, const BufferizationState &state,
bool copy) {
Value tensor;
if (llvm::isa<RankedTensorType>(shapedValue.getType())) {
tensor = shapedValue;
Expand Down Expand Up @@ -210,7 +211,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
// Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
if (copy)
return allocTensorOp.getResult();
FailureOr<BaseMemRefType> copyBufferType = getBufferType(tensor, options);
FailureOr<BaseMemRefType> copyBufferType =
getBufferType(tensor, options, state);
if (failed(copyBufferType))
return failure();
std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
Expand All @@ -222,7 +224,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
}

LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
RewriterBase &rewriter, const AnalysisState &state) {
RewriterBase &rewriter, const AnalysisState &analysisState,
const BufferizationState &bufferizationState) {
OpBuilder::InsertionGuard g(rewriter);
Operation *op = getOperation();
SmallVector<OpOperand *> outOfPlaceOpOperands;
Expand All @@ -235,16 +238,18 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
Type operandType = opOperand.get().getType();
if (!llvm::isa<TensorType>(operandType))
continue;
if (state.isInPlace(opOperand))
if (analysisState.isInPlace(opOperand))
continue;
if (llvm::isa<UnrankedTensorType>(operandType))
return op->emitError("copying of unranked tensors is not implemented");

AliasingValueList aliasingValues = state.getAliasingValues(opOperand);
AliasingValueList aliasingValues =
analysisState.getAliasingValues(opOperand);
if (aliasingValues.getNumAliases() == 1 &&
isa<OpResult>(aliasingValues.getAliases()[0].value) &&
!state.bufferizesToMemoryWrite(opOperand) &&
state.getAliasingOpOperands(aliasingValues.getAliases()[0].value)
!analysisState.bufferizesToMemoryWrite(opOperand) &&
analysisState
.getAliasingOpOperands(aliasingValues.getAliases()[0].value)
.getNumAliases() == 1 &&
!isa<UnrankedTensorType>(
aliasingValues.getAliases()[0].value.getType())) {
Expand All @@ -256,12 +261,12 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
// cannot be copied at the moment).
Value value = aliasingValues.getAliases()[0].value;
outOfPlaceValues.push_back(value);
if (!state.canOmitTensorCopy(opOperand))
if (!analysisState.canOmitTensorCopy(opOperand))
copiedOpValues.insert(value);
} else {
// In all other cases, make a copy of the OpOperand.
outOfPlaceOpOperands.push_back(&opOperand);
if (!state.canOmitTensorCopy(opOperand))
if (!analysisState.canOmitTensorCopy(opOperand))
copiedOpOperands.insert(&opOperand);
}
}
Expand All @@ -270,8 +275,8 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
rewriter.setInsertionPoint(op);
for (OpOperand *opOperand : outOfPlaceOpOperands) {
FailureOr<Value> copy = allocateTensorForShapedValue(
rewriter, op->getLoc(), opOperand->get(), state.getOptions(),
copiedOpOperands.contains(opOperand));
rewriter, op->getLoc(), opOperand->get(), analysisState.getOptions(),
bufferizationState, copiedOpOperands.contains(opOperand));
if (failed(copy))
return failure();
rewriter.modifyOpInPlace(op, [&]() { opOperand->set(*copy); });
Expand All @@ -281,8 +286,8 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
rewriter.setInsertionPointAfter(op);
for (Value value : outOfPlaceValues) {
FailureOr<Value> copy = allocateTensorForShapedValue(
rewriter, op->getLoc(), value, state.getOptions(),
copiedOpValues.count(value));
rewriter, op->getLoc(), value, analysisState.getOptions(),
bufferizationState, copiedOpValues.count(value));
if (failed(copy))
return failure();
SmallVector<OpOperand *> uses = llvm::to_vector(
Expand Down Expand Up @@ -665,7 +670,8 @@ static void ensureToBufferOpIsValid(Value tensor, Type memrefType) {
}

FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
const BufferizationOptions &options) {
const BufferizationOptions &options,
const BufferizationState &state) {
#ifndef NDEBUG
auto tensorType = llvm::dyn_cast<TensorType>(value.getType());
assert(tensorType && "unexpected non-tensor type");
Expand All @@ -678,7 +684,7 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
// Insert to_buffer op.
OpBuilder::InsertionGuard g(rewriter);
setInsertionPointAfter(rewriter, value);
FailureOr<BaseMemRefType> memrefType = getBufferType(value, options);
FailureOr<BaseMemRefType> memrefType = getBufferType(value, options, state);
if (failed(memrefType))
return failure();
ensureToBufferOpIsValid(value, *memrefType);
Expand All @@ -689,14 +695,16 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,

/// Return the buffer type for a given Value (tensor) after bufferization.
FailureOr<BaseMemRefType>
bufferization::getBufferType(Value value, const BufferizationOptions &options) {
bufferization::getBufferType(Value value, const BufferizationOptions &options,
const BufferizationState &state) {
SmallVector<Value> invocationStack;
return getBufferType(value, options, invocationStack);
return getBufferType(value, options, state, invocationStack);
}

/// Return the buffer type for a given Value (tensor) after bufferization.
FailureOr<BaseMemRefType>
bufferization::getBufferType(Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) {
assert(llvm::isa<TensorType>(value.getType()) &&
"unexpected non-tensor type");
Expand All @@ -708,7 +716,7 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
Operation *op = getOwnerOfValue(value);
auto bufferizableOp = options.dynCastBufferizableOp(op);
if (bufferizableOp)
return bufferizableOp.getBufferType(value, options, invocationStack);
return bufferizableOp.getBufferType(value, options, state, invocationStack);

// Op is not bufferizable.
auto memSpace =
Expand Down Expand Up @@ -944,6 +952,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(

FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
Value value, const BufferizationOptions &options,
const BufferizationState &bufferizationState,
SmallVector<Value> &invocationStack) {
assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");

Expand All @@ -954,14 +963,15 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
// Value is an OpResult.
Operation *op = getOwnerOfValue(value);
auto opResult = llvm::cast<OpResult>(value);
AnalysisState state(options);
AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
AnalysisState analysisState(options);
AliasingOpOperandList aliases = analysisState.getAliasingOpOperands(opResult);
if (aliases.getNumAliases() > 0 &&
aliases.getAliases()[0].relation == BufferRelation::Equivalent) {
// If the OpResult has an equivalent OpOperand, both OpResult and
// OpOperand bufferize to the exact same buffer type.
Value equivalentOperand = aliases.getAliases().front().opOperand->get();
return getBufferType(equivalentOperand, options, invocationStack);
return getBufferType(equivalentOperand, options, bufferizationState,
invocationStack);
}

// If we do not know the memory space and there is no default memory space,
Expand Down
Loading