diff --git a/sycl/source/detail/program_impl.cpp b/sycl/source/detail/program_impl.cpp index 71eb1c4e9464e..a6f65f51f3682 100644 --- a/sycl/source/detail/program_impl.cpp +++ b/sycl/source/detail/program_impl.cpp @@ -15,7 +15,9 @@ #include #include +#include #include +#include __SYCL_INLINE_NAMESPACE(cl) { namespace sycl { @@ -38,6 +40,16 @@ program_impl::program_impl( throw runtime_error("Non-empty vector of programs expected", PI_INVALID_VALUE); } + + // Sort the programs to avoid deadlocks due to locking multiple mutexes & + // verify that all programs are unique. + std::sort(ProgramList.begin(), ProgramList.end()); + auto It = std::unique(ProgramList.begin(), ProgramList.end()); + if (It != ProgramList.end()) { + throw runtime_error("Attempting to link a program with itself", + PI_INVALID_PROGRAM); + } + MContext = ProgramList[0]->MContext; MDevices = ProgramList[0]->MDevices; vector_class DevicesSorted; @@ -45,7 +57,9 @@ program_impl::program_impl( DevicesSorted = sort_devices_by_cl_device_id(MDevices); } check_device_feature_support(MDevices); + std::list> Locks; for (const auto &Prg : ProgramList) { + Locks.emplace_back(Prg->MMutex); Prg->throw_if_state_is_not(program_state::compiled); if (Prg->MContext != MContext) { throw invalid_object_error( @@ -184,6 +198,7 @@ cl_program program_impl::get() const { void program_impl::compile_with_kernel_name(string_class KernelName, string_class CompileOptions, OSModuleHandle M) { + std::lock_guard Lock(MMutex); throw_if_state_is_not(program_state::none); MProgramModuleHandle = M; if (!is_host()) { @@ -195,6 +210,7 @@ void program_impl::compile_with_kernel_name(string_class KernelName, void program_impl::compile_with_source(string_class KernelSource, string_class CompileOptions) { + std::lock_guard Lock(MMutex); throw_if_state_is_not(program_state::none); // TODO should it throw if it's host? if (!is_host()) { @@ -207,6 +223,7 @@ void program_impl::compile_with_source(string_class KernelSource, void program_impl::build_with_kernel_name(string_class KernelName, string_class BuildOptions, OSModuleHandle Module) { + std::lock_guard Lock(MMutex); throw_if_state_is_not(program_state::none); MProgramModuleHandle = Module; if (!is_host()) { @@ -227,6 +244,7 @@ void program_impl::build_with_kernel_name(string_class KernelName, void program_impl::build_with_source(string_class KernelSource, string_class BuildOptions) { + std::lock_guard Lock(MMutex); throw_if_state_is_not(program_state::none); // TODO should it throw if it's host? if (!is_host()) { @@ -237,6 +255,7 @@ void program_impl::build_with_source(string_class KernelSource, } void program_impl::link(string_class LinkOptions) { + std::lock_guard Lock(MMutex); throw_if_state_is_not(program_state::compiled); if (!is_host()) { check_device_feature_support(MDevices); diff --git a/sycl/source/detail/program_impl.hpp b/sycl/source/detail/program_impl.hpp index ddf79492f2f3d..36284d0a96f57 100644 --- a/sycl/source/detail/program_impl.hpp +++ b/sycl/source/detail/program_impl.hpp @@ -20,6 +20,7 @@ #include #include #include +#include __SYCL_INLINE_NAMESPACE(cl) { namespace sycl { @@ -385,6 +386,7 @@ class program_impl { RT::PiProgram MProgram = nullptr; program_state MState = program_state::none; + std::mutex MMutex; ContextImplPtr MContext; bool MLinkable = false; vector_class MDevices;