diff --git a/llvm/lib/Target/AIE/AIELegalizerHelper.cpp b/llvm/lib/Target/AIE/AIELegalizerHelper.cpp
index 24ef97844444..34c349676027 100644
--- a/llvm/lib/Target/AIE/AIELegalizerHelper.cpp
+++ b/llvm/lib/Target/AIE/AIELegalizerHelper.cpp
@@ -1196,6 +1196,7 @@ bool AIELegalizerHelper::legalizeG_FPTRUNC(LegalizerHelper &Helper,
 
 bool AIELegalizerHelper::legalizeG_FPEXT(LegalizerHelper &Helper,
                                          MachineInstr &MI) const {
+  const AIEBaseInstrInfo *II = ST.getInstrInfo();
   MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
   MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
 
@@ -1206,6 +1207,67 @@ bool AIELegalizerHelper::legalizeG_FPEXT(LegalizerHelper &Helper,
   LLT DstTy = MRI.getType(DstReg);
   LLT SrcTy = MRI.getType(SrcReg);
 
+  // Vectors
+  /*
+     VDst = G_FPEXT VSrc
+     converts to
+     ZeroVec = G_AIE_BROADCAST_VECTOR 0
+     VShuffleLow = G_AIE_SHUFFLE_VECTOR ZeroVec, VSrc, 2
+     VShuffleHigh = G_AIE_SHUFFLE_VECTOR ZeroVec, VSrc, 3
+     VShuffleLow = G_BITCAST VShuffleLow
+     VShuffleHigh = G_BITCAST VShuffleHigh
+     VDst = G_CONCAT_VECTORS VShuffleLow, VShuffleHigh
+  */
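+  // For example, for a 512-bit <32 x s16> (bf16) source this is the sequence
+  // checked in legalize-vector-fpext.ll:
+  //   ZeroVec:      <32 x s16> = G_AIE_BROADCAST_VECTOR 0
+  //   VShuffleLow:  <32 x s16> = G_AIE_SHUFFLE_VECTOR ZeroVec, VSrc, 2
+  //   VShuffleHigh: <32 x s16> = G_AIE_SHUFFLE_VECTOR ZeroVec, VSrc, 3
+  //   Low:          <16 x s32> = G_BITCAST VShuffleLow
+  //   High:         <16 x s32> = G_BITCAST VShuffleHigh
+  //   VDst:         <32 x s32> = G_CONCAT_VECTORS Low, High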
Expected: " + << CastToNewTy.getSizeInBits() + << " bits, got: " << MRI.getType(VShuffleLow).getSizeInBits() + << " and " << MRI.getType(VShuffleHigh).getSizeInBits() << " bits\n"; + return false; + } + auto VShuffleLowCast = + MIRBuilder.buildCast(CastToNewTy, VShuffleLow).getReg(0); + auto VShuffleHighCast = + MIRBuilder.buildCast(CastToNewTy, VShuffleHigh).getReg(0); + // Step 5: Concatenate the two src vectors into dst vector + MIRBuilder.buildConcatVectors(DstReg, {VShuffleLowCast, VShuffleHighCast}); + + MI.eraseFromParent(); + return true; + } + + // Scalars // We only handle bfloat16 to single precision conversion if (DstTy != LLT::scalar(32) || SrcTy != LLT::scalar(16)) return false; @@ -1300,6 +1362,9 @@ bool AIELegalizerHelper::legalizeG_FMUL(LegalizerHelper &Helper, MI.eraseFromParent(); return true; } +bool isBF16Vector(const LLT Ty) { + return Ty.isVector() && Ty.getScalarSizeInBits() == 16; +} bool AIELegalizerHelper::legalizeG_FADD_G_FSUB(LegalizerHelper &Helper, MachineInstr &MI) const { @@ -1309,6 +1374,74 @@ bool AIELegalizerHelper::legalizeG_FADD_G_FSUB(LegalizerHelper &Helper, const Register DstReg = MI.getOperand(0).getReg(); Register SrcLHS = MI.getOperand(1).getReg(); Register SrcRHS = MI.getOperand(2).getReg(); + const LLT SrcLHSTy = MRI.getType(SrcLHS); + const LLT SrcRHSTy = MRI.getType(SrcRHS); + + // Handle bf16 vectors code assumes the input is <32 x bf16>, the + // LegalizerInfo makes sure that the input is either padded or unmerged to <32 + // x bf16>. + if (isBF16Vector(SrcLHSTy) && isBF16Vector(SrcRHSTy)) { + // vector should be of size 32 asssert + assert(SrcLHSTy.getNumElements() == 32 && SrcRHSTy.getNumElements() == 32 && + "Expected vector of size 32 for inputs of G_FADD/G_FSUB"); + + // Step 1: Convert bf16 vectors to f32 vectors using FPExt + const LLT F32VecTy = + LLT::fixed_vector(SrcLHSTy.getNumElements(), LLT::scalar(32)); + Register SrcLHSF32 = MRI.createGenericVirtualRegister(F32VecTy); + Register SrcRHSF32 = MRI.createGenericVirtualRegister(F32VecTy); + MIRBuilder.buildFPExt(SrcLHSF32, SrcLHS); + MIRBuilder.buildFPExt(SrcRHSF32, SrcRHS); + + // Step 2: Input is going to be <32 x bf16> pad it to <64 x f32> for AIE2P + // as AccV64S32 is legal on AIE2P. + if (ST.isAIE2P()) { + const Register UndefVec = MIRBuilder.buildUndef(F32VecTy).getReg(0); + const Register ConcatLHS = MRI.createGenericVirtualRegister(V64FP32); + const Register ConcatRHS = MRI.createGenericVirtualRegister(V64FP32); + MIRBuilder.buildConcatVectors(ConcatLHS, {SrcLHSF32, UndefVec}); + MIRBuilder.buildConcatVectors(ConcatRHS, {SrcRHSF32, UndefVec}); + SrcLHSF32 = ConcatLHS; + SrcRHSF32 = ConcatRHS; + } + + // Step 3: Perform the floating point operation + Register Res = MIRBuilder + .buildInstr(MI.getOpcode(), {MRI.getType(SrcLHSF32)}, + {SrcLHSF32, SrcRHSF32}) + .getReg(0); + + // Step 4: Handle accumulator conversion based on target + if (ST.isAIE2()) { + Res = MIRBuilder.buildBitcast(V8ACC64, Res).getReg(0); + } else if (ST.isAIE2P()) { + // Unmerge to get 2 vectors of <32xf32> as FADD/FSUB was done on <64xf32> + SmallVector UnmergedRegs; + const auto Unmerge = MIRBuilder.buildUnmerge(F32VecTy, Res); + getUnmergeResults(UnmergedRegs, *Unmerge); + Res = UnmergedRegs[0]; // Take the first <32xf32> vector, other half is + // just zeros. + } + + // Step 5: Convert back to bf16 using the truncation intrinsic + const int VecSize = MRI.getType(Res).getSizeInBits(); + const LLT DstLLT = ST.isAIE2P() ? 
+  if (isBF16Vector(SrcLHSTy) && isBF16Vector(SrcRHSTy)) {
+    // The inputs must be <32 x bf16> vectors at this point.
+    assert(SrcLHSTy.getNumElements() == 32 &&
+           SrcRHSTy.getNumElements() == 32 &&
+           "Expected vector of size 32 for inputs of G_FADD/G_FSUB");
+
+    // Step 1: Convert the bf16 vectors to f32 vectors using FPExt
+    const LLT F32VecTy =
+        LLT::fixed_vector(SrcLHSTy.getNumElements(), LLT::scalar(32));
+    Register SrcLHSF32 = MRI.createGenericVirtualRegister(F32VecTy);
+    Register SrcRHSF32 = MRI.createGenericVirtualRegister(F32VecTy);
+    MIRBuilder.buildFPExt(SrcLHSF32, SrcLHS);
+    MIRBuilder.buildFPExt(SrcRHSF32, SrcRHS);
+
+    // Step 2: The extended input is <32 x f32>; pad it to <64 x f32> for
+    // AIE2P, as AccV64S32 is legal on AIE2P.
+    if (ST.isAIE2P()) {
+      const Register UndefVec = MIRBuilder.buildUndef(F32VecTy).getReg(0);
+      const Register ConcatLHS = MRI.createGenericVirtualRegister(V64FP32);
+      const Register ConcatRHS = MRI.createGenericVirtualRegister(V64FP32);
+      MIRBuilder.buildConcatVectors(ConcatLHS, {SrcLHSF32, UndefVec});
+      MIRBuilder.buildConcatVectors(ConcatRHS, {SrcRHSF32, UndefVec});
+      SrcLHSF32 = ConcatLHS;
+      SrcRHSF32 = ConcatRHS;
+    }
+
+    // Step 3: Perform the floating point operation
+    Register Res = MIRBuilder
+                       .buildInstr(MI.getOpcode(), {MRI.getType(SrcLHSF32)},
+                                   {SrcLHSF32, SrcRHSF32})
+                       .getReg(0);
+
+    // Step 4: Handle the accumulator conversion based on the target
+    if (ST.isAIE2()) {
+      Res = MIRBuilder.buildBitcast(V8ACC64, Res).getReg(0);
+    } else if (ST.isAIE2P()) {
+      // Unmerge to get 2 vectors of <32 x f32>, as the FADD/FSUB was done on
+      // <64 x f32>.
+      SmallVector<Register> UnmergedRegs;
+      const auto Unmerge = MIRBuilder.buildUnmerge(F32VecTy, Res);
+      getUnmergeResults(UnmergedRegs, *Unmerge);
+      // Take the first <32 x f32> vector; the second half only holds the
+      // padding.
+      Res = UnmergedRegs[0];
+    }
+
+    // Step 5: Convert back to bf16 using the truncation intrinsic
+    const int VecSize = MRI.getType(Res).getSizeInBits();
+    const LLT DstLLT = ST.isAIE2P() ? V32BF16 : V16BF16;
+    Res = MIRBuilder
+              .buildIntrinsic(getFpTrunc32ToBF16IntrID(ST, VecSize), {DstLLT},
+                              true, false)
+              .addUse(Res)
+              .getReg(0);
+
+    // Handle AIE2 padding
+    if (ST.isAIE2()) {
+      Res = emitPadUndefVector(MRI, MIRBuilder, V32BF16, Res);
+    }
+
+    MIRBuilder.buildCopy(DstReg, Res);
+
+    MI.eraseFromParent();
+    return true;
+  }
 
   assert(MRI.getType(DstReg) == LLT::scalar(16) &&
          "Expected bfloat16 type in custom legalization.");
diff --git a/llvm/lib/Target/AIE/aie2p/AIE2PLegalizerInfo.cpp b/llvm/lib/Target/AIE/aie2p/AIE2PLegalizerInfo.cpp
index 5ea47376c667..d2cd95b2c832 100644
--- a/llvm/lib/Target/AIE/aie2p/AIE2PLegalizerInfo.cpp
+++ b/llvm/lib/Target/AIE/aie2p/AIE2PLegalizerInfo.cpp
@@ -73,6 +73,31 @@ static LegalityPredicate isValidVectorAIEP(const unsigned TypeIdx) {
   };
 }
 
+// `V2 = G_FPEXT V1` on vectors is valid iff:
+// - V1 and V2 are floating-point vectors
+// - V2 is wider than V1 in total vector size
+// - Both vectors have the same number of elements
+// - The element size of V2 is twice the element size of V1
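+// For example, <32 x s32> = G_FPEXT <32 x s16> qualifies, whereas
+// <16 x s32> = G_FPEXT <32 x s16> does not (the element counts differ).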
+static LegalityPredicate isValidVectorFPEXT(const unsigned TypeIdx_dst,
+                                            const unsigned TypeIdx_src) {
+  return [=](const LegalityQuery &Query) {
+    const LLT DstTy = Query.Types[TypeIdx_dst];
+    const LLT SrcTy = Query.Types[TypeIdx_src];
+    if (DstTy.isVector() && SrcTy.isVector()) {
+      auto DstElementCount = DstTy.getElementCount();
+      auto SrcElementCount = SrcTy.getElementCount();
+      auto DstElementType = DstTy.getElementType();
+      auto SrcElementType = SrcTy.getElementType();
+      auto DstElementSize = DstElementType.getSizeInBits();
+      auto SrcElementSize = SrcElementType.getSizeInBits();
+      return DstTy.getSizeInBits() > SrcTy.getSizeInBits() &&
+             DstElementCount == SrcElementCount &&
+             (DstElementSize == (SrcElementSize * 2));
+    }
+    return false;
+  };
+}
+
 static LegalityPredicate
 negatePredicate(const std::function<bool(const LegalityQuery &)> &Func) {
   return [=](const LegalityQuery &Query) { return !Func(Query); };
@@ -219,6 +244,13 @@ AIE2PLegalizerInfo::AIE2PLegalizerInfo(const AIE2PSubtarget &ST)
   getActionDefinitionsBuilder(G_FPEXT)
       .libcallFor({{S64, S32}})
      .customFor({{S32, S16}})
+      // Add support for vector types.
+      // Extend vectors to have at least 512 bits.
+      .clampMinNumElements(1, S8, 64)
+      .clampMinNumElements(1, S16, 32)
+      .clampMinNumElements(1, S32, 16)
+      .customIf(isValidVectorFPEXT(0 /* Dst */, 1 /* Src */))
       .narrowScalarFor({{S64, S16}},
                        llvm::LegalizeMutations::changeTo(0, S32));
 
   getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
@@ -241,7 +273,34 @@ AIE2PLegalizerInfo::AIE2PLegalizerInfo(const AIE2PSubtarget &ST)
   getActionDefinitionsBuilder({G_FADD, G_FSUB})
       .legalFor({AccV64S32})
-      .customFor({S16})
+      // Handle the custom bf16 case for both scalar and vector types
+      .customFor({S16, V32S16})
+      // Widen vectors smaller than <32 x f32/bf16> to a legal size; this does
+      // not change the element type.
+      .moreElementsIf(
+          [=](const LegalityQuery &Query) {
+            const LLT &Ty = Query.Types[0];
+            return Ty.isVector() &&
+                   (Ty.getScalarSizeInBits() == 32 ||
+                    Ty.getScalarSizeInBits() == 16) &&
+                   Ty.getNumElements() <= 32;
+          },
+          [=](const LegalityQuery &Query) {
+            if (Query.Types[0].getScalarSizeInBits() == 32)
+              return std::make_pair(0, LLT::fixed_vector(64, S32));
+            return std::make_pair(0, LLT::fixed_vector(32, S16));
+          })
+      // Split <64 x bf16> into 2 chunks of <32 x bf16>
+      .fewerElementsIf(
+          [=](const LegalityQuery &Query) {
+            const LLT &Ty = Query.Types[0];
+            return Ty.isVector() && (Ty.getScalarSizeInBits() == 16) &&
+                   Ty.getNumElements() == 64;
+          },
+          [=](const LegalityQuery &Query) {
+            return std::make_pair(0, LLT::fixed_vector(32, S16));
+          })
       .libcallFor({S32, S64});
 
   getActionDefinitionsBuilder({G_FDIV, G_FREM})
diff --git a/llvm/test/CodeGen/AIE/aie2p/GlobalIsel/legalize-vector-fadd.ll b/llvm/test/CodeGen/AIE/aie2p/GlobalIsel/legalize-vector-fadd.ll
new file mode 100644
index 000000000000..c6e7feae1338
--- /dev/null
+++ b/llvm/test/CodeGen/AIE/aie2p/GlobalIsel/legalize-vector-fadd.ll
@@ -0,0 +1,56 @@
+; RUN: llc -mtriple=aie2p -O0 -stop-after=legalizer %s -o - 2>&1 | FileCheck %s
+; This test is carved out from
+; iree-amd-aie/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/multi_reduction_to_reduction_sizes_types.mlir
+; for sending the patch upstream.
+
+; Ideally the reduction should be legalized as follows (with minor changes for
+; each shape):
+;   Input1: <32xbf16> and Input2: <32xbf16>
+;   Extended1<32xf32> = fpext <32xbf16>
+;   Extended2<32xf32> = fpext <32xbf16>
+;   Zero<32xf32> = zeroinitializer
+;   Out1<64xf32> = Concat Extended1<32xf32>, Zero<32xf32>
+;   Out2<64xf32> = Concat Extended2<32xf32>, Zero<32xf32>
+;   Result<64xf32> = fadd Out1<64xf32>, Out2<64xf32>
+;   R1<32xf32>, R2<32xf32> = unmerge Result<64xf32>
+;   R2 is all 0s
+;   R1<32xbf16> = trunc R1<32xf32>
+
+; Checks still to be added for this case:
+; - check the vadd.f
+; - pad checks
+; - checks similar to the <32xbf16> case
+; - unpad checks
+define bfloat @multi_reduction_1d_16_bf16(<16 x bfloat> %0, bfloat %1) {
+  %3 = call reassoc bfloat @llvm.vector.reduce.fadd.v16bf16(bfloat %1, <16 x bfloat> %0)
+  ret bfloat %3
+}
+
+; CHECK-LABEL: name: multi_reduction_1d_32_bf16
+; CHECK: G_CONSTANT i32 0
+; CHECK: G_AIE_BROADCAST_VECTOR %{{[0-9]+}}(s32)
+; CHECK: G_CONSTANT i32 2
+; CHECK: G_CONSTANT i32 3
+; CHECK: G_AIE_SHUFFLE_VECTOR %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}(s32)
+; CHECK: G_AIE_SHUFFLE_VECTOR %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}(s32)
+; CHECK: G_BITCAST %{{[0-9]+}}(<32 x s16>)
+; CHECK: G_BITCAST %{{[0-9]+}}(<32 x s16>)
+; CHECK: G_CONCAT_VECTORS %{{[0-9]+}}(<16 x s32>), %{{[0-9]+}}(<16 x s32>)
+; CHECK: G_IMPLICIT_DEF
+; CHECK: G_CONCAT_VECTORS %{{[0-9]+}}(<32 x s32>), %{{[0-9]+}}(<32 x s32>)
+; CHECK: G_FADD %{{[0-9]+}}, %{{[0-9]+}}
+; CHECK: G_UNMERGE_VALUES %{{[0-9]+}}(<64 x s32>)
+; CHECK: G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@llvm.aie2p.v32accfloat.to.v32bf16), %{{[0-9]+}}(<32 x s32>)
+define bfloat @multi_reduction_1d_32_bf16(<32 x bfloat> %0, bfloat %1) {
+  %3 = call reassoc bfloat @llvm.vector.reduce.fadd.v32bf16(bfloat %1, <32 x bfloat> %0)
+  ret bfloat %3
+}
+
+; Converted to chunks of <32 x bf16>:
+; Check that the input is split into 2 chunks of <32 x bf16>
+; Check each chunk similarly to the <32xbf16> case
+; Check that the results get concatenated back to <64xbf16>
+define bfloat @multi_reduction_1d_64_bf16(<64 x bfloat> %0, bfloat %1) {
+  %3 = call reassoc bfloat @llvm.vector.reduce.fadd.v64bf16(bfloat %1, <64 x bfloat> %0)
+  ret bfloat %3
+}
diff --git a/llvm/test/CodeGen/AIE/aie2p/GlobalIsel/legalize-vector-fpext.ll b/llvm/test/CodeGen/AIE/aie2p/GlobalIsel/legalize-vector-fpext.ll
new file mode 100644
index 000000000000..6eb8fa08279f
--- /dev/null
+++ b/llvm/test/CodeGen/AIE/aie2p/GlobalIsel/legalize-vector-fpext.ll
@@ -0,0 +1,73 @@
+; RUN: llc -mtriple=aie2p -O0 -stop-after=legalizer %s -o - 2>&1 | FileCheck %s
+
+; Validates bfloat -> float legalization.
+; CHECK-LABEL: name: extend
+; CHECK: [[COPY:%[0-9]+]]:_(<32 x s16>) = COPY $x0
+; CHECK-NOT: G_SHL
+; CHECK-NEXT: [[C2:%[0-9]+]]:_(s32) = G_CONSTANT i32 2
+; CHECK-NEXT: [[C3:%[0-9]+]]:_(s32) = G_CONSTANT i32 3
+; CHECK-NEXT: [[C0:%[0-9]+]]:_(s32) = G_CONSTANT i32 0
+; CHECK-NEXT: [[BCAST:%[0-9]+]]:_(<32 x s16>) = G_AIE_BROADCAST_VECTOR [[C0]](s32)
+; CHECK-NEXT: [[SHUF1:%[0-9]+]]:_(<32 x s16>) = G_AIE_SHUFFLE_VECTOR [[BCAST]], [[COPY]], [[C2]](s32)
+; CHECK-NEXT: [[SHUF2:%[0-9]+]]:_(<32 x s16>) = G_AIE_SHUFFLE_VECTOR [[BCAST]], [[COPY]], [[C3]](s32)
+; CHECK-NEXT: [[BIT1:%[0-9]+]]:_(<16 x s32>) = G_BITCAST [[SHUF1]](<32 x s16>)
+; CHECK-NEXT: [[BIT2:%[0-9]+]]:_(<16 x s32>) = G_BITCAST [[SHUF2]](<32 x s16>)
+; CHECK-NEXT: [[CONCAT:%[0-9]+]]:_(<32 x s32>) = G_CONCAT_VECTORS [[BIT1]](<16 x s32>), [[BIT2]](<16 x s32>)
+define <32 x float> @extend(bfloat %o, <32 x bfloat> %in) nounwind {
+  %X = fpext <32 x bfloat> %in to <32 x float>
+  ret <32 x float> %X
+}
+
+; Pads the 17 valid elements with undefined values to form a 32-element vector.
+
+; CHECK-LABEL: name: extend_non_power_of_2
+; CHECK: [[COPY:%[0-9]+]]:_(<32 x s16>) = COPY $x0
+; CHECK-COUNT-17: G_AIE_SEXT_EXTRACT_VECTOR_ELT
+; CHECK-COUNT-32: G_AIE_ADD_VECTOR_ELT_HI
+; CHECK-NEXT: [[C2:%[0-9]+]]:_(s32) = G_CONSTANT i32 2
+; CHECK-NEXT: [[C3:%[0-9]+]]:_(s32) = G_CONSTANT i32 3
+; CHECK-NEXT: [[C0:%[0-9]+]]:_(s32) = G_CONSTANT i32 0
+; CHECK-NEXT: [[BCAST:%[0-9]+]]:_(<32 x s16>) = G_AIE_BROADCAST_VECTOR [[C0]](s32)
+; CHECK-NEXT: [[SHUF1:%[0-9]+]]:_(<32 x s16>) = G_AIE_SHUFFLE_VECTOR [[BCAST]], %{{[0-9]+}}, [[C2]](s32)
+; CHECK-NEXT: [[SHUF2:%[0-9]+]]:_(<32 x s16>) = G_AIE_SHUFFLE_VECTOR [[BCAST]], %{{[0-9]+}}, [[C3]](s32)
+; CHECK-NEXT: [[BIT1:%[0-9]+]]:_(<16 x s32>) = G_BITCAST [[SHUF1]](<32 x s16>)
+; CHECK-NEXT: [[BIT2:%[0-9]+]]:_(<16 x s32>) = G_BITCAST [[SHUF2]](<32 x s16>)
+; CHECK-COUNT-17: G_AIE_SEXT_EXTRACT_VECTOR_ELT
+; CHECK-COUNT-32: G_AIE_ADD_VECTOR_ELT_HI
+; CHECK-NEXT: [[CONCAT:%[0-9]+]]:_(<32 x s32>) = G_CONCAT_VECTORS %{{[0-9]+}}(<16 x s32>), %{{[0-9]+}}(<16 x s32>)
+define <17 x float> @extend_non_power_of_2(<17 x bfloat> %in) nounwind {
+  %X = fpext <17 x bfloat> %in to <17 x float>
+  ret <17 x float> %X
+}
+
+; Validates inputs smaller than 512 bits; they are padded to <32 x bf16> first.
+
+; CHECK-LABEL: name: fpext_bf16_to_f32
+; CHECK: bb.1
+; CHECK: [[VEC_CONCAT:%[0-9]+]]:_(<32 x s16>) = G_CONCAT_VECTORS
+; CHECK: G_AIE_SEXT_EXTRACT_VECTOR_ELT [[VEC_CONCAT]]
+; CHECK: G_AIE_ADD_VECTOR_ELT_HI
+; CHECK: [[SHUFFLE_VEC:%[0-9]+]]:_(<32 x s16>) = G_AIE_SHUFFLE_VECTOR
+; CHECK-NOT: G_AIE_SHUFFLE_VECTOR
+; CHECK: [[BITCAST:%[0-9]+]]:_(<16 x s32>) = G_BITCAST [[SHUFFLE_VEC]]
+; CHECK: $x0 = COPY [[BITCAST]]
+define <16 x float> @fpext_bf16_to_f32(<16 x bfloat> %in) nounwind {
+  %X = fpext <16 x bfloat> %in to <16 x float>
+  ret <16 x float> %X
+}
+
+; Validates the scalar path.
+; CHECK-LABEL: name: fpext_scalar_bf16_to_f32
+; CHECK: [[COPY:%[0-9]+]]:_(s32) = COPY $r1
+; CHECK-NEXT: [[C16:%[0-9]+]]:_(s32) = G_CONSTANT i32 16
+; CHECK-NEXT: [[SHL:%[0-9]+]]:_(s32) = G_SHL [[COPY]], [[C16]](s32)
+; CHECK-NOT: G_AIE_SHUFFLE_VECTOR
+; CHECK-NEXT: $r0 = COPY [[SHL]](s32)
+; CHECK-NEXT: PseudoRET implicit $lr, implicit $r0
+define float @fpext_scalar_bf16_to_f32(bfloat %in) nounwind {
+  %X = fpext bfloat %in to float
+  ret float %X
+}