From 5d43a6bd4fb27015e6917eb471bc725fe31b9ac2 Mon Sep 17 00:00:00 2001 From: Jean Perier Date: Thu, 28 Nov 2024 08:25:05 -0800 Subject: [PATCH 1/6] [flang][hlfir] add hlfir.eval_in_mem operation --- .../flang/Optimizer/Builder/HLFIRTools.h | 19 ++++ .../include/flang/Optimizer/HLFIR/HLFIROps.td | 59 ++++++++++ flang/lib/Optimizer/Builder/HLFIRTools.cpp | 47 ++++++++ flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp | 76 ++++++++++--- .../HLFIR/Transforms/BufferizeHLFIR.cpp | 33 +++++- flang/test/HLFIR/eval_in_mem-codegen.fir | 107 ++++++++++++++++++ flang/test/HLFIR/eval_in_mem.fir | 99 ++++++++++++++++ flang/test/HLFIR/invalid.fir | 34 ++++++ 8 files changed, 454 insertions(+), 20 deletions(-) create mode 100644 flang/test/HLFIR/eval_in_mem-codegen.fir create mode 100644 flang/test/HLFIR/eval_in_mem.fir diff --git a/flang/include/flang/Optimizer/Builder/HLFIRTools.h b/flang/include/flang/Optimizer/Builder/HLFIRTools.h index f073f494b3fb2..efbd9e4f50d43 100644 --- a/flang/include/flang/Optimizer/Builder/HLFIRTools.h +++ b/flang/include/flang/Optimizer/Builder/HLFIRTools.h @@ -33,6 +33,7 @@ class AssociateOp; class ElementalOp; class ElementalOpInterface; class ElementalAddrOp; +class EvaluateInMemoryOp; class YieldElementOp; /// Is this a Fortran variable for which the defining op carrying the Fortran @@ -398,6 +399,24 @@ mlir::Value inlineElementalOp( mlir::IRMapping &mapper, const std::function &mustRecursivelyInline); +/// Create a new temporary with the shape and parameters of the provided +/// hlfir.eval_in_mem operation and clone the body of the hlfir.eval_in_mem +/// operating on this new temporary. returns the temporary and whether the +/// temporary is heap or stack allocated. +std::pair +computeEvaluateOpInNewTemp(mlir::Location, fir::FirOpBuilder &, + hlfir::EvaluateInMemoryOp evalInMem, + mlir::Value shape, mlir::ValueRange typeParams); + +// Clone the body of the hlfir.eval_in_mem operating on this the provided +// storage. The provided storage must be a contiguous "raw" memory reference +// (not a fir.box) big enough to hold the value computed by hlfir.eval_in_mem. +// No runtime check is inserted by this utility to enforce that. It is also +// usually invalid to provide some storage that is already addressed directly +// or indirectly inside the hlfir.eval_in_mem body. +void computeEvaluateOpIn(mlir::Location, fir::FirOpBuilder &, + hlfir::EvaluateInMemoryOp, mlir::Value storage); + std::pair> convertToValue(mlir::Location loc, fir::FirOpBuilder &builder, hlfir::Entity entity); diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td index 1ab8793f72652..a9826543f48b6 100644 --- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td +++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td @@ -1755,4 +1755,63 @@ def hlfir_CharExtremumOp : hlfir_Op<"char_extremum", let hasVerifier = 1; } +def hlfir_EvaluateInMemoryOp : hlfir_Op<"eval_in_mem", [AttrSizedOperandSegments, + RecursiveMemoryEffects, RecursivelySpeculatable, + SingleBlockImplicitTerminator<"fir::FirEndOp">]> { + let summary = "Wrap an in-memory implementation that computes expression value"; + let description = [{ + Returns a Fortran expression value for which the computation is + implemented inside the region operating on the block argument which + is a raw memory reference corresponding to the expression type. + + The shape and type parameters of the expressions are operands of the + operations. + + The memory cannot escape the region, and it is not described how it is + allocated. This facilitates later elision of the temporary storage for the + expression evaluation if it can be evaluated in some other storage (like a + left-hand side variable). + + Example: + + A function returning an array can be represented as: + ``` + %1 = fir.shape %c10 : (index) -> !fir.shape<1> + %2 = hlfir.eval_in_mem shape %1 : (!fir.shape<1>) -> !hlfir.expr<10xf32> { + ^bb0(%arg0: !fir.ref>): + %3 = fir.call @_QParray_func() fastmath : () -> !fir.array<10xf32> + fir.save_result %3 to %arg0(%1) : !fir.array<10xf32>, !fir.ref>, !fir.shape<1> + } + ``` + }]; + + let arguments = (ins + Optional:$shape, + Variadic:$typeparams + ); + + let results = (outs hlfir_ExprType); + let regions = (region SizedRegion<1>:$body); + + let assemblyFormat = [{ + (`shape` $shape^)? (`typeparams` $typeparams^)? + attr-dict `:` functional-type(operands, results) + $body}]; + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<(ins "mlir::Type":$result_type, "mlir::Value":$shape, + CArg<"mlir::ValueRange", "{}">:$typeparams)> + ]; + + let extraClassDeclaration = [{ + // Return block argument representing the memory where the expression + // is evaluated. + mlir::Value getMemory() {return getBody().getArgument(0);} + }]; + + let hasVerifier = 1; +} + + #endif // FORTRAN_DIALECT_HLFIR_OPS diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp index 7425ccf7fc0e3..1bd950f2445ee 100644 --- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp +++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp @@ -535,6 +535,8 @@ static mlir::Value tryRetrievingShapeOrShift(hlfir::Entity entity) { if (mlir::isa(entity.getType())) { if (auto elemental = entity.getDefiningOp()) return elemental.getShape(); + if (auto evalInMem = entity.getDefiningOp()) + return evalInMem.getShape(); return mlir::Value{}; } if (auto varIface = entity.getIfVariableInterface()) @@ -642,6 +644,11 @@ void hlfir::genLengthParameters(mlir::Location loc, fir::FirOpBuilder &builder, result.append(elemental.getTypeparams().begin(), elemental.getTypeparams().end()); return; + } else if (auto evalInMem = + expr.getDefiningOp()) { + result.append(evalInMem.getTypeparams().begin(), + evalInMem.getTypeparams().end()); + return; } else if (auto apply = expr.getDefiningOp()) { result.append(apply.getTypeparams().begin(), apply.getTypeparams().end()); return; @@ -1313,3 +1320,43 @@ hlfir::genTypeAndKindConvert(mlir::Location loc, fir::FirOpBuilder &builder, }; return {hlfir::Entity{convertedRhs}, cleanup}; } + +std::pair hlfir::computeEvaluateOpInNewTemp( + mlir::Location loc, fir::FirOpBuilder &builder, + hlfir::EvaluateInMemoryOp evalInMem, mlir::Value shape, + mlir::ValueRange typeParams) { + llvm::StringRef tmpName{".tmp.expr_result"}; + llvm::SmallVector extents = + hlfir::getIndexExtents(loc, builder, shape); + mlir::Type baseType = + hlfir::getFortranElementOrSequenceType(evalInMem.getType()); + bool heapAllocated = fir::hasDynamicSize(baseType); + // Note: temporaries are stack allocated here when possible (do not require + // stack save/restore) because flang has always stack allocated function + // results. + mlir::Value temp = heapAllocated + ? builder.createHeapTemporary(loc, baseType, tmpName, + extents, typeParams) + : builder.createTemporary(loc, baseType, tmpName, + extents, typeParams); + mlir::Value innerMemory = evalInMem.getMemory(); + temp = builder.createConvert(loc, innerMemory.getType(), temp); + auto declareOp = builder.create( + loc, temp, tmpName, shape, typeParams, + /*dummy_scope=*/nullptr, fir::FortranVariableFlagsAttr{}); + computeEvaluateOpIn(loc, builder, evalInMem, declareOp.getOriginalBase()); + return {hlfir::Entity{declareOp.getBase()}, /*heapAllocated=*/heapAllocated}; +} + +void hlfir::computeEvaluateOpIn(mlir::Location loc, fir::FirOpBuilder &builder, + hlfir::EvaluateInMemoryOp evalInMem, + mlir::Value storage) { + mlir::Value innerMemory = evalInMem.getMemory(); + mlir::Value storageCast = + builder.createConvert(loc, innerMemory.getType(), storage); + mlir::IRMapping mapper; + mapper.map(innerMemory, storageCast); + for (auto &op : evalInMem.getBody().front().without_terminator()) + builder.clone(op, mapper); + return; +} diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp index b593383ff2848..8751988244648 100644 --- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp +++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp @@ -333,6 +333,25 @@ static void printDesignatorComplexPart(mlir::OpAsmPrinter &p, p << "real"; } } +template +static llvm::LogicalResult verifyTypeparams(Op &op, mlir::Type elementType, + unsigned numLenParam) { + if (mlir::isa(elementType)) { + if (numLenParam != 1) + return op.emitOpError("must be provided one length parameter when the " + "result is a character"); + } else if (fir::isRecordWithTypeParameters(elementType)) { + if (numLenParam != + mlir::cast(elementType).getNumLenParams()) + return op.emitOpError("must be provided the same number of length " + "parameters as in the result derived type"); + } else if (numLenParam != 0) { + return op.emitOpError( + "must not be provided length parameters if the result " + "type does not have length parameters"); + } + return mlir::success(); +} llvm::LogicalResult hlfir::DesignateOp::verify() { mlir::Type memrefType = getMemref().getType(); @@ -462,20 +481,10 @@ llvm::LogicalResult hlfir::DesignateOp::verify() { return emitOpError("shape must be a fir.shape or fir.shapeshift with " "the rank of the result"); } - auto numLenParam = getTypeparams().size(); - if (mlir::isa(outputElementType)) { - if (numLenParam != 1) - return emitOpError("must be provided one length parameter when the " - "result is a character"); - } else if (fir::isRecordWithTypeParameters(outputElementType)) { - if (numLenParam != - mlir::cast(outputElementType).getNumLenParams()) - return emitOpError("must be provided the same number of length " - "parameters as in the result derived type"); - } else if (numLenParam != 0) { - return emitOpError("must not be provided length parameters if the result " - "type does not have length parameters"); - } + if (auto res = + verifyTypeparams(*this, outputElementType, getTypeparams().size()); + failed(res)) + return res; } return mlir::success(); } @@ -1989,6 +1998,45 @@ hlfir::GetLengthOp::canonicalize(GetLengthOp getLength, return mlir::success(); } +//===----------------------------------------------------------------------===// +// EvaluateInMemoryOp +//===----------------------------------------------------------------------===// + +void hlfir::EvaluateInMemoryOp::build(mlir::OpBuilder &builder, + mlir::OperationState &odsState, + mlir::Type resultType, mlir::Value shape, + mlir::ValueRange typeparams) { + odsState.addTypes(resultType); + if (shape) + odsState.addOperands(shape); + odsState.addOperands(typeparams); + odsState.addAttribute( + getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr( + {shape ? 1 : 0, static_cast(typeparams.size())})); + mlir::Region *bodyRegion = odsState.addRegion(); + bodyRegion->push_back(new mlir::Block{}); + mlir::Type memType = fir::ReferenceType::get( + hlfir::getFortranElementOrSequenceType(resultType)); + bodyRegion->front().addArgument(memType, odsState.location); + EvaluateInMemoryOp::ensureTerminator(*bodyRegion, builder, odsState.location); +} + +llvm::LogicalResult hlfir::EvaluateInMemoryOp::verify() { + unsigned shapeRank = 0; + if (mlir::Value shape = getShape()) + if (auto shapeTy = mlir::dyn_cast(shape.getType())) + shapeRank = shapeTy.getRank(); + auto exprType = mlir::cast(getResult().getType()); + if (shapeRank != exprType.getRank()) + return emitOpError("`shape` rank must match the result rank"); + mlir::Type elementType = exprType.getElementType(); + if (auto res = verifyTypeparams(*this, elementType, getTypeparams().size()); + failed(res)) + return res; + return mlir::success(); +} + #include "flang/Optimizer/HLFIR/HLFIROpInterfaces.cpp.inc" #define GET_OP_CLASSES #include "flang/Optimizer/HLFIR/HLFIREnums.cpp.inc" diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp index 1848dbe2c7a2c..347f0a5630777 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp @@ -905,6 +905,26 @@ struct CharExtremumOpConversion } }; +struct EvaluateInMemoryOpConversion + : public mlir::OpConversionPattern { + using mlir::OpConversionPattern< + hlfir::EvaluateInMemoryOp>::OpConversionPattern; + explicit EvaluateInMemoryOpConversion(mlir::MLIRContext *ctx) + : mlir::OpConversionPattern{ctx} {} + llvm::LogicalResult + matchAndRewrite(hlfir::EvaluateInMemoryOp evalInMemOp, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::Location loc = evalInMemOp->getLoc(); + fir::FirOpBuilder builder(rewriter, evalInMemOp.getOperation()); + auto [temp, isHeapAlloc] = hlfir::computeEvaluateOpInNewTemp( + loc, builder, evalInMemOp, adaptor.getShape(), adaptor.getTypeparams()); + mlir::Value bufferizedExpr = + packageBufferizedExpr(loc, builder, temp, isHeapAlloc); + rewriter.replaceOp(evalInMemOp, bufferizedExpr); + return mlir::success(); + } +}; + class BufferizeHLFIR : public hlfir::impl::BufferizeHLFIRBase { public: void runOnOperation() override { @@ -918,12 +938,13 @@ class BufferizeHLFIR : public hlfir::impl::BufferizeHLFIRBase { auto module = this->getOperation(); auto *context = &getContext(); mlir::RewritePatternSet patterns(context); - patterns.insert(context); + patterns + .insert(context); mlir::ConversionTarget target(*context); // Note that YieldElementOp is not marked as an illegal operation. // It must be erased by its parent converter and there is no explicit diff --git a/flang/test/HLFIR/eval_in_mem-codegen.fir b/flang/test/HLFIR/eval_in_mem-codegen.fir new file mode 100644 index 0000000000000..26a989832ca92 --- /dev/null +++ b/flang/test/HLFIR/eval_in_mem-codegen.fir @@ -0,0 +1,107 @@ +// Test hlfir.eval_in_mem default code generation. + +// RUN: fir-opt %s --bufferize-hlfir -o - | FileCheck %s + +func.func @_QPtest() { + %c10 = arith.constant 10 : index + %0 = fir.address_of(@_QFtestEx) : !fir.ref> + %1 = fir.shape %c10 : (index) -> !fir.shape<1> + %2 = hlfir.eval_in_mem shape %1 : (!fir.shape<1>) -> !hlfir.expr<10xf32> { + ^bb0(%arg0: !fir.ref>): + %3 = fir.call @_QParray_func() fastmath : () -> !fir.array<10xf32> + fir.save_result %3 to %arg0(%1) : !fir.array<10xf32>, !fir.ref>, !fir.shape<1> + } + hlfir.assign %2 to %0 : !hlfir.expr<10xf32>, !fir.ref> + hlfir.destroy %2 : !hlfir.expr<10xf32> + return +} +fir.global internal @_QFtestEx : !fir.array<10xf32> +func.func private @_QParray_func() -> !fir.array<10xf32> + + +func.func @_QPtest_char() { + %c10 = arith.constant 10 : index + %c5 = arith.constant 5 : index + %0 = fir.address_of(@_QFtest_charEx) : !fir.ref>> + %1 = fir.shape %c10 : (index) -> !fir.shape<1> + %2 = hlfir.eval_in_mem shape %1 typeparams %c5 : (!fir.shape<1>, index) -> !hlfir.expr<10x!fir.char<1,5>> { + ^bb0(%arg0: !fir.ref>>): + %3 = fir.call @_QPchar_array_func() fastmath : () -> !fir.array<10x!fir.char<1,5>> + fir.save_result %3 to %arg0(%1) typeparams %c5 : !fir.array<10x!fir.char<1,5>>, !fir.ref>>, !fir.shape<1>, index + } + hlfir.assign %2 to %0 : !hlfir.expr<10x!fir.char<1,5>>, !fir.ref>> + hlfir.destroy %2 : !hlfir.expr<10x!fir.char<1,5>> + return +} + +fir.global internal @_QFtest_charEx : !fir.array<10x!fir.char<1,5>> +func.func private @_QPchar_array_func() -> !fir.array<10x!fir.char<1,5>> + +func.func @test_dynamic(%arg0: !fir.box>, %arg1: index) { + %0 = fir.shape %arg1 : (index) -> !fir.shape<1> + %1 = hlfir.eval_in_mem shape %0 : (!fir.shape<1>) -> !hlfir.expr { + ^bb0(%arg2: !fir.ref>): + %2 = fir.call @_QPdyn_array_func(%arg1) : (index) -> !fir.array + fir.save_result %2 to %arg2(%0) : !fir.array, !fir.ref>, !fir.shape<1> + } + hlfir.assign %1 to %arg0 : !hlfir.expr, !fir.box> + hlfir.destroy %1 : !hlfir.expr + return +} +func.func private @_QPdyn_array_func(index) -> !fir.array + +// CHECK-LABEL: func.func @_QPtest() { +// CHECK: %[[VAL_0:.*]] = fir.alloca !fir.array<10xf32> {bindc_name = ".tmp.expr_result"} +// CHECK: %[[VAL_1:.*]] = arith.constant 10 : index +// CHECK: %[[VAL_2:.*]] = fir.address_of(@_QFtestEx) : !fir.ref> +// CHECK: %[[VAL_3:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1> +// CHECK: %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_3]]) {uniq_name = ".tmp.expr_result"} : (!fir.ref>, !fir.shape<1>) -> (!fir.ref>, !fir.ref>) +// CHECK: %[[VAL_5:.*]] = fir.call @_QParray_func() fastmath : () -> !fir.array<10xf32> +// CHECK: fir.save_result %[[VAL_5]] to %[[VAL_4]]#1(%[[VAL_3]]) : !fir.array<10xf32>, !fir.ref>, !fir.shape<1> +// CHECK: %[[VAL_6:.*]] = arith.constant false +// CHECK: %[[VAL_7:.*]] = fir.undefined tuple>, i1> +// CHECK: %[[VAL_8:.*]] = fir.insert_value %[[VAL_7]], %[[VAL_6]], [1 : index] : (tuple>, i1>, i1) -> tuple>, i1> +// CHECK: %[[VAL_9:.*]] = fir.insert_value %[[VAL_8]], %[[VAL_4]]#0, [0 : index] : (tuple>, i1>, !fir.ref>) -> tuple>, i1> +// CHECK: hlfir.assign %[[VAL_4]]#0 to %[[VAL_2]] : !fir.ref>, !fir.ref> +// CHECK: return +// CHECK: } +// CHECK: fir.global internal @_QFtestEx : !fir.array<10xf32> +// CHECK: func.func private @_QParray_func() -> !fir.array<10xf32> + +// CHECK-LABEL: func.func @_QPtest_char() { +// CHECK: %[[VAL_0:.*]] = fir.alloca !fir.array<10x!fir.char<1,5>> {bindc_name = ".tmp.expr_result"} +// CHECK: %[[VAL_1:.*]] = arith.constant 10 : index +// CHECK: %[[VAL_2:.*]] = arith.constant 5 : index +// CHECK: %[[VAL_3:.*]] = fir.address_of(@_QFtest_charEx) : !fir.ref>> +// CHECK: %[[VAL_4:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1> +// CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_4]]) typeparams %[[VAL_2]] {uniq_name = ".tmp.expr_result"} : (!fir.ref>>, !fir.shape<1>, index) -> (!fir.ref>>, !fir.ref>>) +// CHECK: %[[VAL_6:.*]] = fir.call @_QPchar_array_func() fastmath : () -> !fir.array<10x!fir.char<1,5>> +// CHECK: fir.save_result %[[VAL_6]] to %[[VAL_5]]#1(%[[VAL_4]]) typeparams %[[VAL_2]] : !fir.array<10x!fir.char<1,5>>, !fir.ref>>, !fir.shape<1>, index +// CHECK: %[[VAL_7:.*]] = arith.constant false +// CHECK: %[[VAL_8:.*]] = fir.undefined tuple>>, i1> +// CHECK: %[[VAL_9:.*]] = fir.insert_value %[[VAL_8]], %[[VAL_7]], [1 : index] : (tuple>>, i1>, i1) -> tuple>>, i1> +// CHECK: %[[VAL_10:.*]] = fir.insert_value %[[VAL_9]], %[[VAL_5]]#0, [0 : index] : (tuple>>, i1>, !fir.ref>>) -> tuple>>, i1> +// CHECK: hlfir.assign %[[VAL_5]]#0 to %[[VAL_3]] : !fir.ref>>, !fir.ref>> +// CHECK: return +// CHECK: } +// CHECK: fir.global internal @_QFtest_charEx : !fir.array<10x!fir.char<1,5>> +// CHECK: func.func private @_QPchar_array_func() -> !fir.array<10x!fir.char<1,5>> + +// CHECK-LABEL: func.func @test_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.box>, +// CHECK-SAME: %[[VAL_1:.*]]: index) { +// CHECK: %[[VAL_2:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1> +// CHECK: %[[VAL_3:.*]] = fir.allocmem !fir.array, %[[VAL_1]] {bindc_name = ".tmp.expr_result", uniq_name = ""} +// CHECK: %[[VAL_4:.*]] = fir.convert %[[VAL_3]] : (!fir.heap>) -> !fir.ref> +// CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_4]](%[[VAL_2]]) {uniq_name = ".tmp.expr_result"} : (!fir.ref>, !fir.shape<1>) -> (!fir.box>, !fir.ref>) +// CHECK: %[[VAL_6:.*]] = fir.call @_QPdyn_array_func(%[[VAL_1]]) : (index) -> !fir.array +// CHECK: fir.save_result %[[VAL_6]] to %[[VAL_5]]#1(%[[VAL_2]]) : !fir.array, !fir.ref>, !fir.shape<1> +// CHECK: %[[VAL_7:.*]] = arith.constant true +// CHECK: %[[VAL_8:.*]] = fir.undefined tuple>, i1> +// CHECK: %[[VAL_9:.*]] = fir.insert_value %[[VAL_8]], %[[VAL_7]], [1 : index] : (tuple>, i1>, i1) -> tuple>, i1> +// CHECK: %[[VAL_10:.*]] = fir.insert_value %[[VAL_9]], %[[VAL_5]]#0, [0 : index] : (tuple>, i1>, !fir.box>) -> tuple>, i1> +// CHECK: hlfir.assign %[[VAL_5]]#0 to %[[VAL_0]] : !fir.box>, !fir.box> +// CHECK: %[[VAL_11:.*]] = fir.box_addr %[[VAL_5]]#0 : (!fir.box>) -> !fir.heap> +// CHECK: fir.freemem %[[VAL_11]] : !fir.heap> +// CHECK: return +// CHECK: } diff --git a/flang/test/HLFIR/eval_in_mem.fir b/flang/test/HLFIR/eval_in_mem.fir new file mode 100644 index 0000000000000..34e48ed5be545 --- /dev/null +++ b/flang/test/HLFIR/eval_in_mem.fir @@ -0,0 +1,99 @@ +// Test hlfir.eval_in_mem operation parse, verify (no errors), and unparse. + +// RUN: fir-opt %s | fir-opt | FileCheck %s + +func.func @_QPtest() { + %c10 = arith.constant 10 : index + %0 = fir.address_of(@_QFtestEx) : !fir.ref> + %1 = fir.shape %c10 : (index) -> !fir.shape<1> + %2 = hlfir.eval_in_mem shape %1 : (!fir.shape<1>) -> !hlfir.expr<10xf32> { + ^bb0(%arg0: !fir.ref>): + %3 = fir.call @_QParray_func() fastmath : () -> !fir.array<10xf32> + fir.save_result %3 to %arg0(%1) : !fir.array<10xf32>, !fir.ref>, !fir.shape<1> + } + hlfir.assign %2 to %0 : !hlfir.expr<10xf32>, !fir.ref> + hlfir.destroy %2 : !hlfir.expr<10xf32> + return +} +fir.global internal @_QFtestEx : !fir.array<10xf32> +func.func private @_QParray_func() -> !fir.array<10xf32> + + +func.func @_QPtest_char() { + %c10 = arith.constant 10 : index + %c5 = arith.constant 5 : index + %0 = fir.address_of(@_QFtest_charEx) : !fir.ref>> + %1 = fir.shape %c10 : (index) -> !fir.shape<1> + %2 = hlfir.eval_in_mem shape %1 typeparams %c5 : (!fir.shape<1>, index) -> !hlfir.expr<10x!fir.char<1,5>> { + ^bb0(%arg0: !fir.ref>>): + %3 = fir.call @_QPchar_array_func() fastmath : () -> !fir.array<10x!fir.char<1,5>> + fir.save_result %3 to %arg0(%1) typeparams %c5 : !fir.array<10x!fir.char<1,5>>, !fir.ref>>, !fir.shape<1>, index + } + hlfir.assign %2 to %0 : !hlfir.expr<10x!fir.char<1,5>>, !fir.ref>> + hlfir.destroy %2 : !hlfir.expr<10x!fir.char<1,5>> + return +} + +fir.global internal @_QFtest_charEx : !fir.array<10x!fir.char<1,5>> +func.func private @_QPchar_array_func() -> !fir.array<10x!fir.char<1,5>> + +func.func @test_dynamic(%arg0: !fir.box>, %arg1: index) { + %0 = fir.shape %arg1 : (index) -> !fir.shape<1> + %1 = hlfir.eval_in_mem shape %0 : (!fir.shape<1>) -> !hlfir.expr { + ^bb0(%arg2: !fir.ref>): + %2 = fir.call @_QPdyn_array_func(%arg1) : (index) -> !fir.array + fir.save_result %2 to %arg2(%0) : !fir.array, !fir.ref>, !fir.shape<1> + } + hlfir.assign %1 to %arg0 : !hlfir.expr, !fir.box> + hlfir.destroy %1 : !hlfir.expr + return +} +func.func private @_QPdyn_array_func(index) -> !fir.array + +// CHECK-LABEL: func.func @_QPtest() { +// CHECK: %[[VAL_0:.*]] = arith.constant 10 : index +// CHECK: %[[VAL_1:.*]] = fir.address_of(@_QFtestEx) : !fir.ref> +// CHECK: %[[VAL_2:.*]] = fir.shape %[[VAL_0]] : (index) -> !fir.shape<1> +// CHECK: %[[VAL_3:.*]] = hlfir.eval_in_mem shape %[[VAL_2]] : (!fir.shape<1>) -> !hlfir.expr<10xf32> { +// CHECK: ^bb0(%[[VAL_4:.*]]: !fir.ref>): +// CHECK: %[[VAL_5:.*]] = fir.call @_QParray_func() fastmath : () -> !fir.array<10xf32> +// CHECK: fir.save_result %[[VAL_5]] to %[[VAL_4]](%[[VAL_2]]) : !fir.array<10xf32>, !fir.ref>, !fir.shape<1> +// CHECK: } +// CHECK: hlfir.assign %[[VAL_3]] to %[[VAL_1]] : !hlfir.expr<10xf32>, !fir.ref> +// CHECK: hlfir.destroy %[[VAL_3]] : !hlfir.expr<10xf32> +// CHECK: return +// CHECK: } +// CHECK: fir.global internal @_QFtestEx : !fir.array<10xf32> +// CHECK: func.func private @_QParray_func() -> !fir.array<10xf32> + +// CHECK-LABEL: func.func @_QPtest_char() { +// CHECK: %[[VAL_0:.*]] = arith.constant 10 : index +// CHECK: %[[VAL_1:.*]] = arith.constant 5 : index +// CHECK: %[[VAL_2:.*]] = fir.address_of(@_QFtest_charEx) : !fir.ref>> +// CHECK: %[[VAL_3:.*]] = fir.shape %[[VAL_0]] : (index) -> !fir.shape<1> +// CHECK: %[[VAL_4:.*]] = hlfir.eval_in_mem shape %[[VAL_3]] typeparams %[[VAL_1]] : (!fir.shape<1>, index) -> !hlfir.expr<10x!fir.char<1,5>> { +// CHECK: ^bb0(%[[VAL_5:.*]]: !fir.ref>>): +// CHECK: %[[VAL_6:.*]] = fir.call @_QPchar_array_func() fastmath : () -> !fir.array<10x!fir.char<1,5>> +// CHECK: fir.save_result %[[VAL_6]] to %[[VAL_5]](%[[VAL_3]]) typeparams %[[VAL_1]] : !fir.array<10x!fir.char<1,5>>, !fir.ref>>, !fir.shape<1>, index +// CHECK: } +// CHECK: hlfir.assign %[[VAL_4]] to %[[VAL_2]] : !hlfir.expr<10x!fir.char<1,5>>, !fir.ref>> +// CHECK: hlfir.destroy %[[VAL_4]] : !hlfir.expr<10x!fir.char<1,5>> +// CHECK: return +// CHECK: } +// CHECK: fir.global internal @_QFtest_charEx : !fir.array<10x!fir.char<1,5>> +// CHECK: func.func private @_QPchar_array_func() -> !fir.array<10x!fir.char<1,5>> + +// CHECK-LABEL: func.func @test_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.box>, +// CHECK-SAME: %[[VAL_1:.*]]: index) { +// CHECK: %[[VAL_2:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1> +// CHECK: %[[VAL_3:.*]] = hlfir.eval_in_mem shape %[[VAL_2]] : (!fir.shape<1>) -> !hlfir.expr { +// CHECK: ^bb0(%[[VAL_4:.*]]: !fir.ref>): +// CHECK: %[[VAL_5:.*]] = fir.call @_QPdyn_array_func(%[[VAL_1]]) : (index) -> !fir.array +// CHECK: fir.save_result %[[VAL_5]] to %[[VAL_4]](%[[VAL_2]]) : !fir.array, !fir.ref>, !fir.shape<1> +// CHECK: } +// CHECK: hlfir.assign %[[VAL_3]] to %[[VAL_0]] : !hlfir.expr, !fir.box> +// CHECK: hlfir.destroy %[[VAL_3]] : !hlfir.expr +// CHECK: return +// CHECK: } +// CHECK: func.func private @_QPdyn_array_func(index) -> !fir.array diff --git a/flang/test/HLFIR/invalid.fir b/flang/test/HLFIR/invalid.fir index c390dddcf3f38..5c5db7aac0697 100644 --- a/flang/test/HLFIR/invalid.fir +++ b/flang/test/HLFIR/invalid.fir @@ -1314,3 +1314,37 @@ func.func @end_associate_with_alloc_comp(%var: !hlfir.expr>}>>>, i1 return } + +// ----- + +func.func @bad_eval_in_mem_1() { + %c10 = arith.constant 10 : index + %1 = fir.shape %c10 : (index) -> !fir.shape<1> +// expected-error@+1 {{'hlfir.eval_in_mem' op result #0 must be The type of an array, character, or derived type Fortran expression, but got '!fir.array<10xf32>'}} + %2 = hlfir.eval_in_mem shape %1 : (!fir.shape<1>) -> !fir.array<10xf32> { + ^bb0(%arg0: !fir.ref>): + } + return +} + +// ----- + +func.func @bad_eval_in_mem_2() { + %c10 = arith.constant 10 : index + %1 = fir.shape %c10, %c10 : (index, index) -> !fir.shape<2> + // expected-error@+1 {{'hlfir.eval_in_mem' op `shape` rank must match the result rank}} + %2 = hlfir.eval_in_mem shape %1 : (!fir.shape<2>) -> !hlfir.expr<10xf32> { + ^bb0(%arg0: !fir.ref>): + } + return +} + +// ----- + +func.func @bad_eval_in_mem_3() { + // expected-error@+1 {{'hlfir.eval_in_mem' op must be provided one length parameter when the result is a character}} + %1 = hlfir.eval_in_mem : () -> !hlfir.expr> { + ^bb0(%arg0: !fir.ref>): + } + return +} From 6be998ec6f74937820e350f51796490b41a76d46 Mon Sep 17 00:00:00 2001 From: Jean Perier Date: Thu, 28 Nov 2024 08:26:39 -0800 Subject: [PATCH 2/6] [flang][hlfir] optimize hlfir.eval_in_mem bufferization --- .../lib/Optimizer/Analysis/AliasAnalysis.cpp | 14 ++- .../Transforms/OptimizedBufferization.cpp | 108 ++++++++++++++++++ .../HLFIR/opt-bufferization-eval_in_mem.fir | 67 +++++++++++ 3 files changed, 188 insertions(+), 1 deletion(-) create mode 100644 flang/test/HLFIR/opt-bufferization-eval_in_mem.fir diff --git a/flang/lib/Optimizer/Analysis/AliasAnalysis.cpp b/flang/lib/Optimizer/Analysis/AliasAnalysis.cpp index 2b24791d6c7c5..c561285b9feef 100644 --- a/flang/lib/Optimizer/Analysis/AliasAnalysis.cpp +++ b/flang/lib/Optimizer/Analysis/AliasAnalysis.cpp @@ -91,6 +91,13 @@ bool AliasAnalysis::Source::isDummyArgument() const { return false; } +static bool isEvaluateInMemoryBlockArg(mlir::Value v) { + if (auto evalInMem = llvm::dyn_cast_or_null( + v.getParentRegion()->getParentOp())) + return evalInMem.getMemory() == v; + return false; +} + bool AliasAnalysis::Source::isData() const { return origin.isData; } bool AliasAnalysis::Source::isBoxData() const { return mlir::isa(fir::unwrapRefType(valueType)) && @@ -698,7 +705,7 @@ AliasAnalysis::Source AliasAnalysis::getSource(mlir::Value v, breakFromLoop = true; }); } - if (!defOp && type == SourceKind::Unknown) + if (!defOp && type == SourceKind::Unknown) { // Check if the memory source is coming through a dummy argument. if (isDummyArgument(v)) { type = SourceKind::Argument; @@ -708,7 +715,12 @@ AliasAnalysis::Source AliasAnalysis::getSource(mlir::Value v, if (isPointerReference(ty)) attributes.set(Attribute::Pointer); + } else if (isEvaluateInMemoryBlockArg(v)) { + // hlfir.eval_in_mem block operands is allocated by the operation. + type = SourceKind::Allocate; + ty = v.getType(); } + } if (type == SourceKind::Global) { return {{global, instantiationPoint, followingData}, diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp index a0160b233e3cd..e8c15a256b9da 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp @@ -1108,6 +1108,113 @@ class ReductionMaskConversion : public mlir::OpRewritePattern { } }; +class EvaluateIntoMemoryAssignBufferization + : public mlir::OpRewritePattern { + +public: + using mlir::OpRewritePattern::OpRewritePattern; + + llvm::LogicalResult + matchAndRewrite(hlfir::EvaluateInMemoryOp, + mlir::PatternRewriter &rewriter) const override; +}; + +static bool mayReadOrWrite(mlir::Region ®ion, mlir::Value var) { + fir::AliasAnalysis aliasAnalysis; + for (mlir::Operation &op : region.getOps()) { + if (op.hasTrait()) { + for (mlir::Region &subRegion : op.getRegions()) + if (mayReadOrWrite(subRegion, var)) + return true; + // In MLIR, RecursiveMemoryEffects can be combined with + // MemoryEffectOpInterface to describe extra effects on top of the + // effects of the nested operations. However, the presence of + // RecursiveMemoryEffects and the absence of MemoryEffectOpInterface + // implies the operation has no other memory effects than the one of its + // nested operations. + if (!mlir::isa(op)) + continue; + } + if (!aliasAnalysis.getModRef(&op, var).isNoModRef()) + return true; + } + return false; +} + +static llvm::LogicalResult +tryUsingAssignLhsDirectly(hlfir::EvaluateInMemoryOp evalInMem, + mlir::PatternRewriter &rewriter) { + mlir::Location loc = evalInMem.getLoc(); + hlfir::DestroyOp destroy; + hlfir::AssignOp assign; + for (auto user : llvm::enumerate(evalInMem->getUsers())) { + if (user.index() > 2) + return mlir::failure(); + mlir::TypeSwitch(user.value()) + .Case([&](hlfir::AssignOp op) { assign = op; }) + .Case([&](hlfir::DestroyOp op) { destroy = op; }); + } + if (!assign || !destroy || destroy.mustFinalizeExpr() || + assign.isAllocatableAssignment()) + return mlir::failure(); + + hlfir::Entity lhs{assign.getLhs()}; + // EvaluateInMemoryOp memory is contiguous, so in general, it can only be + // replace by the LHS if the LHS is contiguous. + if (!lhs.isSimplyContiguous()) + return mlir::failure(); + // Character assignment may involves truncation/padding, so the LHS + // cannot be used to evaluate RHS in place without proving the LHS and + // RHS lengths are the same. + if (lhs.isCharacter()) + return mlir::failure(); + + // The region must not read or write the LHS. + if (mayReadOrWrite(evalInMem.getBody(), lhs)) + return mlir::failure(); + // Any variables affected between the hlfir.evalInMem and assignment must not + // be read or written inside the region since it will be moved at the + // assignment insertion point. + auto effects = getEffectsBetween(evalInMem->getNextNode(), assign); + if (!effects) { + LLVM_DEBUG( + llvm::dbgs() + << "operation with unknown effects between eval_in_mem and assign\n"); + return mlir::failure(); + } + for (const mlir::MemoryEffects::EffectInstance &effect : *effects) { + mlir::Value affected = effect.getValue(); + if (!affected || mayReadOrWrite(evalInMem.getBody(), affected)) + return mlir::failure(); + } + + rewriter.setInsertionPoint(assign); + fir::FirOpBuilder builder(rewriter, evalInMem.getOperation()); + mlir::Value rawLhs = hlfir::genVariableRawAddress(loc, builder, lhs); + hlfir::computeEvaluateOpIn(loc, builder, evalInMem, rawLhs); + rewriter.eraseOp(assign); + rewriter.eraseOp(destroy); + rewriter.eraseOp(evalInMem); + return mlir::success(); +} + +llvm::LogicalResult EvaluateIntoMemoryAssignBufferization::matchAndRewrite( + hlfir::EvaluateInMemoryOp evalInMem, + mlir::PatternRewriter &rewriter) const { + if (mlir::succeeded(tryUsingAssignLhsDirectly(evalInMem, rewriter))) + return mlir::success(); + // Rewrite to temp + as_expr here so that the assign + as_expr pattern can + // kick-in for simple types and at least implement the assignment inline + // instead of call Assign runtime. + fir::FirOpBuilder builder(rewriter, evalInMem.getOperation()); + mlir::Location loc = evalInMem.getLoc(); + auto [temp, isHeapAllocated] = hlfir::computeEvaluateOpInNewTemp( + loc, builder, evalInMem, evalInMem.getShape(), evalInMem.getTypeparams()); + rewriter.replaceOpWithNewOp( + evalInMem, temp, /*mustFree=*/builder.createBool(loc, isHeapAllocated)); + return mlir::success(); +} + class OptimizedBufferizationPass : public hlfir::impl::OptimizedBufferizationBase< OptimizedBufferizationPass> { @@ -1130,6 +1237,7 @@ class OptimizedBufferizationPass patterns.insert(context); patterns.insert(context); patterns.insert(context); + patterns.insert(context); patterns.insert>(context); patterns.insert>(context); patterns.insert>(context); diff --git a/flang/test/HLFIR/opt-bufferization-eval_in_mem.fir b/flang/test/HLFIR/opt-bufferization-eval_in_mem.fir new file mode 100644 index 0000000000000..984c0bcbaddcc --- /dev/null +++ b/flang/test/HLFIR/opt-bufferization-eval_in_mem.fir @@ -0,0 +1,67 @@ +// RUN: fir-opt --opt-bufferization %s | FileCheck %s + +// Fortran F2023 15.5.2.14 point 4. ensures that _QPfoo cannot access _QFtestEx +// and the temporary storage for the result can be avoided. +func.func @_QPtest(%arg0: !fir.ref> {fir.bindc_name = "x"}) { + %c10 = arith.constant 10 : index + %0 = fir.dummy_scope : !fir.dscope + %1 = fir.shape %c10 : (index) -> !fir.shape<1> + %2:2 = hlfir.declare %arg0(%1) dummy_scope %0 {uniq_name = "_QFtestEx"} : (!fir.ref>, !fir.shape<1>, !fir.dscope) -> (!fir.ref>, !fir.ref>) + %3 = hlfir.eval_in_mem shape %1 : (!fir.shape<1>) -> !hlfir.expr<10xf32> { + ^bb0(%arg1: !fir.ref>): + %4 = fir.call @_QPfoo() fastmath : () -> !fir.array<10xf32> + fir.save_result %4 to %arg1(%1) : !fir.array<10xf32>, !fir.ref>, !fir.shape<1> + } + hlfir.assign %3 to %2#0 : !hlfir.expr<10xf32>, !fir.ref> + hlfir.destroy %3 : !hlfir.expr<10xf32> + return +} +func.func private @_QPfoo() -> !fir.array<10xf32> + +// CHECK-LABEL: func.func @_QPtest( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref> {fir.bindc_name = "x"}) { +// CHECK: %[[VAL_1:.*]] = arith.constant 10 : index +// CHECK: %[[VAL_2:.*]] = fir.dummy_scope : !fir.dscope +// CHECK: %[[VAL_3:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1> +// CHECK: %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_3]]) dummy_scope %[[VAL_2]] {uniq_name = "_QFtestEx"} : (!fir.ref>, !fir.shape<1>, !fir.dscope) -> (!fir.ref>, !fir.ref>) +// CHECK: %[[VAL_5:.*]] = fir.call @_QPfoo() fastmath : () -> !fir.array<10xf32> +// CHECK: fir.save_result %[[VAL_5]] to %[[VAL_4]]#1(%[[VAL_3]]) : !fir.array<10xf32>, !fir.ref>, !fir.shape<1> +// CHECK: return +// CHECK: } + + +// Temporary storage cannot be avoided in this case since +// _QFnegative_test_is_targetEx has the TARGET attribute. +func.func @_QPnegative_test_is_target(%arg0: !fir.ref> {fir.bindc_name = "x", fir.target}) { + %c10 = arith.constant 10 : index + %0 = fir.dummy_scope : !fir.dscope + %1 = fir.shape %c10 : (index) -> !fir.shape<1> + %2:2 = hlfir.declare %arg0(%1) dummy_scope %0 {fortran_attrs = #fir.var_attrs, uniq_name = "_QFnegative_test_is_targetEx"} : (!fir.ref>, !fir.shape<1>, !fir.dscope) -> (!fir.ref>, !fir.ref>) + %3 = hlfir.eval_in_mem shape %1 : (!fir.shape<1>) -> !hlfir.expr<10xf32> { + ^bb0(%arg1: !fir.ref>): + %4 = fir.call @_QPfoo() fastmath : () -> !fir.array<10xf32> + fir.save_result %4 to %arg1(%1) : !fir.array<10xf32>, !fir.ref>, !fir.shape<1> + } + hlfir.assign %3 to %2#0 : !hlfir.expr<10xf32>, !fir.ref> + hlfir.destroy %3 : !hlfir.expr<10xf32> + return +} +// CHECK-LABEL: func.func @_QPnegative_test_is_target( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref> {fir.bindc_name = "x", fir.target}) { +// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_2:.*]] = arith.constant false +// CHECK: %[[VAL_3:.*]] = arith.constant 10 : index +// CHECK: %[[VAL_4:.*]] = fir.alloca !fir.array<10xf32> +// CHECK: %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_0]]{{.*}} +// CHECK: %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_4]]{{.*}} +// CHECK: %[[VAL_9:.*]] = fir.call @_QPfoo() fastmath : () -> !fir.array<10xf32> +// CHECK: fir.save_result %[[VAL_9]] to %[[VAL_8]]#1{{.*}} +// CHECK: %[[VAL_10:.*]] = hlfir.as_expr %[[VAL_8]]#0 move %[[VAL_2]] : (!fir.ref>, i1) -> !hlfir.expr<10xf32> +// CHECK: fir.do_loop %[[VAL_11:.*]] = %[[VAL_1]] to %[[VAL_3]] step %[[VAL_1]] unordered { +// CHECK: %[[VAL_12:.*]] = hlfir.apply %[[VAL_10]], %[[VAL_11]] : (!hlfir.expr<10xf32>, index) -> f32 +// CHECK: %[[VAL_13:.*]] = hlfir.designate %[[VAL_7]]#0 (%[[VAL_11]]) : (!fir.ref>, index) -> !fir.ref +// CHECK: hlfir.assign %[[VAL_12]] to %[[VAL_13]] : f32, !fir.ref +// CHECK: } +// CHECK: hlfir.destroy %[[VAL_10]] : !hlfir.expr<10xf32> +// CHECK: return +// CHECK: } From 4debf91864773ba805a9439cde9b74fc44315acc Mon Sep 17 00:00:00 2001 From: Jean Perier Date: Thu, 28 Nov 2024 08:28:21 -0800 Subject: [PATCH 3/6] [flang] optimize array function calls using hlfir.eval_in_mem --- flang/include/flang/Lower/ConvertCall.h | 6 +- .../flang/Optimizer/HLFIR/HLFIRDialect.h | 4 + flang/lib/Lower/ConvertCall.cpp | 102 +++++++++----- flang/lib/Lower/ConvertExpr.cpp | 13 +- flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp | 15 ++ flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp | 11 +- .../order_assignments/where-scheduling.f90 | 2 +- .../test/Lower/HLFIR/calls-array-results.f90 | 131 ++++++++++++++++++ flang/test/Lower/HLFIR/where-nonelemental.f90 | 38 ++--- .../Lower/explicit-interface-results-2.f90 | 2 - .../test/Lower/explicit-interface-results.f90 | 8 +- flang/test/Lower/forall/array-constructor.f90 | 2 +- 12 files changed, 260 insertions(+), 74 deletions(-) create mode 100644 flang/test/Lower/HLFIR/calls-array-results.f90 diff --git a/flang/include/flang/Lower/ConvertCall.h b/flang/include/flang/Lower/ConvertCall.h index bc082907e6176..2c51a887010c8 100644 --- a/flang/include/flang/Lower/ConvertCall.h +++ b/flang/include/flang/Lower/ConvertCall.h @@ -24,6 +24,10 @@ namespace Fortran::lower { +struct LoweredResult { + std::variant result; +}; + /// Given a call site for which the arguments were already lowered, generate /// the call and return the result. This function deals with explicit result /// allocation and lowering if needed. It also deals with passing the host @@ -32,7 +36,7 @@ namespace Fortran::lower { /// It is only used for HLFIR. /// The returned boolean indicates if finalization has been emitted in /// \p stmtCtx for the result. -std::pair genCallOpAndResult( +std::pair genCallOpAndResult( mlir::Location loc, Fortran::lower::AbstractConverter &converter, Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx, Fortran::lower::CallerInterface &caller, mlir::FunctionType callSiteType, diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h index 3830237f96f3c..447d5fbab8999 100644 --- a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h +++ b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h @@ -61,6 +61,10 @@ inline mlir::Type getFortranElementOrSequenceType(mlir::Type type) { return type; } +/// Build the hlfir.expr type for the value held in a variable of type \p +/// variableType. +mlir::Type getExprType(mlir::Type variableType); + /// Is this a fir.box or fir.class address type? inline bool isBoxAddressType(mlir::Type type) { type = fir::dyn_cast_ptrEleTy(type); diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp index e84e7afbe82e0..088d8f96caa41 100644 --- a/flang/lib/Lower/ConvertCall.cpp +++ b/flang/lib/Lower/ConvertCall.cpp @@ -284,7 +284,8 @@ static void remapActualToDummyDescriptors( } } -std::pair Fortran::lower::genCallOpAndResult( +std::pair +Fortran::lower::genCallOpAndResult( mlir::Location loc, Fortran::lower::AbstractConverter &converter, Fortran::lower::SymMap &symMap, Fortran::lower::StatementContext &stmtCtx, Fortran::lower::CallerInterface &caller, mlir::FunctionType callSiteType, @@ -326,6 +327,11 @@ std::pair Fortran::lower::genCallOpAndResult( } } + const bool isExprCall = + converter.getLoweringOptions().getLowerToHighLevelFIR() && + callSiteType.getNumResults() == 1 && + llvm::isa(callSiteType.getResult(0)); + mlir::IndexType idxTy = builder.getIndexType(); auto lowerSpecExpr = [&](const auto &expr) -> mlir::Value { mlir::Value convertExpr = builder.createConvert( @@ -333,6 +339,8 @@ std::pair Fortran::lower::genCallOpAndResult( return fir::factory::genMaxWithZero(builder, loc, convertExpr); }; llvm::SmallVector resultLengths; + mlir::Value arrayResultShape; + hlfir::EvaluateInMemoryOp evaluateInMemory; auto allocatedResult = [&]() -> std::optional { llvm::SmallVector extents; llvm::SmallVector lengths; @@ -366,6 +374,18 @@ std::pair Fortran::lower::genCallOpAndResult( resultLengths = lengths; } + if (!extents.empty()) + arrayResultShape = builder.genShape(loc, extents); + + if (isExprCall) { + mlir::Type exprType = hlfir::getExprType(type); + evaluateInMemory = builder.create( + loc, exprType, arrayResultShape, resultLengths); + builder.setInsertionPointToStart(&evaluateInMemory.getBody().front()); + return toExtendedValue(loc, evaluateInMemory.getMemory(), extents, + lengths); + } + if ((!extents.empty() || !lengths.empty()) && !isElemental) { // Note: in the elemental context, the alloca ownership inside the // elemental region is implicit, and later pass in lowering (stack @@ -384,8 +404,7 @@ std::pair Fortran::lower::genCallOpAndResult( if (mustPopSymMap) symMap.popScope(); - // Place allocated result or prepare the fir.save_result arguments. - mlir::Value arrayResultShape; + // Place allocated result if (allocatedResult) { if (std::optional::PassedEntity> @@ -399,16 +418,6 @@ std::pair Fortran::lower::genCallOpAndResult( else fir::emitFatalError( loc, "only expect character scalar result to be passed by ref"); - } else { - assert(caller.mustSaveResult()); - arrayResultShape = allocatedResult->match( - [&](const fir::CharArrayBoxValue &) { - return builder.createShape(loc, *allocatedResult); - }, - [&](const fir::ArrayBoxValue &) { - return builder.createShape(loc, *allocatedResult); - }, - [&](const auto &) { return mlir::Value{}; }); } } @@ -642,6 +651,19 @@ std::pair Fortran::lower::genCallOpAndResult( callResult = call.getResult(0); } + std::optional retTy = + caller.getCallDescription().proc().GetType(); + // With HLFIR lowering, isElemental must be set to true + // if we are producing an elemental call. In this case, + // the elemental results must not be destroyed, instead, + // the resulting array result will be finalized/destroyed + // as needed by hlfir.destroy. + const bool mustFinalizeResult = + !isElemental && callSiteType.getNumResults() > 0 && + !fir::isPointerType(callSiteType.getResult(0)) && retTy.has_value() && + (retTy->category() == Fortran::common::TypeCategory::Derived || + retTy->IsPolymorphic() || retTy->IsUnlimitedPolymorphic()); + if (caller.mustSaveResult()) { assert(allocatedResult.has_value()); builder.create(loc, callResult, @@ -649,6 +671,19 @@ std::pair Fortran::lower::genCallOpAndResult( arrayResultShape, resultLengths); } + if (evaluateInMemory) { + builder.setInsertionPointAfter(evaluateInMemory); + mlir::Value expr = evaluateInMemory.getResult(); + fir::FirOpBuilder *bldr = &converter.getFirOpBuilder(); + if (!isElemental) + stmtCtx.attachCleanup([bldr, loc, expr, mustFinalizeResult]() { + bldr->create(loc, expr, + /*finalize=*/mustFinalizeResult); + }); + return {LoweredResult{hlfir::EntityWithAttributes{expr}}, + mustFinalizeResult}; + } + if (allocatedResult) { // The result must be optionally destroyed (if it is of a derived type // that may need finalization or deallocation of the components). @@ -679,17 +714,7 @@ std::pair Fortran::lower::genCallOpAndResult( // derived-type. // For polymorphic and unlimited polymorphic enities call the runtime // in any cases. - std::optional retTy = - caller.getCallDescription().proc().GetType(); - // With HLFIR lowering, isElemental must be set to true - // if we are producing an elemental call. In this case, - // the elemental results must not be destroyed, instead, - // the resulting array result will be finalized/destroyed - // as needed by hlfir.destroy. - if (!isElemental && !fir::isPointerType(funcType.getResults()[0]) && - retTy && - (retTy->category() == Fortran::common::TypeCategory::Derived || - retTy->IsPolymorphic() || retTy->IsUnlimitedPolymorphic())) { + if (mustFinalizeResult) { if (retTy->IsPolymorphic() || retTy->IsUnlimitedPolymorphic()) { auto *bldr = &converter.getFirOpBuilder(); stmtCtx.attachCleanup([bldr, loc, allocatedResult]() { @@ -715,12 +740,13 @@ std::pair Fortran::lower::genCallOpAndResult( } } } - return {*allocatedResult, resultIsFinalized}; + return {LoweredResult{*allocatedResult}, resultIsFinalized}; } // subroutine call if (!resultType) - return {fir::ExtendedValue{mlir::Value{}}, /*resultIsFinalized=*/false}; + return {LoweredResult{fir::ExtendedValue{mlir::Value{}}}, + /*resultIsFinalized=*/false}; // For now, Fortran return values are implemented with a single MLIR // function return value. @@ -734,10 +760,13 @@ std::pair Fortran::lower::genCallOpAndResult( mlir::dyn_cast(funcType.getResults()[0]); mlir::Value len = builder.createIntegerConstant( loc, builder.getCharacterLengthType(), charTy.getLen()); - return {fir::CharBoxValue{callResult, len}, /*resultIsFinalized=*/false}; + return { + LoweredResult{fir::ExtendedValue{fir::CharBoxValue{callResult, len}}}, + /*resultIsFinalized=*/false}; } - return {callResult, /*resultIsFinalized=*/false}; + return {LoweredResult{fir::ExtendedValue{callResult}}, + /*resultIsFinalized=*/false}; } static hlfir::EntityWithAttributes genStmtFunctionRef( @@ -1661,19 +1690,26 @@ genUserCall(Fortran::lower::PreparedActualArguments &loweredActuals, // Prepare lowered arguments according to the interface // and map the lowered values to the dummy // arguments. - auto [result, resultIsFinalized] = Fortran::lower::genCallOpAndResult( + auto [loweredResult, resultIsFinalized] = Fortran::lower::genCallOpAndResult( loc, callContext.converter, callContext.symMap, callContext.stmtCtx, caller, callSiteType, callContext.resultType, callContext.isElementalProcWithArrayArgs()); - // For procedure pointer function result, just return the call. - if (callContext.resultType && - mlir::isa(*callContext.resultType)) - return hlfir::EntityWithAttributes(fir::getBase(result)); /// Clean-up associations and copy-in. for (auto cleanUp : callCleanUps) cleanUp.genCleanUp(loc, builder); + if (auto *entity = + std::get_if(&loweredResult.result)) + return *entity; + + auto &result = std::get(loweredResult.result); + + // For procedure pointer function result, just return the call. + if (callContext.resultType && + mlir::isa(*callContext.resultType)) + return hlfir::EntityWithAttributes(fir::getBase(result)); + if (!fir::getBase(result)) return std::nullopt; // subroutine call. diff --git a/flang/lib/Lower/ConvertExpr.cpp b/flang/lib/Lower/ConvertExpr.cpp index 46168b81dd3a0..926e45b807e0a 100644 --- a/flang/lib/Lower/ConvertExpr.cpp +++ b/flang/lib/Lower/ConvertExpr.cpp @@ -2852,10 +2852,11 @@ class ScalarExprLowering { } } - ExtValue result = + auto loweredResult = Fortran::lower::genCallOpAndResult(loc, converter, symMap, stmtCtx, caller, callSiteType, resultType) .first; + auto &result = std::get(loweredResult.result); // Sync pointers and allocatables that may have been modified during the // call. @@ -4881,10 +4882,12 @@ class ArrayExprLowering { [&](const auto &) { return fir::getBase(exv); }); caller.placeInput(argIface, arg); } - return Fortran::lower::genCallOpAndResult(loc, converter, symMap, - getElementCtx(), caller, - callSiteType, retTy) - .first; + Fortran::lower::LoweredResult res = + Fortran::lower::genCallOpAndResult(loc, converter, symMap, + getElementCtx(), caller, + callSiteType, retTy) + .first; + return std::get(res.result); }; } diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp index 0b61c0edce622..c66ba75f912fb 100644 --- a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp +++ b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp @@ -215,3 +215,18 @@ bool hlfir::mayHaveAllocatableComponent(mlir::Type ty) { return fir::isPolymorphicType(ty) || fir::isUnlimitedPolymorphicType(ty) || fir::isRecordWithAllocatableMember(hlfir::getFortranElementType(ty)); } + +mlir::Type hlfir::getExprType(mlir::Type variableType) { + hlfir::ExprType::Shape typeShape; + bool isPolymorphic = fir::isPolymorphicType(variableType); + mlir::Type type = getFortranElementOrSequenceType(variableType); + assert(!fir::isa_trivial(type) && + "numerical and logical scalar should not be wrapped in hlfir.expr"); + if (auto seqType = mlir::dyn_cast(type)) { + assert(!seqType.hasUnknownShape() && "assumed-rank cannot be expressions"); + typeShape.append(seqType.getShape().begin(), seqType.getShape().end()); + type = seqType.getEleTy(); + } + return hlfir::ExprType::get(variableType.getContext(), typeShape, type, + isPolymorphic); +} diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp index 8751988244648..3a172d1b8b540 100644 --- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp +++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp @@ -1427,16 +1427,7 @@ llvm::LogicalResult hlfir::EndAssociateOp::verify() { void hlfir::AsExprOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, mlir::Value var, mlir::Value mustFree) { - hlfir::ExprType::Shape typeShape; - bool isPolymorphic = fir::isPolymorphicType(var.getType()); - mlir::Type type = getFortranElementOrSequenceType(var.getType()); - if (auto seqType = mlir::dyn_cast(type)) { - typeShape.append(seqType.getShape().begin(), seqType.getShape().end()); - type = seqType.getEleTy(); - } - - auto resultType = hlfir::ExprType::get(builder.getContext(), typeShape, type, - isPolymorphic); + mlir::Type resultType = hlfir::getExprType(var.getType()); return build(builder, result, resultType, var, mustFree); } diff --git a/flang/test/HLFIR/order_assignments/where-scheduling.f90 b/flang/test/HLFIR/order_assignments/where-scheduling.f90 index 3010476d4a188..6feaba0d3389a 100644 --- a/flang/test/HLFIR/order_assignments/where-scheduling.f90 +++ b/flang/test/HLFIR/order_assignments/where-scheduling.f90 @@ -134,7 +134,7 @@ end function f !CHECK-NEXT: run 1 save : where/mask !CHECK-NEXT: run 2 evaluate: where/region_assign1 !CHECK-LABEL: ------------ scheduling where in _QPonly_once ------------ -!CHECK-NEXT: unknown effect: %{{[0-9]+}} = llvm.intr.stacksave : !llvm.ptr +!CHECK-NEXT: unknown effect: %11 = fir.call @_QPcall_me_only_once() fastmath : () -> !fir.array<10x!fir.logical<4>> !CHECK-NEXT: saving eval because write effect prevents re-evaluation !CHECK-NEXT: run 1 save (w): where/mask !CHECK-NEXT: run 2 evaluate: where/region_assign1 diff --git a/flang/test/Lower/HLFIR/calls-array-results.f90 b/flang/test/Lower/HLFIR/calls-array-results.f90 new file mode 100644 index 0000000000000..d91844cc2e6f8 --- /dev/null +++ b/flang/test/Lower/HLFIR/calls-array-results.f90 @@ -0,0 +1,131 @@ +! RUN: bbc -emit-hlfir -o - %s -I nowhere | FileCheck %s + +subroutine simple_test() + implicit none + interface + function array_func() + real :: array_func(10) + end function + end interface + real :: x(10) + x = array_func() +end subroutine + +subroutine arg_test(n) + implicit none + interface + function array_func_2(n) + integer(8) :: n + real :: array_func_2(n) + end function + end interface + integer(8) :: n + real :: x(n) + x = array_func_2(n) +end subroutine + +module type_defs + interface + function array_func() + real :: array_func(10) + end function + end interface + type t + contains + procedure, nopass :: array_func => array_func + end type +end module + +subroutine dispatch_test(x, a) + use type_defs, only : t + implicit none + real :: x(10) + class(t) :: a + x = a%array_func() +end subroutine + +! CHECK-LABEL: func.func @_QPsimple_test() { +! CHECK: %[[VAL_0:.*]] = arith.constant 10 : index +! CHECK: %[[VAL_1:.*]] = fir.alloca !fir.array<10xf32> {bindc_name = "x", uniq_name = "_QFsimple_testEx"} +! CHECK: %[[VAL_2:.*]] = fir.shape %[[VAL_0]] : (index) -> !fir.shape<1> +! CHECK: %[[VAL_3:.*]]:2 = hlfir.declare %[[VAL_1]](%[[VAL_2]]) {uniq_name = "_QFsimple_testEx"} : (!fir.ref>, !fir.shape<1>) -> (!fir.ref>, !fir.ref>) +! CHECK: %[[VAL_4:.*]] = arith.constant 10 : i64 +! CHECK: %[[VAL_5:.*]] = arith.constant 1 : i64 +! CHECK: %[[VAL_6:.*]] = arith.subi %[[VAL_4]], %[[VAL_5]] : i64 +! CHECK: %[[VAL_7:.*]] = arith.constant 1 : i64 +! CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_6]], %[[VAL_7]] : i64 +! CHECK: %[[VAL_9:.*]] = fir.convert %[[VAL_8]] : (i64) -> index +! CHECK: %[[VAL_10:.*]] = arith.constant 0 : index +! CHECK: %[[VAL_11:.*]] = arith.cmpi sgt, %[[VAL_9]], %[[VAL_10]] : index +! CHECK: %[[VAL_12:.*]] = arith.select %[[VAL_11]], %[[VAL_9]], %[[VAL_10]] : index +! CHECK: %[[VAL_13:.*]] = fir.shape %[[VAL_12]] : (index) -> !fir.shape<1> +! CHECK: %[[VAL_14:.*]] = hlfir.eval_in_mem shape %[[VAL_13]] : (!fir.shape<1>) -> !hlfir.expr<10xf32> { +! CHECK: ^bb0(%[[VAL_15:.*]]: !fir.ref>): +! CHECK: %[[VAL_16:.*]] = fir.call @_QParray_func() fastmath : () -> !fir.array<10xf32> +! CHECK: fir.save_result %[[VAL_16]] to %[[VAL_15]](%[[VAL_13]]) : !fir.array<10xf32>, !fir.ref>, !fir.shape<1> +! CHECK: } +! CHECK: hlfir.assign %[[VAL_14]] to %[[VAL_3]]#0 : !hlfir.expr<10xf32>, !fir.ref> +! CHECK: hlfir.destroy %[[VAL_14]] : !hlfir.expr<10xf32> +! CHECK: return +! CHECK: } + +! CHECK-LABEL: func.func @_QParg_test( +! CHECK-SAME: %[[VAL_0:.*]]: !fir.ref {fir.bindc_name = "n"}) { +! CHECK: %[[VAL_1:.*]] = fir.dummy_scope : !fir.dscope +! CHECK: %[[VAL_2:.*]]:2 = hlfir.declare %[[VAL_0]] dummy_scope %[[VAL_1]] {uniq_name = "_QFarg_testEn"} : (!fir.ref, !fir.dscope) -> (!fir.ref, !fir.ref) +! CHECK: %[[VAL_3:.*]] = fir.load %[[VAL_2]]#0 : !fir.ref +! CHECK: %[[VAL_4:.*]] = fir.convert %[[VAL_3]] : (i64) -> index +! CHECK: %[[VAL_5:.*]] = arith.constant 0 : index +! CHECK: %[[VAL_6:.*]] = arith.cmpi sgt, %[[VAL_4]], %[[VAL_5]] : index +! CHECK: %[[VAL_7:.*]] = arith.select %[[VAL_6]], %[[VAL_4]], %[[VAL_5]] : index +! CHECK: %[[VAL_8:.*]] = fir.alloca !fir.array, %[[VAL_7]] {bindc_name = "x", uniq_name = "_QFarg_testEx"} +! CHECK: %[[VAL_9:.*]] = fir.shape %[[VAL_7]] : (index) -> !fir.shape<1> +! CHECK: %[[VAL_10:.*]]:2 = hlfir.declare %[[VAL_8]](%[[VAL_9]]) {uniq_name = "_QFarg_testEx"} : (!fir.ref>, !fir.shape<1>) -> (!fir.box>, !fir.ref>) +! CHECK: %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_2]]#1 {uniq_name = "_QFarg_testFarray_func_2En"} : (!fir.ref) -> (!fir.ref, !fir.ref) +! CHECK: %[[VAL_12:.*]] = fir.load %[[VAL_11]]#0 : !fir.ref +! CHECK: %[[VAL_13:.*]] = arith.constant 1 : i64 +! CHECK: %[[VAL_14:.*]] = arith.subi %[[VAL_12]], %[[VAL_13]] : i64 +! CHECK: %[[VAL_15:.*]] = arith.constant 1 : i64 +! CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_14]], %[[VAL_15]] : i64 +! CHECK: %[[VAL_17:.*]] = fir.convert %[[VAL_16]] : (i64) -> index +! CHECK: %[[VAL_18:.*]] = arith.constant 0 : index +! CHECK: %[[VAL_19:.*]] = arith.cmpi sgt, %[[VAL_17]], %[[VAL_18]] : index +! CHECK: %[[VAL_20:.*]] = arith.select %[[VAL_19]], %[[VAL_17]], %[[VAL_18]] : index +! CHECK: %[[VAL_21:.*]] = fir.shape %[[VAL_20]] : (index) -> !fir.shape<1> +! CHECK: %[[VAL_22:.*]] = hlfir.eval_in_mem shape %[[VAL_21]] : (!fir.shape<1>) -> !hlfir.expr { +! CHECK: ^bb0(%[[VAL_23:.*]]: !fir.ref>): +! CHECK: %[[VAL_24:.*]] = fir.call @_QParray_func_2(%[[VAL_2]]#1) fastmath : (!fir.ref) -> !fir.array +! CHECK: fir.save_result %[[VAL_24]] to %[[VAL_23]](%[[VAL_21]]) : !fir.array, !fir.ref>, !fir.shape<1> +! CHECK: } +! CHECK: hlfir.assign %[[VAL_22]] to %[[VAL_10]]#0 : !hlfir.expr, !fir.box> +! CHECK: hlfir.destroy %[[VAL_22]] : !hlfir.expr +! CHECK: return +! CHECK: } + +! CHECK-LABEL: func.func @_QPdispatch_test( +! CHECK-SAME: %[[VAL_0:.*]]: !fir.ref> {fir.bindc_name = "x"}, +! CHECK-SAME: %[[VAL_1:.*]]: !fir.class> {fir.bindc_name = "a"}) { +! CHECK: %[[VAL_2:.*]] = fir.dummy_scope : !fir.dscope +! CHECK: %[[VAL_3:.*]]:2 = hlfir.declare %[[VAL_1]] dummy_scope %[[VAL_2]] {uniq_name = "_QFdispatch_testEa"} : (!fir.class>, !fir.dscope) -> (!fir.class>, !fir.class>) +! CHECK: %[[VAL_4:.*]] = arith.constant 10 : index +! CHECK: %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1> +! CHECK: %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_5]]) dummy_scope %[[VAL_2]] {uniq_name = "_QFdispatch_testEx"} : (!fir.ref>, !fir.shape<1>, !fir.dscope) -> (!fir.ref>, !fir.ref>) +! CHECK: %[[VAL_7:.*]] = arith.constant 10 : i64 +! CHECK: %[[VAL_8:.*]] = arith.constant 1 : i64 +! CHECK: %[[VAL_9:.*]] = arith.subi %[[VAL_7]], %[[VAL_8]] : i64 +! CHECK: %[[VAL_10:.*]] = arith.constant 1 : i64 +! CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_9]], %[[VAL_10]] : i64 +! CHECK: %[[VAL_12:.*]] = fir.convert %[[VAL_11]] : (i64) -> index +! CHECK: %[[VAL_13:.*]] = arith.constant 0 : index +! CHECK: %[[VAL_14:.*]] = arith.cmpi sgt, %[[VAL_12]], %[[VAL_13]] : index +! CHECK: %[[VAL_15:.*]] = arith.select %[[VAL_14]], %[[VAL_12]], %[[VAL_13]] : index +! CHECK: %[[VAL_16:.*]] = fir.shape %[[VAL_15]] : (index) -> !fir.shape<1> +! CHECK: %[[VAL_17:.*]] = hlfir.eval_in_mem shape %[[VAL_16]] : (!fir.shape<1>) -> !hlfir.expr<10xf32> { +! CHECK: ^bb0(%[[VAL_18:.*]]: !fir.ref>): +! CHECK: %[[VAL_19:.*]] = fir.dispatch "array_func"(%[[VAL_3]]#1 : !fir.class>) -> !fir.array<10xf32> +! CHECK: fir.save_result %[[VAL_19]] to %[[VAL_18]](%[[VAL_16]]) : !fir.array<10xf32>, !fir.ref>, !fir.shape<1> +! CHECK: } +! CHECK: hlfir.assign %[[VAL_17]] to %[[VAL_6]]#0 : !hlfir.expr<10xf32>, !fir.ref> +! CHECK: hlfir.destroy %[[VAL_17]] : !hlfir.expr<10xf32> +! CHECK: return +! CHECK: } diff --git a/flang/test/Lower/HLFIR/where-nonelemental.f90 b/flang/test/Lower/HLFIR/where-nonelemental.f90 index 643f417c47674..7be5831890012 100644 --- a/flang/test/Lower/HLFIR/where-nonelemental.f90 +++ b/flang/test/Lower/HLFIR/where-nonelemental.f90 @@ -26,11 +26,12 @@ real elemental function elem_func(x) ! CHECK-LABEL: func.func @_QPtest_where( ! CHECK: hlfir.where { ! CHECK-NOT: hlfir.exactly_once -! CHECK: %[[VAL_17:.*]] = llvm.intr.stacksave : !llvm.ptr -! CHECK: %[[VAL_19:.*]] = fir.call @_QPlogical_func1() fastmath : () -> !fir.array<100x!fir.logical<4>> -! CHECK: hlfir.yield %{{.*}} : !hlfir.expr<100x!fir.logical<4>> cleanup { -! CHECK: llvm.intr.stackrestore %[[VAL_17]] : !llvm.ptr -! CHECK: } +! CHECK: %[[VAL_19:.*]] = hlfir.eval_in_mem {{.*}} { +! CHECK: fir.call @_QPlogical_func1() fastmath : () -> !fir.array<100x!fir.logical<4>> +! CHECK: } +! CHECK: hlfir.yield %[[VAL_19]] : !hlfir.expr<100x!fir.logical<4>> cleanup { +! CHECK: hlfir.destroy %[[VAL_19]] +! CHECK: } ! CHECK: } do { ! CHECK: hlfir.region_assign { ! CHECK: %[[VAL_24:.*]] = hlfir.exactly_once : f32 { @@ -70,10 +71,11 @@ real elemental function elem_func(x) ! CHECK: } ! CHECK: hlfir.elsewhere mask { ! CHECK: %[[VAL_62:.*]] = hlfir.exactly_once : !hlfir.expr<100x!fir.logical<4>> { -! CHECK: %[[VAL_72:.*]] = llvm.intr.stacksave : !llvm.ptr -! CHECK: fir.call @_QPlogical_func2() fastmath : () -> !fir.array<100x!fir.logical<4>> -! CHECK: hlfir.yield %{{.*}} : !hlfir.expr<100x!fir.logical<4>> cleanup { -! CHECK: llvm.intr.stackrestore %[[VAL_72]] : !llvm.ptr +! CHECK: %[[VAL_72:.*]] = hlfir.eval_in_mem {{.*}} { +! CHECK: fir.call @_QPlogical_func2() fastmath : () -> !fir.array<100x!fir.logical<4>> +! CHECK: } +! CHECK: hlfir.yield %[[VAL_72]] : !hlfir.expr<100x!fir.logical<4>> cleanup { +! CHECK: hlfir.destroy %[[VAL_72]] ! CHECK: } ! CHECK: } ! CHECK: hlfir.yield %[[VAL_62]] : !hlfir.expr<100x!fir.logical<4>> @@ -123,11 +125,12 @@ integer pure function pure_ifoo() ! CHECK: } (%[[VAL_10:.*]]: i32) { ! CHECK: %[[VAL_11:.*]] = hlfir.forall_index "i" %[[VAL_10]] : (i32) -> !fir.ref ! CHECK: hlfir.where { -! CHECK: %[[VAL_21:.*]] = llvm.intr.stacksave : !llvm.ptr ! CHECK-NOT: hlfir.exactly_once -! CHECK: %[[VAL_23:.*]] = fir.call @_QPpure_logical_func1() proc_attrs fastmath : () -> !fir.array<100x!fir.logical<4>> -! CHECK: hlfir.yield %{{.*}} : !hlfir.expr<100x!fir.logical<4>> cleanup { -! CHECK: llvm.intr.stackrestore %[[VAL_21]] : !llvm.ptr +! CHECK: %[[VAL_23:.*]] = hlfir.eval_in_mem {{.*}} { +! CHECK: fir.call @_QPpure_logical_func1() proc_attrs fastmath : () -> !fir.array<100x!fir.logical<4>> +! CHECK: } +! CHECK: hlfir.yield %[[VAL_23]] : !hlfir.expr<100x!fir.logical<4>> cleanup { +! CHECK: hlfir.destroy %[[VAL_23]] ! CHECK: } ! CHECK: } do { ! CHECK: hlfir.region_assign { @@ -172,10 +175,11 @@ integer pure function pure_ifoo() ! CHECK: } ! CHECK: hlfir.elsewhere mask { ! CHECK: %[[VAL_129:.*]] = hlfir.exactly_once : !hlfir.expr<100x!fir.logical<4>> { -! CHECK: %[[VAL_139:.*]] = llvm.intr.stacksave : !llvm.ptr -! CHECK: %[[VAL_141:.*]] = fir.call @_QPpure_logical_func2() proc_attrs fastmath : () -> !fir.array<100x!fir.logical<4>> -! CHECK: hlfir.yield %{{.*}} : !hlfir.expr<100x!fir.logical<4>> cleanup { -! CHECK: llvm.intr.stackrestore %[[VAL_139]] : !llvm.ptr +! CHECK: %[[VAL_139:.*]] = hlfir.eval_in_mem {{.*}} { +! CHECK: fir.call @_QPpure_logical_func2() proc_attrs fastmath : () -> !fir.array<100x!fir.logical<4>> +! CHECK: } +! CHECK: hlfir.yield %[[VAL_139]] : !hlfir.expr<100x!fir.logical<4>> cleanup { +! CHECK: hlfir.destroy %[[VAL_139]] ! CHECK: } ! CHECK: } ! CHECK: hlfir.yield %[[VAL_129]] : !hlfir.expr<100x!fir.logical<4>> diff --git a/flang/test/Lower/explicit-interface-results-2.f90 b/flang/test/Lower/explicit-interface-results-2.f90 index 95aee84f4a644..2336053c32a54 100644 --- a/flang/test/Lower/explicit-interface-results-2.f90 +++ b/flang/test/Lower/explicit-interface-results-2.f90 @@ -252,12 +252,10 @@ subroutine test_call_to_used_interface(dummy_proc) call takes_array(dummy_proc()) ! CHECK: %[[VAL_1:.*]] = arith.constant 100 : index ! CHECK: %[[VAL_2:.*]] = fir.alloca !fir.array<100xf32> {bindc_name = ".result"} -! CHECK: %[[VAL_3:.*]] = llvm.intr.stacksave : !llvm.ptr ! CHECK: %[[VAL_4:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1> ! CHECK: %[[VAL_5:.*]] = fir.box_addr %[[VAL_0]] : (!fir.boxproc<() -> ()>) -> (() -> !fir.array<100xf32>) ! CHECK: %[[VAL_6:.*]] = fir.call %[[VAL_5]]() {{.*}}: () -> !fir.array<100xf32> ! CHECK: fir.save_result %[[VAL_6]] to %[[VAL_2]](%[[VAL_4]]) : !fir.array<100xf32>, !fir.ref>, !fir.shape<1> ! CHECK: %[[VAL_7:.*]] = fir.convert %[[VAL_2]] : (!fir.ref>) -> !fir.ref> ! CHECK: fir.call @_QPtakes_array(%[[VAL_7]]) {{.*}}: (!fir.ref>) -> () -! CHECK: llvm.intr.stackrestore %[[VAL_3]] : !llvm.ptr end subroutine diff --git a/flang/test/Lower/explicit-interface-results.f90 b/flang/test/Lower/explicit-interface-results.f90 index 623e875b5f9c9..612d57be36448 100644 --- a/flang/test/Lower/explicit-interface-results.f90 +++ b/flang/test/Lower/explicit-interface-results.f90 @@ -195,8 +195,8 @@ subroutine dyn_array(m, n) ! CHECK-DAG: %[[ncast2:.*]] = fir.convert %[[nadd]] : (i64) -> index ! CHECK-DAG: %[[ncmpi:.*]] = arith.cmpi sgt, %[[ncast2]], %{{.*}} : index ! CHECK-DAG: %[[nselect:.*]] = arith.select %[[ncmpi]], %[[ncast2]], %{{.*}} : index - ! CHECK: %[[tmp:.*]] = fir.alloca !fir.array, %[[mselect]], %[[nselect]] ! CHECK: %[[shape:.*]] = fir.shape %[[mselect]], %[[nselect]] : (index, index) -> !fir.shape<2> + ! CHECK: %[[tmp:.*]] = fir.alloca !fir.array, %[[mselect]], %[[nselect]] ! CHECK: %[[res:.*]] = fir.call @_QMcalleePreturn_dyn_array(%[[m]], %[[n]]) {{.*}}: (!fir.ref, !fir.ref) -> !fir.array ! CHECK: fir.save_result %[[res]] to %[[tmp]](%[[shape]]) : !fir.array, !fir.ref>, !fir.shape<2> print *, return_dyn_array(m, n) @@ -211,8 +211,8 @@ subroutine dyn_char_cst_array(l) ! CHECK: %[[lcast2:.*]] = fir.convert %[[lcast]] : (i64) -> index ! CHECK: %[[cmpi:.*]] = arith.cmpi sgt, %[[lcast2]], %{{.*}} : index ! CHECK: %[[select:.*]] = arith.select %[[cmpi]], %[[lcast2]], %{{.*}} : index - ! CHECK: %[[tmp:.*]] = fir.alloca !fir.array<20x30x!fir.char<1,?>>(%[[select]] : index) ! CHECK: %[[shape:.*]] = fir.shape %{{.*}}, %{{.*}} : (index, index) -> !fir.shape<2> + ! CHECK: %[[tmp:.*]] = fir.alloca !fir.array<20x30x!fir.char<1,?>>(%[[select]] : index) ! CHECK: %[[res:.*]] = fir.call @_QMcalleePreturn_dyn_char_cst_array(%[[l]]) {{.*}}: (!fir.ref) -> !fir.array<20x30x!fir.char<1,?>> ! CHECK: fir.save_result %[[res]] to %[[tmp]](%[[shape]]) typeparams %[[select]] : !fir.array<20x30x!fir.char<1,?>>, !fir.ref>>, !fir.shape<2>, index print *, return_dyn_char_cst_array(l) @@ -236,8 +236,8 @@ subroutine cst_char_dyn_array(m, n) ! CHECK-DAG: %[[ncast2:.*]] = fir.convert %[[nadd]] : (i64) -> index ! CHECK-DAG: %[[ncmpi:.*]] = arith.cmpi sgt, %[[ncast2]], %{{.*}} : index ! CHECK-DAG: %[[nselect:.*]] = arith.select %[[ncmpi]], %[[ncast2]], %{{.*}} : index - ! CHECK: %[[tmp:.*]] = fir.alloca !fir.array>, %[[mselect]], %[[nselect]] ! CHECK: %[[shape:.*]] = fir.shape %[[mselect]], %[[nselect]] : (index, index) -> !fir.shape<2> + ! CHECK: %[[tmp:.*]] = fir.alloca !fir.array>, %[[mselect]], %[[nselect]] ! CHECK: %[[res:.*]] = fir.call @_QMcalleePreturn_cst_char_dyn_array(%[[m]], %[[n]]) {{.*}}: (!fir.ref, !fir.ref) -> !fir.array> ! CHECK: fir.save_result %[[res]] to %[[tmp]](%[[shape]]) typeparams {{.*}} : !fir.array>, !fir.ref>>, !fir.shape<2>, index print *, return_cst_char_dyn_array(m, n) @@ -267,8 +267,8 @@ subroutine dyn_char_dyn_array(l, m, n) ! CHECK-DAG: %[[lcast2:.*]] = fir.convert %[[lcast]] : (i64) -> index ! CHECK-DAG: %[[lcmpi:.*]] = arith.cmpi sgt, %[[lcast2]], %{{.*}} : index ! CHECK-DAG: %[[lselect:.*]] = arith.select %[[lcmpi]], %[[lcast2]], %{{.*}} : index - ! CHECK: %[[tmp:.*]] = fir.alloca !fir.array>(%[[lselect]] : index), %[[mselect]], %[[nselect]] ! CHECK: %[[shape:.*]] = fir.shape %[[mselect]], %[[nselect]] : (index, index) -> !fir.shape<2> + ! CHECK: %[[tmp:.*]] = fir.alloca !fir.array>(%[[lselect]] : index), %[[mselect]], %[[nselect]] ! CHECK: %[[res:.*]] = fir.call @_QMcalleePreturn_dyn_char_dyn_array(%[[l]], %[[m]], %[[n]]) {{.*}}: (!fir.ref, !fir.ref, !fir.ref) -> !fir.array> ! CHECK: fir.save_result %[[res]] to %[[tmp]](%[[shape]]) typeparams {{.*}} : !fir.array>, !fir.ref>>, !fir.shape<2>, index integer :: l, m, n diff --git a/flang/test/Lower/forall/array-constructor.f90 b/flang/test/Lower/forall/array-constructor.f90 index 4c8c756ea689c..6b6b46fdd4688 100644 --- a/flang/test/Lower/forall/array-constructor.f90 +++ b/flang/test/Lower/forall/array-constructor.f90 @@ -232,8 +232,8 @@ end subroutine ac2 ! CHECK: %[[C0:.*]] = arith.constant 0 : index ! CHECK: %[[CMPI:.*]] = arith.cmpi sgt, %[[VAL_80]], %[[C0]] : index ! CHECK: %[[SELECT:.*]] = arith.select %[[CMPI]], %[[VAL_80]], %[[C0]] : index -! CHECK: %[[VAL_81:.*]] = llvm.intr.stacksave : !llvm.ptr ! CHECK: %[[VAL_82:.*]] = fir.shape %[[SELECT]] : (index) -> !fir.shape<1> +! CHECK: %[[VAL_81:.*]] = llvm.intr.stacksave : !llvm.ptr ! CHECK: %[[VAL_83:.*]] = fir.convert %[[VAL_74]] : (!fir.box>) -> !fir.box> ! CHECK: %[[VAL_84:.*]] = fir.call @_QFac2Pfunc(%[[VAL_83]]) {{.*}}: (!fir.box>) -> !fir.array<3xi32> ! CHECK: fir.save_result %[[VAL_84]] to %[[VAL_2]](%[[VAL_82]]) : !fir.array<3xi32>, !fir.ref>, !fir.shape<1> From d68b5b2652831cda053c00465b33039bc645bc02 Mon Sep 17 00:00:00 2001 From: Jean Perier Date: Mon, 2 Dec 2024 01:59:33 -0800 Subject: [PATCH 4/6] PR118069 comment: mayReadOrWrite to getModRef --- .../flang/Optimizer/Analysis/AliasAnalysis.h | 6 +++ .../lib/Optimizer/Analysis/AliasAnalysis.cpp | 27 ++++++++++++++ .../Transforms/OptimizedBufferization.cpp | 37 ++++++------------- 3 files changed, 45 insertions(+), 25 deletions(-) diff --git a/flang/include/flang/Optimizer/Analysis/AliasAnalysis.h b/flang/include/flang/Optimizer/Analysis/AliasAnalysis.h index e410831c0fc3e..8d17e4e476d10 100644 --- a/flang/include/flang/Optimizer/Analysis/AliasAnalysis.h +++ b/flang/include/flang/Optimizer/Analysis/AliasAnalysis.h @@ -198,6 +198,12 @@ struct AliasAnalysis { /// Return the modify-reference behavior of `op` on `location`. mlir::ModRefResult getModRef(mlir::Operation *op, mlir::Value location); + /// Return the modify-reference behavior of operations inside `region` on + /// `location`. Contrary to getModRef(operation, location), this will visit + /// nested regions recursively according to the HasRecursiveMemoryEffects + /// trait. + mlir::ModRefResult getModRef(mlir::Region ®ion, mlir::Value location); + /// Return the memory source of a value. /// If getLastInstantiationPoint is true, the search for the source /// will stop at [hl]fir.declare if it represents a dummy diff --git a/flang/lib/Optimizer/Analysis/AliasAnalysis.cpp b/flang/lib/Optimizer/Analysis/AliasAnalysis.cpp index c561285b9feef..0b0f83d024ce3 100644 --- a/flang/lib/Optimizer/Analysis/AliasAnalysis.cpp +++ b/flang/lib/Optimizer/Analysis/AliasAnalysis.cpp @@ -464,6 +464,33 @@ ModRefResult AliasAnalysis::getModRef(Operation *op, Value location) { return result; } +ModRefResult AliasAnalysis::getModRef(mlir::Region ®ion, + mlir::Value location) { + ModRefResult result = ModRefResult::getNoModRef(); + for (mlir::Operation &op : region.getOps()) { + if (op.hasTrait()) { + for (mlir::Region &subRegion : op.getRegions()) { + result = result.merge(getModRef(subRegion, location)); + // Fast return is already mod and ref. + if (result.isModAndRef()) + return result; + } + // In MLIR, RecursiveMemoryEffects can be combined with + // MemoryEffectOpInterface to describe extra effects on top of the + // effects of the nested operations. However, the presence of + // RecursiveMemoryEffects and the absence of MemoryEffectOpInterface + // implies the operation has no other memory effects than the one of its + // nested operations. + if (!mlir::isa(op)) + continue; + } + result = result.merge(getModRef(&op, location)); + if (result.isModAndRef()) + return result; + } + return result; +} + AliasAnalysis::Source::Attributes getAttrsFromVariable(fir::FortranVariableOpInterface var) { AliasAnalysis::Source::Attributes attrs; diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp index e8c15a256b9da..9327e7ad5875c 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp @@ -1119,28 +1119,6 @@ class EvaluateIntoMemoryAssignBufferization mlir::PatternRewriter &rewriter) const override; }; -static bool mayReadOrWrite(mlir::Region ®ion, mlir::Value var) { - fir::AliasAnalysis aliasAnalysis; - for (mlir::Operation &op : region.getOps()) { - if (op.hasTrait()) { - for (mlir::Region &subRegion : op.getRegions()) - if (mayReadOrWrite(subRegion, var)) - return true; - // In MLIR, RecursiveMemoryEffects can be combined with - // MemoryEffectOpInterface to describe extra effects on top of the - // effects of the nested operations. However, the presence of - // RecursiveMemoryEffects and the absence of MemoryEffectOpInterface - // implies the operation has no other memory effects than the one of its - // nested operations. - if (!mlir::isa(op)) - continue; - } - if (!aliasAnalysis.getModRef(&op, var).isNoModRef()) - return true; - } - return false; -} - static llvm::LogicalResult tryUsingAssignLhsDirectly(hlfir::EvaluateInMemoryOp evalInMem, mlir::PatternRewriter &rewriter) { @@ -1168,9 +1146,17 @@ tryUsingAssignLhsDirectly(hlfir::EvaluateInMemoryOp evalInMem, // RHS lengths are the same. if (lhs.isCharacter()) return mlir::failure(); - + fir::AliasAnalysis aliasAnalysis; // The region must not read or write the LHS. - if (mayReadOrWrite(evalInMem.getBody(), lhs)) + // Note that getModRef is used instead of mlir::MemoryEffects because + // EvaluateInMemoryOp is typically expected to hold fir.calls and that + // Fortran calls cannot be modeled in a useful way with mlir::MemoryEffects: + // it is hard/impossible to list all the read/written SSA values in a call, + // but it is often possible to tell that an SSA value cannot be accessed, + // hence getModRef is needed here and below. Also note that getModRef uses + // mlir::MemoryEffects for operations that do not have special handling in + // getModRef. + if (aliasAnalysis.getModRef(evalInMem.getBody(), lhs).isModOrRef()) return mlir::failure(); // Any variables affected between the hlfir.evalInMem and assignment must not // be read or written inside the region since it will be moved at the @@ -1184,7 +1170,8 @@ tryUsingAssignLhsDirectly(hlfir::EvaluateInMemoryOp evalInMem, } for (const mlir::MemoryEffects::EffectInstance &effect : *effects) { mlir::Value affected = effect.getValue(); - if (!affected || mayReadOrWrite(evalInMem.getBody(), affected)) + if (!affected || + aliasAnalysis.getModRef(evalInMem.getBody(), affected).isModOrRef()) return mlir::failure(); } From 2002996a96326a28d11da797f323966c331b6fe8 Mon Sep 17 00:00:00 2001 From: Jean Perier Date: Mon, 2 Dec 2024 05:52:09 -0800 Subject: [PATCH 5/6] PR118070 comments: simplify LoweredResult --- flang/include/flang/Lower/ConvertCall.h | 7 ++++--- flang/lib/Lower/ConvertCall.cpp | 5 ++--- flang/lib/Lower/ConvertExpr.cpp | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/flang/include/flang/Lower/ConvertCall.h b/flang/include/flang/Lower/ConvertCall.h index 2c51a887010c8..f1cd4f938320b 100644 --- a/flang/include/flang/Lower/ConvertCall.h +++ b/flang/include/flang/Lower/ConvertCall.h @@ -24,9 +24,10 @@ namespace Fortran::lower { -struct LoweredResult { - std::variant result; -}; +/// Data structure packaging the SSA value(s) produced for the result of lowered +/// function calls. +using LoweredResult = + std::variant; /// Given a call site for which the arguments were already lowered, generate /// the call and return the result. This function deals with explicit result diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp index 088d8f96caa41..40cd106e63018 100644 --- a/flang/lib/Lower/ConvertCall.cpp +++ b/flang/lib/Lower/ConvertCall.cpp @@ -1699,11 +1699,10 @@ genUserCall(Fortran::lower::PreparedActualArguments &loweredActuals, for (auto cleanUp : callCleanUps) cleanUp.genCleanUp(loc, builder); - if (auto *entity = - std::get_if(&loweredResult.result)) + if (auto *entity = std::get_if(&loweredResult)) return *entity; - auto &result = std::get(loweredResult.result); + auto &result = std::get(loweredResult); // For procedure pointer function result, just return the call. if (callContext.resultType && diff --git a/flang/lib/Lower/ConvertExpr.cpp b/flang/lib/Lower/ConvertExpr.cpp index 926e45b807e0a..7698fac89c223 100644 --- a/flang/lib/Lower/ConvertExpr.cpp +++ b/flang/lib/Lower/ConvertExpr.cpp @@ -2856,7 +2856,7 @@ class ScalarExprLowering { Fortran::lower::genCallOpAndResult(loc, converter, symMap, stmtCtx, caller, callSiteType, resultType) .first; - auto &result = std::get(loweredResult.result); + auto &result = std::get(loweredResult); // Sync pointers and allocatables that may have been modified during the // call. @@ -4887,7 +4887,7 @@ class ArrayExprLowering { getElementCtx(), caller, callSiteType, retTy) .first; - return std::get(res.result); + return std::get(res); }; } From 4ab4992246fddf1eef2131d2a6d341095feba85f Mon Sep 17 00:00:00 2001 From: Jean Perier Date: Mon, 2 Dec 2024 07:18:49 -0800 Subject: [PATCH 6/6] remove new over-assertive assert I added an assert that prevents ExprType to be "trivial types" (i32, f32, ...) but that actually breaks existing lowering code that generate as_expr + associate patterns to create temporaries: https://github.com/llvm/llvm-project/blob/03b5f8f0f0d10c412842ed04b90e2217cf071218/flang/lib/Lower/ConvertCall.cpp#L1271 ``` subroutine trivial_as_expr(i) integer, optional :: i interface subroutine foo(j) integer, optional, value :: j end subroutine end interface call foo(i) end subroutine ``` This is unrelated to this patch, so leave it like this (it is not even a functional problem). --- flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp index c66ba75f912fb..d67b5fa659807 100644 --- a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp +++ b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp @@ -220,8 +220,6 @@ mlir::Type hlfir::getExprType(mlir::Type variableType) { hlfir::ExprType::Shape typeShape; bool isPolymorphic = fir::isPolymorphicType(variableType); mlir::Type type = getFortranElementOrSequenceType(variableType); - assert(!fir::isa_trivial(type) && - "numerical and logical scalar should not be wrapped in hlfir.expr"); if (auto seqType = mlir::dyn_cast(type)) { assert(!seqType.hasUnknownShape() && "assumed-rank cannot be expressions"); typeShape.append(seqType.getShape().begin(), seqType.getShape().end());