diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index 8ff70fdb1180b..6360ddb57007d 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -916,22 +916,24 @@ static InstructionsState getSameOpcode(ArrayRef VL, if (It == VL.end()) return InstructionsState::invalid(); - Instruction *MainOp = cast(*It); + Value *V = *It; unsigned InstCnt = std::count_if(It, VL.end(), IsaPred); - if ((VL.size() > 2 && !isa(MainOp) && InstCnt < VL.size() / 2) || + if ((VL.size() > 2 && !isa(V) && InstCnt < VL.size() / 2) || (VL.size() == 2 && InstCnt < 2)) return InstructionsState::invalid(); - bool IsCastOp = isa(MainOp); - bool IsBinOp = isa(MainOp); - bool IsCmpOp = isa(MainOp); - CmpInst::Predicate BasePred = IsCmpOp ? cast(MainOp)->getPredicate() - : CmpInst::BAD_ICMP_PREDICATE; - Instruction *AltOp = MainOp; - unsigned Opcode = MainOp->getOpcode(); + bool IsCastOp = isa(V); + bool IsBinOp = isa(V); + bool IsCmpOp = isa(V); + CmpInst::Predicate BasePred = + IsCmpOp ? cast(V)->getPredicate() : CmpInst::BAD_ICMP_PREDICATE; + unsigned Opcode = cast(V)->getOpcode(); unsigned AltOpcode = Opcode; + unsigned AltIndex = std::distance(VL.begin(), It); - bool SwappedPredsCompatible = IsCmpOp && [&]() { + bool SwappedPredsCompatible = [&]() { + if (!IsCmpOp) + return false; SetVector UniquePreds, UniqueNonSwappedPreds; UniquePreds.insert(BasePred); UniqueNonSwappedPreds.insert(BasePred); @@ -954,18 +956,18 @@ static InstructionsState getSameOpcode(ArrayRef VL, }(); // Check for one alternate opcode from another BinaryOperator. // TODO - generalize to support all operators (types, calls etc.). + auto *IBase = cast(V); Intrinsic::ID BaseID = 0; SmallVector BaseMappings; - if (auto *CallBase = dyn_cast(MainOp)) { + if (auto *CallBase = dyn_cast(IBase)) { BaseID = getVectorIntrinsicIDForCall(CallBase, &TLI); BaseMappings = VFDatabase(*CallBase).getMappings(*CallBase); if (!isTriviallyVectorizable(BaseID) && BaseMappings.empty()) return InstructionsState::invalid(); } bool AnyPoison = InstCnt != VL.size(); - // Skip MainOp. - for (Value *V : iterator_range(It + 1, VL.end())) { - auto *I = dyn_cast(V); + for (int Cnt = 0, E = VL.size(); Cnt < E; Cnt++) { + auto *I = dyn_cast(VL[Cnt]); if (!I) continue; @@ -981,11 +983,11 @@ static InstructionsState getSameOpcode(ArrayRef VL, if (Opcode == AltOpcode && isValidForAlternation(InstOpcode) && isValidForAlternation(Opcode)) { AltOpcode = InstOpcode; - AltOp = I; + AltIndex = Cnt; continue; } } else if (IsCastOp && isa(I)) { - Value *Op0 = MainOp->getOperand(0); + Value *Op0 = IBase->getOperand(0); Type *Ty0 = Op0->getType(); Value *Op1 = I->getOperand(0); Type *Ty1 = Op1->getType(); @@ -997,12 +999,12 @@ static InstructionsState getSameOpcode(ArrayRef VL, isValidForAlternation(InstOpcode) && "Cast isn't safe for alternation, logic needs to be updated!"); AltOpcode = InstOpcode; - AltOp = I; + AltIndex = Cnt; continue; } } - } else if (auto *Inst = dyn_cast(I); Inst && IsCmpOp) { - auto *BaseInst = cast(MainOp); + } else if (auto *Inst = dyn_cast(VL[Cnt]); Inst && IsCmpOp) { + auto *BaseInst = cast(V); Type *Ty0 = BaseInst->getOperand(0)->getType(); Type *Ty1 = Inst->getOperand(0)->getType(); if (Ty0 == Ty1) { @@ -1016,21 +1018,21 @@ static InstructionsState getSameOpcode(ArrayRef VL, CmpInst::Predicate SwappedCurrentPred = CmpInst::getSwappedPredicate(CurrentPred); - if ((VL.size() == 2 || SwappedPredsCompatible) && + if ((E == 2 || SwappedPredsCompatible) && (BasePred == CurrentPred || BasePred == SwappedCurrentPred)) continue; if (isCmpSameOrSwapped(BaseInst, Inst, TLI)) continue; - auto *AltInst = cast(AltOp); - if (MainOp != AltOp) { + auto *AltInst = cast(VL[AltIndex]); + if (AltIndex) { if (isCmpSameOrSwapped(AltInst, Inst, TLI)) continue; } else if (BasePred != CurrentPred) { assert( isValidForAlternation(InstOpcode) && "CmpInst isn't safe for alternation, logic needs to be updated!"); - AltOp = I; + AltIndex = Cnt; continue; } CmpInst::Predicate AltPred = AltInst->getPredicate(); @@ -1044,17 +1046,17 @@ static InstructionsState getSameOpcode(ArrayRef VL, "CastInst."); if (auto *Gep = dyn_cast(I)) { if (Gep->getNumOperands() != 2 || - Gep->getOperand(0)->getType() != MainOp->getOperand(0)->getType()) + Gep->getOperand(0)->getType() != IBase->getOperand(0)->getType()) return InstructionsState::invalid(); } else if (auto *EI = dyn_cast(I)) { if (!isVectorLikeInstWithConstOps(EI)) return InstructionsState::invalid(); } else if (auto *LI = dyn_cast(I)) { - auto *BaseLI = cast(MainOp); + auto *BaseLI = cast(IBase); if (!LI->isSimple() || !BaseLI->isSimple()) return InstructionsState::invalid(); } else if (auto *Call = dyn_cast(I)) { - auto *CallBase = cast(MainOp); + auto *CallBase = cast(IBase); if (Call->getCalledFunction() != CallBase->getCalledFunction()) return InstructionsState::invalid(); if (Call->hasOperandBundles() && @@ -1084,7 +1086,8 @@ static InstructionsState getSameOpcode(ArrayRef VL, return InstructionsState::invalid(); } - return InstructionsState(MainOp, AltOp); + return InstructionsState(cast(V), + cast(VL[AltIndex])); } /// \returns true if all of the values in \p VL have the same type or false