diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 4346f82fa5da9..c3163f70b847e 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -119,7 +119,7 @@ static Instruction *foldSelectBinOpIdentity(SelectInst &Sel, /// (shl (and (X, C1)), (log2(TC-FC) - log2(C1))) + FC /// With some variations depending if FC is larger than TC, or the shift /// isn't needed, or the bit widths don't match. -static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp, +static Value *foldSelectICmpAnd(SelectInst &Sel, Value *CondVal, InstCombiner::BuilderTy &Builder, const SimplifyQuery &SQ) { const APInt *SelTC, *SelFC; @@ -129,36 +129,47 @@ static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp, // If this is a vector select, we need a vector compare. Type *SelType = Sel.getType(); - if (SelType->isVectorTy() != Cmp->getType()->isVectorTy()) + if (SelType->isVectorTy() != CondVal->getType()->isVectorTy()) return nullptr; Value *V; APInt AndMask; bool CreateAnd = false; - ICmpInst::Predicate Pred = Cmp->getPredicate(); - if (ICmpInst::isEquality(Pred)) { - if (!match(Cmp->getOperand(1), m_Zero())) - return nullptr; + CmpPredicate Pred; + Value *CmpLHS, *CmpRHS; - V = Cmp->getOperand(0); - const APInt *AndRHS; - if (!match(V, m_And(m_Value(), m_Power2(AndRHS)))) - return nullptr; + if (match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) { + if (ICmpInst::isEquality(Pred)) { + if (!match(CmpRHS, m_Zero())) + return nullptr; + + V = CmpLHS; + const APInt *AndRHS; + if (!match(V, m_And(m_Value(), m_Power2(AndRHS)))) + return nullptr; - AndMask = *AndRHS; - } else if (auto Res = decomposeBitTestICmp(Cmp->getOperand(0), - Cmp->getOperand(1), Pred)) { - assert(ICmpInst::isEquality(Res->Pred) && "Not equality test?"); - AndMask = Res->Mask; - V = Res->X; - KnownBits Known = - computeKnownBits(V, /*Depth=*/0, SQ.getWithInstruction(&Sel)); - AndMask &= Known.getMaxValue(); - if (!AndMask.isPowerOf2()) + AndMask = *AndRHS; + } else if (auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred)) { + assert(ICmpInst::isEquality(Res->Pred) && "Not equality test?"); + AndMask = Res->Mask; + V = Res->X; + KnownBits Known = + computeKnownBits(V, /*Depth=*/0, SQ.getWithInstruction(&Sel)); + AndMask &= Known.getMaxValue(); + if (!AndMask.isPowerOf2()) + return nullptr; + + Pred = Res->Pred; + CreateAnd = true; + } else { return nullptr; + } - Pred = Res->Pred; - CreateAnd = true; + } else if (auto *Trunc = dyn_cast(CondVal)) { + V = Trunc->getOperand(0); + AndMask = APInt(V->getType()->getScalarSizeInBits(), 1); + Pred = ICmpInst::ICMP_NE; + CreateAnd = !Trunc->hasNoUnsignedWrap(); } else { return nullptr; } @@ -176,7 +187,7 @@ static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp, return nullptr; // If we have to create an 'and', then we must kill the cmp to not // increase the instruction count. - if (CreateAnd && !Cmp->hasOneUse()) + if (CreateAnd && !CondVal->hasOneUse()) return nullptr; // (V & AndMaskC) == 0 ? TC : FC --> TC | (V & AndMaskC) @@ -217,7 +228,7 @@ static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp, // a 'select' + 'icmp', then this transformation would result in more // instructions and potentially interfere with other folding. if (CreateAnd + ShouldNotVal + NeedShift + NeedZExtTrunc > - 1 + Cmp->hasOneUse()) + 1 + CondVal->hasOneUse()) return nullptr; // Insert the 'and' instruction on the input to the truncate. @@ -1961,9 +1972,6 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, tryToReuseConstantFromSelectInComparison(SI, *ICI, *this)) return NewSel; - if (Value *V = foldSelectICmpAnd(SI, ICI, Builder, SQ)) - return replaceInstUsesWith(SI, V); - // NOTE: if we wanted to, this is where to detect integer MIN/MAX bool Changed = false; Value *TrueVal = SI.getTrueValue(); @@ -3961,6 +3969,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { if (Instruction *Result = foldSelectInstWithICmp(SI, ICI)) return Result; + if (Value *V = foldSelectICmpAnd(SI, CondVal, Builder, SQ)) + return replaceInstUsesWith(SI, V); + if (Value *V = foldSelectICmpAndBinOp(CondVal, TrueVal, FalseVal, Builder)) return replaceInstUsesWith(SI, V); diff --git a/llvm/test/Transforms/InstCombine/select-icmp-and.ll b/llvm/test/Transforms/InstCombine/select-icmp-and.ll index 8b61b55a62712..4309c603bba81 100644 --- a/llvm/test/Transforms/InstCombine/select-icmp-and.ll +++ b/llvm/test/Transforms/InstCombine/select-icmp-and.ll @@ -809,8 +809,8 @@ define i8 @select_bittest_to_xor(i8 %x) { define i8 @select_trunc_bittest_to_sub(i8 %x) { ; CHECK-LABEL: @select_trunc_bittest_to_sub( -; CHECK-NEXT: [[TRUNC:%.*]] = trunc i8 [[X:%.*]] to i1 -; CHECK-NEXT: [[RET:%.*]] = select i1 [[TRUNC]], i8 3, i8 4 +; CHECK-NEXT: [[TMP1:%.*]] = and i8 [[X:%.*]], 1 +; CHECK-NEXT: [[RET:%.*]] = sub nuw nsw i8 4, [[TMP1]] ; CHECK-NEXT: ret i8 [[RET]] ; %trunc = trunc i8 %x to i1 @@ -820,8 +820,7 @@ define i8 @select_trunc_bittest_to_sub(i8 %x) { define i8 @select_trunc_nuw_bittest_to_sub(i8 %x) { ; CHECK-LABEL: @select_trunc_nuw_bittest_to_sub( -; CHECK-NEXT: [[TRUNC:%.*]] = trunc nuw i8 [[X:%.*]] to i1 -; CHECK-NEXT: [[RET:%.*]] = select i1 [[TRUNC]], i8 3, i8 4 +; CHECK-NEXT: [[RET:%.*]] = sub i8 4, [[X:%.*]] ; CHECK-NEXT: ret i8 [[RET]] ; %trunc = trunc nuw i8 %x to i1 @@ -831,8 +830,8 @@ define i8 @select_trunc_nuw_bittest_to_sub(i8 %x) { define i8 @select_trunc_nsw_bittest_to_sub(i8 %x) { ; CHECK-LABEL: @select_trunc_nsw_bittest_to_sub( -; CHECK-NEXT: [[TRUNC:%.*]] = trunc nsw i8 [[X:%.*]] to i1 -; CHECK-NEXT: [[RET:%.*]] = select i1 [[TRUNC]], i8 3, i8 4 +; CHECK-NEXT: [[TMP1:%.*]] = and i8 [[X:%.*]], 1 +; CHECK-NEXT: [[RET:%.*]] = sub nuw nsw i8 4, [[TMP1]] ; CHECK-NEXT: ret i8 [[RET]] ; %trunc = trunc nsw i8 %x to i1 @@ -844,7 +843,7 @@ define i8 @select_trunc_nuw_bittest_to_sub_extra_use(i8 %x) { ; CHECK-LABEL: @select_trunc_nuw_bittest_to_sub_extra_use( ; CHECK-NEXT: [[TRUNC:%.*]] = trunc nuw i8 [[X:%.*]] to i1 ; CHECK-NEXT: call void @use1(i1 [[TRUNC]]) -; CHECK-NEXT: [[RET:%.*]] = select i1 [[TRUNC]], i8 3, i8 4 +; CHECK-NEXT: [[RET:%.*]] = sub i8 4, [[X]] ; CHECK-NEXT: ret i8 [[RET]] ; %trunc = trunc nuw i8 %x to i1 @@ -868,8 +867,8 @@ define i8 @neg_select_trunc_bittest_to_sub_extra_use(i8 %x) { define i8 @select_trunc_nuw_bittest_to_shl_not(i8 %x) { ; CHECK-LABEL: @select_trunc_nuw_bittest_to_shl_not( -; CHECK-NEXT: [[TRUNC:%.*]] = trunc nuw i8 [[X:%.*]] to i1 -; CHECK-NEXT: [[RET:%.*]] = select i1 [[TRUNC]], i8 0, i8 4 +; CHECK-NEXT: [[TMP1:%.*]] = shl i8 [[X:%.*]], 2 +; CHECK-NEXT: [[RET:%.*]] = xor i8 [[TMP1]], 4 ; CHECK-NEXT: ret i8 [[RET]] ; %trunc = trunc nuw i8 %x to i1 @@ -879,8 +878,8 @@ define i8 @select_trunc_nuw_bittest_to_shl_not(i8 %x) { define i8 @select_trunc_bittest_to_shl(i8 %x) { ; CHECK-LABEL: @select_trunc_bittest_to_shl( -; CHECK-NEXT: [[TRUNC:%.*]] = trunc i8 [[X:%.*]] to i1 -; CHECK-NEXT: [[RET:%.*]] = select i1 [[TRUNC]], i8 4, i8 0 +; CHECK-NEXT: [[TMP1:%.*]] = shl i8 [[X:%.*]], 2 +; CHECK-NEXT: [[RET:%.*]] = and i8 [[TMP1]], 4 ; CHECK-NEXT: ret i8 [[RET]] ; %trunc = trunc i8 %x to i1 @@ -903,8 +902,9 @@ define i8 @neg_select_trunc_bittest_to_shl_extra_use(i8 %x) { define i16 @select_trunc_nuw_bittest_or(i8 %x) { ; CHECK-LABEL: @select_trunc_nuw_bittest_or( -; CHECK-NEXT: [[TMP1:%.*]] = trunc nuw i8 [[X:%.*]] to i1 -; CHECK-NEXT: [[RES:%.*]] = select i1 [[TMP1]], i16 20, i16 4 +; CHECK-NEXT: [[TMP1:%.*]] = zext i8 [[X:%.*]] to i16 +; CHECK-NEXT: [[SELECT:%.*]] = shl nuw nsw i16 [[TMP1]], 4 +; CHECK-NEXT: [[RES:%.*]] = or disjoint i16 [[SELECT]], 4 ; CHECK-NEXT: ret i16 [[RES]] ; %trunc = trunc nuw i8 %x to i1