From 397d3a20ccc63cc65ed460f2062c33a3d739d4e3 Mon Sep 17 00:00:00 2001 From: Amara Emerson Date: Fri, 1 Sep 2023 02:23:18 -0700 Subject: [PATCH 1/2] [GlobalISel] Add constant folding support for G_FMA/G_FMAD in the combiner. --- .../llvm/CodeGen/GlobalISel/CombinerHelper.h | 3 ++ .../include/llvm/Target/GlobalISel/Combine.td | 10 +++- .../lib/CodeGen/GlobalISel/CombinerHelper.cpp | 38 +++++++++++++++ .../GlobalISel/combine-constant-fold-fma.mir | 48 +++++++++++++++++++ 4 files changed, 98 insertions(+), 1 deletion(-) create mode 100644 llvm/test/CodeGen/AArch64/GlobalISel/combine-constant-fold-fma.mir diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h index 1708ef9436979..b7c0cd4fc47fa 100644 --- a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h +++ b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h @@ -657,6 +657,9 @@ class CombinerHelper { /// Do constant FP folding when opportunities are exposed after MIR building. bool matchConstantFoldFPBinOp(MachineInstr &MI, ConstantFP* &MatchInfo); + /// Constant fold G_FMA/G_FMAD. + bool matchConstantFoldFMA(MachineInstr &MI, ConstantFP *&MatchInfo); + /// \returns true if it is possible to narrow the width of a scalar binop /// feeding a G_AND instruction \p MI. bool matchNarrowBinopFeedingAnd(MachineInstr &MI, BuildFnTy &MatchInfo); diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td index e3634e50ec741..7e0691e1ee950 100644 --- a/llvm/include/llvm/Target/GlobalISel/Combine.td +++ b/llvm/include/llvm/Target/GlobalISel/Combine.td @@ -985,6 +985,13 @@ def constant_fold_fp_binop : GICombineRule< [{ return Helper.matchConstantFoldFPBinOp(*${d}, ${matchinfo}); }]), (apply [{ Helper.replaceInstWithFConstant(*${d}, ${matchinfo}); }])>; + +def constant_fold_fma : GICombineRule< + (defs root:$d, constantfp_matchinfo:$matchinfo), + (match (wip_match_opcode G_FMAD, G_FMA):$d, + [{ return Helper.matchConstantFoldFMA(*${d}, ${matchinfo}); }]), + (apply [{ Helper.replaceInstWithFConstant(*${d}, ${matchinfo}); }])>; + def constant_fold_cast_op : GICombineRule< (defs root:$d, apint_matchinfo:$matchinfo), (match (wip_match_opcode G_ZEXT, G_SEXT, G_ANYEXT):$d, @@ -1253,7 +1260,8 @@ def all_combines : GICombineGroup<[trivial_combines, insert_vec_elt_combines, const_combines, xor_of_and_with_same_reg, ptr_add_with_zero, shift_immed_chain, shift_of_shifted_logic_chain, load_or_combine, div_rem_to_divrem, funnel_shift_combines, commute_shift, - form_bitfield_extract, constant_fold_binops, constant_fold_cast_op, fabs_fneg_fold, + form_bitfield_extract, constant_fold_binops, constant_fold_fma, + constant_fold_cast_op, fabs_fneg_fold, intdiv_combines, mulh_combines, redundant_neg_operands, and_or_disjoint_mask, fma_combines, fold_binop_into_select, sub_add_reg, select_to_minmax, redundant_binop_in_equality, diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp index 9030efb9c07b6..6c8e439a15d77 100644 --- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp +++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// #include "llvm/CodeGen/GlobalISel/CombinerHelper.h" +#include "llvm/ADT/APFloat.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h" @@ -4621,6 +4622,43 @@ bool CombinerHelper::matchConstantFoldFPBinOp(MachineInstr &MI, ConstantFP* &Mat return true; } +bool CombinerHelper::matchConstantFoldFMA(MachineInstr &MI, + ConstantFP *&MatchInfo) { + unsigned Opc = MI.getOpcode(); + auto [_, Op1, Op2, Op3] = MI.getFirst4Regs(); + + const ConstantFP *Op3Cst = getConstantFPVRegVal(Op3, MRI); + if (!Op3Cst) + return false; + + const ConstantFP *Op2Cst = getConstantFPVRegVal(Op2, MRI); + if (!Op2Cst) + return false; + + const ConstantFP *Op1Cst = getConstantFPVRegVal(Op1, MRI); + if (!Op1Cst) + return false; + + APFloat Op1F = Op1Cst->getValueAPF(); + APFloat Op2F = Op2Cst->getValueAPF(); + APFloat Op3F = Op3Cst->getValueAPF(); + + switch (Opc) { + case TargetOpcode::G_FMA: + Op1F.fusedMultiplyAdd(Op2F, Op3F, APFloat::rmNearestTiesToEven); + MatchInfo = ConstantFP::get(MI.getMF()->getFunction().getContext(), Op1F); + break; + case TargetOpcode::G_FMAD: { + APFloat Res = (Op1F * Op2F) + Op3F; + MatchInfo = ConstantFP::get(MI.getMF()->getFunction().getContext(), Res); + break; + } + default: + llvm_unreachable("Unexpected opcode"); + } + return true; +} + bool CombinerHelper::matchNarrowBinopFeedingAnd( MachineInstr &MI, std::function &MatchInfo) { // Look for a binop feeding into an AND with a mask: diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/combine-constant-fold-fma.mir b/llvm/test/CodeGen/AArch64/GlobalISel/combine-constant-fold-fma.mir new file mode 100644 index 0000000000000..73e03282f6c7f --- /dev/null +++ b/llvm/test/CodeGen/AArch64/GlobalISel/combine-constant-fold-fma.mir @@ -0,0 +1,48 @@ +# NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py +# RUN: llc -mtriple aarch64 -run-pass=aarch64-prelegalizer-combiner -verify-machineinstrs %s -o - | FileCheck %s + +--- +name: fma +liveins: + - { reg: '$d0' } +body: | + bb.1.entry: + liveins: $d0 + + ; CHECK-LABEL: name: fma + ; CHECK: liveins: $d0 + ; CHECK-NEXT: {{ $}} + ; CHECK-NEXT: %res:_(s64) = G_FCONSTANT double 8.100000e+01 + ; CHECK-NEXT: $d0 = COPY %res(s64) + ; CHECK-NEXT: RET_ReallyLR implicit $d0 + %a:_(s64) = G_FCONSTANT double 40.0 + %b:_(s64) = G_FCONSTANT double 2.0 + %c:_(s64) = G_FCONSTANT double 1.0 + %res:_(s64) = G_FMA %a, %b, %c + $d0 = COPY %res(s64) + RET_ReallyLR implicit $d0 + +... + +--- +name: fmad +liveins: + - { reg: '$d0' } +body: | + bb.1.entry: + liveins: $d0 + + ; CHECK-LABEL: name: fmad + ; CHECK: liveins: $d0 + ; CHECK-NEXT: {{ $}} + ; CHECK-NEXT: %res:_(s64) = G_FCONSTANT double 8.100000e+01 + ; CHECK-NEXT: $d0 = COPY %res(s64) + ; CHECK-NEXT: RET_ReallyLR implicit $d0 + %a:_(s64) = G_FCONSTANT double 40.0 + %b:_(s64) = G_FCONSTANT double 2.0 + %c:_(s64) = G_FCONSTANT double 1.0 + %res:_(s64) = G_FMAD %a, %b, %c + $d0 = COPY %res(s64) + RET_ReallyLR implicit $d0 + +... From d095b9fedc10fd3da6a9a94a94f14a9442e37d06 Mon Sep 17 00:00:00 2001 From: Amara Emerson Date: Fri, 8 Sep 2023 02:29:44 -0700 Subject: [PATCH 2/2] Address comments. --- .../lib/CodeGen/GlobalISel/CombinerHelper.cpp | 22 +++++-------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp index 6c8e439a15d77..2ce6895042409 100644 --- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp +++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp @@ -4624,7 +4624,8 @@ bool CombinerHelper::matchConstantFoldFPBinOp(MachineInstr &MI, ConstantFP* &Mat bool CombinerHelper::matchConstantFoldFMA(MachineInstr &MI, ConstantFP *&MatchInfo) { - unsigned Opc = MI.getOpcode(); + assert(MI.getOpcode() == TargetOpcode::G_FMA || + MI.getOpcode() == TargetOpcode::G_FMAD); auto [_, Op1, Op2, Op3] = MI.getFirst4Regs(); const ConstantFP *Op3Cst = getConstantFPVRegVal(Op3, MRI); @@ -4640,22 +4641,9 @@ bool CombinerHelper::matchConstantFoldFMA(MachineInstr &MI, return false; APFloat Op1F = Op1Cst->getValueAPF(); - APFloat Op2F = Op2Cst->getValueAPF(); - APFloat Op3F = Op3Cst->getValueAPF(); - - switch (Opc) { - case TargetOpcode::G_FMA: - Op1F.fusedMultiplyAdd(Op2F, Op3F, APFloat::rmNearestTiesToEven); - MatchInfo = ConstantFP::get(MI.getMF()->getFunction().getContext(), Op1F); - break; - case TargetOpcode::G_FMAD: { - APFloat Res = (Op1F * Op2F) + Op3F; - MatchInfo = ConstantFP::get(MI.getMF()->getFunction().getContext(), Res); - break; - } - default: - llvm_unreachable("Unexpected opcode"); - } + Op1F.fusedMultiplyAdd(Op2Cst->getValueAPF(), Op3Cst->getValueAPF(), + APFloat::rmNearestTiesToEven); + MatchInfo = ConstantFP::get(MI.getMF()->getFunction().getContext(), Op1F); return true; }