Handle Vector types in G_FADD using G_FPEXT #557

Open: wants to merge 12 commits into base: aie-public
133 changes: 133 additions & 0 deletions llvm/lib/Target/AIE/AIELegalizerHelper.cpp
@@ -1196,6 +1196,7 @@ bool AIELegalizerHelper::legalizeG_FPTRUNC(LegalizerHelper &Helper,

bool AIELegalizerHelper::legalizeG_FPEXT(LegalizerHelper &Helper,
MachineInstr &MI) const {
const AIEBaseInstrInfo *II = ST.getInstrInfo();
MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
MachineRegisterInfo &MRI = *MIRBuilder.getMRI();

@@ -1206,6 +1207,67 @@ bool AIELegalizerHelper::legalizeG_FPEXT(LegalizerHelper &Helper,
LLT DstTy = MRI.getType(DstReg);
LLT SrcTy = MRI.getType(SrcReg);

// Vectors
/*
VDst = G_FPEXT VSrc
converts to
Zero = G_CONSTANT 0
ZeroVec = G_AIE_BROADCAST_VECTOR Zero
VShuffleLow = G_AIE_SHUFFLE_VECTOR ZeroVec, VSrc, 2
VShuffleHigh = G_AIE_SHUFFLE_VECTOR ZeroVec, VSrc, 3
VShuffleLowCast = G_BITCAST VShuffleLow
VShuffleHighCast = G_BITCAST VShuffleHigh
VDst = G_CONCAT_VECTORS VShuffleLowCast, VShuffleHighCast
*/
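// For a <32 x s16> (bf16) source this means: each shuffle pairs a bf16
// element with a zero half-word, so after the bitcast every 32-bit lane
// holds the bf16 bits in its upper half (which is exactly the f32 encoding
// of that value), and the concat forms the widened result.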
if (DstTy.isVector() && SrcTy.isVector()) {
// Extract type information
auto DstElementType = DstTy.getElementType();
auto SrcNumElements = SrcTy.getNumElements();
// Create constants for shuffle modes
Register Mode2 = MIRBuilder.buildConstant(S32, 2).getReg(0);
Register Mode3 = MIRBuilder.buildConstant(S32, 3).getReg(0);
Register Zero = MIRBuilder.buildConstant(S32, 0).getReg(0);
// Get the instructions
const unsigned BroadcastOpc = II->getGenericBroadcastVectorOpcode();
const unsigned VShuffleOpc = II->getGenericShuffleVectorOpcode();

// Step 1: Create a zero vector using broadcast
Register ZeroVec =
MIRBuilder.buildInstr(BroadcastOpc, {SrcTy}, {Zero}).getReg(0);
// Step 2: Create VSHUFFLE for lower 512 bits (mode 2)
Register VShuffleLow =
MIRBuilder.buildInstr(VShuffleOpc, {SrcTy}, {ZeroVec, SrcReg, Mode2})
.getReg(0);
// Step 3: Create VSHUFFLE for high 512 bits (mode 3)
Register VShuffleHigh =
MIRBuilder.buildInstr(VShuffleOpc, {SrcTy}, {ZeroVec, SrcReg, Mode3})
.getReg(0);
// Step 4: bitcast VShuffleLow and VShuffleHigh
// Example: <32xs16> -> <16xs32>
LLT CastToNewTy =
LLT::vector(ElementCount::getFixed(SrcNumElements / 2), DstElementType);
if (CastToNewTy.getSizeInBits() !=
MRI.getType(VShuffleLow).getSizeInBits() ||
CastToNewTy.getSizeInBits() !=
MRI.getType(VShuffleHigh).getSizeInBits()) {
llvm::errs()
<< "Error: Size mismatch in vector bitcast for G_FPEXT. Expected: "
<< CastToNewTy.getSizeInBits()
<< " bits, got: " << MRI.getType(VShuffleLow).getSizeInBits()
<< " and " << MRI.getType(VShuffleHigh).getSizeInBits() << " bits\n";
return false;
}
auto VShuffleLowCast =
MIRBuilder.buildCast(CastToNewTy, VShuffleLow).getReg(0);
auto VShuffleHighCast =
MIRBuilder.buildCast(CastToNewTy, VShuffleHigh).getReg(0);
// Step 5: Concatenate the two src vectors into dst vector
MIRBuilder.buildConcatVectors(DstReg, {VShuffleLowCast, VShuffleHighCast});

MI.eraseFromParent();
return true;
}

// Scalars
// We only handle bfloat16 to single precision conversion
if (DstTy != LLT::scalar(32) || SrcTy != LLT::scalar(16))
return false;
@@ -1300,6 +1362,9 @@ bool AIELegalizerHelper::legalizeG_FMUL(LegalizerHelper &Helper,
MI.eraseFromParent();
return true;
}
bool isBF16Vector(const LLT Ty) {
return Ty.isVector() && Ty.getScalarSizeInBits() == 16;
}

bool AIELegalizerHelper::legalizeG_FADD_G_FSUB(LegalizerHelper &Helper,
MachineInstr &MI) const {
@@ -1309,6 +1374,74 @@ bool AIELegalizerHelper::legalizeG_FADD_G_FSUB(LegalizerHelper &Helper,
const Register DstReg = MI.getOperand(0).getReg();
Register SrcLHS = MI.getOperand(1).getReg();
Register SrcRHS = MI.getOperand(2).getReg();
const LLT SrcLHSTy = MRI.getType(SrcLHS);
const LLT SrcRHSTy = MRI.getType(SrcRHS);

// Handle bf16 vectors. The code assumes the input is <32 x bf16>; the
// LegalizerInfo makes sure that the input is either padded or unmerged to
// <32 x bf16>.
if (isBF16Vector(SrcLHSTy) && isBF16Vector(SrcRHSTy)) {
// Both input vectors must have exactly 32 elements.
assert(SrcLHSTy.getNumElements() == 32 && SrcRHSTy.getNumElements() == 32 &&
"Expected vector of size 32 for inputs of G_FADD/G_FSUB");

// Step 1: Convert bf16 vectors to f32 vectors using FPExt
const LLT F32VecTy =
LLT::fixed_vector(SrcLHSTy.getNumElements(), LLT::scalar(32));
Register SrcLHSF32 = MRI.createGenericVirtualRegister(F32VecTy);
Register SrcRHSF32 = MRI.createGenericVirtualRegister(F32VecTy);
MIRBuilder.buildFPExt(SrcLHSF32, SrcLHS);
MIRBuilder.buildFPExt(SrcRHSF32, SrcRHS);

// Step 2: The extended operands are <32 x f32>; pad them to <64 x f32> for
// AIE2P, since AccV64S32 is the legal accumulator type there.
if (ST.isAIE2P()) {
const Register UndefVec = MIRBuilder.buildUndef(F32VecTy).getReg(0);
const Register ConcatLHS = MRI.createGenericVirtualRegister(V64FP32);
const Register ConcatRHS = MRI.createGenericVirtualRegister(V64FP32);
MIRBuilder.buildConcatVectors(ConcatLHS, {SrcLHSF32, UndefVec});
MIRBuilder.buildConcatVectors(ConcatRHS, {SrcRHSF32, UndefVec});
SrcLHSF32 = ConcatLHS;
SrcRHSF32 = ConcatRHS;
}

// Step 3: Perform the floating point operation
Register Res = MIRBuilder
.buildInstr(MI.getOpcode(), {MRI.getType(SrcLHSF32)},
{SrcLHSF32, SrcRHSF32})
.getReg(0);

// Step 4: Handle accumulator conversion based on target
if (ST.isAIE2()) {
Res = MIRBuilder.buildBitcast(V8ACC64, Res).getReg(0);
} else if (ST.isAIE2P()) {
// Unmerge to get 2 vectors of <32xf32> as FADD/FSUB was done on <64xf32>
SmallVector<Register, 2> UnmergedRegs;
const auto Unmerge = MIRBuilder.buildUnmerge(F32VecTy, Res);
getUnmergeResults(UnmergedRegs, *Unmerge);
Res = UnmergedRegs[0]; // Take the first <32xf32> vector; the other half
// only holds the undef padding.
}

// Step 5: Convert back to bf16 using the truncation intrinsic
const int VecSize = MRI.getType(Res).getSizeInBits();
const LLT DstLLT = ST.isAIE2P() ? V32BF16 : V16BF16;
Res = MIRBuilder
.buildIntrinsic(getFpTrunc32ToBF16IntrID(ST, VecSize), {DstLLT},
true, false)
.addUse(Res)
.getReg(0);

// Handle AIE2 padding
if (ST.isAIE2()) {
Res = emitPadUndefVector(MRI, MIRBuilder, V32BF16, Res);
}

MIRBuilder.buildCopy(DstReg, Res);

MI.eraseFromParent();
return true;
}

assert(MRI.getType(DstReg) == LLT::scalar(16) &&
"Expected bfloat16 type in custom legalization.");
61 changes: 60 additions & 1 deletion llvm/lib/Target/AIE/aie2p/AIE2PLegalizerInfo.cpp
@@ -73,6 +73,31 @@ static LegalityPredicate isValidVectorAIEP(const unsigned TypeIdx) {
};
}

// `V2 = G_FPEXT V1` on vectors is valid iff:
// - V1 and V2 are both (floating-point) vectors
// - V2 is wider than V1 in total vector size
// - Both vectors have the same number of elements
// - The element size of V2 is twice the element size of V1
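// For example, <32 x s16> -> <32 x s32> qualifies, while <32 x s16> -> <32 x s64>
// or <16 x s16> -> <32 x s32> do not.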
static LegalityPredicate isValidVectorFPEXT(const unsigned TypeIdx_dst,
const unsigned TypeIdx_src) {
return [=](const LegalityQuery &Query) {
const LLT DstTy = Query.Types[TypeIdx_dst];
const LLT SrcTy = Query.Types[TypeIdx_src];
if (DstTy.isVector() && SrcTy.isVector()) {
auto DstElementCount = DstTy.getElementCount();
auto SrcElementCount = SrcTy.getElementCount();
auto DstElementType = DstTy.getElementType();
auto SrcElementType = SrcTy.getElementType();
auto DstElementSize = DstElementType.getSizeInBits();
auto SrcElementSize = SrcElementType.getSizeInBits();
return DstTy.getSizeInBits() > SrcTy.getSizeInBits() &&
DstElementCount == SrcElementCount &&
(DstElementSize == (SrcElementSize * 2));
}
return false;
};
}

static LegalityPredicate
negatePredicate(const std::function<bool(const LegalityQuery &)> &Func) {
return [=](const LegalityQuery &Query) { return !Func(Query); };
@@ -219,6 +244,13 @@ AIE2PLegalizerInfo::AIE2PLegalizerInfo(const AIE2PSubtarget &ST)
getActionDefinitionsBuilder(G_FPEXT)
.libcallFor({{S64, S32}})
.customFor({{S32, S16}})
// Add support for vector types.
// Extend vectors to have at least 512 bits.
.clampMinNumElements(1, S8, 64)
.clampMinNumElements(1, S16, 32)
.clampMinNumElements(1, S32, 16)
.customIf(isValidVectorFPEXT(0 /* Dst */, 1 /* Src */))
.narrowScalarFor({{S64, S16}}, llvm::LegalizeMutations::changeTo(0, S32));

getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
@@ -241,7 +273,34 @@

getActionDefinitionsBuilder({G_FADD, G_FSUB})
.legalFor({AccV64S32})
.customFor({S16})
// Handle custom bf16 case for both scalar and vector types
.customFor({S16, V32S16})
// Widen f32/bf16 vectors smaller than the legal sizes (<64 x f32> / <32 x bf16>) without changing the element type
.moreElementsIf(
[=](const LegalityQuery &Query) {
const LLT &Ty = Query.Types[0];
return Ty.isVector() &&
(Ty.getScalarSizeInBits() == 32 ||
Ty.getScalarSizeInBits() == 16) &&
Ty.getNumElements() <= 32;
},
[=](const LegalityQuery &Query) {
if (Query.Types[0].getScalarSizeInBits() == 32) {
return std::make_pair(0, LLT::fixed_vector(64, S32));
} else {
return std::make_pair(0, LLT::fixed_vector(32, S16));
}
})
// Converts <64xbf16> into 2 chunks of <32xbf16>
.fewerElementsIf(
[=](const LegalityQuery &Query) {
const LLT &Ty = Query.Types[0];
return Ty.isVector() && (Ty.getScalarSizeInBits() == 16) &&
Ty.getNumElements() == 64;
},
[=](const LegalityQuery &Query) {
return std::make_pair(0, LLT::fixed_vector(32, S16));
})
.libcallFor({S32, S64});

getActionDefinitionsBuilder({G_FDIV, G_FREM})
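As a rough sketch (inferred from the rules and the legalizer helper above, not verified compiler output), a <16 x bf16> G_FADD on AIE2P would be legalized roughly as:

%sum:_(<16 x s16>) = G_FADD %a, %b
; moreElementsIf widens the operation to <32 x s16>
; customFor {V32S16} then invokes legalizeG_FADD_G_FSUB, which emits:
;   G_FPEXT of each operand to <32 x s32>
;   G_CONCAT_VECTORS with an undef half to form <64 x s32>
;   G_FADD on <64 x s32> (legal as AccV64S32)
;   G_UNMERGE_VALUES back to <32 x s32>, then the bf16 truncation intrinsic to <32 x s16>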
56 changes: 56 additions & 0 deletions llvm/test/CodeGen/AIE/aie2p/GlobalIsel/legalize-vector-fadd.ll
@@ -0,0 +1,56 @@
; RUN: llc -mtriple=aie2p -O0 -stop-after=legalizer %s -o - 2>&1 | FileCheck %s
Review comment (Collaborator):
For the tests, it is better to run just the pass of interest. For example, you can create an MIR test containing just the illegal-type operation and run llc with -run-pass=legalizer. In this way we can easily spot the specific legalization change in action.
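A rough sketch of such an MIR test (the function name and register choices are placeholders, and the CHECK lines are omitted; they would be generated as described in the next comment):

# RUN: llc -mtriple=aie2p -run-pass=legalizer %s -o - | FileCheck %s
---
name:            fadd_v32bf16
legalized:       false
body:             |
  bb.0:
    liveins: $x0, $x1
    %0:_(<32 x s16>) = COPY $x0
    %1:_(<32 x s16>) = COPY $x1
    %2:_(<32 x s16>) = G_FADD %0, %1
    $x0 = COPY %2(<32 x s16>)
    PseudoRET implicit $lr, implicit $x0
...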

Review comment (Collaborator):
The good part is that you can use llvm/utils/update_mir_test_checks.py to update the tests.
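For instance, with the sketch above saved as legalize-vector-fadd.mir and assuming llc is available in a local build directory, something like:

llvm/utils/update_mir_test_checks.py --llc-binary build/bin/llc llvm/test/CodeGen/AIE/aie2p/GlobalIsel/legalize-vector-fadd.mir

would regenerate the CHECK lines in place (the build path is an assumption; point --llc-binary at whatever llc you built).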

; This test was carved out of
; iree-amd-aie/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/multi_reduction_to_reduction_sizes_types.mlir
; for sending a patch upstream.

; Ideally the reduction should be legalized as follows (with minor changes for each shape):
; Input1: <32xbf16> and Input2: <32xbf16>
; Extended1<32xf32> = fpext <32xbf16>
; Extended2<32xf32> = fpext <32xbf16>
; Zero<32xf32> = zeroinitializer
; Out1<64xf32> = Concat zero, <Extended1<32xf32>>
; Out2<64xf32> = Concat zero, <Extended2<32xf32>>
; Result<64xf32> = fadd <Out1<64xf32>>, <Out2<64xf32>>
; R1<32xf32>, R2<32xf32> = unmerge <Result<64xf32>>
; R2 is all 0s
; R1<32xbf16> = trunc <R1<32xf32>>

; TODO: add checks for the vadd.f, the padding, the <32xbf16>-like
; legalization, and the unpadding.
define bfloat @multi_reduction_1d_16_bf16(<16 x bfloat> %0, bfloat %1) {
%3 = call reassoc bfloat @llvm.vector.reduce.fadd.v16bf16(bfloat %1, <16 x bfloat> %0)
ret bfloat %3
}



; CHECK-LABEL: name: multi_reduction_1d_32_bf16
; CHECK: G_CONSTANT i32 0
; CHECK: G_AIE_BROADCAST_VECTOR %{{[0-9]+}}(s32)
; CHECK: G_CONSTANT i32 2
; CHECK: G_CONSTANT i32 3
; CHECK: G_AIE_SHUFFLE_VECTOR %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}(s32)
; CHECK: G_AIE_SHUFFLE_VECTOR %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}(s32)
; CHECK: G_BITCAST %{{[0-9]+}}(<32 x s16>)
; CHECK: G_BITCAST %{{[0-9]+}}(<32 x s16>)
; CHECK: G_CONCAT_VECTORS %{{[0-9]+}}(<16 x s32>), %{{[0-9]+}}(<16 x s32>)
; CHECK: G_IMPLICIT_DEF
; CHECK: G_CONCAT_VECTORS %{{[0-9]+}}(<32 x s32>), %{{[0-9]+}}(<32 x s32>)
; CHECK: G_FADD %{{[0-9]+}}, %{{[0-9]+}}
; CHECK: G_UNMERGE_VALUES %{{[0-9]+}}(<64 x s32>)
; CHECK: G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@llvm.aie2p.v32accfloat.to.v32bf16), %{{[0-9]+}}(<32 x s32>)
define bfloat @multi_reduction_1d_32_bf16(<32 x bfloat> %0, bfloat %1) {
%3 = call reassoc bfloat @llvm.vector.reduce.fadd.v32bf16(bfloat %1, <32 x bfloat> %0)
ret bfloat %3
}

; Converted to chunks of <32 x bf16>:
; Check that the input is split into 2 chunks of <32 x bf16>,
; that each chunk is legalized like the <32 x bf16> case,
; and that both halves get concatenated back to <64 x bf16>.

define bfloat @multi_reduction_1d_64_bf16(<64 x bfloat> %0, bfloat %1) {
%3 = call reassoc bfloat @llvm.vector.reduce.fadd.v64bf16(bfloat %1, <64 x bfloat> %0)
ret bfloat %3
}
73 changes: 73 additions & 0 deletions llvm/test/CodeGen/AIE/aie2p/GlobalIsel/legalize-vector-fpext.ll
@@ -0,0 +1,73 @@
; RUN: llc -mtriple=aie2p -O0 -stop-after=legalizer %s -o - 2>&1 | FileCheck %s


; Validates the <32 x bfloat> -> <32 x float> vector legalization.
; CHECK-LABEL: name: extend
; CHECK: [[COPY:%[0-9]+]]:_(<32 x s16>) = COPY $x0
; CHECK-NOT: G_SHL
; CHECK-NEXT: [[C2:%[0-9]+]]:_(s32) = G_CONSTANT i32 2
; CHECK-NEXT: [[C3:%[0-9]+]]:_(s32) = G_CONSTANT i32 3
; CHECK-NEXT: [[C0:%[0-9]+]]:_(s32) = G_CONSTANT i32 0
; CHECK-NEXT: [[BCAST:%[0-9]+]]:_(<32 x s16>) = G_AIE_BROADCAST_VECTOR [[C0]](s32)
; CHECK-NEXT: [[SHUF1:%[0-9]+]]:_(<32 x s16>) = G_AIE_SHUFFLE_VECTOR [[BCAST]], [[COPY]], [[C2]](s32)
; CHECK-NEXT: [[SHUF2:%[0-9]+]]:_(<32 x s16>) = G_AIE_SHUFFLE_VECTOR [[BCAST]], [[COPY]], [[C3]](s32)
; CHECK-NEXT: [[BIT1:%[0-9]+]]:_(<16 x s32>) = G_BITCAST [[SHUF1]](<32 x s16>)
; CHECK-NEXT: [[BIT2:%[0-9]+]]:_(<16 x s32>) = G_BITCAST [[SHUF2]](<32 x s16>)
; CHECK-NEXT: [[CONCAT:%[0-9]+]]:_(<32 x s32>) = G_CONCAT_VECTORS [[BIT1]](<16 x s32>), [[BIT2]](<16 x s32>)

define <32 x float> @extend(bfloat %o, <32 x bfloat> %in) nounwind {
%X = fpext <32 x bfloat> %in to <32 x float>
ret <32 x float> %X
}

; Pads the 17 valid elements with undef values to form a 32-element vector.

; CHECK-LABEL: name: extend_non_power_of_2
; CHECK: [[COPY:%[0-9]+]]:_(<32 x s16>) = COPY $x0
; CHECK-COUNT-17: G_AIE_SEXT_EXTRACT_VECTOR_ELT
; CHECK-COUNT-32: G_AIE_ADD_VECTOR_ELT_HI
; CHECK-NEXT: [[C2:%[0-9]+]]:_(s32) = G_CONSTANT i32 2
; CHECK-NEXT: [[C3:%[0-9]+]]:_(s32) = G_CONSTANT i32 3
; CHECK-NEXT: [[C0:%[0-9]+]]:_(s32) = G_CONSTANT i32 0
; CHECK-NEXT: [[BCAST:%[0-9]+]]:_(<32 x s16>) = G_AIE_BROADCAST_VECTOR [[C0]](s32)
; CHECK-NEXT: [[SHUF1:%[0-9]+]]:_(<32 x s16>) = G_AIE_SHUFFLE_VECTOR [[BCAST]], %{{[0-9]+}}, [[C2]](s32)
; CHECK-NEXT: [[SHUF2:%[0-9]+]]:_(<32 x s16>) = G_AIE_SHUFFLE_VECTOR [[BCAST]], %{{[0-9]+}}, [[C3]](s32)
; CHECK-NEXT: [[BIT1:%[0-9]+]]:_(<16 x s32>) = G_BITCAST [[SHUF1]](<32 x s16>)
; CHECK-NEXT: [[BIT2:%[0-9]+]]:_(<16 x s32>) = G_BITCAST [[SHUF2]](<32 x s16>)
; CHECK-COUNT-17: G_AIE_SEXT_EXTRACT_VECTOR_ELT
; CHECK-COUNT-32: G_AIE_ADD_VECTOR_ELT_HI
; CHECK-NEXT: [[CONCAT:%[0-9]+]]:_(<32 x s32>) = G_CONCAT_VECTORS %{{[0-9]+}}(<16 x s32>), %{{[0-9]+}}(<16 x s32>)
define <17 x float> @extend_non_power_of_2(<17 x bfloat> %in) nounwind {
%X = fpext <17 x bfloat> %in to <17 x float>
ret <17 x float> %X
}

; Validates the case where the source vector is only 256 bits wide (<16 x bfloat>).

; CHECK-LABEL: name: fpext_bf16_to_f32
; CHECK: bb.1
; CHECK: [[VEC_CONCAT:%[0-9]+]]:_(<32 x s16>) = G_CONCAT_VECTORS
; CHECK: G_AIE_SEXT_EXTRACT_VECTOR_ELT [[VEC_CONCAT]]
; CHECK: G_AIE_ADD_VECTOR_ELT_HI
; CHECK: [[SHUFFLE_VEC:%[0-9]+]]:_(<32 x s16>) = G_AIE_SHUFFLE_VECTOR
; CHECK-NOT: G_AIE_SHUFFLE_VECTOR
; CHECK: [[BITCAST:%[0-9]+]]:_(<16 x s32>) = G_BITCAST [[SHUFFLE_VEC]]
; CHECK: $x0 = COPY [[BITCAST]]
define <16 x float> @fpext_bf16_to_f32(<16 x bfloat> %in) nounwind {
%X = fpext <16 x bfloat> %in to <16 x float>
ret <16 x float> %X
}

; Validates scalar path
; CHECK-LABEL: name: fpext_scalar_bf16_to_f32
; CHECK: [[COPY:%[0-9]+]]:_(s32) = COPY $r1
; CHECK-NEXT: [[C16:%[0-9]+]]:_(s32) = G_CONSTANT i32 16
; CHECK-NEXT: [[SHL:%[0-9]+]]:_(s32) = G_SHL [[COPY]], [[C16]](s32)
; CHECK-NOT: G_AIE_SHUFFLE_VECTOR
; CHECK-NEXT: $r0 = COPY [[SHL]](s32)
; CHECK-NEXT: PseudoRET implicit $lr, implicit $r0

define float @fpext_scalar_bf16_to_f32(bfloat %in) nounwind {
%X = fpext bfloat %in to float
ret float %X
}