diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index c4582df89213d..5241a599e0b25 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -915,24 +915,22 @@ static InstructionsState getSameOpcode(ArrayRef VL, if (It == VL.end()) return InstructionsState::invalid(); - Value *V = *It; + Instruction *MainOp = cast(*It); unsigned InstCnt = std::count_if(It, VL.end(), IsaPred); - if ((VL.size() > 2 && !isa(V) && InstCnt < VL.size() / 2) || + if ((VL.size() > 2 && !isa(MainOp) && InstCnt < VL.size() / 2) || (VL.size() == 2 && InstCnt < 2)) return InstructionsState::invalid(); - bool IsCastOp = isa(V); - bool IsBinOp = isa(V); - bool IsCmpOp = isa(V); - CmpInst::Predicate BasePred = - IsCmpOp ? cast(V)->getPredicate() : CmpInst::BAD_ICMP_PREDICATE; - unsigned Opcode = cast(V)->getOpcode(); + bool IsCastOp = isa(MainOp); + bool IsBinOp = isa(MainOp); + bool IsCmpOp = isa(MainOp); + CmpInst::Predicate BasePred = IsCmpOp ? cast(MainOp)->getPredicate() + : CmpInst::BAD_ICMP_PREDICATE; + Instruction *AltOp = MainOp; + unsigned Opcode = MainOp->getOpcode(); unsigned AltOpcode = Opcode; - unsigned AltIndex = std::distance(VL.begin(), It); - bool SwappedPredsCompatible = [&]() { - if (!IsCmpOp) - return false; + bool SwappedPredsCompatible = IsCmpOp && [&]() { SetVector UniquePreds, UniqueNonSwappedPreds; UniquePreds.insert(BasePred); UniqueNonSwappedPreds.insert(BasePred); @@ -955,18 +953,18 @@ static InstructionsState getSameOpcode(ArrayRef VL, }(); // Check for one alternate opcode from another BinaryOperator. // TODO - generalize to support all operators (types, calls etc.). - auto *IBase = cast(V); Intrinsic::ID BaseID = 0; SmallVector BaseMappings; - if (auto *CallBase = dyn_cast(IBase)) { + if (auto *CallBase = dyn_cast(MainOp)) { BaseID = getVectorIntrinsicIDForCall(CallBase, &TLI); BaseMappings = VFDatabase(*CallBase).getMappings(*CallBase); if (!isTriviallyVectorizable(BaseID) && BaseMappings.empty()) return InstructionsState::invalid(); } bool AnyPoison = InstCnt != VL.size(); - for (int Cnt = 0, E = VL.size(); Cnt < E; Cnt++) { - auto *I = dyn_cast(VL[Cnt]); + // Skip MainOp. + for (Value *V : iterator_range(It + 1, VL.end())) { + auto *I = dyn_cast(V); if (!I) continue; @@ -982,11 +980,11 @@ static InstructionsState getSameOpcode(ArrayRef VL, if (Opcode == AltOpcode && isValidForAlternation(InstOpcode) && isValidForAlternation(Opcode)) { AltOpcode = InstOpcode; - AltIndex = Cnt; + AltOp = I; continue; } } else if (IsCastOp && isa(I)) { - Value *Op0 = IBase->getOperand(0); + Value *Op0 = MainOp->getOperand(0); Type *Ty0 = Op0->getType(); Value *Op1 = I->getOperand(0); Type *Ty1 = Op1->getType(); @@ -998,12 +996,12 @@ static InstructionsState getSameOpcode(ArrayRef VL, isValidForAlternation(InstOpcode) && "Cast isn't safe for alternation, logic needs to be updated!"); AltOpcode = InstOpcode; - AltIndex = Cnt; + AltOp = I; continue; } } - } else if (auto *Inst = dyn_cast(VL[Cnt]); Inst && IsCmpOp) { - auto *BaseInst = cast(V); + } else if (auto *Inst = dyn_cast(I); Inst && IsCmpOp) { + auto *BaseInst = cast(MainOp); Type *Ty0 = BaseInst->getOperand(0)->getType(); Type *Ty1 = Inst->getOperand(0)->getType(); if (Ty0 == Ty1) { @@ -1017,21 +1015,21 @@ static InstructionsState getSameOpcode(ArrayRef VL, CmpInst::Predicate SwappedCurrentPred = CmpInst::getSwappedPredicate(CurrentPred); - if ((E == 2 || SwappedPredsCompatible) && + if ((VL.size() == 2 || SwappedPredsCompatible) && (BasePred == CurrentPred || BasePred == SwappedCurrentPred)) continue; if (isCmpSameOrSwapped(BaseInst, Inst, TLI)) continue; - auto *AltInst = cast(VL[AltIndex]); - if (AltIndex) { + auto *AltInst = cast(AltOp); + if (MainOp != AltOp) { if (isCmpSameOrSwapped(AltInst, Inst, TLI)) continue; } else if (BasePred != CurrentPred) { assert( isValidForAlternation(InstOpcode) && "CmpInst isn't safe for alternation, logic needs to be updated!"); - AltIndex = Cnt; + AltOp = I; continue; } CmpInst::Predicate AltPred = AltInst->getPredicate(); @@ -1045,17 +1043,17 @@ static InstructionsState getSameOpcode(ArrayRef VL, "CastInst."); if (auto *Gep = dyn_cast(I)) { if (Gep->getNumOperands() != 2 || - Gep->getOperand(0)->getType() != IBase->getOperand(0)->getType()) + Gep->getOperand(0)->getType() != MainOp->getOperand(0)->getType()) return InstructionsState::invalid(); } else if (auto *EI = dyn_cast(I)) { if (!isVectorLikeInstWithConstOps(EI)) return InstructionsState::invalid(); } else if (auto *LI = dyn_cast(I)) { - auto *BaseLI = cast(IBase); + auto *BaseLI = cast(MainOp); if (!LI->isSimple() || !BaseLI->isSimple()) return InstructionsState::invalid(); } else if (auto *Call = dyn_cast(I)) { - auto *CallBase = cast(IBase); + auto *CallBase = cast(MainOp); if (Call->getCalledFunction() != CallBase->getCalledFunction()) return InstructionsState::invalid(); if (Call->hasOperandBundles() && @@ -1085,8 +1083,7 @@ static InstructionsState getSameOpcode(ArrayRef VL, return InstructionsState::invalid(); } - return InstructionsState(cast(V), - cast(VL[AltIndex])); + return InstructionsState(MainOp, AltOp); } /// \returns true if all of the values in \p VL have the same type or false