Skip to content

Commit 3d1002b

Browse files
[NFCI][SYCL] Change kernel_bundle_impl::MDevices to store raw device_impl * (#19484)
#18251 extended `device_impl`s' lifetimes until shutdown and #18270 started to pass devices as raw pointers in some of the APIs. This PR builds on top of that and extends usage of raw pointers/references/`device_range` as the devices are known to be alive and extra `std::shared_ptr`'s atomic increments aren't necessary and could be avoided. Since we change the type of `kernel_bundle_impl::MDevices`, other APIs in that class don't need to operate in terms of `sycl::device` or `std::shared_ptr<device_impl>` and we can switch them to use `devices_range` instead. A small number of other modifications are caused by these APIs' changes and are necessary to keep the code buildable.
1 parent fafd9cf commit 3d1002b

File tree

5 files changed

+76
-74
lines changed

5 files changed

+76
-74
lines changed

sycl/source/detail/device_image_impl.hpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -549,11 +549,8 @@ class device_image_impl
549549

550550
devices_range get_devices() const noexcept { return MDevices; }
551551

552-
bool compatible_with_device(const device &Dev) const {
553-
return std::any_of(MDevices.begin(), MDevices.end(),
554-
[Dev = &*getSyclObjImpl(Dev)](device_impl *DevCand) {
555-
return Dev == DevCand;
556-
});
552+
bool compatible_with_device(device_impl &Dev) const {
553+
return get_devices().contains(Dev);
557554
}
558555

559556
const ur_program_handle_t &get_ur_program_ref() const noexcept {

sycl/source/detail/helpers.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ class variadic_iterator {
5454
variadic_iterator(const variadic_iterator &) = default;
5555
variadic_iterator(variadic_iterator &&) = default;
5656
variadic_iterator(variadic_iterator &) = default;
57+
variadic_iterator &operator=(const variadic_iterator &) = default;
58+
variadic_iterator &operator=(variadic_iterator &&) = default;
5759

5860
template <typename IterTy>
5961
variadic_iterator(IterTy &&It) : It(std::forward<IterTy>(It)) {}
@@ -151,6 +153,12 @@ template <typename iterator> class iterator_range {
151153
return Container{std::move(Result)};
152154
}
153155

156+
bool contains(value_type &Other) const {
157+
return std::find_if(begin(), end(), [&Other](value_type &Elem) {
158+
return &Elem == &Other;
159+
}) != end();
160+
}
161+
154162
protected:
155163
template <typename Container>
156164
static constexpr bool has_reserve_v = has_reserve<Container>::value;

sycl/source/detail/kernel_bundle_impl.hpp

Lines changed: 64 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -48,18 +48,17 @@ bool is_source_kernel_bundle_supported(
4848

4949
namespace detail {
5050

51-
static bool checkAllDevicesAreInContext(const std::vector<device> &Devices,
51+
inline bool checkAllDevicesAreInContext(devices_range Devices,
5252
const context &Context) {
53-
return std::all_of(
54-
Devices.begin(), Devices.end(), [&Context](const device &Dev) {
55-
return getSyclObjImpl(Context)->isDeviceValid(*getSyclObjImpl(Dev));
56-
});
53+
return std::all_of(Devices.begin(), Devices.end(),
54+
[&Context](device_impl &Dev) {
55+
return getSyclObjImpl(Context)->isDeviceValid(Dev);
56+
});
5757
}
5858

59-
static bool checkAllDevicesHaveAspect(const std::vector<device> &Devices,
60-
aspect Aspect) {
59+
inline bool checkAllDevicesHaveAspect(devices_range Devices, aspect Aspect) {
6160
return std::all_of(Devices.begin(), Devices.end(),
62-
[&Aspect](const device &Dev) { return Dev.has(Aspect); });
61+
[&Aspect](device_impl &Dev) { return Dev.has(Aspect); });
6362
}
6463

6564
namespace syclex = sycl::ext::oneapi::experimental;
@@ -100,9 +99,10 @@ class kernel_bundle_impl
10099
}
101100

102101
public:
103-
kernel_bundle_impl(context Ctx, std::vector<device> Devs, bundle_state State,
102+
kernel_bundle_impl(context Ctx, devices_range Devs, bundle_state State,
104103
private_tag)
105-
: MContext(std::move(Ctx)), MDevices(std::move(Devs)), MState(State) {
104+
: MContext(std::move(Ctx)),
105+
MDevices(Devs.to<std::vector<device_impl *>>()), MState(State) {
106106

107107
common_ctor_checks();
108108

@@ -112,8 +112,9 @@ class kernel_bundle_impl
112112
}
113113

114114
// Interop constructor used by make_kernel
115-
kernel_bundle_impl(context Ctx, std::vector<device> Devs, private_tag)
116-
: MContext(Ctx), MDevices(Devs), MState(bundle_state::executable) {
115+
kernel_bundle_impl(context Ctx, devices_range Devs, private_tag)
116+
: MContext(Ctx), MDevices(Devs.to<std::vector<device_impl *>>()),
117+
MState(bundle_state::executable) {
117118
if (!checkAllDevicesAreInContext(Devs, Ctx))
118119
throw sycl::exception(
119120
make_error_code(errc::invalid),
@@ -122,9 +123,9 @@ class kernel_bundle_impl
122123
}
123124

124125
// Interop constructor
125-
kernel_bundle_impl(context Ctx, std::vector<device> Devs,
126+
kernel_bundle_impl(context Ctx, devices_range Devs,
126127
device_image_plain &DevImage, private_tag Tag)
127-
: kernel_bundle_impl(std::move(Ctx), std::move(Devs), Tag) {
128+
: kernel_bundle_impl(std::move(Ctx), Devs, Tag) {
128129
MDeviceImages.emplace_back(DevImage);
129130
MUniqueDeviceImages.emplace_back(DevImage);
130131
}
@@ -133,22 +134,19 @@ class kernel_bundle_impl
133134
// Have one constructor because sycl::build and sycl::compile have the same
134135
// signature
135136
kernel_bundle_impl(const kernel_bundle<bundle_state::input> &InputBundle,
136-
std::vector<device> Devs, const property_list &PropList,
137+
devices_range Devs, const property_list &PropList,
137138
bundle_state TargetState, private_tag)
138-
: MContext(InputBundle.get_context()), MDevices(std::move(Devs)),
139-
MState(TargetState) {
139+
: MContext(InputBundle.get_context()),
140+
MDevices(Devs.to<std::vector<device_impl *>>()), MState(TargetState) {
140141

141142
kernel_bundle_impl &InputBundleImpl = *getSyclObjImpl(InputBundle);
142143
MSpecConstValues = InputBundleImpl.get_spec_const_map_ref();
143144

144-
const std::vector<device> &InputBundleDevices =
145-
InputBundleImpl.get_devices();
145+
devices_range InputBundleDevices = InputBundleImpl.get_devices();
146146
const bool AllDevsAssociatedWithInputBundle =
147-
std::all_of(MDevices.begin(), MDevices.end(),
148-
[&InputBundleDevices](const device &Dev) {
149-
return InputBundleDevices.end() !=
150-
std::find(InputBundleDevices.begin(),
151-
InputBundleDevices.end(), Dev);
147+
std::all_of(get_devices().begin(), get_devices().end(),
148+
[&InputBundleDevices](device_impl &Dev) {
149+
return InputBundleDevices.contains(Dev);
152150
});
153151
if (MDevices.empty() || !AllDevsAssociatedWithInputBundle)
154152
throw sycl::exception(
@@ -163,8 +161,8 @@ class kernel_bundle_impl
163161
for (const DevImgPlainWithDeps &DevImgWithDeps :
164162
InputBundleImpl.MDeviceImages) {
165163
// Skip images which are not compatible with devices provided
166-
if (std::none_of(MDevices.begin(), MDevices.end(),
167-
[&DevImgWithDeps](const device &Dev) {
164+
if (std::none_of(get_devices().begin(), get_devices().end(),
165+
[&DevImgWithDeps](device_impl &Dev) {
168166
return getSyclObjImpl(DevImgWithDeps.getMain())
169167
->compatible_with_device(Dev);
170168
}))
@@ -206,8 +204,9 @@ class kernel_bundle_impl
206204
// Matches sycl::link
207205
kernel_bundle_impl(
208206
const std::vector<kernel_bundle<bundle_state::object>> &ObjectBundles,
209-
std::vector<device> Devs, const property_list &PropList, private_tag)
210-
: MDevices(std::move(Devs)), MState(bundle_state::executable) {
207+
devices_range Devs, const property_list &PropList, private_tag)
208+
: MDevices(Devs.to<std::vector<device_impl *>>()),
209+
MState(bundle_state::executable) {
211210
if (MDevices.empty())
212211
throw sycl::exception(make_error_code(errc::invalid),
213212
"Vector of devices is empty");
@@ -226,16 +225,15 @@ class kernel_bundle_impl
226225
// Check if any of the devices in devs are not in the set of associated
227226
// devices for any of the bundles in ObjectBundles
228227
const bool AllDevsAssociatedWithInputBundles = std::all_of(
229-
MDevices.begin(), MDevices.end(), [&ObjectBundles](const device &Dev) {
228+
get_devices().begin(), get_devices().end(),
229+
[&ObjectBundles](device_impl &Dev) {
230230
// Number of devices is expected to be small
231231
return std::all_of(
232232
ObjectBundles.begin(), ObjectBundles.end(),
233233
[&Dev](const kernel_bundle<bundle_state::object> &KernelBundle) {
234-
const std::vector<device> &BundleDevices =
234+
devices_range BundleDevices =
235235
getSyclObjImpl(KernelBundle)->get_devices();
236-
return BundleDevices.end() != std::find(BundleDevices.begin(),
237-
BundleDevices.end(),
238-
Dev);
236+
return BundleDevices.contains(Dev);
239237
});
240238
});
241239
if (!AllDevsAssociatedWithInputBundles)
@@ -363,41 +361,33 @@ class kernel_bundle_impl
363361
}
364362

365363
// Create a link graph and clone it for each device.
366-
device_impl &FirstDevice = *getSyclObjImpl(MDevices[0]);
367-
std::map<std::shared_ptr<device_impl>, LinkGraph<device_image_plain>>
368-
DevImageLinkGraphs;
364+
device_impl &FirstDevice = get_devices().front();
365+
std::map<device_impl *, LinkGraph<device_image_plain>> DevImageLinkGraphs;
369366
const auto &FirstGraph =
370367
DevImageLinkGraphs
371-
.emplace(FirstDevice.shared_from_this(),
368+
.emplace(&FirstDevice,
372369
LinkGraph<device_image_plain>{DevImages, Dependencies})
373370
.first->second;
374-
for (size_t I = 1; I < MDevices.size(); ++I)
375-
DevImageLinkGraphs.emplace(getSyclObjImpl(MDevices[I]),
376-
FirstGraph.Clone());
371+
for (device_impl &Dev : get_devices())
372+
DevImageLinkGraphs.emplace(&Dev, FirstGraph.Clone());
377373

378374
// Poison the images based on whether the corresponding device supports it.
379375
for (auto &GraphIt : DevImageLinkGraphs) {
380-
device Dev = createSyclObjFromImpl<device>(GraphIt.first);
376+
device_impl &Dev = *GraphIt.first;
381377
GraphIt.second.Poison([&Dev](const device_image_plain &DevImg) {
382378
return !getSyclObjImpl(DevImg)->compatible_with_device(Dev);
383379
});
384380
}
385381

386382
// Unify graphs after poisoning.
387-
std::map<std::vector<std::shared_ptr<device_impl>>,
388-
LinkGraph<device_image_plain>>
383+
std::map<std::vector<device_impl *>, LinkGraph<device_image_plain>>
389384
UnifiedGraphs = UnifyGraphs(DevImageLinkGraphs);
390385

391386
// Link based on the resulting graphs.
392387
for (auto &GraphIt : UnifiedGraphs) {
393-
std::vector<device> DeviceGroup;
394-
DeviceGroup.reserve(GraphIt.first.size());
395-
for (const auto &DeviceImgImpl : GraphIt.first)
396-
DeviceGroup.emplace_back(createSyclObjFromImpl<device>(DeviceImgImpl));
397-
398388
std::vector<device_image_plain> LinkedResults =
399389
detail::ProgramManager::getInstance().link(
400-
GraphIt.second.GetNodeValues(), DeviceGroup, PropList);
390+
GraphIt.second.GetNodeValues(), GraphIt.first, PropList);
401391
MDeviceImages.insert(MDeviceImages.end(), LinkedResults.begin(),
402392
LinkedResults.end());
403393
MUniqueDeviceImages.insert(MUniqueDeviceImages.end(),
@@ -410,8 +400,8 @@ class kernel_bundle_impl
410400
for (const DevImgPlainWithDeps *DeviceImageWithDeps :
411401
ImagesWithSpecConsts) {
412402
// Skip images which are not compatible with devices provided
413-
if (std::none_of(MDevices.begin(), MDevices.end(),
414-
[DeviceImageWithDeps](const device &Dev) {
403+
if (std::none_of(get_devices().begin(), get_devices().end(),
404+
[DeviceImageWithDeps](device_impl &Dev) {
415405
return getSyclObjImpl(DeviceImageWithDeps->getMain())
416406
->compatible_with_device(Dev);
417407
}))
@@ -438,10 +428,11 @@ class kernel_bundle_impl
438428
}
439429
}
440430

441-
kernel_bundle_impl(context Ctx, std::vector<device> Devs,
431+
kernel_bundle_impl(context Ctx, devices_range Devs,
442432
const std::vector<kernel_id> &KernelIDs,
443433
bundle_state State, private_tag)
444-
: MContext(std::move(Ctx)), MDevices(std::move(Devs)), MState(State) {
434+
: MContext(std::move(Ctx)),
435+
MDevices(Devs.to<std::vector<device_impl *>>()), MState(State) {
445436

446437
common_ctor_checks();
447438

@@ -450,10 +441,11 @@ class kernel_bundle_impl
450441
fillUniqueDeviceImages();
451442
}
452443

453-
kernel_bundle_impl(context Ctx, std::vector<device> Devs,
444+
kernel_bundle_impl(context Ctx, devices_range Devs,
454445
const DevImgSelectorImpl &Selector, bundle_state State,
455446
private_tag)
456-
: MContext(std::move(Ctx)), MDevices(std::move(Devs)), MState(State) {
447+
: MContext(std::move(Ctx)),
448+
MDevices(Devs.to<std::vector<device_impl *>>()), MState(State) {
457449

458450
common_ctor_checks();
459451

@@ -548,7 +540,9 @@ class kernel_bundle_impl
548540
kernel_bundle_impl(const context &Context, syclex::source_language Lang,
549541
const std::string &Src, include_pairs_t IncludePairsVec,
550542
private_tag)
551-
: MContext(Context), MDevices(Context.get_devices()),
543+
: MContext(Context), MDevices(getSyclObjImpl(Context)
544+
->getDevices()
545+
.to<std::vector<device_impl *>>()),
552546
MDeviceImages{device_image_plain{device_image_impl::create(
553547
Src, MContext, MDevices, Lang, std::move(IncludePairsVec))}},
554548
MUniqueDeviceImages{MDeviceImages[0].getMain()},
@@ -560,7 +554,9 @@ class kernel_bundle_impl
560554
// construct from source bytes
561555
kernel_bundle_impl(const context &Context, syclex::source_language Lang,
562556
const std::vector<std::byte> &Bytes, private_tag)
563-
: MContext(Context), MDevices(Context.get_devices()),
557+
: MContext(Context), MDevices(getSyclObjImpl(Context)
558+
->getDevices()
559+
.to<std::vector<device_impl *>>()),
564560
MDeviceImages{device_image_plain{
565561
device_image_impl::create(Bytes, MContext, MDevices, Lang)}},
566562
MUniqueDeviceImages{MDeviceImages[0].getMain()},
@@ -571,11 +567,11 @@ class kernel_bundle_impl
571567
// oneapi_ext_kernel_compiler
572568
// construct from built source files
573569
kernel_bundle_impl(
574-
const context &Context, const std::vector<device> &Devs,
570+
const context &Context, devices_range Devs,
575571
std::vector<device_image_plain> &&DevImgs,
576572
std::vector<std::shared_ptr<ManagedDeviceBinaries>> &&DevBinaries,
577573
bundle_state State, private_tag)
578-
: MContext(Context), MDevices(Devs),
574+
: MContext(Context), MDevices(Devs.to<std::vector<device_impl *>>()),
579575
MSharedDeviceBinaries(std::move(DevBinaries)),
580576
MUniqueDeviceImages(std::move(DevImgs)), MState(State) {
581577
common_ctor_checks();
@@ -587,10 +583,11 @@ class kernel_bundle_impl
587583
}
588584

589585
// SYCLBIN constructor
590-
kernel_bundle_impl(const context &Context, const std::vector<device> &Devs,
586+
kernel_bundle_impl(const context &Context, devices_range Devs,
591587
const sycl::span<char> Bytes, bundle_state State,
592588
private_tag)
593-
: MContext(Context), MDevices(Devs), MState(State) {
589+
: MContext(Context), MDevices(Devs.to<std::vector<device_impl *>>()),
590+
MState(State) {
594591
common_ctor_checks();
595592

596593
auto &SYCLBIN = MSYCLBINs.emplace_back(
@@ -622,7 +619,7 @@ class kernel_bundle_impl
622619
}
623620

624621
std::shared_ptr<kernel_bundle_impl> build_from_source(
625-
const std::vector<device> Devices,
622+
devices_range Devices,
626623
const std::vector<sycl::detail::string_view> &BuildOptions,
627624
std::string *LogPtr,
628625
const std::vector<sycl::detail::string_view> &RegisteredKernelNames) {
@@ -645,7 +642,7 @@ class kernel_bundle_impl
645642
}
646643

647644
std::shared_ptr<kernel_bundle_impl> compile_from_source(
648-
const std::vector<device> Devices,
645+
devices_range Devices,
649646
const std::vector<sycl::detail::string_view> &CompileOptions,
650647
std::string *LogPtr,
651648
const std::vector<sycl::detail::string_view> &RegisteredKernelNames) {
@@ -733,8 +730,9 @@ class kernel_bundle_impl
733730
void *ext_oneapi_get_device_global_address(const std::string &Name,
734731
const device &Dev) const {
735732
DeviceGlobalMapEntry *Entry = getDeviceGlobalEntry(Name);
733+
device_impl &DeviceImpl = *getSyclObjImpl(Dev);
736734

737-
if (std::find(MDevices.begin(), MDevices.end(), Dev) == MDevices.end()) {
735+
if (!get_devices().contains(DeviceImpl)) {
738736
throw sycl::exception(make_error_code(errc::invalid),
739737
"kernel_bundle not built for device");
740738
}
@@ -745,7 +743,6 @@ class kernel_bundle_impl
745743
"'device_image_scope' property");
746744
}
747745

748-
device_impl &DeviceImpl = *getSyclObjImpl(Dev);
749746
bool SupportContextMemcpy = false;
750747
DeviceImpl.getAdapter().call<UrApiKind::urDeviceGetInfo>(
751748
DeviceImpl.getHandleRef(),
@@ -772,7 +769,7 @@ class kernel_bundle_impl
772769

773770
context get_context() const noexcept { return MContext; }
774771

775-
const std::vector<device> &get_devices() const noexcept { return MDevices; }
772+
devices_range get_devices() const noexcept { return MDevices; }
776773

777774
std::vector<kernel_id> get_kernel_ids() const {
778775
// Collect kernel ids from all device images, then remove duplicates
@@ -1111,7 +1108,7 @@ class kernel_bundle_impl
11111108
}
11121109

11131110
context MContext;
1114-
std::vector<device> MDevices;
1111+
std::vector<device_impl *> MDevices;
11151112

11161113
// For sycl_jit, building from source may have produced sycl binaries that
11171114
// the kernel_bundles now manage.

sycl/source/detail/kernel_impl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ kernel_impl::get_backend_info<info::device::version>() const {
167167
"the info::device::version info descriptor can only "
168168
"be queried with an OpenCL backend");
169169
}
170-
auto Devices = MKernelBundleImpl->get_devices();
170+
auto Devices = MKernelBundleImpl->get_devices().to<std::vector<device>>();
171171
if (Devices.empty()) {
172172
return "No available device";
173173
}

sycl/source/kernel_bundle.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ context kernel_bundle_plain::get_context() const noexcept {
7474
}
7575

7676
std::vector<device> kernel_bundle_plain::get_devices() const noexcept {
77-
return impl->get_devices();
77+
return impl->get_devices().to<std::vector<device>>();
7878
}
7979

8080
std::vector<kernel_id> kernel_bundle_plain::get_kernel_ids() const {

0 commit comments

Comments
 (0)