From 78e1634b60254f3c71fea51740ba65b57270e8fe Mon Sep 17 00:00:00 2001
From: John Brawn
Date: Fri, 1 Nov 2024 15:56:51 +0000
Subject: [PATCH] [InstCombine] Eliminate fptrunc/fpext if fast math flags
 allow it

When expressions of a floating-point type are evaluated at a higher
precision (e.g. _Float16 being evaluated as float) this results in a
fptrunc then fpext between each operation. With the appropriate fast
math flags (nnan ninf contract) we can eliminate these cast
instructions.
---
 .../InstCombine/InstCombineCasts.cpp          | 25 +++++++
 llvm/test/Transforms/InstCombine/fpextend.ll  | 73 +++++++++++++++++++
 2 files changed, 98 insertions(+)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 7221c987b9821..6b8c362c54456 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -1940,6 +1940,31 @@ Instruction *InstCombinerImpl::visitFPExt(CastInst &FPExt) {
       return CastInst::Create(FPCast->getOpcode(), FPCast->getOperand(0), Ty);
   }
 
+  // fpext (fptrunc(x)) -> x, if the fast math flags allow it
+  if (auto *Trunc = dyn_cast<FPTruncInst>(Src)) {
+    // Whether this transformation is possible depends on the fast math flags of
+    // both the fpext and fptrunc.
+    FastMathFlags SrcFlags = Trunc->getFastMathFlags();
+    FastMathFlags DstFlags = FPExt.getFastMathFlags();
+    // Trunc can introduce inf and change the encoding of a nan, so the
+    // destination must have the nnan and ninf flags to indicate that we don't
+    // need to care about that. We are also removing a rounding step, and that
+    // requires both the source and destination to allow contraction.
+    if (DstFlags.noNaNs() && DstFlags.noInfs() && SrcFlags.allowContract() &&
+        DstFlags.allowContract()) {
+      Value *TruncSrc = Trunc->getOperand(0);
+      // We do need a single cast if the source and destination types don't
+      // match.
+      if (TruncSrc->getType() != Ty) {
+        Instruction *Ret = CastInst::CreateFPCast(TruncSrc, Ty);
+        Ret->copyFastMathFlags(&FPExt);
+        return Ret;
+      } else {
+        return replaceInstUsesWith(FPExt, TruncSrc);
+      }
+    }
+  }
+
   return commonCastTransforms(FPExt);
 }
 
diff --git a/llvm/test/Transforms/InstCombine/fpextend.ll b/llvm/test/Transforms/InstCombine/fpextend.ll
index c9adbe10d8db4..c18238d972192 100644
--- a/llvm/test/Transforms/InstCombine/fpextend.ll
+++ b/llvm/test/Transforms/InstCombine/fpextend.ll
@@ -448,3 +448,76 @@ define bfloat @bf16_frem(bfloat %x) {
   %t3 = fptrunc float %t2 to bfloat
   ret bfloat %t3
 }
+
+define double @fptrunc_fpextend_nofast(double %x, double %y, double %z) {
+; CHECK-LABEL: @fptrunc_fpextend_nofast(
+; CHECK-NEXT:    [[ADD1:%.*]] = fadd double [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[TRUNC:%.*]] = fptrunc double [[ADD1]] to float
+; CHECK-NEXT:    [[EXT:%.*]] = fpext float [[TRUNC]] to double
+; CHECK-NEXT:    [[ADD2:%.*]] = fadd double [[Z:%.*]], [[EXT]]
+; CHECK-NEXT:    ret double [[ADD2]]
+;
+  %add1 = fadd double %x, %y
+  %trunc = fptrunc double %add1 to float
+  %ext = fpext float %trunc to double
+  %add2 = fadd double %ext, %z
+  ret double %add2
+}
+
+define double @fptrunc_fpextend_fast(double %x, double %y, double %z) {
+; CHECK-LABEL: @fptrunc_fpextend_fast(
+; CHECK-NEXT:    [[ADD1:%.*]] = fadd double [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[ADD2:%.*]] = fadd double [[ADD1]], [[Z:%.*]]
+; CHECK-NEXT:    ret double [[ADD2]]
+;
+  %add1 = fadd double %x, %y
+  %trunc = fptrunc contract double %add1 to float
+  %ext = fpext nnan ninf contract float %trunc to double
+  %add2 = fadd double %ext, %z
+  ret double %add2
+}
+
+define float @fptrunc_fpextend_result_smaller(double %x, double %y, float %z) {
+; CHECK-LABEL: @fptrunc_fpextend_result_smaller(
+; CHECK-NEXT:    [[ADD1:%.*]] = fadd double [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[EXT:%.*]] = fptrunc nnan ninf contract double [[ADD1]] to float
+; CHECK-NEXT:    [[ADD2:%.*]] = fadd float [[Z:%.*]], [[EXT]]
+; CHECK-NEXT:    ret float [[ADD2]]
+;
+  %add1 = fadd double %x, %y
+  %trunc = fptrunc contract double %add1 to half
+  %ext = fpext nnan ninf contract half %trunc to float
+  %add2 = fadd float %ext, %z
+  ret float %add2
+}
+
+define double @fptrunc_fpextend_result_larger(float %x, float %y, double %z) {
+; CHECK-LABEL: @fptrunc_fpextend_result_larger(
+; CHECK-NEXT:    [[ADD1:%.*]] = fadd float [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[EXT:%.*]] = fpext nnan ninf contract float [[ADD1]] to double
+; CHECK-NEXT:    [[ADD2:%.*]] = fadd double [[Z:%.*]], [[EXT]]
+; CHECK-NEXT:    ret double [[ADD2]]
+;
+  %add1 = fadd float %x, %y
+  %trunc = fptrunc contract float %add1 to half
+  %ext = fpext nnan ninf contract half %trunc to double
+  %add2 = fadd double %ext, %z
+  ret double %add2
+}
+
+define double @fptrunc_fpextend_multiple_use(double %x, double %y, double %a, double %b) {
+; CHECK-LABEL: @fptrunc_fpextend_multiple_use(
+; CHECK-NEXT:    [[ADD1:%.*]] = fadd double [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[ADD2:%.*]] = fadd double [[ADD1]], [[A:%.*]]
+; CHECK-NEXT:    [[ADD3:%.*]] = fadd double [[ADD1]], [[B:%.*]]
+; CHECK-NEXT:    [[MUL:%.*]] = fmul double [[ADD2]], [[ADD3]]
+; CHECK-NEXT:    ret double [[MUL]]
+;
+  %add1 = fadd double %x, %y
+  %trunc = fptrunc contract double %add1 to float
+  %ext = fpext nnan ninf contract float %trunc to double
+  %add2 = fadd double %ext, %a
+  %add3 = fadd double %ext, %b
+  %mul = fmul double %add2, %add3
+  ret double %mul
+}