diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 0bd4b6d1a835a..912d9ac404052 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -1706,6 +1706,34 @@ static Instruction *foldFDivPowDivisor(BinaryOperator &I, return BinaryOperator::CreateFMulFMF(Op0, Pow, &I); } +/// Convert div to mul if we have an sqrt divisor iff sqrt's operand is a fdiv +/// instruction. +static Instruction *foldFDivSqrtDivisor(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + // X / sqrt(Y / Z) --> X * sqrt(Z / Y) + if (!I.hasAllowReassoc() || !I.hasAllowReciprocal()) + return nullptr; + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + auto *II = dyn_cast(Op1); + if (!II || II->getIntrinsicID() != Intrinsic::sqrt || !II->hasOneUse() || + !II->hasAllowReassoc() || !II->hasAllowReciprocal()) + return nullptr; + + Value *Y, *Z; + auto *DivOp = dyn_cast(II->getOperand(0)); + if (!DivOp) + return nullptr; + if (!match(DivOp, m_FDiv(m_Value(Y), m_Value(Z)))) + return nullptr; + if (!DivOp->hasAllowReassoc() || !I.hasAllowReciprocal() || + !DivOp->hasOneUse()) + return nullptr; + Value *SwapDiv = Builder.CreateFDivFMF(Z, Y, DivOp); + Value *NewSqrt = + Builder.CreateUnaryIntrinsic(II->getIntrinsicID(), SwapDiv, II); + return BinaryOperator::CreateFMulFMF(Op0, NewSqrt, &I); +} + Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) { Module *M = I.getModule(); @@ -1813,6 +1841,9 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) { if (Instruction *Mul = foldFDivPowDivisor(I, Builder)) return Mul; + if (Instruction *Mul = foldFDivSqrtDivisor(I, Builder)) + return Mul; + // pow(X, Y) / X --> pow(X, Y-1) if (I.hasAllowReassoc() && match(Op0, m_OneUse(m_Intrinsic(m_Specific(Op1), diff --git a/llvm/test/Transforms/InstCombine/fdiv-sqrt.ll b/llvm/test/Transforms/InstCombine/fdiv-sqrt.ll index 346271be7da76..9f030c5ebf7bb 100644 --- a/llvm/test/Transforms/InstCombine/fdiv-sqrt.ll +++ b/llvm/test/Transforms/InstCombine/fdiv-sqrt.ll @@ -6,9 +6,9 @@ declare double @llvm.sqrt.f64(double) define double @sqrt_div_fast(double %x, double %y, double %z) { ; CHECK-LABEL: @sqrt_div_fast( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[DIV:%.*]] = fdiv fast double [[Y:%.*]], [[Z:%.*]] -; CHECK-NEXT: [[SQRT:%.*]] = call fast double @llvm.sqrt.f64(double [[DIV]]) -; CHECK-NEXT: [[DIV1:%.*]] = fdiv fast double [[X:%.*]], [[SQRT]] +; CHECK-NEXT: [[TMP0:%.*]] = fdiv fast double [[Z:%.*]], [[Y:%.*]] +; CHECK-NEXT: [[TMP1:%.*]] = call fast double @llvm.sqrt.f64(double [[TMP0]]) +; CHECK-NEXT: [[DIV1:%.*]] = fmul fast double [[TMP1]], [[X:%.*]] ; CHECK-NEXT: ret double [[DIV1]] ; entry: @@ -36,9 +36,9 @@ entry: define double @sqrt_div_reassoc_arcp(double %x, double %y, double %z) { ; CHECK-LABEL: @sqrt_div_reassoc_arcp( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[DIV:%.*]] = fdiv reassoc arcp double [[Y:%.*]], [[Z:%.*]] -; CHECK-NEXT: [[SQRT:%.*]] = call reassoc arcp double @llvm.sqrt.f64(double [[DIV]]) -; CHECK-NEXT: [[DIV1:%.*]] = fdiv reassoc arcp double [[X:%.*]], [[SQRT]] +; CHECK-NEXT: [[TMP0:%.*]] = fdiv reassoc arcp double [[Z:%.*]], [[Y:%.*]] +; CHECK-NEXT: [[TMP1:%.*]] = call reassoc arcp double @llvm.sqrt.f64(double [[TMP0]]) +; CHECK-NEXT: [[DIV1:%.*]] = fmul reassoc arcp double [[TMP1]], [[X:%.*]] ; CHECK-NEXT: ret double [[DIV1]] ; entry: @@ -96,9 +96,9 @@ entry: define double @sqrt_div_arcp_missing(double %x, double %y, double %z) { ; CHECK-LABEL: @sqrt_div_arcp_missing( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[DIV:%.*]] = fdiv reassoc double [[Y:%.*]], [[Z:%.*]] -; CHECK-NEXT: [[SQRT:%.*]] = call reassoc arcp double @llvm.sqrt.f64(double [[DIV]]) -; CHECK-NEXT: [[DIV1:%.*]] = fdiv reassoc arcp double [[X:%.*]], [[SQRT]] +; CHECK-NEXT: [[TMP0:%.*]] = fdiv reassoc double [[Z:%.*]], [[Y:%.*]] +; CHECK-NEXT: [[TMP1:%.*]] = call reassoc arcp double @llvm.sqrt.f64(double [[TMP0]]) +; CHECK-NEXT: [[DIV1:%.*]] = fmul reassoc arcp double [[TMP1]], [[X:%.*]] ; CHECK-NEXT: ret double [[DIV1]] ; entry: @@ -173,3 +173,19 @@ entry: ret double %div1 } +define float @sqrt_non_div_operator(float %a) { +; CHECK-LABEL: @sqrt_non_div_operator( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[CONV:%.*]] = fpext float [[A:%.*]] to double +; CHECK-NEXT: [[SQRT:%.*]] = call fast double @llvm.sqrt.f64(double [[CONV]]) +; CHECK-NEXT: [[DIV:%.*]] = fdiv fast double [[CONV]], [[SQRT]] +; CHECK-NEXT: [[CONV2:%.*]] = fptrunc double [[DIV]] to float +; CHECK-NEXT: ret float [[CONV2]] +; +entry: + %conv = fpext float %a to double + %sqrt = call fast double @llvm.sqrt.f64(double %conv) + %div = fdiv fast double %conv, %sqrt + %conv2 = fptrunc double %div to float + ret float %conv2 +}