diff --git a/sycl/source/detail/device_image_impl.hpp b/sycl/source/detail/device_image_impl.hpp index 83af5b246683a..f667fc29a3823 100644 --- a/sycl/source/detail/device_image_impl.hpp +++ b/sycl/source/detail/device_image_impl.hpp @@ -58,10 +58,10 @@ constexpr uint8_t ImageOriginKernelCompiler = 1 << 2; class ManagedDeviceGlobalsRegistry { public: ManagedDeviceGlobalsRegistry( - const std::shared_ptr &ContextImpl, - const std::string &Prefix, std::vector &&DeviceGlobalNames, + context_impl &ContextImpl, const std::string &Prefix, + std::vector &&DeviceGlobalNames, std::vector> &&DeviceGlobalAllocations) - : MContextImpl{ContextImpl}, MPrefix{Prefix}, + : MContextImpl{ContextImpl.shared_from_this()}, MPrefix{Prefix}, MDeviceGlobalNames{std::move(DeviceGlobalNames)}, MDeviceGlobalAllocations{std::move(DeviceGlobalAllocations)} {} @@ -704,12 +704,11 @@ class device_image_impl { assert(MRTCBinInfo); assert(MOrigins & ImageOriginKernelCompiler); - const std::shared_ptr &ContextImpl = - getSyclObjImpl(MContext); + sycl::detail::context_impl &ContextImpl = *getSyclObjImpl(MContext); for (const auto &SyclDev : Devices) { device_impl &DevImpl = *getSyclObjImpl(SyclDev); - if (!ContextImpl->hasDevice(DevImpl)) { + if (!ContextImpl.hasDevice(DevImpl)) { throw sycl::exception(make_error_code(errc::invalid), "device not part of kernel_bundle context"); } @@ -742,7 +741,7 @@ class device_image_impl { Devices, BuildOptions, *SourceStrPtr, UrProgram); } - const AdapterPtr &Adapter = ContextImpl->getAdapter(); + const AdapterPtr &Adapter = ContextImpl.getAdapter(); if (!FetchedFromCache) UrProgram = createProgramFromSource(Devices, BuildOptions, LogPtr); @@ -752,7 +751,7 @@ class device_image_impl { UrProgram, DeviceVec.size(), DeviceVec.data(), XsFlags.c_str()); if (Res == UR_RESULT_ERROR_UNSUPPORTED_FEATURE) { Res = Adapter->call_nocheck( - ContextImpl->getHandleRef(), UrProgram, XsFlags.c_str()); + ContextImpl.getHandleRef(), UrProgram, XsFlags.c_str()); } Adapter->checkUrResult(Res); @@ -796,12 +795,11 @@ class device_image_impl { "compile is only available for kernel_bundle " "when the source language was sycl."); - std::shared_ptr ContextImpl = - getSyclObjImpl(MContext); + sycl::detail::context_impl &ContextImpl = *getSyclObjImpl(MContext); for (const auto &SyclDev : Devices) { detail::device_impl &DevImpl = *getSyclObjImpl(SyclDev); - if (!ContextImpl->hasDevice(DevImpl)) { + if (!ContextImpl.hasDevice(DevImpl)) { throw sycl::exception(make_error_code(errc::invalid), "device not part of kernel_bundle context"); } @@ -873,9 +871,8 @@ class device_image_impl { const std::vector Devices, const std::vector &BuildOptions, const std::string &SourceStr, ur_program_handle_t &UrProgram) const { - const std::shared_ptr &ContextImpl = - getSyclObjImpl(MContext); - const AdapterPtr &Adapter = ContextImpl->getAdapter(); + sycl::detail::context_impl &ContextImpl = *getSyclObjImpl(MContext); + const AdapterPtr &Adapter = ContextImpl.getAdapter(); std::string UserArgs = syclex::detail::userArgsAsString(BuildOptions); @@ -904,7 +901,7 @@ class device_image_impl { Properties.pMetadatas = nullptr; Adapter->call( - ContextImpl->getHandleRef(), DeviceHandles.size(), DeviceHandles.data(), + ContextImpl.getHandleRef(), DeviceHandles.size(), DeviceHandles.data(), Lengths.data(), Binaries.data(), &Properties, &UrProgram); return true; @@ -1132,7 +1129,7 @@ class device_image_impl { } auto DGRegs = std::make_shared( - getSyclObjImpl(MContext), std::string{Prefix}, + *getSyclObjImpl(MContext), std::string{Prefix}, std::move(DeviceGlobalNames), std::move(DeviceGlobalAllocations)); // Mark the image as input so the program manager will bring it into @@ -1195,9 +1192,8 @@ class device_image_impl { createProgramFromSource(const std::vector Devices, const std::vector &Options, std::string *LogPtr) const { - const std::shared_ptr &ContextImpl = - getSyclObjImpl(MContext); - const AdapterPtr &Adapter = ContextImpl->getAdapter(); + sycl::detail::context_impl &ContextImpl = *getSyclObjImpl(MContext); + const AdapterPtr &Adapter = ContextImpl.getAdapter(); const auto spirv = [&]() -> std::vector { switch (MRTCBinInfo->MLanguage) { case syclex::source_language::opencl: { @@ -1234,7 +1230,7 @@ class device_image_impl { }(); ur_program_handle_t UrProgram = nullptr; - Adapter->call(ContextImpl->getHandleRef(), + Adapter->call(ContextImpl.getHandleRef(), spirv.data(), spirv.size(), nullptr, &UrProgram); // program created by urProgramCreateWithIL is implicitly retained.