@@ -48,18 +48,17 @@ bool is_source_kernel_bundle_supported(
48
48
49
49
namespace detail {
50
50
51
- static bool checkAllDevicesAreInContext (const std::vector<device> & Devices,
51
+ inline bool checkAllDevicesAreInContext (devices_range Devices,
52
52
const context &Context) {
53
- return std::all_of (
54
- Devices. begin (), Devices. end (), [&Context](const device &Dev) {
55
- return getSyclObjImpl (Context)->isDeviceValid (* getSyclObjImpl ( Dev) );
56
- });
53
+ return std::all_of (Devices. begin (), Devices. end (),
54
+ [&Context](device_impl &Dev) {
55
+ return getSyclObjImpl (Context)->isDeviceValid (Dev);
56
+ });
57
57
}
58
58
59
- static bool checkAllDevicesHaveAspect (const std::vector<device> &Devices,
60
- aspect Aspect) {
59
+ inline bool checkAllDevicesHaveAspect (devices_range Devices, aspect Aspect) {
61
60
return std::all_of (Devices.begin (), Devices.end (),
62
- [&Aspect](const device &Dev) { return Dev.has (Aspect); });
61
+ [&Aspect](device_impl &Dev) { return Dev.has (Aspect); });
63
62
}
64
63
65
64
namespace syclex = sycl::ext::oneapi::experimental;
@@ -100,9 +99,10 @@ class kernel_bundle_impl
100
99
}
101
100
102
101
public:
103
- kernel_bundle_impl (context Ctx, std::vector<device> Devs, bundle_state State,
102
+ kernel_bundle_impl (context Ctx, devices_range Devs, bundle_state State,
104
103
private_tag)
105
- : MContext(std::move(Ctx)), MDevices(std::move(Devs)), MState(State) {
104
+ : MContext(std::move(Ctx)),
105
+ MDevices (Devs.to<std::vector<device_impl *>>()), MState(State) {
106
106
107
107
common_ctor_checks ();
108
108
@@ -112,8 +112,9 @@ class kernel_bundle_impl
112
112
}
113
113
114
114
// Interop constructor used by make_kernel
115
- kernel_bundle_impl (context Ctx, std::vector<device> Devs, private_tag)
116
- : MContext(Ctx), MDevices(Devs), MState(bundle_state::executable) {
115
+ kernel_bundle_impl (context Ctx, devices_range Devs, private_tag)
116
+ : MContext(Ctx), MDevices(Devs.to<std::vector<device_impl *>>()),
117
+ MState(bundle_state::executable) {
117
118
if (!checkAllDevicesAreInContext (Devs, Ctx))
118
119
throw sycl::exception (
119
120
make_error_code (errc::invalid),
@@ -122,9 +123,9 @@ class kernel_bundle_impl
122
123
}
123
124
124
125
// Interop constructor
125
- kernel_bundle_impl (context Ctx, std::vector<device> Devs,
126
+ kernel_bundle_impl (context Ctx, devices_range Devs,
126
127
device_image_plain &DevImage, private_tag Tag)
127
- : kernel_bundle_impl(std::move(Ctx), std::move( Devs) , Tag) {
128
+ : kernel_bundle_impl(std::move(Ctx), Devs, Tag) {
128
129
MDeviceImages.emplace_back (DevImage);
129
130
MUniqueDeviceImages.emplace_back (DevImage);
130
131
}
@@ -133,22 +134,19 @@ class kernel_bundle_impl
133
134
// Have one constructor because sycl::build and sycl::compile have the same
134
135
// signature
135
136
kernel_bundle_impl (const kernel_bundle<bundle_state::input> &InputBundle,
136
- std::vector<device> Devs, const property_list &PropList,
137
+ devices_range Devs, const property_list &PropList,
137
138
bundle_state TargetState, private_tag)
138
- : MContext(InputBundle.get_context()), MDevices(std::move(Devs)),
139
- MState (TargetState) {
139
+ : MContext(InputBundle.get_context()),
140
+ MDevices(Devs.to<std::vector<device_impl *>>()), MState(TargetState) {
140
141
141
142
kernel_bundle_impl &InputBundleImpl = *getSyclObjImpl (InputBundle);
142
143
MSpecConstValues = InputBundleImpl.get_spec_const_map_ref ();
143
144
144
- const std::vector<device> &InputBundleDevices =
145
- InputBundleImpl.get_devices ();
145
+ devices_range InputBundleDevices = InputBundleImpl.get_devices ();
146
146
const bool AllDevsAssociatedWithInputBundle =
147
- std::all_of (MDevices.begin (), MDevices.end (),
148
- [&InputBundleDevices](const device &Dev) {
149
- return InputBundleDevices.end () !=
150
- std::find (InputBundleDevices.begin (),
151
- InputBundleDevices.end (), Dev);
147
+ std::all_of (get_devices ().begin (), get_devices ().end (),
148
+ [&InputBundleDevices](device_impl &Dev) {
149
+ return InputBundleDevices.contains (Dev);
152
150
});
153
151
if (MDevices.empty () || !AllDevsAssociatedWithInputBundle)
154
152
throw sycl::exception (
@@ -163,8 +161,8 @@ class kernel_bundle_impl
163
161
for (const DevImgPlainWithDeps &DevImgWithDeps :
164
162
InputBundleImpl.MDeviceImages ) {
165
163
// Skip images which are not compatible with devices provided
166
- if (std::none_of (MDevices .begin (), MDevices .end (),
167
- [&DevImgWithDeps](const device &Dev) {
164
+ if (std::none_of (get_devices () .begin (), get_devices () .end (),
165
+ [&DevImgWithDeps](device_impl &Dev) {
168
166
return getSyclObjImpl (DevImgWithDeps.getMain ())
169
167
->compatible_with_device (Dev);
170
168
}))
@@ -206,8 +204,9 @@ class kernel_bundle_impl
206
204
// Matches sycl::link
207
205
kernel_bundle_impl (
208
206
const std::vector<kernel_bundle<bundle_state::object>> &ObjectBundles,
209
- std::vector<device> Devs, const property_list &PropList, private_tag)
210
- : MDevices(std::move(Devs)), MState(bundle_state::executable) {
207
+ devices_range Devs, const property_list &PropList, private_tag)
208
+ : MDevices(Devs.to<std::vector<device_impl *>>()),
209
+ MState(bundle_state::executable) {
211
210
if (MDevices.empty ())
212
211
throw sycl::exception (make_error_code (errc::invalid),
213
212
" Vector of devices is empty" );
@@ -226,16 +225,15 @@ class kernel_bundle_impl
226
225
// Check if any of the devices in devs are not in the set of associated
227
226
// devices for any of the bundles in ObjectBundles
228
227
const bool AllDevsAssociatedWithInputBundles = std::all_of (
229
- MDevices.begin (), MDevices.end (), [&ObjectBundles](const device &Dev) {
228
+ get_devices ().begin (), get_devices ().end (),
229
+ [&ObjectBundles](device_impl &Dev) {
230
230
// Number of devices is expected to be small
231
231
return std::all_of (
232
232
ObjectBundles.begin (), ObjectBundles.end (),
233
233
[&Dev](const kernel_bundle<bundle_state::object> &KernelBundle) {
234
- const std::vector<device> & BundleDevices =
234
+ devices_range BundleDevices =
235
235
getSyclObjImpl (KernelBundle)->get_devices ();
236
- return BundleDevices.end () != std::find (BundleDevices.begin (),
237
- BundleDevices.end (),
238
- Dev);
236
+ return BundleDevices.contains (Dev);
239
237
});
240
238
});
241
239
if (!AllDevsAssociatedWithInputBundles)
@@ -363,41 +361,33 @@ class kernel_bundle_impl
363
361
}
364
362
365
363
// Create a link graph and clone it for each device.
366
- device_impl &FirstDevice = *getSyclObjImpl (MDevices[0 ]);
367
- std::map<std::shared_ptr<device_impl>, LinkGraph<device_image_plain>>
368
- DevImageLinkGraphs;
364
+ device_impl &FirstDevice = get_devices ().front ();
365
+ std::map<device_impl *, LinkGraph<device_image_plain>> DevImageLinkGraphs;
369
366
const auto &FirstGraph =
370
367
DevImageLinkGraphs
371
- .emplace (FirstDevice. shared_from_this () ,
368
+ .emplace (& FirstDevice,
372
369
LinkGraph<device_image_plain>{DevImages, Dependencies})
373
370
.first ->second ;
374
- for (size_t I = 1 ; I < MDevices.size (); ++I)
375
- DevImageLinkGraphs.emplace (getSyclObjImpl (MDevices[I]),
376
- FirstGraph.Clone ());
371
+ for (device_impl &Dev : get_devices ())
372
+ DevImageLinkGraphs.emplace (&Dev, FirstGraph.Clone ());
377
373
378
374
// Poison the images based on whether the corresponding device supports it.
379
375
for (auto &GraphIt : DevImageLinkGraphs) {
380
- device Dev = createSyclObjFromImpl<device>( GraphIt.first ) ;
376
+ device_impl & Dev = * GraphIt.first ;
381
377
GraphIt.second .Poison ([&Dev](const device_image_plain &DevImg) {
382
378
return !getSyclObjImpl (DevImg)->compatible_with_device (Dev);
383
379
});
384
380
}
385
381
386
382
// Unify graphs after poisoning.
387
- std::map<std::vector<std::shared_ptr<device_impl>>,
388
- LinkGraph<device_image_plain>>
383
+ std::map<std::vector<device_impl *>, LinkGraph<device_image_plain>>
389
384
UnifiedGraphs = UnifyGraphs (DevImageLinkGraphs);
390
385
391
386
// Link based on the resulting graphs.
392
387
for (auto &GraphIt : UnifiedGraphs) {
393
- std::vector<device> DeviceGroup;
394
- DeviceGroup.reserve (GraphIt.first .size ());
395
- for (const auto &DeviceImgImpl : GraphIt.first )
396
- DeviceGroup.emplace_back (createSyclObjFromImpl<device>(DeviceImgImpl));
397
-
398
388
std::vector<device_image_plain> LinkedResults =
399
389
detail::ProgramManager::getInstance ().link (
400
- GraphIt.second .GetNodeValues (), DeviceGroup , PropList);
390
+ GraphIt.second .GetNodeValues (), GraphIt. first , PropList);
401
391
MDeviceImages.insert (MDeviceImages.end (), LinkedResults.begin (),
402
392
LinkedResults.end ());
403
393
MUniqueDeviceImages.insert (MUniqueDeviceImages.end (),
@@ -410,8 +400,8 @@ class kernel_bundle_impl
410
400
for (const DevImgPlainWithDeps *DeviceImageWithDeps :
411
401
ImagesWithSpecConsts) {
412
402
// Skip images which are not compatible with devices provided
413
- if (std::none_of (MDevices .begin (), MDevices .end (),
414
- [DeviceImageWithDeps](const device &Dev) {
403
+ if (std::none_of (get_devices () .begin (), get_devices () .end (),
404
+ [DeviceImageWithDeps](device_impl &Dev) {
415
405
return getSyclObjImpl (DeviceImageWithDeps->getMain ())
416
406
->compatible_with_device (Dev);
417
407
}))
@@ -438,10 +428,11 @@ class kernel_bundle_impl
438
428
}
439
429
}
440
430
441
- kernel_bundle_impl (context Ctx, std::vector<device> Devs,
431
+ kernel_bundle_impl (context Ctx, devices_range Devs,
442
432
const std::vector<kernel_id> &KernelIDs,
443
433
bundle_state State, private_tag)
444
- : MContext(std::move(Ctx)), MDevices(std::move(Devs)), MState(State) {
434
+ : MContext(std::move(Ctx)),
435
+ MDevices(Devs.to<std::vector<device_impl *>>()), MState(State) {
445
436
446
437
common_ctor_checks ();
447
438
@@ -450,10 +441,11 @@ class kernel_bundle_impl
450
441
fillUniqueDeviceImages ();
451
442
}
452
443
453
- kernel_bundle_impl (context Ctx, std::vector<device> Devs,
444
+ kernel_bundle_impl (context Ctx, devices_range Devs,
454
445
const DevImgSelectorImpl &Selector, bundle_state State,
455
446
private_tag)
456
- : MContext(std::move(Ctx)), MDevices(std::move(Devs)), MState(State) {
447
+ : MContext(std::move(Ctx)),
448
+ MDevices(Devs.to<std::vector<device_impl *>>()), MState(State) {
457
449
458
450
common_ctor_checks ();
459
451
@@ -548,7 +540,9 @@ class kernel_bundle_impl
548
540
kernel_bundle_impl (const context &Context, syclex::source_language Lang,
549
541
const std::string &Src, include_pairs_t IncludePairsVec,
550
542
private_tag)
551
- : MContext(Context), MDevices(Context.get_devices()),
543
+ : MContext(Context), MDevices(getSyclObjImpl(Context)
544
+ ->getDevices()
545
+ .to<std::vector<device_impl *>>()),
552
546
MDeviceImages{device_image_plain{device_image_impl::create (
553
547
Src, MContext, MDevices, Lang, std::move (IncludePairsVec))}},
554
548
MUniqueDeviceImages{MDeviceImages[0 ].getMain ()},
@@ -560,7 +554,9 @@ class kernel_bundle_impl
560
554
// construct from source bytes
561
555
kernel_bundle_impl (const context &Context, syclex::source_language Lang,
562
556
const std::vector<std::byte> &Bytes, private_tag)
563
- : MContext(Context), MDevices(Context.get_devices()),
557
+ : MContext(Context), MDevices(getSyclObjImpl(Context)
558
+ ->getDevices ()
559
+ .to<std::vector<device_impl *>>()),
564
560
MDeviceImages{device_image_plain{
565
561
device_image_impl::create (Bytes, MContext, MDevices, Lang)}},
566
562
MUniqueDeviceImages{MDeviceImages[0 ].getMain ()},
@@ -571,11 +567,11 @@ class kernel_bundle_impl
571
567
// oneapi_ext_kernel_compiler
572
568
// construct from built source files
573
569
kernel_bundle_impl (
574
- const context &Context, const std::vector<device> & Devs,
570
+ const context &Context, devices_range Devs,
575
571
std::vector<device_image_plain> &&DevImgs,
576
572
std::vector<std::shared_ptr<ManagedDeviceBinaries>> &&DevBinaries,
577
573
bundle_state State, private_tag)
578
- : MContext(Context), MDevices(Devs),
574
+ : MContext(Context), MDevices(Devs.to<std::vector<device_impl *>>() ),
579
575
MSharedDeviceBinaries (std::move(DevBinaries)),
580
576
MUniqueDeviceImages(std::move(DevImgs)), MState(State) {
581
577
common_ctor_checks ();
@@ -587,10 +583,11 @@ class kernel_bundle_impl
587
583
}
588
584
589
585
// SYCLBIN constructor
590
- kernel_bundle_impl (const context &Context, const std::vector<device> & Devs,
586
+ kernel_bundle_impl (const context &Context, devices_range Devs,
591
587
const sycl::span<char > Bytes, bundle_state State,
592
588
private_tag)
593
- : MContext(Context), MDevices(Devs), MState(State) {
589
+ : MContext(Context), MDevices(Devs.to<std::vector<device_impl *>>()),
590
+ MState(State) {
594
591
common_ctor_checks ();
595
592
596
593
auto &SYCLBIN = MSYCLBINs.emplace_back (
@@ -622,7 +619,7 @@ class kernel_bundle_impl
622
619
}
623
620
624
621
std::shared_ptr<kernel_bundle_impl> build_from_source (
625
- const std::vector<device> Devices,
622
+ devices_range Devices,
626
623
const std::vector<sycl::detail::string_view> &BuildOptions,
627
624
std::string *LogPtr,
628
625
const std::vector<sycl::detail::string_view> &RegisteredKernelNames) {
@@ -645,7 +642,7 @@ class kernel_bundle_impl
645
642
}
646
643
647
644
std::shared_ptr<kernel_bundle_impl> compile_from_source (
648
- const std::vector<device> Devices,
645
+ devices_range Devices,
649
646
const std::vector<sycl::detail::string_view> &CompileOptions,
650
647
std::string *LogPtr,
651
648
const std::vector<sycl::detail::string_view> &RegisteredKernelNames) {
@@ -733,8 +730,9 @@ class kernel_bundle_impl
733
730
void *ext_oneapi_get_device_global_address (const std::string &Name,
734
731
const device &Dev) const {
735
732
DeviceGlobalMapEntry *Entry = getDeviceGlobalEntry (Name);
733
+ device_impl &DeviceImpl = *getSyclObjImpl (Dev);
736
734
737
- if (std::find (MDevices. begin (), MDevices. end (), Dev) == MDevices. end ( )) {
735
+ if (! get_devices (). contains (DeviceImpl )) {
738
736
throw sycl::exception (make_error_code (errc::invalid),
739
737
" kernel_bundle not built for device" );
740
738
}
@@ -745,7 +743,6 @@ class kernel_bundle_impl
745
743
" 'device_image_scope' property" );
746
744
}
747
745
748
- device_impl &DeviceImpl = *getSyclObjImpl (Dev);
749
746
bool SupportContextMemcpy = false ;
750
747
DeviceImpl.getAdapter ().call <UrApiKind::urDeviceGetInfo>(
751
748
DeviceImpl.getHandleRef (),
@@ -772,7 +769,7 @@ class kernel_bundle_impl
772
769
773
770
context get_context () const noexcept { return MContext; }
774
771
775
- const std::vector<device> & get_devices () const noexcept { return MDevices; }
772
+ devices_range get_devices () const noexcept { return MDevices; }
776
773
777
774
std::vector<kernel_id> get_kernel_ids () const {
778
775
// Collect kernel ids from all device images, then remove duplicates
@@ -1111,7 +1108,7 @@ class kernel_bundle_impl
1111
1108
}
1112
1109
1113
1110
context MContext;
1114
- std::vector<device > MDevices;
1111
+ std::vector<device_impl * > MDevices;
1115
1112
1116
1113
// For sycl_jit, building from source may have produced sycl binaries that
1117
1114
// the kernel_bundles now manage.
0 commit comments