diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td index b946fc8875860..d3933cad920a3 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -1358,7 +1358,8 @@ def VerifyOp : TransformDialectOp<"verify", } def YieldOp : TransformDialectOp<"yield", - [Terminator, DeclareOpInterfaceMethods]> { + [Terminator, ReturnLike, + DeclareOpInterfaceMethods]> { let summary = "Yields operation handles from a transform IR region"; let description = [{ This terminator operation yields operation handles from regions of the diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 590cae9aa0d66..48b25d19d7dc3 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -104,16 +104,8 @@ transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) { void transform::AlternativesOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl ®ions) { - for (Region &alternative : llvm::drop_begin( - getAlternatives(), - point.isParent() ? 0 - : point.getRegionOrNull()->getRegionNumber() + 1)) { - regions.emplace_back(&alternative, !getOperands().empty() - ? alternative.getArguments() - : Block::BlockArgListType()); - } if (!point.isParent()) - regions.emplace_back(getOperation()->getResults()); + regions.emplace_back(getResults()); } void transform::AlternativesOp::getRegionInvocationBounds( @@ -1502,16 +1494,8 @@ void transform::ForeachOp::getEffects( void transform::ForeachOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl ®ions) { - Region *bodyRegion = &getBody(); - if (point.isParent()) { - regions.emplace_back(bodyRegion, bodyRegion->getArguments()); - return; - } - - // Branch back to the region or the parent. - assert(point == getBody() && "unexpected region index"); - regions.emplace_back(bodyRegion, bodyRegion->getArguments()); - regions.emplace_back(); + if (point.getRegionOrNull() == &getBody()) + regions.emplace_back(getResults()); } OperandRange @@ -2702,16 +2686,8 @@ transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) { void transform::SequenceOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl ®ions) { - if (point.isParent()) { - Region *bodyRegion = &getBody(); - regions.emplace_back(bodyRegion, getNumOperands() != 0 - ? bodyRegion->getArguments() - : Block::BlockArgListType()); - return; - } - - assert(point == getBody() && "unexpected region index"); - regions.emplace_back(getOperation()->getResults()); + if (point.getRegionOrNull() == &getBody()) + regions.emplace_back(getResults()); } void transform::SequenceOp::getRegionInvocationBounds(