diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h index 915ab3016b688..644118ca884c6 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h +++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h @@ -19,6 +19,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/RegionKindInterface.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Interfaces/ParallelCombiningOpInterface.h" diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index 0c93989ca99a4..6130f031ca6ab 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -17,6 +17,7 @@ include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/IR/RegionKindInterface.td" include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td" +include "mlir/Interfaces/DestinationStyleOpInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/ParallelCombiningOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -333,6 +334,7 @@ def ForallOp : SCF_Op<"forall", [ RecursiveMemoryEffects, SingleBlockImplicitTerminator<"scf::InParallelOp">, DeclareOpInterfaceMethods, + DestinationStyleOpInterface ]> { let summary = "evaluate a block multiple times in parallel"; let description = [{ @@ -630,6 +632,9 @@ def ForallOp : SCF_Op<"forall", [ Location loc); InParallelOp getTerminator(); + + // Declare the shared_outs as inits/outs to DestinationStyleOpInterface. + MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); } }]; } diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 3e30e320bee8f..f719cfed6b6dd 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -22,6 +22,7 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" +#include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Support/MathExtras.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" @@ -3970,6 +3971,11 @@ struct FoldTensorCastProducerOp if (isa(op.getOperation())) return failure(); + // Exclude DPS ops that are also LoopLike from this interface as they + // might need special handling of attached regions. + if (isa(op.getOperation())) + return failure(); + // If no operand comes from a tensor::CastOp and can be folded then fail. bool hasTensorCastOperand = llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {