diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index 6d8aaf64e3263..8a9ce949a750d 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -250,17 +250,10 @@ def ForOp : SCF_Op<"for", "expected an index less than the number of region iter args"); return getBody()->getArguments().drop_front(getNumInductionVars())[index]; } - MutableArrayRef getIterOpOperands() { - return - getOperation()->getOpOperands().drop_front(getNumControlOperands()); - } void setLowerBound(Value bound) { getOperation()->setOperand(0, bound); } void setUpperBound(Value bound) { getOperation()->setOperand(1, bound); } void setStep(Value step) { getOperation()->setOperand(2, step); } - void setIterArg(unsigned iterArgNum, Value iterArgValue) { - getOperation()->setOperand(iterArgNum + getNumControlOperands(), iterArgValue); - } /// Number of induction variables, always 1 for scf::ForOp. unsigned getNumInductionVars() { return 1; } diff --git a/mlir/include/mlir/IR/ValueRange.h b/mlir/include/mlir/IR/ValueRange.h index f1a1f1841f179..9c11178f9cd9c 100644 --- a/mlir/include/mlir/IR/ValueRange.h +++ b/mlir/include/mlir/IR/ValueRange.h @@ -165,13 +165,9 @@ class MutableOperandRange { /// Returns the OpOperand at the given index. OpOperand &operator[](unsigned index) const; - OperandRange::iterator begin() const { - return static_cast(*this).begin(); - } - - OperandRange::iterator end() const { - return static_cast(*this).end(); - } + /// Iterators enumerate OpOperands. + MutableArrayRef::iterator begin() const; + MutableArrayRef::iterator end() const; private: /// Update the length of this range to the one provided. diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp index 43ba11cf132cb..09d3083582808 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp @@ -47,6 +47,10 @@ static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) { static bool isMemref(Value v) { return v.getType().isa(); } +static bool isMemrefOperand(OpOperand &operand) { + return isMemref(operand.get()); +} + //===----------------------------------------------------------------------===// // Backedges analysis //===----------------------------------------------------------------------===// @@ -937,7 +941,7 @@ BufferDeallocation::handleInterface(RegionBranchTerminatorOpInterface op) { // Add an additional operand for every MemRef for the ownership indicator. if (!funcWithoutDynamicOwnership) { - unsigned numMemRefs = llvm::count_if(operands, isMemref); + unsigned numMemRefs = llvm::count_if(operands, isMemrefOperand); SmallVector newOperands{OperandRange(operands)}; auto ownershipValues = deallocOp.getUpdatedConditions().take_front(numMemRefs); diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp index e847e946eef1b..9423af2542690 100644 --- a/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp @@ -96,12 +96,12 @@ struct CondBranchOpInterface mapping[retained] = ownership; } SmallVector replacements, ownerships; - for (Value operand : destOperands) { - replacements.push_back(operand); - if (isMemref(operand)) { - assert(mapping.contains(operand) && + for (OpOperand &operand : destOperands) { + replacements.push_back(operand.get()); + if (isMemref(operand.get())) { + assert(mapping.contains(operand.get()) && "Should be contained at this point"); - ownerships.push_back(mapping[operand]); + ownerships.push_back(mapping[operand.get()]); } } replacements.append(ownerships); diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 5565aefbad18d..2a760c76d2f68 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -932,7 +932,7 @@ replaceTensorCastForOpIterArg(PatternRewriter &rewriter, OpOperand &operand, assert(operand.get().getType() != replacement.getType() && "Expected a different type"); SmallVector newIterOperands; - for (OpOperand &opOperand : forOp.getIterOpOperands()) { + for (OpOperand &opOperand : forOp.getInitArgsMutable()) { if (opOperand.getOperandNumber() == operand.getOperandNumber()) { newIterOperands.push_back(replacement); continue; @@ -1015,7 +1015,7 @@ struct ForOpTensorCastFolder : public OpRewritePattern { LogicalResult matchAndRewrite(ForOp op, PatternRewriter &rewriter) const override { - for (auto it : llvm::zip(op.getIterOpOperands(), op.getResults())) { + for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) { OpOperand &iterOpOperand = std::get<0>(it); auto incomingCast = iterOpOperand.get().getDefiningOp(); if (!incomingCast || diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp index 11cfefed890c6..8c04a8887013c 100644 --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -332,7 +332,7 @@ DenseSet getEquivalentBuffers(Block::BlockArgListType bbArgs, /// Helper function for loop bufferization. Return the bufferized values of the /// given OpOperands. If an operand is not a tensor, return the original value. static FailureOr> -getBuffers(RewriterBase &rewriter, MutableArrayRef operands, +getBuffers(RewriterBase &rewriter, MutableOperandRange operands, const BufferizationOptions &options) { SmallVector result; for (OpOperand &opOperand : operands) { @@ -606,7 +606,7 @@ struct ForOpInterface // The new memref init_args of the loop. FailureOr> maybeInitArgs = - getBuffers(rewriter, forOp.getIterOpOperands(), options); + getBuffers(rewriter, forOp.getInitArgsMutable(), options); if (failed(maybeInitArgs)) return failure(); SmallVector initArgs = *maybeInitArgs; @@ -825,7 +825,7 @@ struct WhileOpInterface // The new memref init_args of the loop. FailureOr> maybeInitArgs = - getBuffers(rewriter, whileOp->getOpOperands(), options); + getBuffers(rewriter, whileOp.getInitsMutable(), options); if (failed(maybeInitArgs)) return failure(); SmallVector initArgs = *maybeInitArgs; diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 1ce25565edcaf..ceec0756e421f 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -508,7 +508,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter, MutableArrayRef loops) { // 1. Get the producer of the source (potentially walking through // `iter_args` of nested `scf.for`) - auto [fusableProducer, destinationIterArg] = + auto [fusableProducer, destinationInitArg] = getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable()[0], loops); if (!fusableProducer) @@ -575,17 +575,15 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter, // TODO: This can be modeled better if the `DestinationStyleOpInterface`. // Update to use that when it does become available. scf::ForOp outerMostLoop = loops.front(); - std::optional iterArgNumber; - if (destinationIterArg) { - iterArgNumber = - outerMostLoop.getIterArgNumberForOpOperand(*destinationIterArg.value()); - } - if (iterArgNumber) { + if (destinationInitArg && + (*destinationInitArg)->getOwner() == outerMostLoop) { + std::optional iterArgNumber = + outerMostLoop.getIterArgNumberForOpOperand(**destinationInitArg); int64_t resultNumber = fusableProducer.getResultNumber(); if (auto dstOp = dyn_cast(fusableProducer.getOwner())) { - outerMostLoop.setIterArg(iterArgNumber.value(), - dstOp.getTiedOpOperand(fusableProducer)->get()); + (*destinationInitArg) + ->set(dstOp.getTiedOpOperand(fusableProducer)->get()); } for (auto tileAndFusedOp : tileAndFuseResult->tiledOps) { auto dstOp = dyn_cast(tileAndFusedOp); diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp index 7b17e231ce106..b0c50f3d6e298 100644 --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -522,6 +522,14 @@ OpOperand &MutableOperandRange::operator[](unsigned index) const { return owner->getOpOperand(start + index); } +MutableArrayRef::iterator MutableOperandRange::begin() const { + return owner->getOpOperands().slice(start, length).begin(); +} + +MutableArrayRef::iterator MutableOperandRange::end() const { + return owner->getOpOperands().slice(start, length).end(); +} + //===----------------------------------------------------------------------===// // MutableOperandRangeRange diff --git a/mlir/lib/Transforms/Utils/CFGToSCF.cpp b/mlir/lib/Transforms/Utils/CFGToSCF.cpp index 9aab89ed75536..e7bf6628ccbd7 100644 --- a/mlir/lib/Transforms/Utils/CFGToSCF.cpp +++ b/mlir/lib/Transforms/Utils/CFGToSCF.cpp @@ -137,6 +137,13 @@ getMutableSuccessorOperands(Block *block, unsigned successorIndex) { return succOps.getMutableForwardedOperands(); } +/// Return the operand range used to transfer operands from `block` to its +/// successor with the given index. +static OperandRange getSuccessorOperands(Block *block, + unsigned successorIndex) { + return getMutableSuccessorOperands(block, successorIndex); +} + /// Appends all the block arguments from `other` to the block arguments of /// `block`, copying their types and locations. static void addBlockArgumentsFromOther(Block *block, Block *other) { @@ -175,8 +182,14 @@ class Edge { /// Returns the arguments of this edge that are passed to the block arguments /// of the successor. - MutableOperandRange getSuccessorOperands() const { - return getMutableSuccessorOperands(fromBlock, successorIndex); + MutableOperandRange getMutableSuccessorOperands() const { + return ::getMutableSuccessorOperands(fromBlock, successorIndex); + } + + /// Returns the arguments of this edge that are passed to the block arguments + /// of the successor. + OperandRange getSuccessorOperands() const { + return ::getSuccessorOperands(fromBlock, successorIndex); } }; @@ -262,7 +275,7 @@ class EdgeMultiplexer { assert(result != blockArgMapping.end() && "Edge was not originally passed to `create` method."); - MutableOperandRange successorOperands = edge.getSuccessorOperands(); + MutableOperandRange successorOperands = edge.getMutableSuccessorOperands(); // Extra arguments are always appended at the end of the block arguments. unsigned extraArgsBeginIndex = @@ -666,7 +679,7 @@ transformToReduceLoop(Block *loopHeader, Block *exitBlock, // invalidated when mutating the operands through a different // `MutableOperandRange` of the same operation. SmallVector loopHeaderSuccessorOperands = - llvm::to_vector(getMutableSuccessorOperands(latch, loopHeaderIndex)); + llvm::to_vector(getSuccessorOperands(latch, loopHeaderIndex)); // Add all values used in the next iteration to the exit block. Replace // any uses that are outside the loop with the newly created exit block. @@ -742,7 +755,7 @@ transformToReduceLoop(Block *loopHeader, Block *exitBlock, loopHeaderSuccessorOperands.push_back(argument); for (Edge edge : successorEdges(latch)) - edge.getSuccessorOperands().append(argument); + edge.getMutableSuccessorOperands().append(argument); } use.set(blockArgument); @@ -939,9 +952,8 @@ static FailureOr> transformToStructuredCFBranches( if (regionEntry->getNumSuccessors() == 1) { // Single successor we can just splice together. Block *successor = regionEntry->getSuccessor(0); - for (auto &&[oldValue, newValue] : - llvm::zip(successor->getArguments(), - getMutableSuccessorOperands(regionEntry, 0))) + for (auto &&[oldValue, newValue] : llvm::zip( + successor->getArguments(), getSuccessorOperands(regionEntry, 0))) oldValue.replaceAllUsesWith(newValue); regionEntry->getTerminator()->erase();