Skip to content

[SPIR-V] Add support for SPV_INTEL_usm_storage_classes extension #1985

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llvm-spirv/include/LLVMSPIRVExtensions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ EXT(SPV_INTEL_arbitrary_precision_integers)
EXT(SPV_INTEL_optimization_hints)
EXT(SPV_INTEL_float_controls2)
EXT(SPV_INTEL_vector_compute)
EXT(SPV_INTEL_usm_storage_classes)
4 changes: 3 additions & 1 deletion llvm-spirv/lib/SPIRV/Mangler/ParameterType.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,9 @@ enum TypeAttributeEnum {
ATTR_CONSTANT,
ATTR_LOCAL,
ATTR_GENERIC,
ATTR_ADDR_SPACE_LAST = ATTR_GENERIC,
ATTR_GLOBAL_DEVICE,
ATTR_GLOBAL_HOST,
ATTR_ADDR_SPACE_LAST = ATTR_GLOBAL_HOST,
ATTR_NONE,
ATTR_NUM = ATTR_NONE
};
Expand Down
4 changes: 4 additions & 0 deletions llvm-spirv/lib/SPIRV/OCLUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,10 @@ static SPIR::TypeAttributeEnum mapAddrSpaceEnums(SPIRAddressSpace Addrspace) {
return SPIR::ATTR_LOCAL;
case SPIRAS_Generic:
return SPIR::ATTR_GENERIC;
case SPIRAS_GlobalDevice:
return SPIR::ATTR_GLOBAL_DEVICE;
case SPIRAS_GlobalHost:
return SPIR::ATTR_GLOBAL_HOST;
default:
llvm_unreachable("Invalid addrspace enum member");
}
Expand Down
6 changes: 6 additions & 0 deletions llvm-spirv/lib/SPIRV/SPIRVInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,8 @@ enum SPIRAddressSpace {
SPIRAS_Constant,
SPIRAS_Local,
SPIRAS_Generic,
SPIRAS_GlobalDevice,
SPIRAS_GlobalHost,
SPIRAS_Input,
SPIRAS_Output,
SPIRAS_Count,
Expand All @@ -200,6 +202,8 @@ template <> inline void SPIRVMap<SPIRAddressSpace, std::string>::init() {
add(SPIRAS_Local, "Local");
add(SPIRAS_Generic, "Generic");
add(SPIRAS_Input, "Input");
add(SPIRAS_GlobalDevice, "GlobalDevice");
add(SPIRAS_GlobalHost, "GlobalHost");
}
typedef SPIRVMap<SPIRAddressSpace, SPIRVStorageClassKind>
SPIRAddrSpaceCapitalizedNameMap;
Expand All @@ -212,6 +216,8 @@ inline void SPIRVMap<SPIRAddressSpace, SPIRVStorageClassKind>::init() {
add(SPIRAS_Local, StorageClassWorkgroup);
add(SPIRAS_Generic, StorageClassGeneric);
add(SPIRAS_Input, StorageClassInput);
add(SPIRAS_GlobalDevice, StorageClassDeviceOnlyINTEL);
add(SPIRAS_GlobalHost, StorageClassHostOnlyINTEL);
}
typedef SPIRVMap<SPIRAddressSpace, SPIRVStorageClassKind> SPIRSPIRVAddrSpaceMap;

Expand Down
12 changes: 11 additions & 1 deletion llvm-spirv/lib/SPIRV/SPIRVReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -931,8 +931,18 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
switch (BC->getOpCode()) {
case OpPtrCastToGeneric:
case OpGenericCastToPtr:
case OpPtrCastToCrossWorkgroupINTEL:
case OpCrossWorkgroupCastToPtrINTEL: {
// If module has pointers with DeviceOnlyINTEL and HostOnlyINTEL storage
// classes there will be a situation, when global_device/global_host
// address space will be lowered to just global address space. If there also
// is an addrspacecast - we need to replace it with source pointer.
if (Src->getType()->getPointerAddressSpace() ==
Dst->getPointerAddressSpace())
return Src;
CO = Instruction::AddrSpaceCast;
break;
}
case OpSConvert:
CO = IsExt ? Instruction::SExt : Instruction::Trunc;
break;
Expand Down Expand Up @@ -3359,7 +3369,7 @@ bool SPIRVToLLVM::transOCLMetadata(SPIRVFunction *BF) {
if (F->getCallingConv() != CallingConv::SPIR_KERNEL)
return true;

// Generate metadata for kernel_arg_address_spaces
// Generate metadata for kernel_arg_addr_space
addOCLKernelArgumentMetadata(
Context, SPIR_MD_KERNEL_ARG_ADDR_SPACE, BF, F,
[=](SPIRVFunctionParameter *Arg) {
Expand Down
64 changes: 58 additions & 6 deletions llvm-spirv/lib/SPIRV/SPIRVWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,14 @@ SPIRVType *LLVMToSPIRV::transType(Type *T) {
return nullptr;
auto ST = dyn_cast<StructType>(ET);
auto AddrSpc = T->getPointerAddressSpace();
// Lower global_device and global_host address spaces that were added in
// SYCL as part of SYCL_INTEL_usm_address_spaces extension to just global
// address space if device doesn't support SPV_INTEL_usm_storage_classes
// extension
if (!BM->isAllowedToUseExtension(
ExtensionID::SPV_INTEL_usm_storage_classes) &&
((AddrSpc == SPIRAS_GlobalDevice) || (AddrSpc == SPIRAS_GlobalHost)))
AddrSpc = SPIRAS_Global;
if (ST && !ST->isSized()) {
Op OpCode;
StringRef STName = ST->getName();
Expand Down Expand Up @@ -760,14 +768,47 @@ SPIRV::SPIRVInstruction *LLVMToSPIRV::transUnaryInst(UnaryInstruction *U,
Op BOC = OpNop;
SPIRVValue *Op = nullptr;
if (auto Cast = dyn_cast<AddrSpaceCastInst>(U)) {
if (Cast->getDestTy()->getPointerAddressSpace() == SPIRAS_Generic) {
assert(Cast->getSrcTy()->getPointerAddressSpace() != SPIRAS_Constant &&
const auto SrcAddrSpace = Cast->getSrcTy()->getPointerAddressSpace();
const auto DestAddrSpace = Cast->getDestTy()->getPointerAddressSpace();
if (DestAddrSpace == SPIRAS_Generic) {
assert(SrcAddrSpace != SPIRAS_Constant &&
"Casts from constant address space to generic are illegal");
BOC = OpPtrCastToGeneric;
// In SPIR-V only casts to/from generic are allowed. But with
// SPV_INTEL_usm_storage_classes we can also have casts from global_device
// and global_host to global addr space and vice versa.
} else if (SrcAddrSpace == SPIRAS_GlobalDevice ||
SrcAddrSpace == SPIRAS_GlobalHost) {
assert(
(DestAddrSpace == SPIRAS_Global || DestAddrSpace == SPIRAS_Generic) &&
"Casts from global_device/global_host only allowed to \
global/generic");
if (!BM->isAllowedToUseExtension(
ExtensionID::SPV_INTEL_usm_storage_classes)) {
if (DestAddrSpace == SPIRAS_Global)
return nullptr;
BOC = OpPtrCastToGeneric;
} else {
BOC = OpPtrCastToCrossWorkgroupINTEL;
}
} else if (DestAddrSpace == SPIRAS_GlobalDevice ||
DestAddrSpace == SPIRAS_GlobalHost) {
assert(
(SrcAddrSpace == SPIRAS_Global || SrcAddrSpace == SPIRAS_Generic) &&
"Casts to global_device/global_host only allowed from \
global/generic");
if (!BM->isAllowedToUseExtension(
ExtensionID::SPV_INTEL_usm_storage_classes)) {
if (SrcAddrSpace == SPIRAS_Global)
return nullptr;
BOC = OpGenericCastToPtr;
} else {
BOC = OpCrossWorkgroupCastToPtrINTEL;
}
} else {
assert(Cast->getDestTy()->getPointerAddressSpace() != SPIRAS_Constant &&
assert(DestAddrSpace != SPIRAS_Constant &&
"Casts from generic address space to constant are illegal");
assert(Cast->getSrcTy()->getPointerAddressSpace() == SPIRAS_Generic);
assert(SrcAddrSpace == SPIRAS_Generic);
BOC = OpGenericCastToPtr;
}
} else {
Expand Down Expand Up @@ -1056,8 +1097,18 @@ SPIRVValue *LLVMToSPIRV::transValueWithoutDecoration(Value *V,
if (IsVectorCompute)
StorageClass =
VectorComputeUtil::getVCGlobalVarStorageClass(AddressSpace);
else
else {
// Lower global_device and global_host address spaces that were added in
// SYCL as part of SYCL_INTEL_usm_address_spaces extension to just global
// address space if device doesn't support SPV_INTEL_usm_storage_classes
// extension
if ((AddressSpace == SPIRAS_GlobalDevice ||
AddressSpace == SPIRAS_GlobalHost) &&
!BM->isAllowedToUseExtension(
ExtensionID::SPV_INTEL_usm_storage_classes))
AddressSpace = SPIRAS_Global;
StorageClass = SPIRSPIRVAddrSpaceMap::map(AddressSpace);
}

auto BVar = static_cast<SPIRVVariable *>(
BM->addVariable(transType(Ty), GV->isConstant(), transLinkageType(GV),
Expand Down Expand Up @@ -1315,7 +1366,8 @@ SPIRVValue *LLVMToSPIRV::transValueWithoutDecoration(Value *V,
if (UnaryInstruction *U = dyn_cast<UnaryInstruction>(V)) {
if (isSpecialTypeInitializer(U))
return mapValue(V, transValue(U->getOperand(0), BB));
return mapValue(V, transUnaryInst(U, BB));
auto UI = transUnaryInst(U, BB);
return mapValue(V, UI ? UI : transValue(U->getOperand(0), BB));
}

if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(V)) {
Expand Down
2 changes: 2 additions & 0 deletions llvm-spirv/lib/SPIRV/libSPIRV/SPIRVEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,8 @@ template <> inline void SPIRVMap<SPIRVStorageClassKind, SPIRVCapVec>::init() {
ADD_VEC_INIT(StorageClassGeneric, {CapabilityGenericPointer});
ADD_VEC_INIT(StorageClassPushConstant, {CapabilityShader});
ADD_VEC_INIT(StorageClassAtomicCounter, {CapabilityAtomicStorage});
ADD_VEC_INIT(StorageClassDeviceOnlyINTEL, {CapabilityUSMStorageClassesINTEL});
ADD_VEC_INIT(StorageClassHostOnlyINTEL, {CapabilityUSMStorageClassesINTEL});
}

template <> inline void SPIRVMap<SPIRVImageDimKind, SPIRVCapVec>::init() {
Expand Down
2 changes: 2 additions & 0 deletions llvm-spirv/lib/SPIRV/libSPIRV/SPIRVInstruction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ bool isSpecConstantOpAllowedOp(Op OC) {
OpConvertUToPtr,
OpGenericCastToPtr,
OpPtrCastToGeneric,
OpCrossWorkgroupCastToPtrINTEL,
OpPtrCastToCrossWorkgroupINTEL,
OpBitcast,
OpQuantizeToF16,
OpSNegate,
Expand Down
14 changes: 11 additions & 3 deletions llvm-spirv/lib/SPIRV/libSPIRV/SPIRVInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -538,9 +538,15 @@ class SPIRVStore : public SPIRVInstruction, public SPIRVMemoryAccess {
SPIRVInstruction::validate();
if (getSrc()->isForward() || getDst()->isForward())
return;
assert(getValueType(PtrId)->getPointerElementType() ==
getValueType(ValId) &&
"Inconsistent operand types");
#ifndef NDEBUG
if (getValueType(PtrId)->getPointerElementType() != getValueType(ValId)) {
assert(getValueType(PtrId)
->getPointerElementType()
->getPointerStorageClass() ==
getValueType(ValId)->getPointerStorageClass() &&
"Inconsistent operand types");
}
#endif // NDEBUG
}

private:
Expand Down Expand Up @@ -1603,6 +1609,8 @@ _SPIRV_OP(ConvertPtrToU)
_SPIRV_OP(ConvertUToPtr)
_SPIRV_OP(PtrCastToGeneric)
_SPIRV_OP(GenericCastToPtr)
_SPIRV_OP(CrossWorkgroupCastToPtrINTEL)
_SPIRV_OP(PtrCastToCrossWorkgroupINTEL)
_SPIRV_OP(Bitcast)
_SPIRV_OP(SNegate)
_SPIRV_OP(FNegate)
Expand Down
4 changes: 4 additions & 0 deletions llvm-spirv/lib/SPIRV/libSPIRV/SPIRVIsValidEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ inline bool isValid(spv::StorageClass V) {
case StorageClassPushConstant:
case StorageClassAtomicCounter:
case StorageClassImage:
case StorageClassDeviceOnlyINTEL:
case StorageClassHostOnlyINTEL:
return true;
default:
return false;
Expand Down Expand Up @@ -728,6 +730,8 @@ inline bool isValid(spv::Op V) {
case OpConvertUToPtr:
case OpPtrCastToGeneric:
case OpGenericCastToPtr:
case OpPtrCastToCrossWorkgroupINTEL:
case OpCrossWorkgroupCastToPtrINTEL:
case OpGenericCastToPtrExplicit:
case OpBitcast:
case OpSNegate:
Expand Down
3 changes: 3 additions & 0 deletions llvm-spirv/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ template <> inline void SPIRVMap<StorageClass, std::string>::init() {
add(StorageClassPushConstant, "PushConstant");
add(StorageClassAtomicCounter, "AtomicCounter");
add(StorageClassImage, "Image");
add(StorageClassDeviceOnlyINTEL, "DeviceOnlyINTEL");
add(StorageClassHostOnlyINTEL, "HostOnlyINTEL");
}
SPIRV_DEF_NAMEMAP(StorageClass, SPIRVStorageClassNameMap)

Expand Down Expand Up @@ -549,6 +551,7 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
add(CapabilityGroupNonUniformShuffleRelative,
"GroupNonUniformShuffleRelative");
add(CapabilityGroupNonUniformClustered, "GroupNonUniformClustered");
add(CapabilityUSMStorageClassesINTEL, "USMStorageClassesINTEL");
}
SPIRV_DEF_NAMEMAP(Capability, SPIRVCapabilityNameMap)

Expand Down
4 changes: 3 additions & 1 deletion llvm-spirv/lib/SPIRV/libSPIRV/SPIRVOpCode.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@ inline bool isCmpOpCode(Op OpCode) {

inline bool isCvtOpCode(Op OpCode) {
return ((unsigned)OpCode >= OpConvertFToU && (unsigned)OpCode <= OpBitcast) ||
OpCode == OpSatConvertSToU || OpCode == OpSatConvertUToS;
OpCode == OpSatConvertSToU || OpCode == OpSatConvertUToS ||
OpCode == OpPtrCastToCrossWorkgroupINTEL ||
OpCode == OpCrossWorkgroupCastToPtrINTEL;
}

inline bool isCvtToUnsignedOpCode(Op OpCode) {
Expand Down
2 changes: 2 additions & 0 deletions llvm-spirv/lib/SPIRV/libSPIRV/SPIRVOpCodeEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,8 @@ _SPIRV_OP(SubgroupAvcSicGetPackedSkcLumaCountThresholdINTEL, 5814)
_SPIRV_OP(SubgroupAvcSicGetPackedSkcLumaSumThresholdINTEL, 5815)
_SPIRV_OP(SubgroupAvcSicGetInterRawSadsINTEL, 5816)
_SPIRV_OP(LoopControlINTEL, 5887)
_SPIRV_OP(PtrCastToCrossWorkgroupINTEL, 5934)
_SPIRV_OP(CrossWorkgroupCastToPtrINTEL, 5938)
_SPIRV_OP(ReadPipeBlockingINTEL, 5946)
_SPIRV_OP(WritePipeBlockingINTEL, 5947)
_SPIRV_OP(FPGARegINTEL, 5949)
5 changes: 5 additions & 0 deletions llvm-spirv/lib/SPIRV/libSPIRV/spirv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,8 @@ enum StorageClass {
StorageClassShaderRecordBufferNV = 5343,
StorageClassPhysicalStorageBuffer = 5349,
StorageClassPhysicalStorageBufferEXT = 5349,
StorageClassDeviceOnlyINTEL = 5936,
StorageClassHostOnlyINTEL = 5937,
StorageClassMax = 0x7fffffff,
};

Expand Down Expand Up @@ -953,6 +955,7 @@ enum Capability {
CapabilityFPGARegINTEL = 5948,
CapabilityKernelAttributesINTEL = 5892,
CapabilityFPGAKernelAttributesINTEL = 5897,
CapabilityUSMStorageClassesINTEL = 5935,
CapabilityIOPipeINTEL = 5943,
CapabilityMax = 0x7fffffff,
};
Expand Down Expand Up @@ -1493,6 +1496,8 @@ enum Op {
OpSubgroupAvcSicGetPackedSkcLumaSumThresholdINTEL = 5815,
OpSubgroupAvcSicGetInterRawSadsINTEL = 5816,
OpLoopControlINTEL = 5887,
OpPtrCastToCrossWorkgroupINTEL = 5934,
OpCrossWorkgroupCastToPtrINTEL = 5938,
OpReadPipeBlockingINTEL = 5946,
OpWritePipeBlockingINTEL = 5947,
OpFPGARegINTEL = 5949,
Expand Down
Loading