Skip to content

Commit c7afc22

Browse files
author
Alexander Johnston
authored
[SYCL] Ensure proper definition of spirv builtins for SYCL (#1393)
Ensures proper definition of SPIRV global and workgroup builtin variables for both SYCL compilation with both NVPTX and SPIRV. Removes the downstream changes from llvm-spirv, instead opting to implement the required spirv vars in the spirv_vars.hpp header. Due to the removal of the llvm-spirv changes, the libdevice library now requires access to spirv_vars.hpp to get all required spirv vars. Signed-off-by: Alexander Johnston <[email protected]>
1 parent cf9d1a4 commit c7afc22

File tree

7 files changed

+220
-164
lines changed

7 files changed

+220
-164
lines changed

libdevice/device.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99
#ifndef __LIBDEVICE_DEVICE_H__
1010
#define __LIBDEVICE_DEVICE_H__
1111

12+
// We need the following header to ensure the definition of all spirv variables
13+
// required by the wrapper libraries.
14+
#include "spirv_vars.hpp"
15+
1216
#ifdef __cplusplus
1317
#define EXTERN_C extern "C"
1418
#else // __cplusplus

libdevice/spirv_vars.hpp

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
//==---------- spirv_vars.hpp --- SPIRV variables -------------------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
// ===-------------------------------------------------------------------=== //
8+
9+
#pragma once
10+
11+
#include <cstddef>
12+
#include <cstdint>
13+
14+
#ifdef __SYCL_DEVICE_ONLY__
15+
16+
#ifdef __SYCL_NVPTX__
17+
18+
SYCL_EXTERNAL size_t __spirv_GlobalInvocationId_x();
19+
SYCL_EXTERNAL size_t __spirv_GlobalInvocationId_y();
20+
SYCL_EXTERNAL size_t __spirv_GlobalInvocationId_z();
21+
22+
SYCL_EXTERNAL size_t __spirv_LocalInvocationId_x();
23+
SYCL_EXTERNAL size_t __spirv_LocalInvocationId_y();
24+
SYCL_EXTERNAL size_t __spirv_LocalInvocationId_z();
25+
26+
#else // __SYCL_NVPTX__
27+
28+
typedef size_t size_t_vec __attribute__((ext_vector_type(3)));
29+
extern "C" const __attribute__((opencl_constant))
30+
size_t_vec __spirv_BuiltInGlobalInvocationId;
31+
extern "C" const __attribute__((opencl_constant))
32+
size_t_vec __spirv_BuiltInLocalInvocationId;
33+
34+
SYCL_EXTERNAL inline size_t __spirv_GlobalInvocationId_x() {
35+
return __spirv_BuiltInGlobalInvocationId.x;
36+
}
37+
SYCL_EXTERNAL inline size_t __spirv_GlobalInvocationId_y() {
38+
return __spirv_BuiltInGlobalInvocationId.y;
39+
}
40+
SYCL_EXTERNAL inline size_t __spirv_GlobalInvocationId_z() {
41+
return __spirv_BuiltInGlobalInvocationId.z;
42+
}
43+
44+
SYCL_EXTERNAL inline size_t __spirv_LocalInvocationId_x() {
45+
return __spirv_BuiltInLocalInvocationId.x;
46+
}
47+
SYCL_EXTERNAL inline size_t __spirv_LocalInvocationId_y() {
48+
return __spirv_BuiltInLocalInvocationId.y;
49+
}
50+
SYCL_EXTERNAL inline size_t __spirv_LocalInvocationId_z() {
51+
return __spirv_BuiltInLocalInvocationId.z;
52+
}
53+
54+
#endif // __SYCL_NVPTX__
55+
56+
#define DEFINE_FUNC_ID_TO_XYZ_CONVERTER(POSTFIX) \
57+
template <int ID> static inline size_t get##POSTFIX(); \
58+
template <> size_t get##POSTFIX<0>() { return __spirv_##POSTFIX##_x(); } \
59+
template <> size_t get##POSTFIX<1>() { return __spirv_##POSTFIX##_y(); } \
60+
template <> size_t get##POSTFIX<2>() { return __spirv_##POSTFIX##_z(); }
61+
62+
namespace __spirv {
63+
64+
DEFINE_FUNC_ID_TO_XYZ_CONVERTER(GlobalInvocationId);
65+
DEFINE_FUNC_ID_TO_XYZ_CONVERTER(LocalInvocationId);
66+
67+
} // namespace __spirv
68+
69+
#undef DEFINE_FUNC_ID_TO_XYZ_CONVERTER
70+
71+
#define DEFINE_INIT_SIZES(POSTFIX) \
72+
\
73+
template <int Dim, class DstT> struct InitSizesST##POSTFIX; \
74+
\
75+
template <class DstT> struct InitSizesST##POSTFIX<1, DstT> { \
76+
static DstT initSize() { return {get##POSTFIX<0>()}; } \
77+
}; \
78+
\
79+
template <class DstT> struct InitSizesST##POSTFIX<2, DstT> { \
80+
static DstT initSize() { return {get##POSTFIX<1>(), get##POSTFIX<0>()}; } \
81+
}; \
82+
\
83+
template <class DstT> struct InitSizesST##POSTFIX<3, DstT> { \
84+
static DstT initSize() { \
85+
return {get##POSTFIX<2>(), get##POSTFIX<1>(), get##POSTFIX<0>()}; \
86+
} \
87+
}; \
88+
\
89+
template <int Dims, class DstT> static DstT init##POSTFIX() { \
90+
return InitSizesST##POSTFIX<Dims, DstT>::initSize(); \
91+
}
92+
93+
namespace __spirv {
94+
95+
DEFINE_INIT_SIZES(GlobalInvocationId)
96+
DEFINE_INIT_SIZES(LocalInvocationId)
97+
98+
} // namespace __spirv
99+
100+
#undef DEFINE_INIT_SIZES
101+
102+
#endif // __SYCL_DEVICE_ONLY__

llvm-spirv/lib/SPIRV/OCL20ToSPIRV.cpp

Lines changed: 7 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -273,31 +273,11 @@ class OCL20ToSPIRV : public ModulePass, public InstVisitor<OCL20ToSPIRV> {
273273
Module *M;
274274
LLVMContext *Ctx;
275275
unsigned CLVer; /// OpenCL version as major*10+minor
276-
unsigned CLLang; /// OpenCL language, see `spv::SourceLanguage`.
277276
std::set<Value *> ValuesToDelete;
278277

279278
ConstantInt *addInt32(int I) { return getInt32(M, I); }
280279
ConstantInt *addSizet(uint64_t I) { return getSizet(M, I); }
281280

282-
/// Return the index of the id dimension represented by the demangled built-in name.
283-
/// ie. given `__spirv__GlobalInvocationId_x`, return `0`.
284-
Optional<uint64_t> spirvDimensionFromBuiltin(StringRef Name) {
285-
if (!Name.startswith("__spirv_")) {
286-
return {};
287-
}
288-
289-
Optional<uint64_t> Result = {};
290-
if (Name.endswith("_x")) {
291-
Result = 0;
292-
} else if (Name.endswith("_y")) {
293-
Result = 1;
294-
} else if (Name.endswith("_z")) {
295-
Result = 2;
296-
}
297-
298-
return Result;
299-
}
300-
301281
/// Get vector width from OpenCL vload* function name.
302282
SPIRVWord getVecLoadWidth(const std::string &DemangledName) {
303283
SPIRVWord Width = 0;
@@ -347,8 +327,7 @@ bool OCL20ToSPIRV::runOnModule(Module &Module) {
347327
M = &Module;
348328
Ctx = &M->getContext();
349329
auto Src = getSPIRVSource(&Module);
350-
CLLang = std::get<0>(Src);
351-
if (CLLang != spv::SourceLanguageOpenCL_C && CLLang != spv::SourceLanguageOpenCL_CPP)
330+
if (std::get<0>(Src) != spv::SourceLanguageOpenCL_C)
352331
return false;
353332

354333
CLVer = std::get<1>(Src);
@@ -1245,18 +1224,9 @@ void OCL20ToSPIRV::transWorkItemBuiltinsToVariables() {
12451224
std::vector<Function *> WorkList;
12461225
for (auto &I : *M) {
12471226
StringRef DemangledName;
1248-
auto MangledName = I.getName();
1249-
LLVM_DEBUG(dbgs() << "Function mangled name: " << MangledName << '\n');
1250-
if (!oclIsBuiltin(MangledName, DemangledName))
1227+
if (!oclIsBuiltin(I.getName(), DemangledName))
12511228
continue;
12521229
LLVM_DEBUG(dbgs() << "Function demangled name: " << DemangledName << '\n');
1253-
auto SpirvDimension {spirvDimensionFromBuiltin(DemangledName)};
1254-
auto IsSpirvBuiltinWithDimensions {SpirvDimension.hasValue()};
1255-
if ((!IsSpirvBuiltinWithDimensions && CLLang == spv::SourceLanguageOpenCL_CPP) ||
1256-
(IsSpirvBuiltinWithDimensions && CLLang == spv::SourceLanguageOpenCL_C)) {
1257-
// Only transform `__spirv_` builtins in OpenCL C++.
1258-
continue;
1259-
}
12601230
std::string BuiltinVarName;
12611231
SPIRVBuiltinVariableKind BVKind;
12621232
if (!SPIRSPIRVBuiltinVariableMap::find(DemangledName.str(), &BVKind))
@@ -1265,15 +1235,11 @@ void OCL20ToSPIRV::transWorkItemBuiltinsToVariables() {
12651235
std::string(kSPIRVName::Prefix) + SPIRVBuiltInNameMap::map(BVKind);
12661236
LLVM_DEBUG(dbgs() << "builtin variable name: " << BuiltinVarName << '\n');
12671237
bool IsVec = I.getFunctionType()->getNumParams() > 0;
1268-
Type *GVType = (IsVec || IsSpirvBuiltinWithDimensions) ?
1269-
VectorType::get(I.getReturnType(), 3) : I.getReturnType();
1270-
// Each of the `__spirv__GlobalInvocationId_*` functions all extract an element of
1271-
// the same global variable, so ensure that we only create the global once.
1272-
auto BV = M->getOrInsertGlobal(BuiltinVarName, GVType, [&] {
1273-
return new GlobalVariable(
1274-
*M, GVType, true, GlobalValue::ExternalLinkage, nullptr, BuiltinVarName,
1275-
0, GlobalVariable::NotThreadLocal, SPIRAS_Input);
1276-
});
1238+
Type *GVType =
1239+
IsVec ? VectorType::get(I.getReturnType(), 3) : I.getReturnType();
1240+
auto BV = new GlobalVariable(*M, GVType, true, GlobalValue::ExternalLinkage,
1241+
nullptr, BuiltinVarName, 0,
1242+
GlobalVariable::NotThreadLocal, SPIRAS_Input);
12771243
std::vector<Instruction *> InstList;
12781244
for (auto UI = I.user_begin(), UE = I.user_end(); UI != UE; ++UI) {
12791245
auto CI = dyn_cast<CallInst>(*UI);
@@ -1284,10 +1250,6 @@ void OCL20ToSPIRV::transWorkItemBuiltinsToVariables() {
12841250
NewValue =
12851251
ExtractElementInst::Create(NewValue, CI->getArgOperand(0), "", CI);
12861252
LLVM_DEBUG(dbgs() << *NewValue << '\n');
1287-
} else if (IsSpirvBuiltinWithDimensions) {
1288-
auto Index = ConstantInt::get(I.getReturnType(), SpirvDimension.getValue(), false);
1289-
NewValue = ExtractElementInst::Create(NewValue, Index, "", CI);
1290-
LLVM_DEBUG(dbgs() << *NewValue << '\n');
12911253
}
12921254
NewValue->takeName(CI);
12931255
CI->replaceAllUsesWith(NewValue);

llvm-spirv/lib/SPIRV/OCLUtil.h

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -595,46 +595,16 @@ template <> inline void SPIRVMap<OclExt::Kind, SPIRVCapabilityKind>::init() {
595595
template <>
596596
inline void SPIRVMap<std::string, SPIRVBuiltinVariableKind>::init() {
597597
add("get_work_dim", BuiltInWorkDim);
598-
add("__spirv_GlobalSize_x", BuiltInGlobalSize);
599-
add("__spirv_GlobalSize_y", BuiltInGlobalSize);
600-
add("__spirv_GlobalSize_z", BuiltInGlobalSize);
601598
add("get_global_size", BuiltInGlobalSize);
602-
add("__spirv_GlobalInvocationId_x", BuiltInGlobalInvocationId);
603-
add("__spirv_GlobalInvocationId_y", BuiltInGlobalInvocationId);
604-
add("__spirv_GlobalInvocationId_z", BuiltInGlobalInvocationId);
605599
add("get_global_id", BuiltInGlobalInvocationId);
606-
add("__spirv_GlobalOffset_x", BuiltInGlobalOffset);
607-
add("__spirv_GlobalOffset_y", BuiltInGlobalOffset);
608-
add("__spirv_GlobalOffset_z", BuiltInGlobalOffset);
609600
add("get_global_offset", BuiltInGlobalOffset);
610-
add("__spirv_WorkgroupSize_x", BuiltInWorkgroupSize);
611-
add("__spirv_WorkgroupSize_y", BuiltInWorkgroupSize);
612-
add("__spirv_WorkgroupSize_z", BuiltInWorkgroupSize);
613601
add("get_local_size", BuiltInWorkgroupSize);
614-
add("__spirv_WorkgroupSize_x", BuiltInWorkgroupSize);
615-
add("__spirv_WorkgroupSize_y", BuiltInWorkgroupSize);
616-
add("__spirv_WorkgroupSize_z", BuiltInWorkgroupSize);
617602
add("get_enqueued_local_size", BuiltInEnqueuedWorkgroupSize);
618-
add("__spirv_LocalInvocationId_x", BuiltInLocalInvocationId);
619-
add("__spirv_LocalInvocationId_y", BuiltInLocalInvocationId);
620-
add("__spirv_LocalInvocationId_z", BuiltInLocalInvocationId);
621603
add("get_local_id", BuiltInLocalInvocationId);
622-
add("__spirv_NumWorkgroups_x", BuiltInNumWorkgroups);
623-
add("__spirv_NumWorkgroups_y", BuiltInNumWorkgroups);
624-
add("__spirv_NumWorkgroups_z", BuiltInNumWorkgroups);
625604
add("get_num_groups", BuiltInNumWorkgroups);
626-
add("__spirv_WorkgroupId_x", BuiltInWorkgroupId);
627-
add("__spirv_WorkgroupId_y", BuiltInWorkgroupId);
628-
add("__spirv_WorkgroupId_z", BuiltInWorkgroupId);
629605
add("get_group_id", BuiltInWorkgroupId);
630-
add("__spirv_WorkgroupId_x", BuiltInWorkgroupId);
631-
add("__spirv_WorkgroupId_y", BuiltInWorkgroupId);
632-
add("__spirv_WorkgroupId_z", BuiltInWorkgroupId);
633606
add("get_global_linear_id", BuiltInGlobalLinearId);
634607
add("get_local_linear_id", BuiltInLocalInvocationIndex);
635-
add("__spirv_LocalInvocationId_x", BuiltInLocalInvocationId);
636-
add("__spirv_LocalInvocationId_y", BuiltInLocalInvocationId);
637-
add("__spirv_LocalInvocationId_z", BuiltInLocalInvocationId);
638608
add("get_sub_group_size", BuiltInSubgroupSize);
639609
add("get_max_sub_group_size", BuiltInSubgroupMaxSize);
640610
add("get_num_sub_groups", BuiltInNumSubgroups);

llvm-spirv/test/builtin_vars_to_func.ll

Lines changed: 0 additions & 41 deletions
This file was deleted.

llvm-spirv/test/builtin_vars_to_func_cpp.ll

Lines changed: 0 additions & 42 deletions
This file was deleted.

0 commit comments

Comments
 (0)