diff --git a/sycl/source/detail/platform_impl.cpp b/sycl/source/detail/platform_impl.cpp index f3dfa35606ab4..307390d2d3b23 100644 --- a/sycl/source/detail/platform_impl.cpp +++ b/sycl/source/detail/platform_impl.cpp @@ -58,6 +58,12 @@ struct DevDescT { const char *devDriverVer = nullptr; int devDriverVerSize = 0; + + const char *platformName = nullptr; + int platformNameSize = 0; + + const char *platformVer = nullptr; + int platformVerSize = 0; }; static std::vector getWhiteListDesc() { @@ -68,6 +74,8 @@ static std::vector getWhiteListDesc() { std::vector decDescs; const char devNameStr[] = "DeviceName"; const char driverVerStr[] = "DriverVersion"; + const char platformNameStr[] = "PlatformName"; + const char platformVerStr[] = "PlatformVersion"; decDescs.emplace_back(); while ('\0' != *str) { const char **valuePtr = nullptr; @@ -78,6 +86,15 @@ static std::vector getWhiteListDesc() { valuePtr = &decDescs.back().devName; size = &decDescs.back().devNameSize; str += sizeof(devNameStr) - 1; + } else if (0 == + strncmp(platformNameStr, str, sizeof(platformNameStr) - 1)) { + valuePtr = &decDescs.back().platformName; + size = &decDescs.back().platformNameSize; + str += sizeof(platformNameStr) - 1; + } else if (0 == strncmp(platformVerStr, str, sizeof(platformVerStr) - 1)) { + valuePtr = &decDescs.back().platformVer; + size = &decDescs.back().platformVerSize; + str += sizeof(platformVerStr) - 1; } else if (0 == strncmp(driverVerStr, str, sizeof(driverVerStr) - 1)) { valuePtr = &decDescs.back().devDriverVer; size = &decDescs.back().devDriverVerSize; @@ -125,23 +142,43 @@ static std::vector getWhiteListDesc() { return decDescs; } -static void filterWhiteList(vector_class &pi_devices) { +static void filterWhiteList(vector_class &pi_devices, + RT::PiPlatform pi_platform) { const std::vector whiteList(getWhiteListDesc()); if (whiteList.empty()) return; + const string_class platformName = + sycl::detail::get_platform_info::get( + pi_platform); + + const string_class platformVer = sycl::detail::get_platform_info< + string_class, info::platform::version>::get(pi_platform); + int insertIDx = 0; for (RT::PiDevice dev : pi_devices) { const string_class devName = - sycl::detail::get_device_info::get(dev); + sycl::detail::get_device_info::get( + dev); const string_class devDriverVer = sycl::detail::get_device_info::get(dev); for (const DevDescT &desc : whiteList) { - // At least device name is required field to consider the filter so far - if (nullptr == desc.devName || + if (nullptr != desc.platformName && + !std::regex_match(platformName, + std::regex(std::string(desc.platformName, + desc.platformNameSize)))) + continue; + + if (nullptr != desc.platformVer && + !std::regex_match( + platformVer, + std::regex(std::string(desc.platformVer, desc.platformVerSize)))) + continue; + + if (nullptr != desc.devName && !std::regex_match( devName, std::regex(std::string(desc.devName, desc.devNameSize)))) continue; @@ -179,7 +216,7 @@ platform_impl_pi::get_devices(info::device_type deviceType) const { // Filter out devices that are not present in the white list if (SYCLConfig::get()) - filterWhiteList(pi_devices); + filterWhiteList(pi_devices, m_platform); std::for_each(pi_devices.begin(), pi_devices.end(), [&res](const RT::PiDevice &a_pi_device) { diff --git a/sycl/test/config/white_list.cpp b/sycl/test/config/white_list.cpp index e3f9e11a5bc79..0b9db514828ea 100644 --- a/sycl/test/config/white_list.cpp +++ b/sycl/test/config/white_list.cpp @@ -1,8 +1,13 @@ // REQUIRES: cpu // RUN: %clangxx -fsycl %s -o %t.out -// RUN: env PRINT_DEVICE_INFO=1 %t.out > %t.conf -// RUN: env TEST_DEVICE_AVAILABLE=1 env SYCL_CONFIG_FILE_NAME=%t.conf %t.out -// RUN: env TEST_DEVICE_IS_NOT_AVAILABLE=1 env SYCL_DEVICE_WHITE_LIST="" %t.out +// +// RUN: env PRINT_DEVICE_INFO=1 %t.out > %t1.conf +// RUN: env TEST_DEVICE_AVAILABLE=1 env SYCL_CONFIG_FILE_NAME=%t1.conf %t.out +// +// RUN: env PRINT_PLATFORM_INFO=1 %t.out > %t2.conf +// RUN: env TEST_DEVICE_AVAILABLE=1 env SYCL_CONFIG_FILE_NAME=%t2.conf %t.out +// +// RUN: env TEST_DEVICE_IS_NOT_AVAILABLE=1 env SYCL_DEVICE_WHITE_LIST="PlatformName:{{SUCH NAME DOESN'T EXIST}}" %t.out #include #include @@ -12,25 +17,50 @@ using namespace cl; +static void replaceSpecialCharacters(std::string &Str) { + // Replace common special symbols with '.' which matches to any character + std::replace_if(Str.begin(), Str.end(), + [](const char Sym) { return '(' == Sym || ')' == Sym; }, '.'); +} + int main() { + // Expected that white list filter is not set + if (getenv("PRINT_PLATFORM_INFO")) { + for (const sycl::platform &Platform : sycl::platform::get_platforms()) + if (!Platform.is_host()) { + + std::string Name = Platform.get_info(); + std::string Ver = Platform.get_info(); + // As a string will be used as regexp pattern, we need to get rid of + // symbols that can be treated in a special way. + replaceSpecialCharacters(Name); + replaceSpecialCharacters(Ver); + + std::cout << "SYCL_DEVICE_WHITE_LIST=PlatformName:{{" << Name + << "}},PlatformVersion:{{" << Ver << "}}"; + + return 0; + } + throw std::runtime_error("Non host device is not found"); + } + // Expected that white list filter is not set if (getenv("PRINT_DEVICE_INFO")) { - for (const sycl::platform &Plt : sycl::platform::get_platforms()) - if (!Plt.is_host()) { - const sycl::device Dev = Plt.get_devices().at(0); - std::string DevName = Dev.get_info(); - const std::string DevVer = - Dev.get_info(); - // As device name string will be used as regexp pattern, we need to - // get rid of symbols that can be treated in a special way. - // Replace common special symbols with '.' which matches to any sybmol - for (char &Sym : DevName) { - if (')' == Sym || '(' == Sym) - Sym = '.'; - } - std::cout << "SYCL_DEVICE_WHITE_LIST=DeviceName:{{" << DevName - << "}},DriverVersion:{{" << DevVer << "}}"; + for (const sycl::platform &Platform : sycl::platform::get_platforms()) + if (!Platform.is_host()) { + const sycl::device Dev = Platform.get_devices().at(0); + std::string Name = Dev.get_info(); + std::string Ver = Dev.get_info(); + + // As a string will be used as regexp pattern, we need to get rid of + // symbols that can be treated in a special way. + replaceSpecialCharacters(Name); + replaceSpecialCharacters(Ver); + + std::cout << "SYCL_DEVICE_WHITE_LIST=DeviceName:{{" << Name + << "}},DriverVersion:{{" << Ver << "}}"; + return 0; } throw std::runtime_error("Non host device is not found"); @@ -38,9 +68,9 @@ int main() { // Expected white list to be set with result from "PRINT_DEVICE_INFO" run if (getenv("TEST_DEVICE_AVAILABLE")) { - for (const sycl::platform &Plt : sycl::platform::get_platforms()) - if (!Plt.is_host()) { - if (Plt.get_devices().size() != 1) + for (const sycl::platform &Platform : sycl::platform::get_platforms()) + if (!Platform.is_host()) { + if (Platform.get_devices().size() != 1) throw std::runtime_error("Expected only one non host device."); return 0; @@ -50,8 +80,8 @@ int main() { // Expected white list to be set but empty if (getenv("TEST_DEVICE_IS_NOT_AVAILABLE")) { - for (const sycl::platform &Plt : sycl::platform::get_platforms()) - if (!Plt.is_host()) + for (const sycl::platform &Platform : sycl::platform::get_platforms()) + if (!Platform.is_host()) throw std::runtime_error("Expected no non host device is available"); return 0; }