-
Notifications
You must be signed in to change notification settings - Fork 797
[SYCL] enable_shared_from_this for kernel_bundle_impl #18899
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
06ee707
1ad2b56
c6130ff
ea89dd2
a777b1e
7974ec0
6aa90e9
4b651b6
986711d
47af740
60c1806
6c42146
1e3f2d4
35a8acf
3f1efb2
b2f44c4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -67,10 +67,15 @@ class kernel_impl; | |
/// The class is an impl counterpart of the sycl::kernel_bundle. | ||
// It provides an access and utilities to manage set of sycl::device_images | ||
// objects. | ||
class kernel_bundle_impl { | ||
class kernel_bundle_impl | ||
: public std::enable_shared_from_this<kernel_bundle_impl> { | ||
|
||
using SpecConstMapT = std::map<std::string, std::vector<unsigned char>>; | ||
|
||
struct private_tag { | ||
explicit private_tag() = default; | ||
}; | ||
|
||
void common_ctor_checks() const { | ||
const bool AllDevicesInTheContext = | ||
checkAllDevicesAreInContext(MDevices, MContext); | ||
|
@@ -92,7 +97,8 @@ class kernel_bundle_impl { | |
} | ||
|
||
public: | ||
kernel_bundle_impl(context Ctx, std::vector<device> Devs, bundle_state State) | ||
kernel_bundle_impl(context Ctx, std::vector<device> Devs, bundle_state State, | ||
private_tag) | ||
: MContext(std::move(Ctx)), MDevices(std::move(Devs)), MState(State) { | ||
|
||
common_ctor_checks(); | ||
|
@@ -103,7 +109,7 @@ class kernel_bundle_impl { | |
} | ||
|
||
// Interop constructor used by make_kernel | ||
kernel_bundle_impl(context Ctx, std::vector<device> Devs) | ||
kernel_bundle_impl(context Ctx, std::vector<device> Devs, private_tag) | ||
: MContext(Ctx), MDevices(Devs), MState(bundle_state::executable) { | ||
if (!checkAllDevicesAreInContext(Devs, Ctx)) | ||
throw sycl::exception( | ||
|
@@ -114,8 +120,8 @@ class kernel_bundle_impl { | |
|
||
// Interop constructor | ||
kernel_bundle_impl(context Ctx, std::vector<device> Devs, | ||
device_image_plain &DevImage) | ||
: kernel_bundle_impl(Ctx, Devs) { | ||
device_image_plain &DevImage, private_tag Tag) | ||
: kernel_bundle_impl(Ctx, Devs, Tag) { | ||
MDeviceImages.emplace_back(DevImage); | ||
MUniqueDeviceImages.emplace_back(DevImage); | ||
} | ||
|
@@ -125,7 +131,7 @@ class kernel_bundle_impl { | |
// signature | ||
kernel_bundle_impl(const kernel_bundle<bundle_state::input> &InputBundle, | ||
std::vector<device> Devs, const property_list &PropList, | ||
bundle_state TargetState) | ||
bundle_state TargetState, private_tag) | ||
: MContext(InputBundle.get_context()), MDevices(std::move(Devs)), | ||
MState(TargetState) { | ||
|
||
|
@@ -193,7 +199,7 @@ class kernel_bundle_impl { | |
// Matches sycl::link | ||
kernel_bundle_impl( | ||
const std::vector<kernel_bundle<bundle_state::object>> &ObjectBundles, | ||
std::vector<device> Devs, const property_list &PropList) | ||
std::vector<device> Devs, const property_list &PropList, private_tag) | ||
: MDevices(std::move(Devs)), MState(bundle_state::executable) { | ||
if (MDevices.empty()) | ||
throw sycl::exception(make_error_code(errc::invalid), | ||
|
@@ -414,7 +420,7 @@ class kernel_bundle_impl { | |
|
||
kernel_bundle_impl(context Ctx, std::vector<device> Devs, | ||
const std::vector<kernel_id> &KernelIDs, | ||
bundle_state State) | ||
bundle_state State, private_tag) | ||
: MContext(std::move(Ctx)), MDevices(std::move(Devs)), MState(State) { | ||
|
||
common_ctor_checks(); | ||
|
@@ -425,7 +431,8 @@ class kernel_bundle_impl { | |
} | ||
|
||
kernel_bundle_impl(context Ctx, std::vector<device> Devs, | ||
const DevImgSelectorImpl &Selector, bundle_state State) | ||
const DevImgSelectorImpl &Selector, bundle_state State, | ||
private_tag) | ||
: MContext(std::move(Ctx)), MDevices(std::move(Devs)), MState(State) { | ||
|
||
common_ctor_checks(); | ||
|
@@ -437,7 +444,7 @@ class kernel_bundle_impl { | |
|
||
// C'tor matches sycl::join API | ||
kernel_bundle_impl(const std::vector<detail::KernelBundleImplPtr> &Bundles, | ||
bundle_state State) | ||
bundle_state State, private_tag) | ||
: MState(State) { | ||
if (Bundles.empty()) | ||
return; | ||
|
@@ -501,7 +508,8 @@ class kernel_bundle_impl { | |
// oneapi_ext_kernel_compiler | ||
// construct from source string | ||
kernel_bundle_impl(const context &Context, syclex::source_language Lang, | ||
const std::string &Src, include_pairs_t IncludePairsVec) | ||
const std::string &Src, include_pairs_t IncludePairsVec, | ||
private_tag) | ||
: MContext(Context), MDevices(Context.get_devices()), | ||
MDeviceImages{device_image_plain{std::make_shared<device_image_impl>( | ||
Src, MContext, MDevices, Lang, std::move(IncludePairsVec))}}, | ||
|
@@ -513,7 +521,7 @@ class kernel_bundle_impl { | |
// oneapi_ext_kernel_compiler | ||
// construct from source bytes | ||
kernel_bundle_impl(const context &Context, syclex::source_language Lang, | ||
const std::vector<std::byte> &Bytes) | ||
const std::vector<std::byte> &Bytes, private_tag) | ||
: MContext(Context), MDevices(Context.get_devices()), | ||
MDeviceImages{device_image_plain{std::make_shared<device_image_impl>( | ||
Bytes, MContext, MDevices, Lang)}}, | ||
|
@@ -528,7 +536,7 @@ class kernel_bundle_impl { | |
const context &Context, const std::vector<device> &Devs, | ||
std::vector<device_image_plain> &&DevImgs, | ||
std::vector<std::shared_ptr<ManagedDeviceBinaries>> &&DevBinaries, | ||
bundle_state State) | ||
bundle_state State, private_tag) | ||
: MContext(Context), MDevices(Devs), | ||
MSharedDeviceBinaries(std::move(DevBinaries)), | ||
MUniqueDeviceImages(std::move(DevImgs)), MState(State) { | ||
|
@@ -540,6 +548,12 @@ class kernel_bundle_impl { | |
MDeviceImages.emplace_back(DevImg); | ||
} | ||
|
||
template <typename... Ts> | ||
static std::shared_ptr<kernel_bundle_impl> create(Ts &&...args) { | ||
return std::make_shared<kernel_bundle_impl>(std::forward<Ts>(args)..., | ||
private_tag{}); | ||
} | ||
|
||
std::shared_ptr<kernel_bundle_impl> build_from_source( | ||
const std::vector<device> Devices, | ||
const std::vector<sycl::detail::string_view> &BuildOptions, | ||
|
@@ -559,9 +573,8 @@ class kernel_bundle_impl { | |
for (std::shared_ptr<device_image_impl> &DevImgImpl : NewDevImgImpls) | ||
NewDevImgs.emplace_back(std::move(DevImgImpl)); | ||
} | ||
return std::make_shared<kernel_bundle_impl>( | ||
MContext, Devices, std::move(NewDevImgs), std::move(NewBinReso), | ||
bundle_state::executable); | ||
return create(MContext, Devices, std::move(NewDevImgs), | ||
std::move(NewBinReso), bundle_state::executable); | ||
} | ||
|
||
std::shared_ptr<kernel_bundle_impl> compile_from_source( | ||
|
@@ -584,9 +597,8 @@ class kernel_bundle_impl { | |
for (std::shared_ptr<device_image_impl> &DevImgImpl : NewDevImgImpls) | ||
NewDevImgs.emplace_back(std::move(DevImgImpl)); | ||
} | ||
return std::make_shared<kernel_bundle_impl>( | ||
MContext, Devices, std::move(NewDevImgs), std::move(NewBinReso), | ||
bundle_state::object); | ||
return create(MContext, Devices, std::move(NewDevImgs), | ||
std::move(NewBinReso), bundle_state::object); | ||
} | ||
|
||
public: | ||
|
@@ -597,15 +609,16 @@ class kernel_bundle_impl { | |
}); | ||
} | ||
|
||
kernel | ||
ext_oneapi_get_kernel(const std::string &Name, | ||
const std::shared_ptr<kernel_bundle_impl> &Self) const { | ||
kernel ext_oneapi_get_kernel(const std::string &Name) const { | ||
if (!hasSourceBasedImages()) | ||
throw sycl::exception(make_error_code(errc::invalid), | ||
"'ext_oneapi_get_kernel' is only available in " | ||
"kernel_bundles successfully built from " | ||
"kernel_bundle<bundle_state::ext_oneapi_source>."); | ||
|
||
std::shared_ptr<kernel_bundle_impl> Self = | ||
const_cast<kernel_bundle_impl *>(this)->shared_from_this(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Temporary until tryGetSourceBasedKernel is updated, right? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry, I missed the idea. How is it related to |
||
|
||
// TODO: When linking is properly implemented for kernel compiler binaries, | ||
// there can be scenarios where multiple binaries have the same | ||
// kernels. In this case, all these bundles should be found and the | ||
|
@@ -717,11 +730,8 @@ class kernel_bundle_impl { | |
return Result; | ||
} | ||
|
||
kernel | ||
get_kernel(const kernel_id &KernelID, | ||
const std::shared_ptr<detail::kernel_bundle_impl> &Self) const { | ||
if (std::shared_ptr<kernel_impl> KernelImpl = | ||
tryGetOfflineKernel(KernelID, Self)) | ||
kernel get_kernel(const kernel_id &KernelID) const { | ||
if (std::shared_ptr<kernel_impl> KernelImpl = tryGetOfflineKernel(KernelID)) | ||
return detail::createSyclObjFromImpl<kernel>(std::move(KernelImpl)); | ||
throw sycl::exception(make_error_code(errc::invalid), | ||
"The kernel bundle does not contain the kernel " | ||
|
@@ -876,9 +886,8 @@ class kernel_bundle_impl { | |
}); | ||
} | ||
|
||
std::shared_ptr<kernel_impl> tryGetOfflineKernel( | ||
const kernel_id &KernelID, | ||
const std::shared_ptr<detail::kernel_bundle_impl> &Self) const { | ||
std::shared_ptr<kernel_impl> | ||
tryGetOfflineKernel(const kernel_id &KernelID) const { | ||
using ImageImpl = std::shared_ptr<detail::device_image_impl>; | ||
// Selected image. | ||
ImageImpl SelectedImage = nullptr; | ||
|
@@ -938,13 +947,15 @@ class kernel_bundle_impl { | |
SelectedImage->get_ur_program_ref()); | ||
|
||
return std::make_shared<kernel_impl>( | ||
Kernel, detail::getSyclObjImpl(MContext), SelectedImage, Self, ArgMask, | ||
Kernel, detail::getSyclObjImpl(MContext), SelectedImage, | ||
const_cast<kernel_bundle_impl *>(this)->shared_from_this(), ArgMask, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we can do
If compiler would complain about incomplete types, make it a template:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So the idea is to have ugly cast in single place, right? Done. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm actually not sure if this or just dropping a bunch of |
||
SelectedImage->get_ur_program_ref(), CacheMutex); | ||
} | ||
|
||
std::shared_ptr<kernel_impl> | ||
tryGetKernel(detail::KernelNameStrRefT Name, | ||
const std::shared_ptr<kernel_bundle_impl> &Self) const { | ||
tryGetKernel(detail::KernelNameStrRefT Name) const { | ||
std::shared_ptr<kernel_bundle_impl> Self = | ||
const_cast<kernel_bundle_impl *>(this)->shared_from_this(); | ||
// TODO: For source-based kernels, it may be faster to keep a map between | ||
// {kernel_name, device} and their corresponding image. | ||
// First look through the kernels registered in source-based images. | ||
|
@@ -961,7 +972,7 @@ class kernel_bundle_impl { | |
if (std::optional<kernel_id> MaybeKernelID = | ||
sycl::detail::ProgramManager::getInstance().tryGetSYCLKernelID( | ||
Name)) | ||
return tryGetOfflineKernel(*MaybeKernelID, Self); | ||
return tryGetOfflineKernel(*MaybeKernelID); | ||
return nullptr; | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why can't this be done now? Same in the next file.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, done.