diff --git a/sycl/plugins/cuda/pi_cuda.cpp b/sycl/plugins/cuda/pi_cuda.cpp index 56242de042225..108e7adb176f0 100644 --- a/sycl/plugins/cuda/pi_cuda.cpp +++ b/sycl/plugins/cuda/pi_cuda.cpp @@ -9,13 +9,15 @@ #include #include #include + #include #include #include #include -#include #include +#include #include +#include std::string getCudaVersionString() { int driver_version = 0; @@ -447,6 +449,28 @@ pi_result getInfo(size_t param_value_size, void *param_value, param_value_size_ret, value); } +/// Finds kernel names by searching for entry points in the PTX source, as the +/// CUDA driver API doesn't expose an operation for this. +/// Note: This is currently only being used by the SYCL program class for the +/// has_kernel method, so an alternative would be to move the has_kernel +/// query to PI and use cuModuleGetFunction to check for a kernel. +std::string getKernelNames(pi_program program) { + std::string source(program->source_, + program->source_ + program->sourceLength_); + std::regex entries_pattern(".entry\\s+([^\\([:s:]]*)"); + std::string names(""); + std::smatch match; + bool first_match = true; + while (std::regex_search(source, match, entries_pattern)) { + assert(match.size() == 2); + names += first_match ? "" : ";"; + names += match[1]; // Second element is the group. + source = match.suffix().str(); + first_match = false; + } + return names; +} + /// RAII object that calls the reference count release function on the held PI /// object on destruction. /// @@ -1993,7 +2017,7 @@ pi_result cuda_piProgramGetInfo(pi_program program, pi_program_info param_name, &program->source_); case PI_PROGRAM_INFO_KERNEL_NAMES: { return getInfo(param_value_size, param_value, param_value_size_ret, - "not implemented"); + getKernelNames(program).c_str()); } default: PI_HANDLE_UNKNOWN_PARAM_NAME(param_name); diff --git a/sycl/test/basic_tests/kernel_info.cpp b/sycl/test/basic_tests/kernel_info.cpp index b2c8ffa92a912..541892023ba24 100644 --- a/sycl/test/basic_tests/kernel_info.cpp +++ b/sycl/test/basic_tests/kernel_info.cpp @@ -33,6 +33,7 @@ int main() { program prg(q.get_context()); prg.build_with_kernel_type(); + CHECK(prg.has_kernel()); kernel krn = prg.get_kernel(); q.submit([&](handler &cgh) {