diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h index 4eab357f1b33b..c43a1b5c1b2aa 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfo.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h @@ -1247,13 +1247,16 @@ class TargetTransformInfo { /// cases or optimizations based on those values. /// \p CxtI is the optional original context instruction, if one exists, to /// provide even more information. + /// \p TLibInfo is used to search for platform specific vector library + /// functions for instructions that might be converted to calls (e.g. frem). InstructionCost getArithmeticInstrCost( unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput, TTI::OperandValueInfo Opd1Info = {TTI::OK_AnyValue, TTI::OP_None}, TTI::OperandValueInfo Opd2Info = {TTI::OK_AnyValue, TTI::OP_None}, ArrayRef Args = ArrayRef(), - const Instruction *CxtI = nullptr) const; + const Instruction *CxtI = nullptr, + const TargetLibraryInfo *TLibInfo = nullptr) const; /// Returns the cost estimation for alternating opcode pattern that can be /// lowered to a single instruction on the target. 
In X86 this is for the diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp index 15311be4dba27..2e0bd84339659 100644 --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -9,6 +9,7 @@ #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/LoopIterator.h" +#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfoImpl.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Dominators.h" @@ -874,7 +875,22 @@ TargetTransformInfo::getOperandInfo(const Value *V) { InstructionCost TargetTransformInfo::getArithmeticInstrCost( unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind, OperandValueInfo Op1Info, OperandValueInfo Op2Info, - ArrayRef Args, const Instruction *CxtI) const { + ArrayRef Args, const Instruction *CxtI, + const TargetLibraryInfo *TLibInfo) const { + + // Use call cost for frem instructions that have platform specific vector math + // functions, as those will be replaced with calls later by SelectionDAG or + // ReplaceWithVecLib pass. 
+ if (TLibInfo && Opcode == Instruction::FRem) { + VectorType *VecTy = dyn_cast(Ty); + LibFunc Func; + if (VecTy && + TLibInfo->getLibFunc(Instruction::FRem, Ty->getScalarType(), Func) && + TLibInfo->isFunctionVectorizable(TLibInfo->getName(Func), + VecTy->getElementCount())) + return getCallInstrCost(nullptr, VecTy, {VecTy, VecTy}, CostKind); + } + InstructionCost Cost = TTIImpl->getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info, Op2Info, diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index edaad4d033bdf..52b992b19e4b0 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -6911,25 +6911,10 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF, Op2Info.Kind = TargetTransformInfo::OK_UniformValue; SmallVector Operands(I->operand_values()); - auto InstrCost = TTI.getArithmeticInstrCost( + return TTI.getArithmeticInstrCost( I->getOpcode(), VectorTy, CostKind, {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None}, - Op2Info, Operands, I); - - // Some targets can replace frem with vector library calls. 
- InstructionCost VecCallCost = InstructionCost::getInvalid(); - if (I->getOpcode() == Instruction::FRem) { - LibFunc Func; - if (TLI->getLibFunc(I->getOpcode(), I->getType(), Func) && - TLI->isFunctionVectorizable(TLI->getName(Func), VF)) { - SmallVector OpTypes; - for (auto &Op : I->operands()) - OpTypes.push_back(Op->getType()); - VecCallCost = - TTI.getCallInstrCost(nullptr, VectorTy, OpTypes, CostKind); - } - } - return std::min(InstrCost, VecCallCost); + Op2Info, Operands, I, TLI); } case Instruction::FNeg: { return TTI.getArithmeticInstrCost( diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index 7b99c3ac8c55a..aa98004c38c73 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -8852,7 +8852,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef VectorizedVals, TTI::OperandValueInfo Op1Info = getOperandInfo(E->getOperand(0)); TTI::OperandValueInfo Op2Info = getOperandInfo(E->getOperand(OpIdx)); return TTI->getArithmeticInstrCost(ShuffleOrOp, VecTy, CostKind, Op1Info, - Op2Info) + + Op2Info, std::nullopt, nullptr, TLI) + CommonCost; }; return GetCostDiff(GetScalarCost, GetVectorCost); diff --git a/llvm/test/Transforms/SLPVectorizer/AArch64/slp-frem.ll b/llvm/test/Transforms/SLPVectorizer/AArch64/slp-frem.ll new file mode 100644 index 0000000000000..a38f4bdc4640e --- /dev/null +++ b/llvm/test/Transforms/SLPVectorizer/AArch64/slp-frem.ll @@ -0,0 +1,55 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4 +; RUN: opt < %s -S -mtriple=aarch64 -vector-library=ArmPL -passes=slp-vectorizer | FileCheck %s + +@a = common global ptr null, align 8 + +define void @frem_v2double() { +; CHECK-LABEL: define void @frem_v2double() { +; CHECK-NEXT: entry: +; CHECK-NEXT: [[TMP0:%.*]] = load <2 x double>, ptr @a, align 8 +; CHECK-NEXT: [[TMP1:%.*]] = load <2 x double>, ptr @a, align 8 +; CHECK-NEXT: 
[[TMP2:%.*]] = frem <2 x double> [[TMP0]], [[TMP1]] +; CHECK-NEXT: store <2 x double> [[TMP2]], ptr @a, align 8 +; CHECK-NEXT: ret void +; +entry: + %a0 = load double, ptr getelementptr inbounds (double, ptr @a, i64 0), align 8 + %a1 = load double, ptr getelementptr inbounds (double, ptr @a, i64 1), align 8 + %b0 = load double, ptr getelementptr inbounds (double, ptr @a, i64 0), align 8 + %b1 = load double, ptr getelementptr inbounds (double, ptr @a, i64 1), align 8 + %r0 = frem double %a0, %b0 + %r1 = frem double %a1, %b1 + store double %r0, ptr getelementptr inbounds (double, ptr @a, i64 0), align 8 + store double %r1, ptr getelementptr inbounds (double, ptr @a, i64 1), align 8 + ret void +} + +define void @frem_v4float() { +; CHECK-LABEL: define void @frem_v4float() { +; CHECK-NEXT: entry: +; CHECK-NEXT: [[TMP0:%.*]] = load <4 x float>, ptr @a, align 8 +; CHECK-NEXT: [[TMP1:%.*]] = load <4 x float>, ptr @a, align 8 +; CHECK-NEXT: [[TMP2:%.*]] = frem <4 x float> [[TMP0]], [[TMP1]] +; CHECK-NEXT: store <4 x float> [[TMP2]], ptr @a, align 8 +; CHECK-NEXT: ret void +; +entry: + %a0 = load float, ptr getelementptr inbounds (float, ptr @a, i64 0), align 8 + %a1 = load float, ptr getelementptr inbounds (float, ptr @a, i64 1), align 8 + %a2 = load float, ptr getelementptr inbounds (float, ptr @a, i64 2), align 8 + %a3 = load float, ptr getelementptr inbounds (float, ptr @a, i64 3), align 8 + %b0 = load float, ptr getelementptr inbounds (float, ptr @a, i64 0), align 8 + %b1 = load float, ptr getelementptr inbounds (float, ptr @a, i64 1), align 8 + %b2 = load float, ptr getelementptr inbounds (float, ptr @a, i64 2), align 8 + %b3 = load float, ptr getelementptr inbounds (float, ptr @a, i64 3), align 8 + %r0 = frem float %a0, %b0 + %r1 = frem float %a1, %b1 + %r2 = frem float %a2, %b2 + %r3 = frem float %a3, %b3 + store float %r0, ptr getelementptr inbounds (float, ptr @a, i64 0), align 8 + store float %r1, ptr getelementptr inbounds (float, ptr @a, i64 1), align 8 + store 
float %r2, ptr getelementptr inbounds (float, ptr @a, i64 2), align 8 + store float %r3, ptr getelementptr inbounds (float, ptr @a, i64 3), align 8 + ret void +} +