Skip to content

Commit 053e4d2

Browse files
committed
[SPIR-V] Add support for SPV_INTEL_usm_storage_classes extension
With this extension 2 new Storage Classes are introduced: DeviceOnlyINTEL and HostOnlyINTEL appropriately mapped on global_device and global_host SYCL/OpenCL address spaces which are part of SYCL_INTEL_usm_address_spaces extension. Co-authored-by: Viktoria Maksimova <[email protected]> Signed-off-by: Dmitry Sidorov <[email protected]> Signed-off-by: Viktoria Maksimova <[email protected]>
1 parent 3aec7f7 commit 053e4d2

15 files changed

+338
-12
lines changed

llvm-spirv/include/LLVMSPIRVExtensions.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,4 @@ EXT(SPV_INTEL_arbitrary_precision_integers)
2121
EXT(SPV_INTEL_optimization_hints)
2222
EXT(SPV_INTEL_float_controls2)
2323
EXT(SPV_INTEL_vector_compute)
24+
EXT(SPV_INTEL_usm_storage_classes)

llvm-spirv/lib/SPIRV/Mangler/ParameterType.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,9 @@ enum TypeAttributeEnum {
136136
ATTR_CONSTANT,
137137
ATTR_LOCAL,
138138
ATTR_GENERIC,
139-
ATTR_ADDR_SPACE_LAST = ATTR_GENERIC,
139+
ATTR_GLOBAL_DEVICE,
140+
ATTR_GLOBAL_HOST,
141+
ATTR_ADDR_SPACE_LAST = ATTR_GLOBAL_HOST,
140142
ATTR_NONE,
141143
ATTR_NUM = ATTR_NONE
142144
};

llvm-spirv/lib/SPIRV/OCLUtil.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,10 @@ static SPIR::TypeAttributeEnum mapAddrSpaceEnums(SPIRAddressSpace Addrspace) {
349349
return SPIR::ATTR_LOCAL;
350350
case SPIRAS_Generic:
351351
return SPIR::ATTR_GENERIC;
352+
case SPIRAS_GlobalDevice:
353+
return SPIR::ATTR_GLOBAL_DEVICE;
354+
case SPIRAS_GlobalHost:
355+
return SPIR::ATTR_GLOBAL_HOST;
352356
default:
353357
llvm_unreachable("Invalid addrspace enum member");
354358
}

llvm-spirv/lib/SPIRV/SPIRVInternal.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,8 @@ enum SPIRAddressSpace {
188188
SPIRAS_Constant,
189189
SPIRAS_Local,
190190
SPIRAS_Generic,
191+
SPIRAS_GlobalDevice,
192+
SPIRAS_GlobalHost,
191193
SPIRAS_Input,
192194
SPIRAS_Output,
193195
SPIRAS_Count,
@@ -200,6 +202,8 @@ template <> inline void SPIRVMap<SPIRAddressSpace, std::string>::init() {
200202
add(SPIRAS_Local, "Local");
201203
add(SPIRAS_Generic, "Generic");
202204
add(SPIRAS_Input, "Input");
205+
add(SPIRAS_GlobalDevice, "GlobalDevice");
206+
add(SPIRAS_GlobalHost, "GlobalHost");
203207
}
204208
typedef SPIRVMap<SPIRAddressSpace, SPIRVStorageClassKind>
205209
SPIRAddrSpaceCapitalizedNameMap;
@@ -212,6 +216,8 @@ inline void SPIRVMap<SPIRAddressSpace, SPIRVStorageClassKind>::init() {
212216
add(SPIRAS_Local, StorageClassWorkgroup);
213217
add(SPIRAS_Generic, StorageClassGeneric);
214218
add(SPIRAS_Input, StorageClassInput);
219+
add(SPIRAS_GlobalDevice, StorageClassDeviceOnlyINTEL);
220+
add(SPIRAS_GlobalHost, StorageClassHostOnlyINTEL);
215221
}
216222
typedef SPIRVMap<SPIRAddressSpace, SPIRVStorageClassKind> SPIRSPIRVAddrSpaceMap;
217223

llvm-spirv/lib/SPIRV/SPIRVReader.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -931,8 +931,18 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
931931
switch (BC->getOpCode()) {
932932
case OpPtrCastToGeneric:
933933
case OpGenericCastToPtr:
934+
case OpPtrCastToCrossWorkgroupINTEL:
935+
case OpCrossWorkgroupCastToPtrINTEL: {
936+
// If module has pointers with DeviceOnlyINTEL and HostOnlyINTEL storage
937+
// classes there will be a situation, when global_device/global_host
938+
// address space will be lowered to just global address space. If there also
939+
// is an addrspacecast - we need to replace it with source pointer.
940+
if (Src->getType()->getPointerAddressSpace() ==
941+
Dst->getPointerAddressSpace())
942+
return Src;
934943
CO = Instruction::AddrSpaceCast;
935944
break;
945+
}
936946
case OpSConvert:
937947
CO = IsExt ? Instruction::SExt : Instruction::Trunc;
938948
break;
@@ -3359,7 +3369,7 @@ bool SPIRVToLLVM::transOCLMetadata(SPIRVFunction *BF) {
33593369
if (F->getCallingConv() != CallingConv::SPIR_KERNEL)
33603370
return true;
33613371

3362-
// Generate metadata for kernel_arg_address_spaces
3372+
// Generate metadata for kernel_arg_addr_space
33633373
addOCLKernelArgumentMetadata(
33643374
Context, SPIR_MD_KERNEL_ARG_ADDR_SPACE, BF, F,
33653375
[=](SPIRVFunctionParameter *Arg) {

llvm-spirv/lib/SPIRV/SPIRVWriter.cpp

Lines changed: 58 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,14 @@ SPIRVType *LLVMToSPIRV::transType(Type *T) {
307307
return nullptr;
308308
auto ST = dyn_cast<StructType>(ET);
309309
auto AddrSpc = T->getPointerAddressSpace();
310+
// Lower global_device and global_host address spaces that were added in
311+
// SYCL as part of SYCL_INTEL_usm_address_spaces extension to just global
312+
// address space if device doesn't support SPV_INTEL_usm_storage_classes
313+
// extension
314+
if (!BM->isAllowedToUseExtension(
315+
ExtensionID::SPV_INTEL_usm_storage_classes) &&
316+
((AddrSpc == SPIRAS_GlobalDevice) || (AddrSpc == SPIRAS_GlobalHost)))
317+
AddrSpc = SPIRAS_Global;
310318
if (ST && !ST->isSized()) {
311319
Op OpCode;
312320
StringRef STName = ST->getName();
@@ -760,14 +768,47 @@ SPIRV::SPIRVInstruction *LLVMToSPIRV::transUnaryInst(UnaryInstruction *U,
760768
Op BOC = OpNop;
761769
SPIRVValue *Op = nullptr;
762770
if (auto Cast = dyn_cast<AddrSpaceCastInst>(U)) {
763-
if (Cast->getDestTy()->getPointerAddressSpace() == SPIRAS_Generic) {
764-
assert(Cast->getSrcTy()->getPointerAddressSpace() != SPIRAS_Constant &&
771+
const auto SrcAddrSpace = Cast->getSrcTy()->getPointerAddressSpace();
772+
const auto DestAddrSpace = Cast->getDestTy()->getPointerAddressSpace();
773+
if (DestAddrSpace == SPIRAS_Generic) {
774+
assert(SrcAddrSpace != SPIRAS_Constant &&
765775
"Casts from constant address space to generic are illegal");
766776
BOC = OpPtrCastToGeneric;
777+
// In SPIR-V only casts to/from generic are allowed. But with
778+
// SPV_INTEL_usm_storage_classes we can also have casts from global_device
779+
// and global_host to global addr space and vice versa.
780+
} else if (SrcAddrSpace == SPIRAS_GlobalDevice ||
781+
SrcAddrSpace == SPIRAS_GlobalHost) {
782+
assert(
783+
(DestAddrSpace == SPIRAS_Global || DestAddrSpace == SPIRAS_Generic) &&
784+
"Casts from global_device/global_host only allowed to \
785+
global/generic");
786+
if (!BM->isAllowedToUseExtension(
787+
ExtensionID::SPV_INTEL_usm_storage_classes)) {
788+
if (DestAddrSpace == SPIRAS_Global)
789+
return nullptr;
790+
BOC = OpPtrCastToGeneric;
791+
} else {
792+
BOC = OpPtrCastToCrossWorkgroupINTEL;
793+
}
794+
} else if (DestAddrSpace == SPIRAS_GlobalDevice ||
795+
DestAddrSpace == SPIRAS_GlobalHost) {
796+
assert(
797+
(SrcAddrSpace == SPIRAS_Global || SrcAddrSpace == SPIRAS_Generic) &&
798+
"Casts to global_device/global_host only allowed from \
799+
global/generic");
800+
if (!BM->isAllowedToUseExtension(
801+
ExtensionID::SPV_INTEL_usm_storage_classes)) {
802+
if (SrcAddrSpace == SPIRAS_Global)
803+
return nullptr;
804+
BOC = OpGenericCastToPtr;
805+
} else {
806+
BOC = OpCrossWorkgroupCastToPtrINTEL;
807+
}
767808
} else {
768-
assert(Cast->getDestTy()->getPointerAddressSpace() != SPIRAS_Constant &&
809+
assert(DestAddrSpace != SPIRAS_Constant &&
769810
"Casts from generic address space to constant are illegal");
770-
assert(Cast->getSrcTy()->getPointerAddressSpace() == SPIRAS_Generic);
811+
assert(SrcAddrSpace == SPIRAS_Generic);
771812
BOC = OpGenericCastToPtr;
772813
}
773814
} else {
@@ -1056,8 +1097,18 @@ SPIRVValue *LLVMToSPIRV::transValueWithoutDecoration(Value *V,
10561097
if (IsVectorCompute)
10571098
StorageClass =
10581099
VectorComputeUtil::getVCGlobalVarStorageClass(AddressSpace);
1059-
else
1100+
else {
1101+
// Lower global_device and global_host address spaces that were added in
1102+
// SYCL as part of SYCL_INTEL_usm_address_spaces extension to just global
1103+
// address space if device doesn't support SPV_INTEL_usm_storage_classes
1104+
// extension
1105+
if ((AddressSpace == SPIRAS_GlobalDevice ||
1106+
AddressSpace == SPIRAS_GlobalHost) &&
1107+
!BM->isAllowedToUseExtension(
1108+
ExtensionID::SPV_INTEL_usm_storage_classes))
1109+
AddressSpace = SPIRAS_Global;
10601110
StorageClass = SPIRSPIRVAddrSpaceMap::map(AddressSpace);
1111+
}
10611112

10621113
auto BVar = static_cast<SPIRVVariable *>(
10631114
BM->addVariable(transType(Ty), GV->isConstant(), transLinkageType(GV),
@@ -1315,7 +1366,8 @@ SPIRVValue *LLVMToSPIRV::transValueWithoutDecoration(Value *V,
13151366
if (UnaryInstruction *U = dyn_cast<UnaryInstruction>(V)) {
13161367
if (isSpecialTypeInitializer(U))
13171368
return mapValue(V, transValue(U->getOperand(0), BB));
1318-
return mapValue(V, transUnaryInst(U, BB));
1369+
auto UI = transUnaryInst(U, BB);
1370+
return mapValue(V, UI ? UI : transValue(U->getOperand(0), BB));
13191371
}
13201372

13211373
if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(V)) {

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVEnum.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,8 @@ template <> inline void SPIRVMap<SPIRVStorageClassKind, SPIRVCapVec>::init() {
266266
ADD_VEC_INIT(StorageClassGeneric, {CapabilityGenericPointer});
267267
ADD_VEC_INIT(StorageClassPushConstant, {CapabilityShader});
268268
ADD_VEC_INIT(StorageClassAtomicCounter, {CapabilityAtomicStorage});
269+
ADD_VEC_INIT(StorageClassDeviceOnlyINTEL, {CapabilityUSMStorageClassesINTEL});
270+
ADD_VEC_INIT(StorageClassHostOnlyINTEL, {CapabilityUSMStorageClassesINTEL});
269271
}
270272

271273
template <> inline void SPIRVMap<SPIRVImageDimKind, SPIRVCapVec>::init() {

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVInstruction.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,8 @@ bool isSpecConstantOpAllowedOp(Op OC) {
183183
OpConvertUToPtr,
184184
OpGenericCastToPtr,
185185
OpPtrCastToGeneric,
186+
OpCrossWorkgroupCastToPtrINTEL,
187+
OpPtrCastToCrossWorkgroupINTEL,
186188
OpBitcast,
187189
OpQuantizeToF16,
188190
OpSNegate,

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -538,9 +538,15 @@ class SPIRVStore : public SPIRVInstruction, public SPIRVMemoryAccess {
538538
SPIRVInstruction::validate();
539539
if (getSrc()->isForward() || getDst()->isForward())
540540
return;
541-
assert(getValueType(PtrId)->getPointerElementType() ==
542-
getValueType(ValId) &&
543-
"Inconsistent operand types");
541+
#ifndef NDEBUG
542+
if (getValueType(PtrId)->getPointerElementType() != getValueType(ValId)) {
543+
assert(getValueType(PtrId)
544+
->getPointerElementType()
545+
->getPointerStorageClass() ==
546+
getValueType(ValId)->getPointerStorageClass() &&
547+
"Inconsistent operand types");
548+
}
549+
#endif // NDEBUG
544550
}
545551

546552
private:
@@ -1603,6 +1609,8 @@ _SPIRV_OP(ConvertPtrToU)
16031609
_SPIRV_OP(ConvertUToPtr)
16041610
_SPIRV_OP(PtrCastToGeneric)
16051611
_SPIRV_OP(GenericCastToPtr)
1612+
_SPIRV_OP(CrossWorkgroupCastToPtrINTEL)
1613+
_SPIRV_OP(PtrCastToCrossWorkgroupINTEL)
16061614
_SPIRV_OP(Bitcast)
16071615
_SPIRV_OP(SNegate)
16081616
_SPIRV_OP(FNegate)

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVIsValidEnum.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,8 @@ inline bool isValid(spv::StorageClass V) {
174174
case StorageClassPushConstant:
175175
case StorageClassAtomicCounter:
176176
case StorageClassImage:
177+
case StorageClassDeviceOnlyINTEL:
178+
case StorageClassHostOnlyINTEL:
177179
return true;
178180
default:
179181
return false;
@@ -728,6 +730,8 @@ inline bool isValid(spv::Op V) {
728730
case OpConvertUToPtr:
729731
case OpPtrCastToGeneric:
730732
case OpGenericCastToPtr:
733+
case OpPtrCastToCrossWorkgroupINTEL:
734+
case OpCrossWorkgroupCastToPtrINTEL:
731735
case OpGenericCastToPtrExplicit:
732736
case OpBitcast:
733737
case OpSNegate:

0 commit comments

Comments
 (0)