Skip to content

Commit 27c361b

Browse files
committed
Lower llvm.dx.rawBufferLoad to DXIL ops
This PR lowers the @llvm.dx.rawBufferLoad intrinsic to the @dx.op.rawBufferLoad DXIL operation.
1 parent 8e1bca0 commit 27c361b

File tree

4 files changed

+396
-0
lines changed

4 files changed

+396
-0
lines changed

llvm/include/llvm/IR/IntrinsicsDirectX.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ def int_dx_handle_fromBinding
2727
[llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i1_ty],
2828
[IntrNoMem]>;
2929

30+
// Raw/structured buffer load: overloaded result type; operands are an
// overloaded first operand followed by two i32 operands. Marked IntrReadMem
// like int_dx_typedBufferLoad, since a buffer load only reads memory.
def int_dx_rawBufferLoad
    : DefaultAttrsIntrinsic<[llvm_any_ty],
                            [llvm_any_ty, llvm_i32_ty, llvm_i32_ty],
                            [IntrReadMem]>;
32+
3033
// Typed buffer load: overloaded result type; operands are an overloaded first
// operand followed by one i32 operand. Declared IntrReadMem.
def int_dx_typedBufferLoad
3134
: DefaultAttrsIntrinsic<[llvm_any_ty], [llvm_any_ty, llvm_i32_ty],
3235
[IntrReadMem]>;

llvm/lib/Target/DirectX/DXIL.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -854,6 +854,17 @@ def AnnotateHandle : DXILOp<216, annotateHandle> {
854854
let stages = [Stages<DXIL1_6, [all_stages]>];
855855
}
856856

857+
// DXIL operation record for rawBufferLoad (opcode 139). The result type is the
// overload type, restricted to the half/float/i16/i32 ResRet struct types.
// NOTE(review): both the overloads and the stages are gated on DXIL1_0;
// confirm this is the correct minimum DXIL version for this opcode.
def RawBufferLoad : DXILOp<139, rawBufferLoad> {
858+
let Doc = "reads from a ByteAddressBuffer or StructuredBuffer";
859+
// Handle, Coord0, Coord1, mask, alignment
860+
let arguments = [HandleTy, Int32Ty, Int32Ty, Int8Ty, Int32Ty];
861+
let result = OverloadTy;
862+
let overloads =
863+
[Overloads<DXIL1_0,
864+
[ResRetHalfTy, ResRetFloatTy, ResRetInt16Ty, ResRetInt32Ty]>];
865+
let stages = [Stages<DXIL1_0, [all_stages]>];
866+
}
867+
857868
def CreateHandleFromBinding : DXILOp<217, createHandleFromBinding> {
858869
let Doc = "create resource handle from binding";
859870
let arguments = [ResBindTy, Int32Ty, Int1Ty];

llvm/lib/Target/DirectX/DXILOpLowering.cpp

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -628,6 +628,193 @@ class OpLowerer {
628628
});
629629
}
630630

631+
/// Emit one dx.op.rawBufferLoad call that reads \p NumComponents components.
///
/// \param handle        Resource handle operand of the op.
/// \param bufIdx        First coordinate. When null, \p offset is used as the
///                      first coordinate and the second coordinate is undef.
/// \param offset        Second coordinate (must be non-null; its type is used
///                      to build the undef value in the case above).
/// \param Ty            Loaded type; only its scalar type is used to select
///                      the ResRet struct returned by the op.
/// \param Builder       Used to obtain the i8 type for the component mask.
/// \param NumComponents Number of components to read, in the range [1, 4].
/// \param alignment     Alignment operand passed through to the op.
/// \returns the created call, or nullptr if the op could not be created.
///          Callers must check for nullptr before using the result.
Value *GenerateRawBufLd(Value *handle, Value *bufIdx, Value *offset, Type *Ty,
                        IRBuilder<> &Builder, unsigned NumComponents,
                        Constant *alignment) {
  if (bufIdx == nullptr) {
    // This is actually a byte address buffer load with a struct template
    // type. The call takes only one coordinate for the offset.
    bufIdx = offset;
    offset = UndefValue::get(offset->getType());
  }

  // One mask bit per component, starting at X:
  //   1 -> 0b0001 (X)       2 -> 0b0011 (X|Y)
  //   3 -> 0b0111 (X|Y|Z)   4 -> 0b1111 (X|Y|Z|W)
  assert(NumComponents > 0 && NumComponents < 5 &&
         "rawBufferLoad reads between 1 and 4 components");
  Constant *mask =
      ConstantInt::get(Builder.getInt8Ty(), (1u << NumComponents) - 1);

  Value *Args[] = {handle, bufIdx, offset, mask, alignment};
  Type *NewRetTy = OpBuilder.getResRetType(Ty->getScalarType());
  // TODO: Need name argument?
  Expected<CallInst *> OpCall =
      OpBuilder.tryCreateOp(OpCode::RawBufferLoad, Args, "", NewRetTy);
  if (!OpCall) {
    // An llvm::Error must be handled before it is destroyed; dropping it
    // unchecked aborts in builds with ABI-breaking checks enabled. This
    // function can only signal failure through a null return, so the error
    // payload is deliberately discarded here.
    consumeError(OpCall.takeError());
    return nullptr;
  }

  return *OpCall;
}
658+
659+
/// Split a load of \p ElemCount elements into dx.op.rawBufferLoad calls of at
/// most four components each and append them, in order, to \p bufLds.
///
/// \param Ty         Loaded type; its scalar type determines the element size.
/// \param ElemCount  Total number of elements to load.
/// \param Builder    Builder used for the offset arithmetic and constants.
/// \param handle     Resource handle forwarded to every load.
/// \param bufIdx     First coordinate forwarded to every load (may be null,
///                   see GenerateRawBufLd).
/// \param baseOffset Starting offset; null is treated as the i32 constant 0.
/// \param DL         Data layout used to query the element's alloc size.
/// \param bufLds     Output list; one entry is appended per emitted load.
///                   Entries are whatever GenerateRawBufLd returned, so they
///                   may be null on failure.
/// \param baseAlign  Upper bound for the alignment operand; the value passed
///                   to the op is min(baseAlign, element alloc size).
/// \param isScalarTy Currently unused.
void TranslateRawBufVecLd(Type *Ty, unsigned ElemCount, IRBuilder<> &Builder,
                          Value *handle, Value *bufIdx, Value *baseOffset,
                          const DataLayout &DL, std::vector<Value *> &bufLds,
                          unsigned baseAlign, bool isScalarTy) {
  (void)isScalarTy; // Not needed yet; kept for interface compatibility.

  Type *VecEltTy = Ty->getScalarType();

  unsigned EltSize = DL.getTypeAllocSize(VecEltTy);
  unsigned alignment = std::min(baseAlign, EltSize);
  Constant *alignmentVal =
      ConstantInt::get(M.getContext(), APInt(32, alignment));

  if (baseOffset == nullptr)
    baseOffset = ConstantInt::get(Builder.getInt32Ty(), 0);

  // Full groups of four components first; the offset advances by the size of
  // four elements after each group.
  const unsigned rest = ElemCount % 4;
  for (unsigned i = 0; i < ElemCount - rest; i += 4) {
    bufLds.emplace_back(GenerateRawBufLd(handle, bufIdx, baseOffset, Ty,
                                         Builder, 4, alignmentVal));

    baseOffset = Builder.CreateAdd(
        baseOffset, ConstantInt::get(Builder.getInt32Ty(), 4 * EltSize));
  }

  // Trailing partial group (1-3 components), if any.
  if (rest)
    bufLds.emplace_back(GenerateRawBufLd(handle, bufIdx, baseOffset, Ty,
                                         Builder, rest, alignmentVal));
}
691+
692+
/// Rewrite every use of the buffer-load intrinsic call \p Intrin in terms of
/// the already-emitted loads in \p bufLds, then erase \p Intrin.
///
/// Element I of the intrinsic's result is taken from field (I % 4) of
/// bufLds[I / 4]. Every entry of \p bufLds that is accessed must be a
/// non-null CallInst.
///
/// - Scalar result: all uses are replaced by field 0 of the first load.
/// - Fixed-vector result: extractelement users with a constant index are
///   replaced by the matching extractvalue. If any other use remains, the
///   whole vector is rebuilt with insertelement and substituted for it.
///
/// \returns Error::success() in all cases.
Error replaceMultiResRetsUses(CallInst *Intrin,
                              std::vector<Value *> &bufLds) {
  IRBuilder<> &IRB = OpBuilder.getIRB();

  // TODO: HasCheckBit????

  Type *OldTy = Intrin->getType();

  // Creates the extractvalue for result element Idx.
  auto ExtractComponent = [&](unsigned Idx) -> Value * {
    assert(Idx / 4 < bufLds.size() && "No buffer load covers this element");
    // cast<> (not dyn_cast<>): every entry is required to be a CallInst.
    auto *Op = cast<CallInst>(bufLds[Idx / 4]);
    return IRB.CreateExtractValue(Op, Idx % 4);
  };

  // For scalars, we just extract the first element.
  const auto *VecTy = dyn_cast<FixedVectorType>(OldTy);
  if (!VecTy) {
    Value *EVI = ExtractComponent(0);
    Intrin->replaceAllUsesWith(EVI);
    Intrin->eraseFromParent();
    return Error::success();
  }

  const unsigned N = VecTy->getNumElements();

  // Lazily-populated cache of one extractvalue per result element.
  std::vector<Value *> Extracts(N);
  auto GetOrCreateExtract = [&](unsigned Idx) -> Value * {
    if (!Extracts[Idx])
      Extracts[Idx] = ExtractComponent(Idx);
    return Extracts[Idx];
  };

  // The users of the operation should all be scalarized, so we attempt to
  // replace the extractelements with extractvalues directly.
  for (Use &U : make_early_inc_range(Intrin->uses())) {
    auto *EEI = dyn_cast<ExtractElementInst>(U.getUser());
    if (!EEI)
      continue;
    auto *IndexOp = dyn_cast<ConstantInt>(EEI->getIndexOperand());
    if (!IndexOp)
      continue; // TODO: dynamic accesses are left for the fallback below.

    const size_t IndexVal = IndexOp->getZExtValue();
    assert(IndexVal < N && "Index into buffer load out of range");
    EEI->replaceAllUsesWith(GetOrCreateExtract(IndexVal));
    EEI->eraseFromParent();
  }

  // TODO: a dynamic access may need to round trip through stack memory so
  // that we don't leave vectors around; not handled for raw buffers yet.

  // If we still have uses, then we're not fully scalarized and need to
  // recreate the vector. This should only happen for things like exported
  // functions from libraries.
  if (!Intrin->use_empty()) {
    // Materialize all missing extracts first, then build the vector, so the
    // extractvalues precede the insertelement chain in the emitted IR.
    for (unsigned I = 0; I != N; ++I)
      GetOrCreateExtract(I);

    Value *Vec = UndefValue::get(OldTy);
    for (unsigned I = 0; I != N; ++I)
      Vec = IRB.CreateInsertElement(Vec, Extracts[I], I);

    Intrin->replaceAllUsesWith(Vec);
  }

  // TODO: Remove the dx.op.rawbufferload without any uses now?

  Intrin->eraseFromParent();

  return Error::success();
}
768+
769+
/// Lower every call to the dx.rawBufferLoad intrinsic function \p F into one
/// or more dx.op.rawBufferLoad calls and rewrite the users of the original
/// call.
///
/// The per-call work is handed to replaceFunction as a callback that returns
/// an Error; this function returns whatever replaceFunction returns.
///
/// The resource kind is currently hard-coded to StructuredBuffer (the real
/// lookup is disabled under `#if 0`), so the base alignment is always 8 bytes.
[[nodiscard]] bool lowerRawBufferLoad(Function &F) {
  IRBuilder<> &IRB = OpBuilder.getIRB();

  return replaceFunction(F, [&](CallInst *CI) -> Error {
    IRB.SetInsertPoint(CI);
#if 0
    auto *It = DRM.find(dyn_cast<CallInst>(CI->getArgOperand(0)));
    assert(It != DRM.end() && "Resource not in map?");
    dxil::ResourceInfo &RI = *It;

    assert((RI.getResourceKind() == dxil::ResourceKind::StructuredBuffer) ||
           (RI.getResourceKind() == dxil::ResourceKind::RawBuffer));
#else
    ResourceKind RCKind = dxil::ResourceKind::StructuredBuffer;
#endif

    Type *Ty = CI->getType();
    // TODO: Need check Bool type load???

    unsigned numComponents = 1;
    if (Ty->isVectorTy()) {
      // isVectorTy() is also true for scalable vectors, for which the
      // FixedVectorType cast yields null; reject those instead of
      // dereferencing a null pointer.
      const auto *VecTy = dyn_cast<FixedVectorType>(Ty);
      if (!VecTy)
        return make_error<StringError>(
            "rawBufferLoad of a non-fixed vector type is not supported",
            inconvertibleErrorCode());
      numComponents = VecTy->getNumElements();
    }

    Value *Handle =
        createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
    Value *bufIdx = CI->getArgOperand(1);
    Value *baseOffset = CI->getArgOperand(2);

    const bool isScalarTy = !Ty->isVectorTy();

    // Base alignment in bytes: 8 for structured buffers, 4 otherwise.
    const unsigned baseAlign =
        (RCKind == dxil::ResourceKind::StructuredBuffer) ? 8 : 4;

    std::vector<Value *> bufLds;
    TranslateRawBufVecLd(Ty, numComponents, IRB, Handle, bufIdx, baseOffset,
                         F.getDataLayout(), bufLds, baseAlign, isScalarTy);

    // GenerateRawBufLd signals failure with a null entry; surface that as an
    // error rather than letting the use-rewriting step dereference it.
    for (Value *Ld : bufLds)
      if (!Ld)
        return make_error<StringError>(
            "failed to create dx.op.rawBufferLoad operation",
            inconvertibleErrorCode());

    return replaceMultiResRetsUses(CI, bufLds);
  });
}
817+
631818
bool lowerIntrinsics() {
632819
bool Updated = false;
633820
bool HasErrors = false;
@@ -647,6 +834,9 @@ class OpLowerer {
647834
case Intrinsic::dx_handle_fromBinding:
648835
HasErrors |= lowerHandleFromBinding(F);
649836
break;
837+
case Intrinsic::dx_rawBufferLoad:
838+
HasErrors |= lowerRawBufferLoad(F);
839+
break;
650840
case Intrinsic::dx_typedBufferLoad:
651841
HasErrors |= lowerTypedBufferLoad(F, /*HasCheckBit=*/false);
652842
break;

0 commit comments

Comments
 (0)