Skip to content

[WIP] Unified G_BUILD_VECTOR combine #385

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: aie-public
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 10 additions & 20 deletions llvm/lib/Target/AIE/AIECombine.td
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,11 @@ def combine_extract_vector_elt_and_zsa_ext : GICombineRule<
(apply [{ applyExtractVecEltAndExt(*${root}, MRI, B, ${matchinfo}); }])
>;

def combine_splat_vector_matchdata: GIDefMatchData<"std::pair<Register, Register>">;
def combine_splat_vector : GICombineRule<
(defs root:$root, combine_splat_vector_matchdata:$matchinfo),
(match (wip_match_opcode G_BUILD_VECTOR): $root,
[{ return matchSplatVector(*${root}, MRI, ${matchinfo}); }]),
(apply [{ applySplatVector(*${root}, MRI, B, ${matchinfo}); }])
>;

def combine_single_diff_build_vector_matchdata: GIDefMatchData<"AIESingleDiffLaneBuildVectorMatchData">;
def combine_single_diff_build_vector : GICombineRule<
(defs root:$root, combine_single_diff_build_vector_matchdata:$matchinfo),
def combine_build_vector : GICombineRule<
(defs root:$root, build_fn_matchinfo:$matchinfo),
(match (wip_match_opcode G_BUILD_VECTOR): $root,
[{ return matchSingleDiffLaneBuildVector(*${root}, MRI, ${matchinfo}); }]),
(apply [{ applySingleDiffLaneBuildVector(*${root}, MRI, B, ${matchinfo}); }])
[{ return matchBuildVectorPatterns(*${root}, MRI, ${matchinfo}); }]),
(apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])
>;

def combine_pad_vector_matchdata: GIDefMatchData<"Register">;
Expand All @@ -72,11 +63,12 @@ def combine_unpad_vector : GICombineRule<
(apply [{ applyUnpadVector(*${root}, MRI, B); }])
>;

def combine_vector_broadcast_matchdata: GIDefMatchData<"std::pair<Register, Register>">;
def combine_vector_broadcast : GICombineRule<
(defs root:$root, combine_splat_vector_matchdata:$matchinfo),
(defs root:$root, combine_vector_broadcast_matchdata:$matchinfo),
(match (wip_match_opcode G_SHUFFLE_VECTOR): $root,
[{ return matchBroadcastElement(*${root}, MRI, ${matchinfo}); }]),
(apply [{ applySplatVector(*${root}, MRI, B, ${matchinfo}); }])>;
(apply [{ applyBroadcastElement(*${root}, MRI, B, ${matchinfo}); }])>;

def combine_vector_shuffle_broadcast : GICombineRule<
(defs root:$root, build_fn_matchinfo:$matchinfo),
Expand Down Expand Up @@ -107,8 +99,7 @@ def AIE2PreLegalizerCombiner
all_combines, combine_S20NarrowingOpt,
combine_globalval_offset,
combine_extract_vector_elt_and_zsa_ext,
combine_splat_vector, combine_concat_to_pad_vector,
combine_single_diff_build_vector]> {
combine_build_vector, combine_concat_to_pad_vector]> {
let CombineAllMethodName = "tryCombineAllImpl";
}

Expand All @@ -119,11 +110,10 @@ def AIE2PPreLegalizerCombiner
all_combines, combine_S20NarrowingOpt,
combine_globalval_offset,
combine_extract_vector_elt_and_zsa_ext,
combine_splat_vector, combine_vector_broadcast,
combine_build_vector, combine_vector_broadcast,
combine_concat_to_pad_vector,
combine_vector_shuffle_vsel,
combine_vector_shuffle_broadcast,
combine_single_diff_build_vector]> {
combine_vector_shuffle_broadcast]> {
let CombineAllMethodName = "tryCombineAllImpl";
}

Expand Down
200 changes: 96 additions & 104 deletions llvm/lib/Target/AIE/AIECombinerHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1133,44 +1133,6 @@ void llvm::applyExtractVecEltAndExt(
MI.eraseFromParent();
MatchMI->eraseFromParent();
}

// Match something like:
// %0:_(<32 x s16>) = G_BUILD_VECTOR %1:_(s16), ... x32
//
// To turn it into
// %0:_(<32 x s16>) = G_AIE_BROADCAST_VECTOR %1:_(s16)
bool llvm::matchSplatVector(MachineInstr &MI, MachineRegisterInfo &MRI,
std::pair<Register, Register> &MatchInfo) {

assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR &&
"Expected a G_BUILD_VECTOR");

const Register DstVecReg = MI.getOperand(0).getReg();
const LLT DstVecTy = MRI.getType(DstVecReg);
const unsigned DstVecSize = DstVecTy.getSizeInBits();

switch (DstVecSize) {
case 256:
case 512:
case 1024:
case 2048:
break;
default:
// unimplemented
return false;
}

const unsigned NumOps = MI.getNumOperands();
const MachineOperand FirstOp = MI.getOperand(1);
for (unsigned i = 2; i < NumOps; i++) {
if (!MI.getOperand(i).isIdenticalTo(FirstOp)) {
return false;
}
}
MatchInfo = std::make_pair(DstVecReg, FirstOp.getReg());
return true;
}

static void buildBroadcastVector(MachineIRBuilder &B, MachineRegisterInfo &MRI,
Register SrcReg, Register DstVecReg) {
const AIEBaseInstrInfo &AIETII = (const AIEBaseInstrInfo &)B.getTII();
Expand Down Expand Up @@ -1230,101 +1192,121 @@ static void buildBroadcastVector(MachineIRBuilder &B, MachineRegisterInfo &MRI,
}
}

bool llvm::applySplatVector(MachineInstr &MI, MachineRegisterInfo &MRI,
MachineIRBuilder &B,
std::pair<Register, Register> &MatchInfo) {
B.setInstrAndDebugLoc(MI);
auto [DstVecReg, SrcReg] = MatchInfo;
buildBroadcastVector(B, MRI, SrcReg, DstVecReg);
MI.eraseFromParent();
// This helper attempts to match a single-different-lane splat vector pattern.
// That pattern is a nearly-splat vector (one repeated register) but with one
// lane "differing" from the rest.
static bool matchSingleDiffSplatVector(
DenseMap<Register, std::pair<unsigned, unsigned>> UniqueRegs,
Register DstVecReg, MachineRegisterInfo &MRI, BuildFnTy &MatchInfo) {
Register SplatReg, DifferingReg;
unsigned DifferingIndex;

// Identify splat (multiple uses) and differing (single use) registers
for (const auto &[Reg, RegInfo] : UniqueRegs) {
if (RegInfo.first == 1) {
DifferingReg = Reg;
DifferingIndex = RegInfo.second;
} else {
SplatReg = Reg;
}
}
// Validate that one register was used exactly once
if (!DifferingReg.isValid() || !SplatReg.isValid())
return false;

// Ignore G_IMPLICIT_DEF to avoid conflicts with \fn matchBroadcastElement
const MachineInstr *SplatRegDef = getDefIgnoringCopies(SplatReg, MRI);
if (!SplatRegDef || SplatRegDef->getOpcode() == TargetOpcode::G_IMPLICIT_DEF)
return false;

// If we match, build a function-lambda that does the transformation:
// 1) Create a broadcast of the SplatReg into the destination vector.
// 2) Insert the differing lane at DifferingIndex.
MatchInfo = [=, &MRI](MachineIRBuilder &B) {
const LLT DstVecRegTy = MRI.getType(DstVecReg);
const Register BcstDstReg = MRI.createGenericVirtualRegister(DstVecRegTy);
const LLT S32 = LLT::scalar(32);

buildBroadcastVector(B, MRI, SplatReg, BcstDstReg);
const Register IdxReg = B.buildConstant(S32, DifferingIndex).getReg(0);
B.buildInsertVectorElement(DstVecReg, BcstDstReg, DifferingReg, IdxReg);
};
return true;
}

// Match something like:
// %0:_(<32 x s16>) = G_BUILD_VECTOR %1:_(s16), ... x32
//
// To turn it into
// %0:_(<32 x s16>) = G_AIE_BROADCAST_VECTOR %1:_(s16)
//
// And this:
// %0:_(<32 x s16>) = G_BUILD_VECTOR %2:(s16), %2:(s16), %1:(s16) ... x32
//
// To turn it into
// %3:_(<32 x s16>) = G_AIE_BROADCAST_VECTOR %2:_(s16)
// %0:(<32 x s16>) = G_AIE_INSERT_VECTOR_ELT %3:(<32 x s16>), %1:_(s16), 2
bool llvm::matchSingleDiffLaneBuildVector(
MachineInstr &MI, MachineRegisterInfo &MRI,
AIESingleDiffLaneBuildVectorMatchData &MatchInfo) {
assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR &&
"Expected a G_BUILD_VECTOR");
static bool matchSplatVector(MachineInstr &MI, MachineRegisterInfo &MRI,
BuildFnTy &MatchInfo) {

const Register DstVecReg = MI.getOperand(0).getReg();
const LLT DstVecTy = MRI.getType(DstVecReg);
const unsigned DstVecSize = DstVecTy.getSizeInBits();

switch (DstVecSize) {
case 256:
case 512:
case 1024:
case 2048:
break;
default:
// unimplemented
return false;
}
// DenseMap to hold unique registers and their (count, last index)
DenseMap<Register, std::pair<unsigned, unsigned>> UniqueRegs;
const Register DstVecReg = MI.getOperand(0).getReg();
const unsigned NumOps = MI.getNumOperands();

for (unsigned i = 1; i < NumOps; i++) {
const Register OpReg = MI.getOperand(i).getReg();
auto &RegInfo = UniqueRegs[OpReg];
RegInfo.first += 1;
RegInfo.second = i - 1;

if (UniqueRegs.size() > 2)
return false;
}
// Ensure exactly 2 unique registers to match the single differing lane build
// vector pattern. More than 2 registers won't match; 1 unique register would
// be a splat vector combine
if (UniqueRegs.size() != 2)
return false;

Register SplatReg, DifferingReg;
unsigned DifferingIndex;
switch (UniqueRegs.size()) {
case 1: {
// Pure splat as there is only one unique register.

// Identify splat (multiple uses) and differing (single use) registers
for (const auto &[Reg, RegInfo] : UniqueRegs) {
if (RegInfo.first == 1) {
DifferingReg = Reg;
DifferingIndex = RegInfo.second;
} else {
SplatReg = Reg;
}
auto It = UniqueRegs.begin();
Register SplatReg = It->first;
// Build a lambda that creates a broadcast of SplatReg into DstVecReg.
MatchInfo = [=, &MRI](MachineIRBuilder &B) {
buildBroadcastVector(B, MRI, SplatReg, DstVecReg);
};
return true;
}
// Validate that one register was used exactly once
if (!DifferingReg.isValid() || !SplatReg.isValid())
case 2:
// Check for single differing lane splat, as there are 2 unique registers.
return matchSingleDiffSplatVector(UniqueRegs, DstVecReg, MRI, MatchInfo);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like to take this further. We still have separate apply lambdas. Instead, we could match a splat pattern, and pack enough information in the apply lambda so that it can build the broadcast of the splat, and then, if we have different elements, perform one (or more) additional inserts. I'd make the apply lambda accept any number of additional asserts, and let the match function decide how many to allow.
The next step is then to recognize that the generic build vector inserts all elements on top of an undefined, so that we can also reuse that apply lambda to build a generic vector. Note that a broadcast is no more expensive than an insert into an undef.. The inserts could also be done by push lo or push hi.

Copy link
Collaborator

@martien-de-jong martien-de-jong Mar 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's more fun to be had, by imagining the different registers to be inserted in vectors, and translating the BUILD initializer to a shufflemask. By virtualizing the shuffle match and apply methods, you could then build on all of the shuffle tricks, e.g. to create two (or more) broadcast vectors for two different registers, and to pick from the two using a a (or more) VSELs. (Not now, just to illustrate reuse possibilities)

default:
return false;
}
}

// Ignore G_IMPLICIT_DEF to avoid conflicts with \fn matchBroadcastElement
const MachineInstr *SplatRegDef = getDefIgnoringCopies(SplatReg, MRI);
if (!SplatRegDef || SplatRegDef->getOpcode() == TargetOpcode::G_IMPLICIT_DEF)
return false;
bool llvm::matchBuildVectorPatterns(MachineInstr &MI, MachineRegisterInfo &MRI,
BuildFnTy &MatchInfo) {

MatchInfo = {DstVecReg, SplatReg, DifferingReg, DifferingIndex};
return true;
}
assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR &&
"Expected a G_BUILD_VECTOR");

bool llvm::applySingleDiffLaneBuildVector(
MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
AIESingleDiffLaneBuildVectorMatchData &MatchInfo) {
B.setInstrAndDebugLoc(MI);
const Register DstVecReg = MatchInfo.DstVecReg;
const LLT DstVecRegTy = MRI.getType(DstVecReg);
const Register BcstDstReg = MRI.createGenericVirtualRegister(DstVecRegTy);
const LLT S32 = LLT::scalar(32);
const Register DstVecReg = MI.getOperand(0).getReg();
const LLT DstVecTy = MRI.getType(DstVecReg);
const unsigned DstVecSize = DstVecTy.getSizeInBits();

buildBroadcastVector(B, MRI, MatchInfo.SplatReg, BcstDstReg);
const Register IdxReg =
B.buildConstant(S32, MatchInfo.DifferingIndex).getReg(0);
B.buildInsertVectorElement(DstVecReg, BcstDstReg, MatchInfo.DifferingReg,
IdxReg);
MI.eraseFromParent();
return true;
switch (DstVecSize) {
case 256:
case 512:
case 1024:
case 2048:
break;
default:
// unimplemented
return false;
}
if (matchSplatVector(MI, MRI, MatchInfo)) {
return true;
}
return false;
}

// Match something like:
Expand Down Expand Up @@ -1903,6 +1885,16 @@ bool llvm::matchBroadcastElement(MachineInstr &MI, MachineRegisterInfo &MRI,
return true;
}

bool llvm::applyBroadcastElement(MachineInstr &MI, MachineRegisterInfo &MRI,
MachineIRBuilder &B,
std::pair<Register, Register> &MatchInfo) {
B.setInstrAndDebugLoc(MI);
auto [DstVecReg, SrcReg] = MatchInfo;
buildBroadcastVector(B, MRI, SrcReg, DstVecReg);
MI.eraseFromParent();
return true;
}

static void buildUnmergeVector(MachineIRBuilder &B, MachineRegisterInfo &MRI,
Register DstReg, Register SrcReg,
unsigned NumSubVectors, unsigned SubIdx) {
Expand Down
28 changes: 5 additions & 23 deletions llvm/lib/Target/AIE/AIECombinerHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,6 @@ class MaskMatch {
int Amplitude = 1;
};

struct AIESingleDiffLaneBuildVectorMatchData {
/// Destination register of G_BUILD_VECTOR
Register DstVecReg;
/// The repeated register
Register SplatReg;
/// Register for the differing element
Register DifferingReg;
/// Lane index of the single differing element
unsigned DifferingIndex;
};

/// Look for any PtrAdd instruction that use the same base as \a MI that can be
/// combined with it and stores it in \a MatchData
/// \return true if an instruction is found
Expand Down Expand Up @@ -101,6 +90,9 @@ bool matchGlobalValOffset(MachineInstr &MI, MachineRegisterInfo &MRI,
/// idiom into G_AIE_BROADCAST
bool matchBroadcastElement(MachineInstr &MI, MachineRegisterInfo &MRI,
std::pair<Register, Register> &MatchInfo);
bool applyBroadcastElement(MachineInstr &MI, MachineRegisterInfo &MRI,
MachineIRBuilder &B,
std::pair<Register, Register> &MatchInfo);
bool matchShuffleToBroadcast(MachineInstr &MI, MachineRegisterInfo &MRI,
const AIEBaseInstrInfo &TII, BuildFnTy &MatchInfo);
/// Combine G_SHUFFLE_VECTOR(G_BUILD_VECTOR (VAL, UNDEF, ...), mask<0,0,...>)
Expand Down Expand Up @@ -186,18 +178,8 @@ void applyExtractVecEltAndExt(MachineInstr &MI, MachineRegisterInfo &MRI,
MachineIRBuilder &B,
std::pair<MachineInstr *, bool> &MatchInfo);

bool matchSplatVector(MachineInstr &MI, MachineRegisterInfo &MRI,
std::pair<Register, Register> &MatchInfo);
bool applySplatVector(MachineInstr &MI, MachineRegisterInfo &MRI,
MachineIRBuilder &B,
std::pair<Register, Register> &MatchInfo);

bool matchSingleDiffLaneBuildVector(
MachineInstr &MI, MachineRegisterInfo &MRI,
AIESingleDiffLaneBuildVectorMatchData &MatchInfo);
bool applySingleDiffLaneBuildVector(
MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
AIESingleDiffLaneBuildVectorMatchData &MatchInfo);
bool matchBuildVectorPatterns(MachineInstr &MI, MachineRegisterInfo &MRI,
BuildFnTy &MatchInfo);

bool matchUnpadVector(MachineInstr &MI, MachineRegisterInfo &MRI,
const AIEBaseInstrInfo &TII);
Expand Down