-
Notifications
You must be signed in to change notification settings - Fork 14.6k
[RISCV][GISel] Add ISel supports for SHXADD from Zba extension #67863
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
08f77d6
4d81ad5
2d4dce1
9de0c2f
0b2e658
5484c7e
9eef0c5
d8edaba
7df267f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ | |
#include "RISCVTargetMachine.h" | ||
#include "llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h" | ||
#include "llvm/CodeGen/GlobalISel/InstructionSelector.h" | ||
#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h" | ||
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" | ||
#include "llvm/IR/IntrinsicsRISCV.h" | ||
#include "llvm/Support/Debug.h" | ||
|
@@ -55,6 +56,14 @@ class RISCVInstructionSelector : public InstructionSelector { | |
|
||
ComplexRendererFns selectShiftMask(MachineOperand &Root) const; | ||
|
||
ComplexRendererFns selectNonImm12(MachineOperand &Root) const; | ||
|
||
ComplexRendererFns selectSHXADDOp(MachineOperand &Root, unsigned ShAmt) const; | ||
template <unsigned ShAmt> | ||
ComplexRendererFns selectSHXADDOp(MachineOperand &Root) const { | ||
return selectSHXADDOp(Root, ShAmt); | ||
} | ||
|
||
// Custom renderers for tablegen | ||
void renderNegImm(MachineInstrBuilder &MIB, const MachineInstr &MI, | ||
int OpIdx) const; | ||
|
@@ -105,6 +114,127 @@ RISCVInstructionSelector::selectShiftMask(MachineOperand &Root) const { | |
return {{[=](MachineInstrBuilder &MIB) { MIB.add(Root); }}}; | ||
} | ||
|
||
// This complex pattern actually serves as a predicate that is effectively | ||
// `!isInt<12>(Imm)`. | ||
InstructionSelector::ComplexRendererFns | ||
RISCVInstructionSelector::selectNonImm12(MachineOperand &Root) const { | ||
MachineFunction &MF = *Root.getParent()->getParent()->getParent(); | ||
MachineRegisterInfo &MRI = MF.getRegInfo(); | ||
|
||
if (Root.isReg() && Root.getReg()) | ||
if (auto Val = getIConstantVRegValWithLookThrough(Root.getReg(), MRI)) { | ||
// We do NOT want immediates that fit in 12 bits. | ||
if (isInt<12>(Val->Value.getSExtValue())) | ||
return std::nullopt; | ||
} | ||
|
||
return {{[=](MachineInstrBuilder &MIB) { MIB.add(Root); }}}; | ||
} | ||
|
||
InstructionSelector::ComplexRendererFns | ||
RISCVInstructionSelector::selectSHXADDOp(MachineOperand &Root, | ||
unsigned ShAmt) const { | ||
using namespace llvm::MIPatternMatch; | ||
MachineFunction &MF = *Root.getParent()->getParent()->getParent(); | ||
MachineRegisterInfo &MRI = MF.getRegInfo(); | ||
|
||
if (!Root.isReg()) | ||
return std::nullopt; | ||
Register RootReg = Root.getReg(); | ||
|
||
const unsigned XLen = STI.getXLen(); | ||
APInt Mask, C2; | ||
Register RegY; | ||
std::optional<bool> LeftShift; | ||
// (and (shl y, c2), mask) | ||
if (mi_match(RootReg, MRI, | ||
m_GAnd(m_GShl(m_Reg(RegY), m_ICst(C2)), m_ICst(Mask)))) | ||
LeftShift = true; | ||
// (and (lshr y, c2), mask) | ||
else if (mi_match(RootReg, MRI, | ||
m_GAnd(m_GLShr(m_Reg(RegY), m_ICst(C2)), m_ICst(Mask)))) | ||
LeftShift = false; | ||
|
||
if (LeftShift.has_value()) { | ||
if (*LeftShift) | ||
Mask &= maskTrailingZeros<uint64_t>(C2.getLimitedValue()); | ||
else | ||
Mask &= maskTrailingOnes<uint64_t>(XLen - C2.getLimitedValue()); | ||
|
||
if (Mask.isShiftedMask()) { | ||
unsigned Leading = XLen - Mask.getActiveBits(); | ||
unsigned Trailing = Mask.countr_zero(); | ||
// Given (and (shl y, c2), mask) in which mask has no leading zeros and c3 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm copying the comments from RISCVISelDAGToDAG.cpp to here for better readability. |
||
// trailing zeros. We can use an SRLI by c3 - c2 followed by a SHXADD. | ||
if (*LeftShift && Leading == 0 && C2.ult(Trailing) && Trailing == ShAmt) { | ||
Register DstReg = | ||
MRI.createGenericVirtualRegister(MRI.getType(RootReg)); | ||
return {{[=](MachineInstrBuilder &MIB) { | ||
MachineIRBuilder(*MIB.getInstr()) | ||
.buildInstr(RISCV::SRLI, {DstReg}, {RegY}) | ||
.addImm(Trailing - C2.getLimitedValue()); | ||
MIB.addReg(DstReg); | ||
}}}; | ||
} | ||
|
||
// Given (and (lshr y, c2), mask) in which mask has c2 leading zeros and c3 | ||
// trailing zeros. We can use an SRLI by c2 + c3 followed by a SHXADD. | ||
if (!*LeftShift && Leading == C2 && Trailing == ShAmt) { | ||
Register DstReg = | ||
MRI.createGenericVirtualRegister(MRI.getType(RootReg)); | ||
return {{[=](MachineInstrBuilder &MIB) { | ||
MachineIRBuilder(*MIB.getInstr()) | ||
.buildInstr(RISCV::SRLI, {DstReg}, {RegY}) | ||
.addImm(Leading + Trailing); | ||
MIB.addReg(DstReg); | ||
}}}; | ||
} | ||
} | ||
} | ||
|
||
LeftShift.reset(); | ||
|
||
// (shl (and y, mask), c2) | ||
if (mi_match(RootReg, MRI, | ||
m_GShl(m_OneNonDBGUse(m_GAnd(m_Reg(RegY), m_ICst(Mask))), | ||
m_ICst(C2)))) | ||
LeftShift = true; | ||
// (lshr (and y, mask), c2) | ||
else if (mi_match(RootReg, MRI, | ||
m_GLShr(m_OneNonDBGUse(m_GAnd(m_Reg(RegY), m_ICst(Mask))), | ||
m_ICst(C2)))) | ||
LeftShift = false; | ||
|
||
if (LeftShift.has_value()) | ||
if (Mask.isShiftedMask()) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we merge this condition with the previous if? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Was this comment addressed? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is addressed now |
||
unsigned Leading = XLen - Mask.getActiveBits(); | ||
unsigned Trailing = Mask.countr_zero(); | ||
|
||
// Given (shl (and y, mask), c2) in which mask has 32 leading zeros and | ||
// c3 trailing zeros. If c2 + c3 == ShAmt, we can emit SRLIW + SHXADD. | ||
bool Cond = *LeftShift && Leading == 32 && Trailing > 0 && | ||
(Trailing + C2.getLimitedValue()) == ShAmt; | ||
if (!Cond) | ||
// Given (lshr (and y, mask), c2) in which mask has 32 leading zeros and | ||
// c3 trailing zeros. If c3 - c2 == ShAmt, we can emit SRLIW + SHXADD. | ||
Cond = !*LeftShift && Leading == 32 && C2.ult(Trailing) && | ||
(Trailing - C2.getLimitedValue()) == ShAmt; | ||
|
||
if (Cond) { | ||
Register DstReg = | ||
MRI.createGenericVirtualRegister(MRI.getType(RootReg)); | ||
return {{[=](MachineInstrBuilder &MIB) { | ||
MachineIRBuilder(*MIB.getInstr()) | ||
.buildInstr(RISCV::SRLIW, {DstReg}, {RegY}) | ||
.addImm(Trailing); | ||
MIB.addReg(DstReg); | ||
}}}; | ||
} | ||
} | ||
|
||
return std::nullopt; | ||
} | ||
|
||
// Tablegen doesn't allow us to write SRLIW/SRAIW/SLLIW patterns because the | ||
// immediate Operand has type XLenVT. GlobalISel wants it to be i32. | ||
bool RISCVInstructionSelector::earlySelectShift( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -235,10 +235,7 @@ def SimmShiftRightBy3XForm : SDNodeXForm<imm, [{ | |
}]>; | ||
|
||
// Pattern to exclude simm12 immediates from matching. | ||
def non_imm12 : PatLeaf<(XLenVT GPR:$a), [{ | ||
auto *C = dyn_cast<ConstantSDNode>(N); | ||
return !C || !isInt<12>(C->getSExtValue()); | ||
}]>; | ||
def non_imm12 : ComplexPattern<XLenVT, 1, "selectNonImm12", [], [], 0>; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. After some digging, I think the answer is no and I'm sad about it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Correct, using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would add GISelPredicateCode to something like this work topperc@01205c1 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This works, though I'm a little concerned that this might create too many boilerplate code in the future, since there needs to be a Predicate TG record for every opcode that goes with non_imm12 (even we abstract the real predicate logics into a function). What do you think? Also, interestingly There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I used There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Done: it's no longer using ComplexPattern of |
||
|
||
def Shifted32OnesMask : PatLeaf<(imm), [{ | ||
uint64_t Imm = N->getZExtValue(); | ||
|
@@ -651,19 +648,19 @@ let Predicates = [HasStdExtZbb, IsRV64] in | |
def : Pat<(i64 (and GPR:$rs, 0xFFFF)), (ZEXT_H_RV64 GPR:$rs)>; | ||
|
||
let Predicates = [HasStdExtZba] in { | ||
def : Pat<(add (shl GPR:$rs1, (XLenVT 1)), non_imm12:$rs2), | ||
def : Pat<(add (shl GPR:$rs1, (XLenVT 1)), (non_imm12 (XLenVT GPR:$rs2))), | ||
(SH1ADD GPR:$rs1, GPR:$rs2)>; | ||
def : Pat<(add (shl GPR:$rs1, (XLenVT 2)), non_imm12:$rs2), | ||
def : Pat<(add (shl GPR:$rs1, (XLenVT 2)), (non_imm12 (XLenVT GPR:$rs2))), | ||
(SH2ADD GPR:$rs1, GPR:$rs2)>; | ||
def : Pat<(add (shl GPR:$rs1, (XLenVT 3)), non_imm12:$rs2), | ||
def : Pat<(add (shl GPR:$rs1, (XLenVT 3)), (non_imm12 (XLenVT GPR:$rs2))), | ||
(SH3ADD GPR:$rs1, GPR:$rs2)>; | ||
|
||
// More complex cases use a ComplexPattern. | ||
def : Pat<(add sh1add_op:$rs1, non_imm12:$rs2), | ||
def : Pat<(add sh1add_op:$rs1, (non_imm12 (XLenVT GPR:$rs2))), | ||
(SH1ADD sh1add_op:$rs1, GPR:$rs2)>; | ||
def : Pat<(add sh2add_op:$rs1, non_imm12:$rs2), | ||
def : Pat<(add sh2add_op:$rs1, (non_imm12 (XLenVT GPR:$rs2))), | ||
(SH2ADD sh2add_op:$rs1, GPR:$rs2)>; | ||
def : Pat<(add sh3add_op:$rs1, non_imm12:$rs2), | ||
def : Pat<(add sh3add_op:$rs1, (non_imm12 (XLenVT GPR:$rs2))), | ||
(SH3ADD sh3add_op:$rs1, GPR:$rs2)>; | ||
|
||
def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 6)), GPR:$rs2), | ||
|
@@ -735,48 +732,48 @@ def : Pat<(i64 (and GPR:$rs1, Shifted32OnesMask:$mask)), | |
(SLLI_UW (SRLI GPR:$rs1, Shifted32OnesMask:$mask), | ||
Shifted32OnesMask:$mask)>; | ||
|
||
def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFFF), non_imm12:$rs2)), | ||
def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFFF), (non_imm12 (XLenVT GPR:$rs2)))), | ||
(ADD_UW GPR:$rs1, GPR:$rs2)>; | ||
def : Pat<(i64 (and GPR:$rs, 0xFFFFFFFF)), (ADD_UW GPR:$rs, (XLenVT X0))>; | ||
|
||
def : Pat<(i64 (or_is_add (and GPR:$rs1, 0xFFFFFFFF), non_imm12:$rs2)), | ||
def : Pat<(i64 (or_is_add (and GPR:$rs1, 0xFFFFFFFF), (non_imm12 (XLenVT GPR:$rs2)))), | ||
(ADD_UW GPR:$rs1, GPR:$rs2)>; | ||
|
||
def : Pat<(i64 (add (shl (and GPR:$rs1, 0xFFFFFFFF), (i64 1)), non_imm12:$rs2)), | ||
def : Pat<(i64 (add (shl (and GPR:$rs1, 0xFFFFFFFF), (i64 1)), (non_imm12 (XLenVT GPR:$rs2)))), | ||
(SH1ADD_UW GPR:$rs1, GPR:$rs2)>; | ||
def : Pat<(i64 (add (shl (and GPR:$rs1, 0xFFFFFFFF), (i64 2)), non_imm12:$rs2)), | ||
def : Pat<(i64 (add (shl (and GPR:$rs1, 0xFFFFFFFF), (i64 2)), (non_imm12 (XLenVT GPR:$rs2)))), | ||
(SH2ADD_UW GPR:$rs1, GPR:$rs2)>; | ||
def : Pat<(i64 (add (shl (and GPR:$rs1, 0xFFFFFFFF), (i64 3)), non_imm12:$rs2)), | ||
def : Pat<(i64 (add (shl (and GPR:$rs1, 0xFFFFFFFF), (i64 3)), (non_imm12 (XLenVT GPR:$rs2)))), | ||
(SH3ADD_UW GPR:$rs1, GPR:$rs2)>; | ||
|
||
def : Pat<(i64 (add (and (shl GPR:$rs1, (i64 1)), 0x1FFFFFFFF), non_imm12:$rs2)), | ||
def : Pat<(i64 (add (and (shl GPR:$rs1, (i64 1)), 0x1FFFFFFFF), (non_imm12 (XLenVT GPR:$rs2)))), | ||
(SH1ADD_UW GPR:$rs1, GPR:$rs2)>; | ||
def : Pat<(i64 (add (and (shl GPR:$rs1, (i64 2)), 0x3FFFFFFFF), non_imm12:$rs2)), | ||
def : Pat<(i64 (add (and (shl GPR:$rs1, (i64 2)), 0x3FFFFFFFF), (non_imm12 (XLenVT GPR:$rs2)))), | ||
(SH2ADD_UW GPR:$rs1, GPR:$rs2)>; | ||
def : Pat<(i64 (add (and (shl GPR:$rs1, (i64 3)), 0x7FFFFFFFF), non_imm12:$rs2)), | ||
def : Pat<(i64 (add (and (shl GPR:$rs1, (i64 3)), 0x7FFFFFFFF), (non_imm12 (XLenVT GPR:$rs2)))), | ||
(SH3ADD_UW GPR:$rs1, GPR:$rs2)>; | ||
|
||
// More complex cases use a ComplexPattern. | ||
def : Pat<(i64 (add sh1add_uw_op:$rs1, non_imm12:$rs2)), | ||
def : Pat<(i64 (add sh1add_uw_op:$rs1, (non_imm12 (XLenVT GPR:$rs2)))), | ||
(SH1ADD_UW sh1add_uw_op:$rs1, GPR:$rs2)>; | ||
def : Pat<(i64 (add sh2add_uw_op:$rs1, non_imm12:$rs2)), | ||
def : Pat<(i64 (add sh2add_uw_op:$rs1, (non_imm12 (XLenVT GPR:$rs2)))), | ||
(SH2ADD_UW sh2add_uw_op:$rs1, GPR:$rs2)>; | ||
def : Pat<(i64 (add sh3add_uw_op:$rs1, non_imm12:$rs2)), | ||
def : Pat<(i64 (add sh3add_uw_op:$rs1, (non_imm12 (XLenVT GPR:$rs2)))), | ||
(SH3ADD_UW sh3add_uw_op:$rs1, GPR:$rs2)>; | ||
|
||
def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFFE), non_imm12:$rs2)), | ||
def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFFE), (non_imm12 (XLenVT GPR:$rs2)))), | ||
(SH1ADD (SRLIW GPR:$rs1, 1), GPR:$rs2)>; | ||
def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFFC), non_imm12:$rs2)), | ||
def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFFC), (non_imm12 (XLenVT GPR:$rs2)))), | ||
(SH2ADD (SRLIW GPR:$rs1, 2), GPR:$rs2)>; | ||
def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFF8), non_imm12:$rs2)), | ||
def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFF8), (non_imm12 (XLenVT GPR:$rs2)))), | ||
(SH3ADD (SRLIW GPR:$rs1, 3), GPR:$rs2)>; | ||
|
||
// Use SRLI to clear the LSBs and SHXADD_UW to mask and shift. | ||
def : Pat<(i64 (add (and GPR:$rs1, 0x1FFFFFFFE), non_imm12:$rs2)), | ||
def : Pat<(i64 (add (and GPR:$rs1, 0x1FFFFFFFE), (non_imm12 (XLenVT GPR:$rs2)))), | ||
(SH1ADD_UW (SRLI GPR:$rs1, 1), GPR:$rs2)>; | ||
def : Pat<(i64 (add (and GPR:$rs1, 0x3FFFFFFFC), non_imm12:$rs2)), | ||
def : Pat<(i64 (add (and GPR:$rs1, 0x3FFFFFFFC), (non_imm12 (XLenVT GPR:$rs2)))), | ||
(SH2ADD_UW (SRLI GPR:$rs1, 2), GPR:$rs2)>; | ||
def : Pat<(i64 (add (and GPR:$rs1, 0x7FFFFFFF8), non_imm12:$rs2)), | ||
def : Pat<(i64 (add (and GPR:$rs1, 0x7FFFFFFF8), (non_imm12 (XLenVT GPR:$rs2)))), | ||
(SH3ADD_UW (SRLI GPR:$rs1, 3), GPR:$rs2)>; | ||
|
||
def : Pat<(i64 (mul (and_oneuse GPR:$r, 0xFFFFFFFF), C3LeftShiftUW:$i)), | ||
|
Uh oh!
There was an error while loading. Please reload this page.