diff --git a/sycl/source/backend.cpp b/sycl/source/backend.cpp index 9ce7acec912c5..217c16817704a 100644 --- a/sycl/source/backend.cpp +++ b/sycl/source/backend.cpp @@ -112,7 +112,7 @@ __SYCL_EXPORT context make_context(ur_native_handle_t NativeHandle, NativeHandle, Adapter->getUrAdapter(), DeviceHandles.size(), DeviceHandles.data(), &Properties, &UrContext); // Construct the SYCL context from UR context. - return detail::createSyclObjFromImpl(std::make_shared( + return detail::createSyclObjFromImpl(context_impl::create( UrContext, Handler, Adapter, DeviceList, !KeepOwnership)); } diff --git a/sycl/source/context.cpp b/sycl/source/context.cpp index 6878b0ce2442e..69e9233e69ed4 100644 --- a/sycl/source/context.cpp +++ b/sycl/source/context.cpp @@ -69,8 +69,7 @@ context::context(const std::vector &DeviceList, throw exception(make_error_code(errc::invalid), "Can't add devices across platforms to a single context."); else - impl = std::make_shared(DeviceList, AsyncHandler, - PropList); + impl = detail::context_impl::create(DeviceList, AsyncHandler, PropList); } context::context(cl_context ClContext, async_handler AsyncHandler) { const auto &Adapter = sycl::detail::ur::getAdapter(); @@ -81,8 +80,7 @@ context::context(cl_context ClContext, async_handler AsyncHandler) { Adapter->call( nativeHandle, Adapter->getUrAdapter(), 0, nullptr, nullptr, &hContext); - impl = - std::make_shared(hContext, AsyncHandler, Adapter); + impl = detail::context_impl::create(hContext, AsyncHandler, Adapter); } template diff --git a/sycl/source/detail/context_impl.cpp b/sycl/source/detail/context_impl.cpp index 9fe103986f913..115fc58794d90 100644 --- a/sycl/source/detail/context_impl.cpp +++ b/sycl/source/detail/context_impl.cpp @@ -27,19 +27,9 @@ namespace sycl { inline namespace _V1 { namespace detail { -context_impl::context_impl(const device &Device, async_handler AsyncHandler, - const property_list &PropList) - : MOwnedByRuntime(true), MAsyncHandler(AsyncHandler), MDevices(1, Device), - MContext(nullptr), - MPlatform(detail::getSyclObjImpl(Device.get_platform())), - MPropList(PropList), MSupportBufferLocationByDevices(NotChecked) { - verifyProps(PropList); - MKernelProgramCache.setContextPtr(this); -} - context_impl::context_impl(const std::vector Devices, async_handler AsyncHandler, - const property_list &PropList) + const property_list &PropList, private_tag) : MOwnedByRuntime(true), MAsyncHandler(AsyncHandler), MDevices(Devices), MContext(nullptr), MPlatform(detail::getSyclObjImpl(MDevices[0].get_platform())), @@ -72,7 +62,7 @@ context_impl::context_impl(ur_context_handle_t UrContext, async_handler AsyncHandler, const AdapterPtr &Adapter, const std::vector &DeviceList, - bool OwnedByRuntime) + bool OwnedByRuntime, private_tag) : MOwnedByRuntime(OwnedByRuntime), MAsyncHandler(AsyncHandler), MDevices(DeviceList), MContext(UrContext), MPlatform(), MSupportBufferLocationByDevices(NotChecked) { diff --git a/sycl/source/detail/context_impl.hpp b/sycl/source/detail/context_impl.hpp index adb372f46115b..5a79591ba87eb 100644 --- a/sycl/source/detail/context_impl.hpp +++ b/sycl/source/detail/context_impl.hpp @@ -29,20 +29,12 @@ inline namespace _V1 { // Forward declaration class device; namespace detail { -class context_impl { -public: - /// Constructs a context_impl using a single SYCL devices. - /// - /// The constructed context_impl will use the AsyncHandler parameter to - /// handle exceptions. - /// PropList carries the properties of the constructed context_impl. - /// - /// \param Device is an instance of SYCL device. - /// \param AsyncHandler is an instance of async_handler. - /// \param PropList is an instance of property_list. - context_impl(const device &Device, async_handler AsyncHandler, - const property_list &PropList); +class context_impl : std::enable_shared_from_this { + struct private_tag { + explicit private_tag() = default; + }; +public: /// Constructs a context_impl using a list of SYCL devices. /// /// Newly created instance will save each SYCL device in the list. This @@ -56,7 +48,8 @@ class context_impl { /// \param AsyncHandler is an instance of async_handler. /// \param PropList is an instance of property_list. context_impl(const std::vector DeviceList, - async_handler AsyncHandler, const property_list &PropList); + async_handler AsyncHandler, const property_list &PropList, + private_tag); /// Construct a context_impl using plug-in interoperability handle. /// @@ -70,8 +63,23 @@ class context_impl { /// transferred to runtime context_impl(ur_context_handle_t UrContext, async_handler AsyncHandler, const AdapterPtr &Adapter, - const std::vector &DeviceList = {}, - bool OwnedByRuntime = true); + const std::vector &DeviceList, bool OwnedByRuntime, + private_tag); + + context_impl(ur_context_handle_t UrContext, async_handler AsyncHandler, + const AdapterPtr &Adapter, private_tag tag) + : context_impl(UrContext, AsyncHandler, Adapter, + std::vector{}, + /*OwnedByRuntime*/ true, tag) {} + + // Single variadic method works because all the ctors are expected to be + // "public" except the `private_tag` part restricting the creation to + // `std::shared_ptr` allocations. + template + static std::shared_ptr create(Ts &&...args) { + return std::make_shared(std::forward(args)..., + private_tag{}); + } ~context_impl();