diff --git a/flang/lib/Optimizer/CodeGen/Target.cpp b/flang/lib/Optimizer/CodeGen/Target.cpp index f7bffbf53c190..c332493eb8072 100644 --- a/flang/lib/Optimizer/CodeGen/Target.cpp +++ b/flang/lib/Optimizer/CodeGen/Target.cpp @@ -788,6 +788,8 @@ struct TargetX86_64Win : public GenericTarget { //===----------------------------------------------------------------------===// namespace { +// AArch64 procedure call standard: +// https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#parameter-passing struct TargetAArch64 : public GenericTarget { using GenericTarget::GenericTarget; @@ -826,7 +828,7 @@ struct TargetAArch64 : public GenericTarget { return marshal; } - // Flatten a RecordType::TypeList containing more record types or array types + // Flatten a RecordType::TypeList containing more record types or array type static std::optional> flattenTypeList(const RecordType::TypeList &types) { std::vector flatTypes; @@ -870,52 +872,144 @@ struct TargetAArch64 : public GenericTarget { // Determine if the type is a Homogenous Floating-point Aggregate (HFA). An // HFA is a record type with up to 4 floating-point members of the same type. - static bool isHFA(fir::RecordType ty) { + static std::optional usedRegsForHFA(fir::RecordType ty) { RecordType::TypeList types = ty.getTypeList(); if (types.empty() || types.size() > 4) - return false; + return std::nullopt; std::optional> flatTypes = flattenTypeList(types); if (!flatTypes || flatTypes->size() > 4) { - return false; + return std::nullopt; } if (!isa_real(flatTypes->front())) { - return false; + return std::nullopt; } - return llvm::all_equal(*flatTypes); + return llvm::all_equal(*flatTypes) ? std::optional{flatTypes->size()} + : std::nullopt; } - // AArch64 procedure call ABI: - // https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#parameter-passing - CodeGenSpecifics::Marshalling - structReturnType(mlir::Location loc, fir::RecordType ty) const override { - CodeGenSpecifics::Marshalling marshal; + struct NRegs { + int n{0}; + bool isSimd{false}; + }; - if (isHFA(ty)) { - // Just return the existing record type - marshal.emplace_back(ty, AT{}); - return marshal; + NRegs usedRegsForRecordType(mlir::Location loc, fir::RecordType type) const { + if (std::optional size = usedRegsForHFA(type)) + return {*size, true}; + + auto [size, align] = fir::getTypeSizeAndAlignmentOrCrash( + loc, type, getDataLayout(), kindMap); + + if (size <= 16) + return {static_cast((size + 7) / 8), false}; + + // Pass on the stack, i.e. no registers used + return {}; + } + + NRegs usedRegsForType(mlir::Location loc, mlir::Type type) const { + return llvm::TypeSwitch(type) + .Case([&](auto intTy) { + return intTy.getWidth() == 128 ? NRegs{2, false} : NRegs{1, false}; + }) + .Case([&](auto) { return NRegs{1, true}; }) + .Case([&](auto) { return NRegs{2, true}; }) + .Case([&](auto) { return NRegs{1, false}; }) + .Case([&](auto) { return NRegs{1, false}; }) + .Case([&](auto ty) { + assert(ty.getShape().size() == 1 && + "invalid array dimensions in BIND(C)"); + NRegs nregs = usedRegsForType(loc, ty.getEleTy()); + nregs.n *= ty.getShape()[0]; + return nregs; + }) + .Case( + [&](auto ty) { return usedRegsForRecordType(loc, ty); }) + .Case([&](auto) { + TODO(loc, "passing vector argument to C by value is not supported"); + return NRegs{}; + }); + } + + bool hasEnoughRegisters(mlir::Location loc, fir::RecordType type, + const Marshalling &previousArguments) const { + int availIntRegisters = 8; + int availSIMDRegisters = 8; + + // Check previous arguments to see how many registers are used already + for (auto [type, attr] : previousArguments) { + if (availIntRegisters <= 0 || availSIMDRegisters <= 0) + break; + + if (attr.isByVal()) + continue; // Previous argument passed on the stack + + NRegs nregs = usedRegsForType(loc, type); + if (nregs.isSimd) + availSIMDRegisters -= nregs.n; + else + availIntRegisters -= nregs.n; } - auto [size, align] = + NRegs nregs = usedRegsForRecordType(loc, type); + + if (nregs.isSimd) + return nregs.n <= availSIMDRegisters; + + return nregs.n <= availIntRegisters; + } + + CodeGenSpecifics::Marshalling + passOnTheStack(mlir::Location loc, mlir::Type ty, bool isResult) const { + CodeGenSpecifics::Marshalling marshal; + auto sizeAndAlign = fir::getTypeSizeAndAlignmentOrCrash(loc, ty, getDataLayout(), kindMap); + // The stack is always 8 byte aligned + unsigned short align = + std::max(sizeAndAlign.second, static_cast(8)); + marshal.emplace_back(fir::ReferenceType::get(ty), + AT{align, /*byval=*/!isResult, /*sret=*/isResult}); + return marshal; + } - // return in registers if size <= 16 bytes - if (size <= 16) { - std::size_t dwordSize = (size + 7) / 8; - auto newTy = fir::SequenceType::get( - dwordSize, mlir::IntegerType::get(ty.getContext(), 64)); - marshal.emplace_back(newTy, AT{}); - return marshal; + CodeGenSpecifics::Marshalling + structType(mlir::Location loc, fir::RecordType type, bool isResult) const { + NRegs nregs = usedRegsForRecordType(loc, type); + + // If the type needs no registers it must need to be passed on the stack + if (nregs.n == 0) + return passOnTheStack(loc, type, isResult); + + CodeGenSpecifics::Marshalling marshal; + + mlir::Type pcsType; + if (nregs.isSimd) { + pcsType = type; + } else { + pcsType = fir::SequenceType::get( + nregs.n, mlir::IntegerType::get(type.getContext(), 64)); } - unsigned short stackAlign = std::max(align, 8u); - marshal.emplace_back(fir::ReferenceType::get(ty), - AT{stackAlign, false, true}); + marshal.emplace_back(pcsType, AT{}); return marshal; } + + CodeGenSpecifics::Marshalling + structArgumentType(mlir::Location loc, fir::RecordType ty, + const Marshalling &previousArguments) const override { + if (!hasEnoughRegisters(loc, ty, previousArguments)) { + return passOnTheStack(loc, ty, /*isResult=*/false); + } + + return structType(loc, ty, /*isResult=*/false); + } + + CodeGenSpecifics::Marshalling + structReturnType(mlir::Location loc, fir::RecordType ty) const override { + return structType(loc, ty, /*isResult=*/true); + } }; } // namespace diff --git a/flang/test/Fir/struct-passing-aarch64-byval.fir b/flang/test/Fir/struct-passing-aarch64-byval.fir new file mode 100644 index 0000000000000..27143459dde2f --- /dev/null +++ b/flang/test/Fir/struct-passing-aarch64-byval.fir @@ -0,0 +1,73 @@ +// Test AArch64 ABI rewrite of struct passed by value (BIND(C), VALUE derived types). +// RUN: fir-opt --target-rewrite="target=aarch64-unknown-linux-gnu" %s | FileCheck %s + +// CHECK-LABEL: func.func private @small_i32(!fir.array<2xi64>) +func.func private @small_i32(!fir.type) +// CHECK-LABEL: func.func private @small_i64(!fir.array<2xi64>) +func.func private @small_i64(!fir.type) +// CHECK-LABEL: func.func private @small_mixed(!fir.array<2xi64>) +func.func private @small_mixed(!fir.type) +// CHECK-LABEL: func.func private @small_non_hfa(!fir.array<2xi64>) +func.func private @small_non_hfa(!fir.type) + +// CHECK-LABEL: func.func private @hfa_f16(!fir.type) +func.func private @hfa_f16(!fir.type) +// CHECK-LABEL: func.func private @hfa_bf16(!fir.type) +func.func private @hfa_bf16(!fir.type) +// CHECK-LABEL: func.func private @hfa_f32(!fir.type) +func.func private @hfa_f32(!fir.type) +// CHECK-LABEL: func.func private @hfa_f64(!fir.type) +func.func private @hfa_f64(!fir.type) +// CHECK-LABEL: func.func private @hfa_f128(!fir.type) +func.func private @hfa_f128(!fir.type) + +// CHECK-LABEL: func.func private @multi_small_integer(!fir.array<2xi64>, !fir.array<2xi64>) +func.func private @multi_small_integer(!fir.type, !fir.type) +// CHECK-LABEL: func.func private @multi_hfas(!fir.type, !fir.type) +func.func private @multi_hfas(!fir.type, !fir.type) +// CHECK-LABEL: func.func private @multi_mixed(!fir.type, !fir.array<2xi64>, !fir.type, !fir.array<2xi64>) +func.func private @multi_mixed(!fir.type,!fir.type,!fir.type,!fir.type) + +// CHECK-LABEL: func.func private @int_max(!fir.array<2xi64>, +// CHECK-SAME: !fir.array<2xi64>, +// CHECK-SAME: !fir.array<2xi64>, +// CHECK-SAME: !fir.array<2xi64>) +func.func private @int_max(!fir.type, + !fir.type, + !fir.type, + !fir.type) +// CHECK-LABEL: func.func private @hfa_max(!fir.type, !fir.type) +func.func private @hfa_max(!fir.type, !fir.type) +// CHECK-LABEL: func.func private @max(!fir.type, +// CHECK-SAME: !fir.type, +// CHECK-SAME: !fir.array<2xi64>, +// CHECK-SAME: !fir.array<2xi64>, +// CHECK-SAME: !fir.array<2xi64>, +// CHECK-SAME: !fir.array<2xi64>) +func.func private @max(!fir.type, + !fir.type, + !fir.type, + !fir.type, + !fir.type, + !fir.type) + + +// CHECK-LABEL: func.func private @too_many_int(!fir.array<2xi64>, +// CHECK-SAME: !fir.array<2xi64>, +// CHECK-SAME: !fir.array<2xi64>, +// CHECK-SAME: !fir.array<2xi64>, +// CHECK-SAME: !fir.ref> {{{.*}}, llvm.byval = !fir.type}) +func.func private @too_many_int(!fir.type, + !fir.type, + !fir.type, + !fir.type, + !fir.type) +// CHECK-LABEL: func.func private @too_many_hfa(!fir.type, +// CHECK-SAME: !fir.type, +// CHECK-SAME: !fir.ref> {{{.*}}, llvm.byval = !fir.type}) +func.func private @too_many_hfa(!fir.type, + !fir.type, + !fir.type) + +// CHECK-LABEL: func.func private @too_big(!fir.ref}>> {{{.*}}, llvm.byval = !fir.type}>}) +func.func private @too_big(!fir.type}>)