Skip to content

[SYCL] Fixing device check in program link constructor #86

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions sycl/include/CL/sycl/detail/device_host.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ class device_host : public device_impl {
cl_device_id get() const override {
throw invalid_object_error("This instance of device is a host instance");
}
cl_device_id &getHandleRef() override {
throw invalid_object_error("This instance of device is a host instance");
}

bool is_host() const override { return true; }

Expand Down
6 changes: 6 additions & 0 deletions sycl/include/CL/sycl/detail/device_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ class device_impl {

virtual cl_device_id get() const = 0;

// Returns underlying native device object (if any) w/o reference count
// modification. Caller must ensure the returned object lives on stack only.
// It can also be safely passed to the underlying native runtime API.
// Warning. Returned reference will be invalid if device_impl was destroyed.
virtual cl_device_id &getHandleRef() = 0;

virtual bool is_host() const = 0;

virtual bool is_cpu() const = 0;
Expand Down
4 changes: 4 additions & 0 deletions sycl/include/CL/sycl/detail/device_opencl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ class device_opencl : public device_impl {
return id;
}

cl_device_id &getHandleRef() override{
return id;
}

bool is_host() const override { return false; }

bool is_cpu() const override { return (type == CL_DEVICE_TYPE_CPU); }
Expand Down
39 changes: 35 additions & 4 deletions sycl/include/CL/sycl/detail/program_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,23 @@ class program_impl {
}
Context = ProgramList[0]->Context;
Devices = ProgramList[0]->Devices;
std::vector<device> DevicesSorted;
if (!is_host()) {
DevicesSorted = sort_devices_by_cl_device_id(Devices);
}
for (const auto &Prg : ProgramList) {
Prg->throw_if_state_is_not(program_state::compiled);
if (Prg->Context != Context) {
throw invalid_object_error(
"Not all programs are associated with the same context");
}
if (Prg->Devices != Devices) {
throw invalid_object_error(
"Not all programs are associated with the same devices");
if (!is_host()) {
std::vector<device> PrgDevicesSorted =
sort_devices_by_cl_device_id(Prg->Devices);
if (PrgDevicesSorted != DevicesSorted) {
throw invalid_object_error(
"Not all programs are associated with the same devices");
}
}
}

Expand Down Expand Up @@ -92,7 +100,20 @@ class program_impl {
CHECK_OCL_CODE(clGetProgramInfo(ClProgram, CL_PROGRAM_DEVICES,
sizeof(cl_device_id) * NumDevices,
ClDevices.data(), nullptr));
Devices = vector_class<device>(ClDevices.begin(), ClDevices.end());
vector_class<device> SyclContextDevices = Context.get_devices();

// Keep only the subset of the devices (associated with context) that
// were actually used to create the program.
// This is possible when clCreateProgramWithBinary is used.
auto iterator = std::remove_if(
SyclContextDevices.begin(), SyclContextDevices.end(),
[&ClDevices](const sycl::device &Dev) {
return ClDevices.end() ==
std::find(ClDevices.begin(), ClDevices.end(),
detail::getSyclObjImpl(Dev)->getHandleRef());
});
SyclContextDevices.erase(iterator, SyclContextDevices.end());
Devices = SyclContextDevices;
// TODO check build for each device instead
cl_program_binary_type BinaryType;
CHECK_OCL_CODE(clGetProgramBuildInfo(
Expand Down Expand Up @@ -371,6 +392,16 @@ class program_impl {
return ClKernel;
}

std::vector<device>
sort_devices_by_cl_device_id(vector_class<device> Devices) {
std::sort(Devices.begin(), Devices.end(),
[](const device &id1, const device &id2) {
return (detail::getSyclObjImpl(id1)->getHandleRef() <
detail::getSyclObjImpl(id2)->getHandleRef());
});
return Devices;
}

void throw_if_state_is(program_state State) const {
if (this->State == State) {
throw invalid_object_error("Invalid program state");
Expand Down