diff --git a/flang/lib/Optimizer/Transforms/LoopVersioning.cpp b/flang/lib/Optimizer/Transforms/LoopVersioning.cpp index b524b11f59664..4d3ea51ae1a5f 100644 --- a/flang/lib/Optimizer/Transforms/LoopVersioning.cpp +++ b/flang/lib/Optimizer/Transforms/LoopVersioning.cpp @@ -77,6 +77,72 @@ class LoopVersioningPass void runOnOperation() override; }; +/// @struct ArgInfo +/// A structure to hold an argument, the size of the argument and dimension +/// information. +struct ArgInfo { + mlir::Value arg; + size_t size; + unsigned rank; + fir::BoxDimsOp dims[CFI_MAX_RANK]; +}; + +/// @struct ArgsUsageInLoop +/// A structure providing information about the function arguments +/// usage by the instructions immediately nested in a loop. +struct ArgsUsageInLoop { + /// Mapping between the memref operand of an array indexing + /// operation (e.g. fir.coordinate_of) and the argument information. + llvm::DenseMap usageInfo; + /// Some array indexing operations inside a loop cannot be transformed. + /// This vector holds the memref operands of such operations. + /// The vector is used to make sure that we do not try to transform + /// any outer loop, since this will imply the operation rewrite + /// in this loop. + llvm::SetVector cannotTransform; + + // Debug dump of the structure members assuming that + // the information has been collected for the given loop. 
+ void dump(fir::DoLoopOp loop) const { + // clang-format off + LLVM_DEBUG( + mlir::OpPrintingFlags printFlags; + printFlags.skipRegions(); + llvm::dbgs() << "Arguments usage info for loop:\n"; + loop.print(llvm::dbgs(), printFlags); + llvm::dbgs() << "\nUsed args:\n"; + for (auto &use : usageInfo) { + mlir::Value v = use.first; + v.print(llvm::dbgs(), printFlags); + llvm::dbgs() << "\n"; + } + llvm::dbgs() << "\nCannot transform args:\n"; + for (mlir::Value arg : cannotTransform) { + arg.print(llvm::dbgs(), printFlags); + llvm::dbgs() << "\n"; + } + llvm::dbgs() << "====\n" + ); + // clang-format on + } + + // Erase usageInfo and cannotTransform entries for a set + // of given arguments. + void eraseUsage(const llvm::SetVector &args) { + for (auto &arg : args) + usageInfo.erase(arg); + cannotTransform.set_subtract(args); + } + + // Erase usageInfo and cannotTransform entries for a set + // of given arguments provided in the form of usageInfo map. + void eraseUsage(const llvm::DenseMap &args) { + for (auto &arg : args) { + usageInfo.erase(arg.first); + cannotTransform.remove(arg.first); + } + } +}; } // namespace /// @c replaceOuterUses - replace uses outside of @c op with result of @c @@ -179,16 +245,6 @@ void LoopVersioningPass::runOnOperation() { LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n"); mlir::func::FuncOp func = getOperation(); - /// @c ArgInfo - /// A structure to hold an argument, the size of the argument and dimension - /// information. - struct ArgInfo { - mlir::Value arg; - size_t size; - unsigned rank; - fir::BoxDimsOp dims[CFI_MAX_RANK]; - }; - // First look for arguments with assumed shape = unknown extent in the lowest // dimension. 
LLVM_DEBUG(llvm::dbgs() << "Func-name:" << func.getSymName() << "\n"); @@ -224,58 +280,137 @@ void LoopVersioningPass::runOnOperation() { } } - if (argsOfInterest.empty()) + if (argsOfInterest.empty()) { + LLVM_DEBUG(llvm::dbgs() + << "No suitable arguments.\n=== End " DEBUG_TYPE " ===\n"); return; + } - struct OpsWithArgs { - mlir::Operation *op; - mlir::SmallVector argsAndDims; - }; - // Now see if those arguments are used inside any loop. - mlir::SmallVector loopsOfInterest; + // A list of all loops in the function in post-order. + mlir::SmallVector originalLoops; + // Information about the arguments usage by the instructions + // immediately nested in a loop. + llvm::DenseMap argsInLoops; + // Traverse the loops in post-order and see + // if those arguments are used inside any loop. func.walk([&](fir::DoLoopOp loop) { mlir::Block &body = *loop.getBody(); - mlir::SmallVector argsInLoop; + auto &argsInLoop = argsInLoops[loop]; + originalLoops.push_back(loop); body.walk([&](mlir::Operation *op) { - // support either fir.array_coor or fir.coordinate_of - if (auto arrayCoor = mlir::dyn_cast(op)) { - // no support currently for sliced arrays - if (arrayCoor.getSlice()) - return; - } else if (!mlir::isa(op)) { + // Support either fir.array_coor or fir.coordinate_of. + if (!mlir::isa(op)) return; - } - - // The current operation could be inside another loop than - // the one we're currently processing. Skip it, we'll get - // to it later. + // Process only operations immediately nested in the current loop. if (op->getParentOfType() != loop) return; mlir::Value operand = op->getOperand(0); for (auto a : argsOfInterest) { if (a.arg == normaliseVal(operand)) { - // use the reboxed value, not the block arg when re-creating the loop: + // Use the reboxed value, not the block arg when re-creating the loop. + // TODO: should we check that the operand dominates the loop? 
+ // If this might be the case, we should record such operands in + // argsInLoop.cannotTransform, so that they disable the transformation + // for the parent loops as well. a.arg = operand; - // Only add if it's not already in the list. - if (std::find_if(argsInLoop.begin(), argsInLoop.end(), [&](auto it) { - return it.arg == a.arg; - }) == argsInLoop.end()) { - argsInLoop.push_back(a); + // No support currently for sliced arrays. + // This means that we cannot transform properly + // instructions referencing a.arg in the whole loop + // nest this loop is located in. + if (auto arrayCoor = mlir::dyn_cast(op)) + if (arrayCoor.getSlice()) + argsInLoop.cannotTransform.insert(a.arg); + + if (argsInLoop.cannotTransform.contains(a.arg)) { + // Remove any previously recorded usage, if any. + argsInLoop.usageInfo.erase(a.arg); break; } + + // Record the a.arg usage, if not recorded yet. + argsInLoop.usageInfo.try_emplace(a.arg, a); + break; } } }); - - if (!argsInLoop.empty()) { - OpsWithArgs ops = {loop, argsInLoop}; - loopsOfInterest.push_back(ops); - } }); - if (loopsOfInterest.empty()) + + // Dump loops info after initial collection. + // clang-format off + LLVM_DEBUG( + llvm::dbgs() << "Initial usage info:\n"; + for (fir::DoLoopOp loop : originalLoops) { + auto &argsInLoop = argsInLoops[loop]; + argsInLoop.dump(loop); + } + ); + // clang-format on + + // Clear argument usage for parent loops if an inner loop + // contains a non-transformable usage. + for (fir::DoLoopOp loop : originalLoops) { + auto &argsInLoop = argsInLoops[loop]; + if (argsInLoop.cannotTransform.empty()) + continue; + + fir::DoLoopOp parent = loop; + while ((parent = parent->getParentOfType())) + argsInLoops[parent].eraseUsage(argsInLoop.cannotTransform); + } + + // If an argument access can be optimized in a loop and + // its descendant loop, then it does not make sense to + // generate the contiguity check for the descendant loop.
+ // The check will be produced as part of the ancestor + // loop's transformation. So we can clear the argument + // usage for all descendant loops. + for (fir::DoLoopOp loop : originalLoops) { + auto &argsInLoop = argsInLoops[loop]; + if (argsInLoop.usageInfo.empty()) + continue; + + loop.getBody()->walk([&](fir::DoLoopOp dloop) { + argsInLoops[dloop].eraseUsage(argsInLoop.usageInfo); + }); + } + + // clang-format off + LLVM_DEBUG( + llvm::dbgs() << "Final usage info:\n"; + for (fir::DoLoopOp loop : originalLoops) { + auto &argsInLoop = argsInLoops[loop]; + argsInLoop.dump(loop); + } + ); + // clang-format on + + // Reduce the collected information to a list of loops + // with attached arguments usage information. + // The list must hold the loops in post-order, so that + // the inner loops are transformed before the outer loops. + struct OpsWithArgs { + mlir::Operation *op; + mlir::SmallVector argsAndDims; + }; + mlir::SmallVector loopsOfInterest; + for (fir::DoLoopOp loop : originalLoops) { + auto &argsInLoop = argsInLoops[loop]; + if (argsInLoop.usageInfo.empty()) + continue; + OpsWithArgs info; + info.op = loop; + for (auto &arg : argsInLoop.usageInfo) + info.argsAndDims.push_back(arg.second); + loopsOfInterest.emplace_back(std::move(info)); + } + + if (loopsOfInterest.empty()) { + LLVM_DEBUG(llvm::dbgs() + << "No loops to transform.\n=== End " DEBUG_TYPE " ===\n"); return; + } // If we get here, there are loops to process.
fir::FirOpBuilder builder{module, std::move(kindMap)}; diff --git a/flang/test/Transforms/loop-versioning.fir b/flang/test/Transforms/loop-versioning.fir index 566903d0897f2..f2768d7325f74 100644 --- a/flang/test/Transforms/loop-versioning.fir +++ b/flang/test/Transforms/loop-versioning.fir @@ -118,8 +118,6 @@ func.func @sum1dfixed(%arg0: !fir.ref> {fir.bindc_name = "a"}, // ----- -// RUN: fir-opt --loop-versioning %s | FileCheck %s - // Check that "no result" from a versioned loop works correctly // This code was the basis for this, but `read` is replaced with a function called Func // subroutine test3(x, y) @@ -1266,4 +1264,174 @@ func.func @test_optional_arg(%arg0: !fir.box> {fir.bindc_name // CHECK: fir.store %[[VAL_166:.*]]#1 to %[[VAL_18]] : !fir.ref // CHECK: return // CHECK: } + +// ! Verify that neither of the loops is versioned +// ! due to the array section in the inner loop: +// subroutine test_slice(x) +// real :: x(:,:) +// do i=10,100 +// x(i,7) = 1.0 +// x(i,3:5) = 2.0 +// end do +// end subroutine test_slice +func.func @_QPtest_slice(%arg0: !fir.box> {fir.bindc_name = "x"}) { + %c10 = arith.constant 10 : index + %c100 = arith.constant 100 : index + %c6_i64 = arith.constant 6 : i64 + %c3 = arith.constant 3 : index + %c2 = arith.constant 2 : index + %c5 = arith.constant 5 : index + %cst = arith.constant 2.000000e+00 : f32 + %c0 = arith.constant 0 : index + %c1_i64 = arith.constant 1 : i64 + %cst_0 = arith.constant 1.000000e+00 : f32 + %c1 = arith.constant 1 : index + %0 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFtest_sliceEi"} + %1 = fir.convert %c10 : (index) -> i32 + %2:2 = fir.do_loop %arg1 = %c10 to %c100 step %c1 iter_args(%arg2 = %1) -> (index, i32) { + fir.store %arg2 to %0 : !fir.ref + %3 = fir.load %0 : !fir.ref + %4 = fir.convert %3 : (i32) -> i64 + %5 = arith.subi %4, %c1_i64 : i64 + %6 = fir.coordinate_of %arg0, %5, %c6_i64 : (!fir.box>, i64, i64) -> !fir.ref + fir.store %cst_0 to %6 : !fir.ref + %7 = fir.load %0 : !fir.ref + %8 
= fir.convert %7 : (i32) -> i64 + %9 = fir.undefined index + %10 = fir.convert %7 : (i32) -> index + %11 = fir.slice %8, %9, %9, %c3, %c5, %c1 : (i64, index, index, index, index, index) -> !fir.slice<2> + %12 = fir.undefined !fir.array + %13 = fir.do_loop %arg3 = %c0 to %c2 step %c1 unordered iter_args(%arg4 = %12) -> (!fir.array) { + %18 = arith.addi %arg3, %c1 : index + %19 = fir.array_coor %arg0 [%11] %10, %18 : (!fir.box>, !fir.slice<2>, index, index) -> !fir.ref + fir.store %cst to %19 : !fir.ref + fir.result %12 : !fir.array + } + %14 = arith.addi %arg1, %c1 : index + %15 = fir.convert %c1 : (index) -> i32 + %16 = fir.load %0 : !fir.ref + %17 = arith.addi %16, %15 : i32 + fir.result %14, %17 : index, i32 + } + fir.store %2#1 to %0 : !fir.ref + return +} +// CHECK-LABEL: func.func @_QPtest_slice( +// CHECK-NOT: fir.if + +// ! Verify versioning for argument 'x' but not for 'y': +// subroutine test_independent_args(x, y) +// real :: x(:,:), y(:,:) +// do i=10,100 +// x(i,7) = 1.0 +// y(i,3:5) = 2.0 +// end do +// end subroutine test_independent_args +func.func @_QPtest_independent_args(%arg0: !fir.box> {fir.bindc_name = "x"}, %arg1: !fir.box> {fir.bindc_name = "y"}) { + %c10 = arith.constant 10 : index + %c100 = arith.constant 100 : index + %c6_i64 = arith.constant 6 : i64 + %c3 = arith.constant 3 : index + %c2 = arith.constant 2 : index + %c5 = arith.constant 5 : index + %cst = arith.constant 2.000000e+00 : f32 + %c0 = arith.constant 0 : index + %c1_i64 = arith.constant 1 : i64 + %cst_0 = arith.constant 1.000000e+00 : f32 + %c1 = arith.constant 1 : index + %0 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFtest_independent_argsEi"} + %1 = fir.convert %c10 : (index) -> i32 + %2:2 = fir.do_loop %arg2 = %c10 to %c100 step %c1 iter_args(%arg3 = %1) -> (index, i32) { + fir.store %arg3 to %0 : !fir.ref + %3 = fir.load %0 : !fir.ref + %4 = fir.convert %3 : (i32) -> i64 + %5 = arith.subi %4, %c1_i64 : i64 + %6 = fir.coordinate_of %arg0, %5, %c6_i64 : (!fir.box>, 
i64, i64) -> !fir.ref + fir.store %cst_0 to %6 : !fir.ref + %7 = fir.load %0 : !fir.ref + %8 = fir.convert %7 : (i32) -> i64 + %9 = fir.undefined index + %10 = fir.convert %7 : (i32) -> index + %11 = fir.slice %8, %9, %9, %c3, %c5, %c1 : (i64, index, index, index, index, index) -> !fir.slice<2> + %12 = fir.undefined !fir.array + %13 = fir.do_loop %arg4 = %c0 to %c2 step %c1 unordered iter_args(%arg5 = %12) -> (!fir.array) { + %18 = arith.addi %arg4, %c1 : index + %19 = fir.array_coor %arg1 [%11] %10, %18 : (!fir.box>, !fir.slice<2>, index, index) -> !fir.ref + fir.store %cst to %19 : !fir.ref + fir.result %12 : !fir.array + } + %14 = arith.addi %arg2, %c1 : index + %15 = fir.convert %c1 : (index) -> i32 + %16 = fir.load %0 : !fir.ref + %17 = arith.addi %16, %15 : i32 + fir.result %14, %17 : index, i32 + } + fir.store %2#1 to %0 : !fir.ref + return +} +// CHECK-LABEL: func.func @_QPtest_independent_args( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.box> {fir.bindc_name = "x"}, +// CHECK-SAME: %[[VAL_1:.*]]: !fir.box> {fir.bindc_name = "y"}) { +// CHECK: %[[VAL_16:.*]]:3 = fir.box_dims %[[VAL_0]], %{{.*}} : (!fir.box>, index) -> (index, index, index) +// CHECK: %[[VAL_19:.*]] = arith.constant 4 : index +// CHECK: %[[VAL_20:.*]] = arith.cmpi eq, %[[VAL_16]]#2, %[[VAL_19]] : index +// CHECK: %[[VAL_21:.*]]:2 = fir.if %[[VAL_20]] -> (index, i32) { +// CHECK-NOT: fir.if + + +// ! Verify that the whole loop nest is versioned +// ! 
without additional contiguity check for the inner loop: +// subroutine test_loop_nest(x) +// real :: x(:) +// do i=10,100 +// x(i) = 1.0 +// do j=10,100 +// x(j) = 2.0 +// end do +// end do +// end subroutine test_loop_nest +func.func @_QPtest_loop_nest(%arg0: !fir.box> {fir.bindc_name = "x"}) { + %c10 = arith.constant 10 : index + %c100 = arith.constant 100 : index + %cst = arith.constant 2.000000e+00 : f32 + %c1_i64 = arith.constant 1 : i64 + %cst_0 = arith.constant 1.000000e+00 : f32 + %c1 = arith.constant 1 : index + %0 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFtest_loop_nestEi"} + %1 = fir.alloca i32 {bindc_name = "j", uniq_name = "_QFtest_loop_nestEj"} + %2 = fir.convert %c10 : (index) -> i32 + %3:2 = fir.do_loop %arg1 = %c10 to %c100 step %c1 iter_args(%arg2 = %2) -> (index, i32) { + fir.store %arg2 to %0 : !fir.ref + %4 = fir.load %0 : !fir.ref + %5 = fir.convert %4 : (i32) -> i64 + %6 = arith.subi %5, %c1_i64 : i64 + %7 = fir.coordinate_of %arg0, %6 : (!fir.box>, i64) -> !fir.ref + fir.store %cst_0 to %7 : !fir.ref + %8:2 = fir.do_loop %arg3 = %c10 to %c100 step %c1 iter_args(%arg4 = %2) -> (index, i32) { + fir.store %arg4 to %1 : !fir.ref + %13 = fir.load %1 : !fir.ref + %14 = fir.convert %13 : (i32) -> i64 + %15 = arith.subi %14, %c1_i64 : i64 + %16 = fir.coordinate_of %arg0, %15 : (!fir.box>, i64) -> !fir.ref + fir.store %cst to %16 : !fir.ref + %17 = arith.addi %arg3, %c1 : index + %18 = fir.convert %c1 : (index) -> i32 + %19 = fir.load %1 : !fir.ref + %20 = arith.addi %19, %18 : i32 + fir.result %17, %20 : index, i32 + } + fir.store %8#1 to %1 : !fir.ref + %9 = arith.addi %arg1, %c1 : index + %10 = fir.convert %c1 : (index) -> i32 + %11 = fir.load %0 : !fir.ref + %12 = arith.addi %11, %10 : i32 + fir.result %9, %12 : index, i32 + } + fir.store %3#1 to %0 : !fir.ref + return +} +// CHECK-LABEL: func.func @_QPtest_loop_nest( +// CHECK: fir.if +// CHECK-NOT: fir.if + } // End module