Skip to content

Commit 16a24b0

Browse files
committed
[SYCL] Add support for platform name and platform version in device whitelist
PlatformName and PlatformVersion are now supported keys. Also all fields are optional now, so SYCL_DEVICE_WHITE_LIST="" matches any device Signed-off-by: Vlad Romanov <[email protected]>
1 parent b06fc66 commit 16a24b0

File tree

2 files changed

+82
-18
lines changed

2 files changed

+82
-18
lines changed

sycl/source/detail/platform_impl.cpp

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,16 +58,26 @@ struct DevDescT {
5858

5959
const char *devDriverVer = nullptr;
6060
int devDriverVerSize = 0;
61+
62+
const char *pltName = nullptr;
63+
int pltNameSize = 0;
64+
65+
const char *pltVer = nullptr;
66+
int pltVerSize = 0;
6167
};
6268

6369
static std::vector<DevDescT> getWhiteListDesc() {
6470
const char *str = SYCLConfig<SYCL_DEVICE_WHITE_LIST>::get();
6571
if (!str)
6672
return {};
6773

74+
75+
6876
std::vector<DevDescT> decDescs;
6977
const char devNameStr[] = "DeviceName";
7078
const char driverVerStr[] = "DriverVersion";
79+
const char pltNameStr[] = "PlatformName";
80+
const char platformVerStr[] = "PlatformVersion";
7181
decDescs.emplace_back();
7282
while ('\0' != *str) {
7383
const char **valuePtr = nullptr;
@@ -78,6 +88,14 @@ static std::vector<DevDescT> getWhiteListDesc() {
7888
valuePtr = &decDescs.back().devName;
7989
size = &decDescs.back().devNameSize;
8090
str += sizeof(devNameStr) - 1;
91+
} else if (0 == strncmp(pltNameStr, str, sizeof(pltNameStr) - 1)) {
92+
valuePtr = &decDescs.back().pltName;
93+
size = &decDescs.back().pltNameSize;
94+
str += sizeof(pltNameStr) - 1;
95+
} else if (0 == strncmp(platformVerStr, str, sizeof(platformVerStr) - 1)) {
96+
valuePtr = &decDescs.back().pltVer;
97+
size = &decDescs.back().pltVerSize;
98+
str += sizeof(platformVerStr) - 1;
8199
} else if (0 == strncmp(driverVerStr, str, sizeof(driverVerStr) - 1)) {
82100
valuePtr = &decDescs.back().devDriverVer;
83101
size = &decDescs.back().devDriverVerSize;
@@ -125,11 +143,19 @@ static std::vector<DevDescT> getWhiteListDesc() {
125143
return decDescs;
126144
}
127145

128-
static void filterWhiteList(vector_class<RT::PiDevice> &pi_devices) {
146+
static void filterWhiteList(vector_class<RT::PiDevice> &pi_devices,
147+
RT::PiPlatform pi_platform) {
129148
const std::vector<DevDescT> whiteList(getWhiteListDesc());
130149
if (whiteList.empty())
131150
return;
132151

152+
const string_class pltName =
153+
sycl::detail::get_platform_info<string_class, info::platform::name>::get(
154+
pi_platform);
155+
156+
const string_class pltVer = sycl::detail::get_platform_info<
157+
string_class, info::platform::version>::get(pi_platform);
158+
133159
int insertIDx = 0;
134160
for (RT::PiDevice dev : pi_devices) {
135161
const string_class devName =
@@ -140,8 +166,17 @@ static void filterWhiteList(vector_class<RT::PiDevice> &pi_devices) {
140166
info::device::driver_version>::get(dev);
141167

142168
for (const DevDescT &desc : whiteList) {
143-
// At least device name is required field to consider the filter so far
144-
if (nullptr == desc.devName ||
169+
if (nullptr != desc.pltName &&
170+
!std::regex_match(
171+
pltName, std::regex(std::string(desc.pltName, desc.pltNameSize))))
172+
continue;
173+
174+
if (nullptr != desc.pltVer &&
175+
!std::regex_match(
176+
pltVer, std::regex(std::string(desc.pltVer, desc.pltVerSize))))
177+
continue;
178+
179+
if (nullptr != desc.devName &&
145180
!std::regex_match(
146181
devName, std::regex(std::string(desc.devName, desc.devNameSize))))
147182
continue;
@@ -179,7 +214,7 @@ platform_impl_pi::get_devices(info::device_type deviceType) const {
179214

180215
// Filter out devices that are not present in the white list
181216
if (SYCLConfig<SYCL_DEVICE_WHITE_LIST>::get())
182-
filterWhiteList(pi_devices);
217+
filterWhiteList(pi_devices, m_platform);
183218

184219
std::for_each(pi_devices.begin(), pi_devices.end(),
185220
[&res](const RT::PiDevice &a_pi_device) {

sycl/test/config/white_list.cpp

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
// REQUIRES: cpu
22
// RUN: %clangxx -fsycl %s -o %t.out
3-
// RUN: env PRINT_DEVICE_INFO=1 %t.out > %t.conf
4-
// RUN: env TEST_DEVICE_AVAILABLE=1 env SYCL_CONFIG_FILE_NAME=%t.conf %t.out
5-
// RUN: env TEST_DEVICE_IS_NOT_AVAILABLE=1 env SYCL_DEVICE_WHITE_LIST="" %t.out
3+
//
4+
// RUN: env PRINT_DEVICE_INFO=1 %t.out > %t1.conf
5+
// RUN: env TEST_DEVICE_AVAILABLE=1 env SYCL_CONFIG_FILE_NAME=%t1.conf %t.out
6+
//
7+
// RUN: env PRINT_PLATFORM_INFO=1 %t.out > %t2.conf
8+
// RUN: env TEST_DEVICE_AVAILABLE=1 env SYCL_CONFIG_FILE_NAME=%t2.conf %t.out
9+
//
10+
// RUN: env TEST_DEVICE_IS_NOT_AVAILABLE=1 env SYCL_DEVICE_WHITE_LIST="PlatformName:{{SUCH NAME DOESN'T EXIST}}" %t.out
611

712
#include <CL/sycl.hpp>
813
#include <iostream>
@@ -12,25 +17,49 @@
1217

1318
using namespace cl;
1419

20+
static void replaceEscapeCharacters(std::string &Str) {
21+
// As a stringwill be used as regexp pattern, we need to get rid of symbols
22+
// that can be treated in a special way. Replace common special symbols with
23+
// '.' which matches to any character
24+
std::replace_if(Str.begin(), Str.end(),
25+
[](const char Sym) { return '(' == Sym || ')' == Sym; }, '.');
26+
}
27+
1528
int main() {
1629

30+
// Expected that white list filter is not set
31+
if (getenv("PRINT_PLATFORM_INFO")) {
32+
for (const sycl::platform &Plt : sycl::platform::get_platforms())
33+
if (!Plt.is_host()) {
34+
35+
std::string Name = Plt.get_info<sycl::info::platform::name>();
36+
const std::string Ver =
37+
Plt.get_info<sycl::info::platform::version>();
38+
39+
replaceEscapeCharacters(Name);
40+
41+
std::cout << "SYCL_DEVICE_WHITE_LIST=PlatformName:{{" << Name
42+
<< "}},PlatformVersion:{{" << Ver << "}}";
43+
44+
return 0;
45+
}
46+
throw std::runtime_error("Non host device is not found");
47+
}
48+
1749
// Expected that white list filter is not set
1850
if (getenv("PRINT_DEVICE_INFO")) {
1951
for (const sycl::platform &Plt : sycl::platform::get_platforms())
2052
if (!Plt.is_host()) {
2153
const sycl::device Dev = Plt.get_devices().at(0);
22-
std::string DevName = Dev.get_info<sycl::info::device::name>();
23-
const std::string DevVer =
54+
std::string Name = Dev.get_info<sycl::info::device::name>();
55+
const std::string Ver =
2456
Dev.get_info<sycl::info::device::driver_version>();
25-
// As device name string will be used as regexp pattern, we need to
26-
// get rid of symbols that can be treated in a special way.
27-
// Replace common special symbols with '.' which matches to any sybmol
28-
for (char &Sym : DevName) {
29-
if (')' == Sym || '(' == Sym)
30-
Sym = '.';
31-
}
32-
std::cout << "SYCL_DEVICE_WHITE_LIST=DeviceName:{{" << DevName
33-
<< "}},DriverVersion:{{" << DevVer << "}}";
57+
58+
replaceEscapeCharacters(Name);
59+
60+
std::cout << "SYCL_DEVICE_WHITE_LIST=DeviceName:{{" << Name
61+
<< "}},DriverVersion:{{" << Ver << "}}";
62+
3463
return 0;
3564
}
3665
throw std::runtime_error("Non host device is not found");

0 commit comments

Comments
 (0)