diff --git a/sycl/include/sycl/handler.hpp b/sycl/include/sycl/handler.hpp index 7c50701c5af0e..79707ea022685 100644 --- a/sycl/include/sycl/handler.hpp +++ b/sycl/include/sycl/handler.hpp @@ -891,7 +891,7 @@ class __SYCL_EXPORT handler { // If the kernel lambda is callable with a kernel_handler argument, manifest // the associated kernel handler. if constexpr (IsCallableWithKernelHandler) { - getOrInsertHandlerKernelBundle(/*Insert=*/true); + getOrInsertHandlerKernelBundlePtr(/*Insert=*/true); } } @@ -1706,13 +1706,26 @@ class __SYCL_EXPORT handler { void setStateSpecConstSet(); bool isStateExplicitKernelBundle() const; +#ifndef __INTEL_PREVIEW_BREAKING_CHANGES std::shared_ptr getOrInsertHandlerKernelBundle(bool Insert) const; +#endif + +#ifdef __INTEL_PREVIEW_BREAKING_CHANGES + // Rename to just getOrInsertHandlerKernelBundle +#endif + detail::kernel_bundle_impl * + getOrInsertHandlerKernelBundlePtr(bool Insert) const; void setHandlerKernelBundle(kernel Kernel); +#ifndef __INTEL_PREVIEW_BREAKING_CHANGES void setHandlerKernelBundle( const std::shared_ptr &NewKernelBundleImpPtr); +#endif + + template + void setHandlerKernelBundle(SharedPtrT &&NewKernelBundleImpPtr); void SetHostTask(std::function &&Func); void SetHostTask(std::function &&Func); @@ -1760,6 +1773,8 @@ class __SYCL_EXPORT handler { /// called. void setUserFacingNodeType(ext::oneapi::experimental::node_type Type); + kernel_bundle getKernelBundle() const; + public: handler(const handler &) = delete; handler(handler &&) = delete; diff --git a/sycl/include/sycl/kernel_bundle.hpp b/sycl/include/sycl/kernel_bundle.hpp index 803504d21f585..497aa3653d2a0 100644 --- a/sycl/include/sycl/kernel_bundle.hpp +++ b/sycl/include/sycl/kernel_bundle.hpp @@ -1330,12 +1330,7 @@ void handler::set_specialization_constant( setStateSpecConstSet(); - std::shared_ptr KernelBundleImplPtr = - getOrInsertHandlerKernelBundle(/*Insert=*/true); - - detail::createSyclObjFromImpl>( - std::move(KernelBundleImplPtr)) - .set_specialization_constant(Value); + getKernelBundle().set_specialization_constant(Value); } template @@ -1347,12 +1342,7 @@ handler::get_specialization_constant() const { "Specialization constants cannot be read after " "explicitly setting the used kernel bundle"); - std::shared_ptr KernelBundleImplPtr = - getOrInsertHandlerKernelBundle(/*Insert=*/true); - - return detail::createSyclObjFromImpl>( - std::move(KernelBundleImplPtr)) - .get_specialization_constant(); + return getKernelBundle().get_specialization_constant(); } } // namespace _V1 diff --git a/sycl/source/CMakeLists.txt b/sycl/source/CMakeLists.txt index f2e5494fb6218..fae02cbf3bdd7 100644 --- a/sycl/source/CMakeLists.txt +++ b/sycl/source/CMakeLists.txt @@ -269,6 +269,7 @@ set(SYCL_COMMON_SOURCES "detail/host_pipe_map.cpp" "detail/device_global_map.cpp" "detail/device_global_map_entry.cpp" + "detail/device_image_impl.cpp" "detail/device_impl.cpp" "detail/error_handling/error_handling.cpp" "detail/event_impl.cpp" diff --git a/sycl/source/backend.cpp b/sycl/source/backend.cpp index ebd3e4357904c..d262256634d61 100644 --- a/sycl/source/backend.cpp +++ b/sycl/source/backend.cpp @@ -306,7 +306,7 @@ make_kernel_bundle(ur_native_handle_t NativeHandle, ImageOriginInterop); device_image_plain DevImg{DevImgImpl}; - return std::make_shared(TargetContext, Devices, DevImg); + return kernel_bundle_impl::create(TargetContext, Devices, DevImg); } // TODO: Unused. Remove when allowed. diff --git a/sycl/source/detail/device_image_impl.cpp b/sycl/source/detail/device_image_impl.cpp new file mode 100644 index 0000000000000..1ee97f34f3ade --- /dev/null +++ b/sycl/source/detail/device_image_impl.cpp @@ -0,0 +1,59 @@ +//==----------------- device_image_impl.cpp - SYCL device_image_impl -------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include +#include + +namespace sycl { +inline namespace _V1 { +namespace detail { + +std::shared_ptr device_image_impl::tryGetSourceBasedKernel( + std::string_view Name, const context &Context, + const kernel_bundle_impl &OwnerBundle, + const std::shared_ptr &Self) const { + if (!(getOriginMask() & ImageOriginKernelCompiler)) + return nullptr; + + assert(MRTCBinInfo); + std::string AdjustedName = adjustKernelName(Name); + if (MRTCBinInfo->MLanguage == syclex::source_language::sycl) { + auto &PM = ProgramManager::getInstance(); + for (const std::string &Prefix : MRTCBinInfo->MPrefixes) { + auto KID = PM.tryGetSYCLKernelID(Prefix + AdjustedName); + + if (!KID || !has_kernel(*KID)) + continue; + + auto UrProgram = get_ur_program_ref(); + auto [UrKernel, CacheMutex, ArgMask] = + PM.getOrCreateKernel(Context, AdjustedName, + /*PropList=*/{}, UrProgram); + return std::make_shared(UrKernel, *getSyclObjImpl(Context), + Self, OwnerBundle.shared_from_this(), + ArgMask, UrProgram, CacheMutex); + } + return nullptr; + } + + ur_program_handle_t UrProgram = get_ur_program_ref(); + const AdapterPtr &Adapter = getSyclObjImpl(Context)->getAdapter(); + ur_kernel_handle_t UrKernel = nullptr; + Adapter->call(UrProgram, AdjustedName.c_str(), + &UrKernel); + // Kernel created by urKernelCreate is implicitly retained. + + return std::make_shared( + UrKernel, *detail::getSyclObjImpl(Context), Self, + OwnerBundle.shared_from_this(), /*ArgMask=*/nullptr, UrProgram, + /*CacheMutex=*/nullptr); +} + +} // namespace detail +} // namespace _V1 +} // namespace sycl diff --git a/sycl/source/detail/device_image_impl.hpp b/sycl/source/detail/device_image_impl.hpp index 4d7885a315456..1e307851e5c6a 100644 --- a/sycl/source/detail/device_image_impl.hpp +++ b/sycl/source/detail/device_image_impl.hpp @@ -617,45 +617,10 @@ class device_image_impl { MRTCBinInfo->MKernelNames.end(); } - std::shared_ptr tryGetSourceBasedKernel( - std::string_view Name, const context &Context, - const std::shared_ptr &OwnerBundle, - const std::shared_ptr &Self) const { - if (!(getOriginMask() & ImageOriginKernelCompiler)) - return nullptr; - - assert(MRTCBinInfo); - std::string AdjustedName = adjustKernelName(Name); - if (MRTCBinInfo->MLanguage == syclex::source_language::sycl) { - auto &PM = ProgramManager::getInstance(); - for (const std::string &Prefix : MRTCBinInfo->MPrefixes) { - auto KID = PM.tryGetSYCLKernelID(Prefix + AdjustedName); - - if (!KID || !has_kernel(*KID)) - continue; - - auto UrProgram = get_ur_program_ref(); - auto [UrKernel, CacheMutex, ArgMask] = - PM.getOrCreateKernel(Context, AdjustedName, - /*PropList=*/{}, UrProgram); - return std::make_shared(UrKernel, *getSyclObjImpl(Context), - Self, OwnerBundle, ArgMask, - UrProgram, CacheMutex); - } - return nullptr; - } - - ur_program_handle_t UrProgram = get_ur_program_ref(); - const AdapterPtr &Adapter = getSyclObjImpl(Context)->getAdapter(); - ur_kernel_handle_t UrKernel = nullptr; - Adapter->call(UrProgram, AdjustedName.c_str(), - &UrKernel); - // Kernel created by urKernelCreate is implicitly retained. - - return std::make_shared( - UrKernel, *detail::getSyclObjImpl(Context), Self, OwnerBundle, - /*ArgMask=*/nullptr, UrProgram, /*CacheMutex=*/nullptr); - } + std::shared_ptr + tryGetSourceBasedKernel(std::string_view Name, const context &Context, + const kernel_bundle_impl &OwnerBundle, + const std::shared_ptr &Self) const; bool hasDeviceGlobalName(const std::string &Name) const noexcept { if (!MRTCBinInfo.has_value()) diff --git a/sycl/source/detail/graph_impl.cpp b/sycl/source/detail/graph_impl.cpp index a9a46a0199ec0..8d9415cecf49f 100644 --- a/sycl/source/detail/graph_impl.cpp +++ b/sycl/source/detail/graph_impl.cpp @@ -860,7 +860,7 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx, std::tie(CmdTraceEvent, InstanceID) = emitKernelInstrumentationData( StreamID, CGExec->MSyclKernel, CodeLoc, CGExec->MIsTopCodeLoc, CGExec->MKernelName.data(), CGExec->MKernelNameBasedCachePtr, nullptr, - CGExec->MNDRDesc, CGExec->MKernelBundle, CGExec->MArgs); + CGExec->MNDRDesc, CGExec->MKernelBundle.get(), CGExec->MArgs); if (CmdTraceEvent) sycl::detail::emitInstrumentationGeneral( StreamID, InstanceID, CmdTraceEvent, xpti::trace_task_begin, nullptr); @@ -1536,8 +1536,7 @@ void exec_graph_impl::populateURKernelUpdateStructs( EliminatedArgMask = Kernel->getKernelArgMask(); } else if (auto SyclKernelImpl = KernelBundleImplPtr - ? KernelBundleImplPtr->tryGetKernel(ExecCG.MKernelName, - KernelBundleImplPtr) + ? KernelBundleImplPtr->tryGetKernel(ExecCG.MKernelName) : std::shared_ptr{nullptr}) { UrKernel = SyclKernelImpl->getHandleRef(); EliminatedArgMask = SyclKernelImpl->getKernelArgMask(); diff --git a/sycl/source/detail/helpers.cpp b/sycl/source/detail/helpers.cpp index 894b8b8063178..14e7aca275221 100644 --- a/sycl/source/detail/helpers.cpp +++ b/sycl/source/detail/helpers.cpp @@ -73,8 +73,7 @@ retrieveKernelBinary(queue_impl &Queue, KernelNameStrRefT KernelName, DeviceImage = KernelCG->MSyclKernel->getDeviceImage()->get_bin_image_ref(); Program = KernelCG->MSyclKernel->getDeviceImage()->get_ur_program_ref(); } else if (auto SyclKernelImpl = - KernelBundleImpl ? KernelBundleImpl->tryGetKernel( - KernelName, KernelBundleImpl) + KernelBundleImpl ? KernelBundleImpl->tryGetKernel(KernelName) : std::shared_ptr{nullptr}) { // Retrieve the device image from the kernel bundle. DeviceImage = SyclKernelImpl->getDeviceImage()->get_bin_image_ref(); diff --git a/sycl/source/detail/kernel_bundle_impl.hpp b/sycl/source/detail/kernel_bundle_impl.hpp index 8bb2df1ac7e20..feb2d0b9f197a 100644 --- a/sycl/source/detail/kernel_bundle_impl.hpp +++ b/sycl/source/detail/kernel_bundle_impl.hpp @@ -67,9 +67,15 @@ class kernel_impl; /// The class is an impl counterpart of the sycl::kernel_bundle. // It provides an access and utilities to manage set of sycl::device_images // objects. -class kernel_bundle_impl { +class kernel_bundle_impl + : public std::enable_shared_from_this { using SpecConstMapT = std::map>; + using Base = std::enable_shared_from_this; + + struct private_tag { + explicit private_tag() = default; + }; void common_ctor_checks() const { const bool AllDevicesInTheContext = @@ -92,7 +98,8 @@ class kernel_bundle_impl { } public: - kernel_bundle_impl(context Ctx, std::vector Devs, bundle_state State) + kernel_bundle_impl(context Ctx, std::vector Devs, bundle_state State, + private_tag) : MContext(std::move(Ctx)), MDevices(std::move(Devs)), MState(State) { common_ctor_checks(); @@ -103,7 +110,7 @@ class kernel_bundle_impl { } // Interop constructor used by make_kernel - kernel_bundle_impl(context Ctx, std::vector Devs) + kernel_bundle_impl(context Ctx, std::vector Devs, private_tag) : MContext(Ctx), MDevices(Devs), MState(bundle_state::executable) { if (!checkAllDevicesAreInContext(Devs, Ctx)) throw sycl::exception( @@ -114,8 +121,8 @@ class kernel_bundle_impl { // Interop constructor kernel_bundle_impl(context Ctx, std::vector Devs, - device_image_plain &DevImage) - : kernel_bundle_impl(Ctx, Devs) { + device_image_plain &DevImage, private_tag Tag) + : kernel_bundle_impl(Ctx, Devs, Tag) { MDeviceImages.emplace_back(DevImage); MUniqueDeviceImages.emplace_back(DevImage); } @@ -125,7 +132,7 @@ class kernel_bundle_impl { // signature kernel_bundle_impl(const kernel_bundle &InputBundle, std::vector Devs, const property_list &PropList, - bundle_state TargetState) + bundle_state TargetState, private_tag) : MContext(InputBundle.get_context()), MDevices(std::move(Devs)), MState(TargetState) { @@ -193,7 +200,7 @@ class kernel_bundle_impl { // Matches sycl::link kernel_bundle_impl( const std::vector> &ObjectBundles, - std::vector Devs, const property_list &PropList) + std::vector Devs, const property_list &PropList, private_tag) : MDevices(std::move(Devs)), MState(bundle_state::executable) { if (MDevices.empty()) throw sycl::exception(make_error_code(errc::invalid), @@ -414,7 +421,7 @@ class kernel_bundle_impl { kernel_bundle_impl(context Ctx, std::vector Devs, const std::vector &KernelIDs, - bundle_state State) + bundle_state State, private_tag) : MContext(std::move(Ctx)), MDevices(std::move(Devs)), MState(State) { common_ctor_checks(); @@ -425,7 +432,8 @@ class kernel_bundle_impl { } kernel_bundle_impl(context Ctx, std::vector Devs, - const DevImgSelectorImpl &Selector, bundle_state State) + const DevImgSelectorImpl &Selector, bundle_state State, + private_tag) : MContext(std::move(Ctx)), MDevices(std::move(Devs)), MState(State) { common_ctor_checks(); @@ -437,7 +445,7 @@ class kernel_bundle_impl { // C'tor matches sycl::join API kernel_bundle_impl(const std::vector &Bundles, - bundle_state State) + bundle_state State, private_tag) : MState(State) { if (Bundles.empty()) return; @@ -501,7 +509,8 @@ class kernel_bundle_impl { // oneapi_ext_kernel_compiler // construct from source string kernel_bundle_impl(const context &Context, syclex::source_language Lang, - const std::string &Src, include_pairs_t IncludePairsVec) + const std::string &Src, include_pairs_t IncludePairsVec, + private_tag) : MContext(Context), MDevices(Context.get_devices()), MDeviceImages{device_image_plain{std::make_shared( Src, MContext, MDevices, Lang, std::move(IncludePairsVec))}}, @@ -513,7 +522,7 @@ class kernel_bundle_impl { // oneapi_ext_kernel_compiler // construct from source bytes kernel_bundle_impl(const context &Context, syclex::source_language Lang, - const std::vector &Bytes) + const std::vector &Bytes, private_tag) : MContext(Context), MDevices(Context.get_devices()), MDeviceImages{device_image_plain{std::make_shared( Bytes, MContext, MDevices, Lang)}}, @@ -528,7 +537,7 @@ class kernel_bundle_impl { const context &Context, const std::vector &Devs, std::vector &&DevImgs, std::vector> &&DevBinaries, - bundle_state State) + bundle_state State, private_tag) : MContext(Context), MDevices(Devs), MSharedDeviceBinaries(std::move(DevBinaries)), MUniqueDeviceImages(std::move(DevImgs)), MState(State) { @@ -540,6 +549,12 @@ class kernel_bundle_impl { MDeviceImages.emplace_back(DevImg); } + template + static std::shared_ptr create(Ts &&...args) { + return std::make_shared(std::forward(args)..., + private_tag{}); + } + std::shared_ptr build_from_source( const std::vector Devices, const std::vector &BuildOptions, @@ -559,9 +574,8 @@ class kernel_bundle_impl { for (std::shared_ptr &DevImgImpl : NewDevImgImpls) NewDevImgs.emplace_back(std::move(DevImgImpl)); } - return std::make_shared( - MContext, Devices, std::move(NewDevImgs), std::move(NewBinReso), - bundle_state::executable); + return create(MContext, Devices, std::move(NewDevImgs), + std::move(NewBinReso), bundle_state::executable); } std::shared_ptr compile_from_source( @@ -584,9 +598,8 @@ class kernel_bundle_impl { for (std::shared_ptr &DevImgImpl : NewDevImgImpls) NewDevImgs.emplace_back(std::move(DevImgImpl)); } - return std::make_shared( - MContext, Devices, std::move(NewDevImgs), std::move(NewBinReso), - bundle_state::object); + return create(MContext, Devices, std::move(NewDevImgs), + std::move(NewBinReso), bundle_state::object); } public: @@ -597,9 +610,7 @@ class kernel_bundle_impl { }); } - kernel - ext_oneapi_get_kernel(const std::string &Name, - const std::shared_ptr &Self) const { + kernel ext_oneapi_get_kernel(const std::string &Name) const { if (!hasSourceBasedImages()) throw sycl::exception(make_error_code(errc::invalid), "'ext_oneapi_get_kernel' is only available in " @@ -615,7 +626,7 @@ class kernel_bundle_impl { const std::shared_ptr &DevImgImpl = getSyclObjImpl(DevImg); if (std::shared_ptr PotentialKernelImpl = - DevImgImpl->tryGetSourceBasedKernel(Name, MContext, Self, + DevImgImpl->tryGetSourceBasedKernel(Name, MContext, *this, DevImgImpl)) return detail::createSyclObjFromImpl( std::move(PotentialKernelImpl)); @@ -717,11 +728,8 @@ class kernel_bundle_impl { return Result; } - kernel - get_kernel(const kernel_id &KernelID, - const std::shared_ptr &Self) const { - if (std::shared_ptr KernelImpl = - tryGetOfflineKernel(KernelID, Self)) + kernel get_kernel(const kernel_id &KernelID) const { + if (std::shared_ptr KernelImpl = tryGetOfflineKernel(KernelID)) return detail::createSyclObjFromImpl(std::move(KernelImpl)); throw sycl::exception(make_error_code(errc::invalid), "The kernel bundle does not contain the kernel " @@ -876,9 +884,8 @@ class kernel_bundle_impl { }); } - std::shared_ptr tryGetOfflineKernel( - const kernel_id &KernelID, - const std::shared_ptr &Self) const { + std::shared_ptr + tryGetOfflineKernel(const kernel_id &KernelID) const { using ImageImpl = std::shared_ptr; // Selected image. ImageImpl SelectedImage = nullptr; @@ -938,13 +945,13 @@ class kernel_bundle_impl { SelectedImage->get_ur_program_ref()); return std::make_shared( - Kernel, *detail::getSyclObjImpl(MContext), SelectedImage, Self, ArgMask, - SelectedImage->get_ur_program_ref(), CacheMutex); + Kernel, *detail::getSyclObjImpl(MContext), SelectedImage, + shared_from_this(), ArgMask, SelectedImage->get_ur_program_ref(), + CacheMutex); } std::shared_ptr - tryGetKernel(detail::KernelNameStrRefT Name, - const std::shared_ptr &Self) const { + tryGetKernel(detail::KernelNameStrRefT Name) const { // TODO: For source-based kernels, it may be faster to keep a map between // {kernel_name, device} and their corresponding image. // First look through the kernels registered in source-based images. @@ -952,7 +959,7 @@ class kernel_bundle_impl { const std::shared_ptr &DevImgImpl = getSyclObjImpl(DevImg); if (std::shared_ptr SourceBasedKernel = - DevImgImpl->tryGetSourceBasedKernel(Name, MContext, Self, + DevImgImpl->tryGetSourceBasedKernel(Name, MContext, *this, DevImgImpl)) return SourceBasedKernel; } @@ -961,10 +968,14 @@ class kernel_bundle_impl { if (std::optional MaybeKernelID = sycl::detail::ProgramManager::getInstance().tryGetSYCLKernelID( Name)) - return tryGetOfflineKernel(*MaybeKernelID, Self); + return tryGetOfflineKernel(*MaybeKernelID); return nullptr; } + std::shared_ptr shared_from_this() const { + return const_cast(this)->Base::shared_from_this(); + } + private: DeviceGlobalMapEntry *getDeviceGlobalEntry(const std::string &Name) const { if (!hasSourceBasedImages()) { diff --git a/sycl/source/detail/kernel_impl.cpp b/sycl/source/detail/kernel_impl.cpp index d5197a238c0d8..2977cffd28fc2 100644 --- a/sycl/source/detail/kernel_impl.cpp +++ b/sycl/source/detail/kernel_impl.cpp @@ -39,7 +39,7 @@ kernel_impl::kernel_impl(ur_kernel_handle_t Kernel, context_impl &Context, kernel_impl::kernel_impl(ur_kernel_handle_t Kernel, context_impl &ContextImpl, DeviceImageImplPtr DeviceImageImpl, - KernelBundleImplPtr KernelBundleImpl, + KernelBundleImplPtr &&KernelBundleImpl, const KernelArgMask *ArgMask, ur_program_handle_t Program, std::mutex *CacheMutex) : MKernel(Kernel), MContext(ContextImpl.shared_from_this()), diff --git a/sycl/source/detail/kernel_impl.hpp b/sycl/source/detail/kernel_impl.hpp index 3dc45ae8e8602..4aa67ee165e62 100644 --- a/sycl/source/detail/kernel_impl.hpp +++ b/sycl/source/detail/kernel_impl.hpp @@ -51,7 +51,7 @@ class kernel_impl { /// \param KernelBundleImpl is a valid instance of kernel_bundle_impl kernel_impl(ur_kernel_handle_t Kernel, context_impl &ContextImpl, DeviceImageImplPtr DeviceImageImpl, - KernelBundleImplPtr KernelBundleImpl, + KernelBundleImplPtr &&KernelBundleImpl, const KernelArgMask *ArgMask, ur_program_handle_t Program, std::mutex *CacheMutex); diff --git a/sycl/source/detail/scheduler/commands.cpp b/sycl/source/detail/scheduler/commands.cpp index 94b98877d6399..4404bbf98352f 100644 --- a/sycl/source/detail/scheduler/commands.cpp +++ b/sycl/source/detail/scheduler/commands.cpp @@ -1985,7 +1985,7 @@ std::string instrumentationGetKernelName( void instrumentationAddExtraKernelMetadata( xpti_td *&CmdTraceEvent, const NDRDescT &NDRDesc, - const std::shared_ptr &KernelBundleImplPtr, + detail::kernel_bundle_impl *KernelBundleImplPtr, KernelNameStrRefT KernelName, KernelNameBasedCacheT *KernelNameBasedCachePtr, const std::shared_ptr &SyclKernel, queue_impl *Queue, @@ -2003,9 +2003,9 @@ void instrumentationAddExtraKernelMetadata( if (!SyclKernel->isCreatedFromSource()) EliminatedArgMask = SyclKernel->getKernelArgMask(); } else if (auto SyclKernelImpl = - KernelBundleImplPtr ? KernelBundleImplPtr->tryGetKernel( - KernelName, KernelBundleImplPtr) - : std::shared_ptr{nullptr}) { + KernelBundleImplPtr + ? KernelBundleImplPtr->tryGetKernel(KernelName) + : std::shared_ptr{nullptr}) { EliminatedArgMask = SyclKernelImpl->getKernelArgMask(); } else if (Queue) { // NOTE: Queue can be null when kernel is directly enqueued to a command @@ -2102,8 +2102,7 @@ std::pair emitKernelInstrumentationData( const detail::code_location &CodeLoc, bool IsTopCodeLoc, const std::string_view SyclKernelName, KernelNameBasedCacheT *KernelNameBasedCachePtr, queue_impl *Queue, - const NDRDescT &NDRDesc, - const std::shared_ptr &KernelBundleImplPtr, + const NDRDescT &NDRDesc, detail::kernel_bundle_impl *KernelBundleImplPtr, std::vector &CGArgs) { auto XptiObjects = std::make_pair(nullptr, -1); @@ -2196,7 +2195,7 @@ void ExecCGCommand::emitInstrumentationData() { auto KernelCG = reinterpret_cast(MCommandGroup.get()); instrumentationAddExtraKernelMetadata( - CmdTraceEvent, KernelCG->MNDRDesc, KernelCG->getKernelBundle(), + CmdTraceEvent, KernelCG->MNDRDesc, KernelCG->getKernelBundle().get(), KernelCG->MKernelName, KernelCG->MKernelNameBasedCachePtr, KernelCG->MSyclKernel, MQueue.get(), KernelCG->MArgs); } @@ -2539,10 +2538,9 @@ getCGKernelInfo(const CGExecKernel &CommandGroup, ContextImplPtr ContextImpl, UrKernel = Kernel->getHandleRef(); EliminatedArgMask = Kernel->getKernelArgMask(); } else if (auto SyclKernelImpl = - KernelBundleImplPtr - ? KernelBundleImplPtr->tryGetKernel( - CommandGroup.MKernelName, KernelBundleImplPtr) - : std::shared_ptr{nullptr}) { + KernelBundleImplPtr ? KernelBundleImplPtr->tryGetKernel( + CommandGroup.MKernelName) + : std::shared_ptr{nullptr}) { UrKernel = SyclKernelImpl->getHandleRef(); DeviceImageImpl = SyclKernelImpl->getDeviceImage(); EliminatedArgMask = SyclKernelImpl->getKernelArgMask(); @@ -2663,7 +2661,7 @@ ur_result_t enqueueImpCommandBufferKernel( void enqueueImpKernel( queue_impl &Queue, NDRDescT &NDRDesc, std::vector &Args, - const std::shared_ptr &KernelBundleImplPtr, + detail::kernel_bundle_impl *KernelBundleImplPtr, const detail::kernel_impl *MSyclKernel, KernelNameStrRefT KernelName, KernelNameBasedCacheT *KernelNameBasedCachePtr, std::vector &RawEvents, detail::event_impl *OutEventImpl, @@ -2699,10 +2697,10 @@ void enqueueImpKernel( // their duplication in such cases. KernelMutex = &MSyclKernel->getNoncacheableEnqueueMutex(); EliminatedArgMask = MSyclKernel->getKernelArgMask(); - } else if ((SyclKernelImpl = KernelBundleImplPtr - ? KernelBundleImplPtr->tryGetKernel( - KernelName, KernelBundleImplPtr) - : std::shared_ptr{nullptr})) { + } else if ((SyclKernelImpl = + KernelBundleImplPtr + ? KernelBundleImplPtr->tryGetKernel(KernelName) + : std::shared_ptr{nullptr})) { Kernel = SyclKernelImpl->getHandleRef(); DeviceImageImpl = SyclKernelImpl->getDeviceImage(); @@ -3261,10 +3259,11 @@ ur_result_t ExecCGCommand::enqueueImpQueue() { assert(BinImage && "Failed to obtain a binary image."); } enqueueImpKernel( - *MQueue, NDRDesc, Args, ExecKernel->getKernelBundle(), SyclKernel.get(), - KernelName, ExecKernel->MKernelNameBasedCachePtr, RawEvents, EventImpl, - getMemAllocationFunc, ExecKernel->MKernelCacheConfig, - ExecKernel->MKernelIsCooperative, ExecKernel->MKernelUsesClusterLaunch, + *MQueue, NDRDesc, Args, ExecKernel->getKernelBundle().get(), + SyclKernel.get(), KernelName, ExecKernel->MKernelNameBasedCachePtr, + RawEvents, EventImpl, getMemAllocationFunc, + ExecKernel->MKernelCacheConfig, ExecKernel->MKernelIsCooperative, + ExecKernel->MKernelUsesClusterLaunch, ExecKernel->MKernelWorkGroupMemorySize, BinImage); return UR_RESULT_SUCCESS; diff --git a/sycl/source/detail/scheduler/commands.hpp b/sycl/source/detail/scheduler/commands.hpp index c21c8bf240255..258a92adce002 100644 --- a/sycl/source/detail/scheduler/commands.hpp +++ b/sycl/source/detail/scheduler/commands.hpp @@ -630,7 +630,7 @@ ur_result_t enqueueReadWriteHostPipe(const QueueImplPtr &Queue, void enqueueImpKernel( queue_impl &Queue, NDRDescT &NDRDesc, std::vector &Args, - const std::shared_ptr &KernelBundleImplPtr, + detail::kernel_bundle_impl *KernelBundleImplPtr, const detail::kernel_impl *MSyclKernel, KernelNameStrRefT KernelName, KernelNameBasedCacheT *KernelNameBasedCachePtr, std::vector &RawEvents, detail::event_impl *OutEventImpl, @@ -699,8 +699,7 @@ std::pair emitKernelInstrumentationData( const detail::code_location &CodeLoc, bool IsTopCodeLoc, std::string_view SyclKernelName, KernelNameBasedCacheT *KernelNameBasedCachePtr, queue_impl *Queue, - const NDRDescT &NDRDesc, - const std::shared_ptr &KernelBundleImplPtr, + const NDRDescT &NDRDesc, detail::kernel_bundle_impl *KernelBundleImplPtr, std::vector &CGArgs); #endif diff --git a/sycl/source/handler.cpp b/sycl/source/handler.cpp index a28f21c408f80..6ce506849601a 100644 --- a/sycl/source/handler.cpp +++ b/sycl/source/handler.cpp @@ -361,27 +361,54 @@ bool handler::isStateExplicitKernelBundle() const { return impl->isStateExplicitKernelBundle(); } +#ifndef __INTEL_PREVIEW_BREAKING_CHANGES // Returns a shared_ptr to the kernel_bundle. // If there is no kernel_bundle created: // returns newly created kernel_bundle if Insert is true // returns shared_ptr(nullptr) if Insert is false std::shared_ptr handler::getOrInsertHandlerKernelBundle(bool Insert) const { - if (!impl->MKernelBundle && Insert) { - context Ctx = detail::createSyclObjFromImpl(impl->get_context()); - impl->MKernelBundle = - detail::getSyclObjImpl(get_kernel_bundle( - Ctx, {detail::createSyclObjFromImpl(impl->get_device())}, - {})); - } + if (impl->MKernelBundle || !Insert) + return impl->MKernelBundle; + + context Ctx = detail::createSyclObjFromImpl(impl->get_context()); + impl->MKernelBundle = + detail::getSyclObjImpl(get_kernel_bundle( + Ctx, {detail::createSyclObjFromImpl(impl->get_device())}, + {})); return impl->MKernelBundle; } +#endif + +// Returns a ptr to the kernel_bundle. +// If there is no kernel_bundle created: +// returns newly created kernel_bundle if Insert is true +// returns nullptr if Insert is false +detail::kernel_bundle_impl * +handler::getOrInsertHandlerKernelBundlePtr(bool Insert) const { + if (impl->MKernelBundle || !Insert) + return impl->MKernelBundle.get(); + + context Ctx = detail::createSyclObjFromImpl(impl->get_context()); + impl->MKernelBundle = + detail::getSyclObjImpl(get_kernel_bundle( + Ctx, {detail::createSyclObjFromImpl(impl->get_device())}, + {})); + return impl->MKernelBundle.get(); +} // Sets kernel bundle to the provided one. +template +void handler::setHandlerKernelBundle(SharedPtrT &&NewKernelBundleImpPtr) { + impl->MKernelBundle = std::forward(NewKernelBundleImpPtr); +} + +#ifndef __INTEL_PREVIEW_BREAKING_CHANGES void handler::setHandlerKernelBundle( const std::shared_ptr &NewKernelBundleImpPtr) { impl->MKernelBundle = NewKernelBundleImpPtr; } +#endif void handler::setHandlerKernelBundle(kernel Kernel) { // Kernel may not have an associated kernel bundle if it is created from a @@ -389,7 +416,7 @@ void handler::setHandlerKernelBundle(kernel Kernel) { // the other way around: getSyclObjImp(Kernel->get_kernel_bundle()). std::shared_ptr KernelBundleImpl = detail::getSyclObjImpl(Kernel)->get_kernel_bundle(); - setHandlerKernelBundle(KernelBundleImpl); + setHandlerKernelBundle(std::move(KernelBundleImpl)); } #ifdef __INTEL_PREVIEW_BREAKING_CHANGES @@ -462,16 +489,15 @@ event handler::finalize() { if (type == detail::CGType::Kernel) { // If there were uses of set_specialization_constant build the kernel_bundle - std::shared_ptr KernelBundleImpPtr = - getOrInsertHandlerKernelBundle(/*Insert=*/false); + detail::kernel_bundle_impl *KernelBundleImpPtr = + getOrInsertHandlerKernelBundlePtr(/*Insert=*/false); if (KernelBundleImpPtr) { // Make sure implicit non-interop kernel bundles have the kernel if (!impl->isStateExplicitKernelBundle() && !(MKernel && MKernel->isInterop()) && (KernelBundleImpPtr->empty() || KernelBundleImpPtr->hasSYCLOfflineImages()) && - !KernelBundleImpPtr->tryGetKernel(toKernelNameStrT(MKernelName), - KernelBundleImpPtr)) { + !KernelBundleImpPtr->tryGetKernel(toKernelNameStrT(MKernelName))) { detail::device_impl &Dev = impl->get_device(); kernel_id KernelID = detail::ProgramManager::getInstance().getSYCLKernelID( @@ -484,11 +510,13 @@ event handler::finalize() { KernelBundleImpPtr->get_bundle_state() == bundle_state::input) { auto KernelBundle = detail::createSyclObjFromImpl>( - KernelBundleImpPtr); + *KernelBundleImpPtr); kernel_bundle ExecKernelBundle = build(KernelBundle); - KernelBundleImpPtr = detail::getSyclObjImpl(ExecKernelBundle); - setHandlerKernelBundle(KernelBundleImpPtr); + KernelBundleImpPtr = detail::getSyclObjImpl(ExecKernelBundle).get(); + // Raw ptr KernelBundleImpPtr is valid, because we saved the + // shared_ptr to the handler + setHandlerKernelBundle(KernelBundleImpPtr->shared_from_this()); KernelInserted = KernelBundleImpPtr->add_kernel( KernelID, detail::createSyclObjFromImpl(Dev)); } @@ -503,9 +531,11 @@ event handler::finalize() { // Underlying level expects kernel_bundle to be in executable state kernel_bundle ExecBundle = build( detail::createSyclObjFromImpl>( - KernelBundleImpPtr)); - KernelBundleImpPtr = detail::getSyclObjImpl(ExecBundle); - setHandlerKernelBundle(KernelBundleImpPtr); + *KernelBundleImpPtr)); + KernelBundleImpPtr = detail::getSyclObjImpl(ExecBundle).get(); + // Raw ptr KernelBundleImpPtr is valid, because we saved the shared_ptr + // to the handler + setHandlerKernelBundle(KernelBundleImpPtr->shared_from_this()); break; } case bundle_state::executable: @@ -1344,8 +1374,8 @@ detail::ABINeutralKernelNameStrT handler::getKernelName() { } void handler::verifyUsedKernelBundleInternal(detail::string_view KernelName) { - auto UsedKernelBundleImplPtr = - getOrInsertHandlerKernelBundle(/*Insert=*/false); + detail::kernel_bundle_impl *UsedKernelBundleImplPtr = + getOrInsertHandlerKernelBundlePtr(/*Insert=*/false); if (!UsedKernelBundleImplPtr) return; @@ -2234,6 +2264,14 @@ void handler::setUserFacingNodeType(ext::oneapi::experimental::node_type Type) { impl->MUserFacingNodeType = Type; } +kernel_bundle handler::getKernelBundle() const { + detail::kernel_bundle_impl *KernelBundleImplPtr = + getOrInsertHandlerKernelBundlePtr(/*Insert=*/true); + + return detail::createSyclObjFromImpl>( + *KernelBundleImplPtr); +} + std::optional> handler::getMaxWorkGroups() { device_impl &DeviceImpl = impl->get_device(); std::array UrResult = {}; diff --git a/sycl/source/kernel_bundle.cpp b/sycl/source/kernel_bundle.cpp index 276f651465576..4d3d430eecd80 100644 --- a/sycl/source/kernel_bundle.cpp +++ b/sycl/source/kernel_bundle.cpp @@ -90,7 +90,7 @@ bool kernel_bundle_plain::native_specialization_constant() const noexcept { } kernel kernel_bundle_plain::get_kernel(const kernel_id &KernelID) const { - return impl->get_kernel(KernelID, impl); + return impl->get_kernel(KernelID); } const device_image_plain *kernel_bundle_plain::begin() const { @@ -135,7 +135,7 @@ bool kernel_bundle_plain::ext_oneapi_has_kernel(detail::string_view name) { } kernel kernel_bundle_plain::ext_oneapi_get_kernel(detail::string_view name) { - return impl->ext_oneapi_get_kernel(std::string(std::string_view(name)), impl); + return impl->ext_oneapi_get_kernel(std::string(std::string_view(name))); } detail::string @@ -194,34 +194,32 @@ kernel_id get_kernel_id_impl(string_view KernelName) { detail::KernelBundleImplPtr get_kernel_bundle_impl(const context &Ctx, const std::vector &Devs, bundle_state State) { - return std::make_shared(Ctx, Devs, State); + return detail::kernel_bundle_impl::create(Ctx, Devs, State); } detail::KernelBundleImplPtr get_kernel_bundle_impl(const context &Ctx, const std::vector &Devs, const std::vector &KernelIDs, bundle_state State) { - return std::make_shared(Ctx, Devs, KernelIDs, - State); + return detail::kernel_bundle_impl::create(Ctx, Devs, KernelIDs, State); } detail::KernelBundleImplPtr get_kernel_bundle_impl(const context &Ctx, const std::vector &Devs, bundle_state State, const DevImgSelectorImpl &Selector) { - return std::make_shared(Ctx, Devs, Selector, - State); + return detail::kernel_bundle_impl::create(Ctx, Devs, Selector, State); } detail::KernelBundleImplPtr get_empty_interop_kernel_bundle_impl(const context &Ctx, const std::vector &Devs) { - return std::make_shared(Ctx, Devs); + return detail::kernel_bundle_impl::create(Ctx, Devs); } std::shared_ptr join_impl(const std::vector &Bundles, bundle_state State) { - return std::make_shared(Bundles, State); + return detail::kernel_bundle_impl::create(Bundles, State); } bool has_kernel_bundle_impl(const context &Ctx, const std::vector &Devs, @@ -301,22 +299,21 @@ bool has_kernel_bundle_impl(const context &Ctx, const std::vector &Devs, std::shared_ptr compile_impl(const kernel_bundle &InputBundle, const std::vector &Devs, const property_list &PropList) { - return std::make_shared( - InputBundle, Devs, PropList, bundle_state::object); + return detail::kernel_bundle_impl::create(InputBundle, Devs, PropList, + bundle_state::object); } std::shared_ptr link_impl(const std::vector> &ObjectBundles, const std::vector &Devs, const property_list &PropList) { - return std::make_shared(ObjectBundles, Devs, - PropList); + return detail::kernel_bundle_impl::create(ObjectBundles, Devs, PropList); } std::shared_ptr build_impl(const kernel_bundle &InputBundle, const std::vector &Devs, const property_list &PropList) { - return std::make_shared( - InputBundle, Devs, PropList, bundle_state::executable); + return detail::kernel_bundle_impl::create(InputBundle, Devs, PropList, + bundle_state::executable); } // This function finds intersection of associated devices in common for all @@ -488,8 +485,7 @@ make_kernel_bundle_from_source(const context &SyclContext, // } std::shared_ptr KBImpl = - std::make_shared(SyclContext, Language, Source, - IncludePairs); + kernel_bundle_impl::create(SyclContext, Language, Source, IncludePairs); return sycl::detail::createSyclObjFromImpl(std::move(KBImpl)); } @@ -503,7 +499,7 @@ source_kb make_kernel_bundle_from_source(const context &SyclContext, "kernel_bundle creation from source not supported"); std::shared_ptr KBImpl = - std::make_shared(SyclContext, Language, Bytes); + kernel_bundle_impl::create(SyclContext, Language, Bytes); return sycl::detail::createSyclObjFromImpl(std::move(KBImpl)); } diff --git a/sycl/test/abi/sycl_symbols_linux.dump b/sycl/test/abi/sycl_symbols_linux.dump index 6d3abbb03b0f1..0fc1af0bafe99 100644 --- a/sycl/test/abi/sycl_symbols_linux.dump +++ b/sycl/test/abi/sycl_symbols_linux.dump @@ -4083,11 +4083,13 @@ _ZNK4sycl3_V17context8get_infoINS0_4info7context8platformEEENS0_6detail20is_cont _ZNK4sycl3_V17context9getNativeEv _ZNK4sycl3_V17handler11eventNeededEv _ZNK4sycl3_V17handler15getCommandGraphEv +_ZNK4sycl3_V17handler15getKernelBundleEv _ZNK4sycl3_V17handler16getDeviceBackendEv _ZNK4sycl3_V17handler17getContextImplPtrEv _ZNK4sycl3_V17handler21HasAssociatedAccessorEPNS0_6detail16AccessorImplHostENS0_6access6targetE _ZNK4sycl3_V17handler27isStateExplicitKernelBundleEv _ZNK4sycl3_V17handler30getOrInsertHandlerKernelBundleEb +_ZNK4sycl3_V17handler33getOrInsertHandlerKernelBundlePtrEb _ZNK4sycl3_V17handler7getTypeEv _ZNK4sycl3_V17sampler11getPropListEv _ZNK4sycl3_V17sampler18get_filtering_modeEv diff --git a/sycl/test/abi/sycl_symbols_windows.dump b/sycl/test/abi/sycl_symbols_windows.dump index ac99be10319ae..0b198da691920 100644 --- a/sycl/test/abi/sycl_symbols_windows.dump +++ b/sycl/test/abi/sycl_symbols_windows.dump @@ -4065,6 +4065,7 @@ ?getElementSize@UnsampledImageAccessorBaseHost@detail@_V1@sycl@@QEBAHXZ ?getElementSize@image_plain@detail@_V1@sycl@@IEBA_KXZ ?getEndTime@HostProfilingInfo@detail@_V1@sycl@@QEBA_KXZ +?getKernelBundle@handler@_V1@sycl@@AEBA?AV?$kernel_bundle@$0A@@23@XZ ?getKernelName@handler@_V1@sycl@@AEAA?AVstring@detail@23@XZ ?getMaxWorkGroups@handler@_V1@sycl@@AEAA?AV?$optional@V?$array@_K$02@std@@@std@@XZ ?getMaxWorkGroups_v2@handler@_V1@sycl@@AEAA?AV?$tuple@V?$array@_K$02@std@@_N@std@@XZ @@ -4095,6 +4096,7 @@ ?getOffset@AccessorBaseHost@detail@_V1@sycl@@QEAAAEAV?$id@$02@34@XZ ?getOffset@AccessorBaseHost@detail@_V1@sycl@@QEBAAEBV?$id@$02@34@XZ ?getOrInsertHandlerKernelBundle@handler@_V1@sycl@@AEBA?AV?$shared_ptr@Vkernel_bundle_impl@detail@_V1@sycl@@@std@@_N@Z +?getOrInsertHandlerKernelBundlePtr@handler@_V1@sycl@@AEBAPEAVkernel_bundle_impl@detail@23@_N@Z ?getPitch@SampledImageAccessorBaseHost@detail@_V1@sycl@@QEBA?AV?$id@$02@34@XZ ?getPitch@UnsampledImageAccessorBaseHost@detail@_V1@sycl@@QEBA?AV?$id@$02@34@XZ ?getPixelCoordLinearFiltMode@detail@_V1@sycl@@YA?AV?$vec@H$07@23@V?$vec@M$03@23@W4addressing_mode@23@V?$range@$02@23@AEAV523@@Z diff --git a/sycl/unittests/program_manager/arg_mask/EliminatedArgMask.cpp b/sycl/unittests/program_manager/arg_mask/EliminatedArgMask.cpp index ff464c6d2a1ee..cc72ba7f8b2ef 100644 --- a/sycl/unittests/program_manager/arg_mask/EliminatedArgMask.cpp +++ b/sycl/unittests/program_manager/arg_mask/EliminatedArgMask.cpp @@ -182,8 +182,8 @@ const sycl::detail::KernelArgMask *getKernelArgMaskFromBundle( EXPECT_TRUE(KernelBundleImplPtr) << "Expect command group to contain kernel bundle"; - auto SyclKernelImpl = KernelBundleImplPtr->tryGetKernel( - ExecKernel->MKernelName, KernelBundleImplPtr); + auto SyclKernelImpl = + KernelBundleImplPtr->tryGetKernel(ExecKernel->MKernelName); EXPECT_TRUE(SyclKernelImpl != nullptr); std::shared_ptr DeviceImageImpl = SyclKernelImpl->getDeviceImage();