diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 7ef016f88be37..f169857ae3741 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -263,16 +263,17 @@ static void promoteIfBlock(AffineIfOp ifOp, bool elseBlock) {
 static Operation *getOutermostInvariantForOp(AffineIfOp ifOp) {
   // Walk up the parents past all for op that this conditional is invariant on.
   auto ifOperands = ifOp.getOperands();
-  auto *res = ifOp.getOperation();
-  while (!isa(res->getParentOp())) {
+  Operation *res = ifOp;
+  while (!res->getParentOp()->hasTrait()) {
     auto *parentOp = res->getParentOp();
     if (auto forOp = dyn_cast(parentOp)) {
       if (llvm::is_contained(ifOperands, forOp.getInductionVar()))
         break;
     } else if (auto parallelOp = dyn_cast(parentOp)) {
-      for (auto iv : parallelOp.getIVs())
-        if (llvm::is_contained(ifOperands, iv))
-          break;
+      if (llvm::any_of(parallelOp.getIVs(), [&](Value iv) {
+            return llvm::is_contained(ifOperands, iv);
+          }))
+        break;
     } else if (!isa(parentOp)) {
       // Won't walk up past anything other than affine.for/if ops.
       break;
@@ -438,11 +439,10 @@ LogicalResult mlir::affine::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
   if (folded)
     *folded = false;
 
-  // The folding above should have ensured this, but the affine.if's
-  // canonicalization is missing composition of affine.applys into it.
+  // The folding above should have ensured this.
   assert(llvm::all_of(ifOp.getOperands(),
                       [](Value v) {
-                        return isTopLevelValue(v) || isAffineForInductionVar(v);
+                        return isTopLevelValue(v) || isAffineInductionVar(v);
                       }) &&
          "operands not composed");
 
diff --git a/mlir/test/Dialect/Affine/loop-unswitch.mlir b/mlir/test/Dialect/Affine/loop-unswitch.mlir
index 5a58941937bf5..c94f5515d49ea 100644
--- a/mlir/test/Dialect/Affine/loop-unswitch.mlir
+++ b/mlir/test/Dialect/Affine/loop-unswitch.mlir
@@ -254,3 +254,31 @@ func.func @multiple_if(%N : index) {
 // CHECK-NEXT: return
 
 func.func private @external()
+
+// Check to ensure affine.parallel ops are handled as well.
+
+#set = affine_set<(d0) : (-d0 + 3 >= 0)>
+// CHECK-LABEL: affine_parallel
+func.func @affine_parallel(%arg0: memref<35xf32>) -> memref<35xf32> {
+  %0 = llvm.mlir.constant(1.000000e+00 : f32) : f32
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<35xf32>
+  // CHECK: affine.parallel
+  affine.parallel (%arg1) = (0) to (35) step (32) {
+    // This can't be hoisted further.
+    // CHECK-NEXT: affine.if
+    affine.if #set(%arg1) {
+      affine.parallel (%arg2) = (%arg1) to (%arg1 + 32) {
+        %1 = affine.load %arg0[%arg2] : memref<35xf32>
+        %2 = llvm.fdiv %0, %1 : f32
+        affine.store %2, %alloc[%arg2] : memref<35xf32>
+      }
+    } else {
+      affine.parallel (%arg2) = (%arg1) to (min(%arg1 + 32, 35)) {
+        %1 = affine.load %arg0[%arg2] : memref<35xf32>
+        %2 = llvm.fdiv %0, %1 : f32
+        affine.store %2, %alloc[%arg2] : memref<35xf32>
+      }
+    }
+  }
+  return %alloc : memref<35xf32>
+}