Skip to content

[SYCL] Add support of multiple devices within a context #2343

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Aug 25, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions sycl/source/detail/context_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ class context_impl {
/// more details.
///
/// \returns a map with device library programs.
std::map<DeviceLibExt, RT::PiProgram> &getCachedLibPrograms() {
std::map<std::pair<DeviceLibExt, RT::PiDevice>, RT::PiProgram> &
getCachedLibPrograms() {
return MCachedLibPrograms;
}

Expand All @@ -155,7 +156,8 @@ class context_impl {
PlatformImplPtr MPlatform;
bool MHostContext;
bool MUseCUDAPrimaryContext;
std::map<DeviceLibExt, RT::PiProgram> MCachedLibPrograms;
std::map<std::pair<DeviceLibExt, RT::PiDevice>, RT::PiProgram>
MCachedLibPrograms;
mutable KernelProgramCache MKernelProgramCache;
};

Expand Down
8 changes: 4 additions & 4 deletions sycl/source/detail/program_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -474,10 +474,10 @@ void program_impl::create_pi_program_with_kernel_name(
bool JITCompilationIsRequired) {
assert(!MProgram && "This program already has an encapsulated PI program");
ProgramManager &PM = ProgramManager::getInstance();
RTDeviceBinaryImage &Img =
PM.getDeviceImage(Module, KernelName, get_context(), get_devices()[0],
JITCompilationIsRequired);
MProgram = PM.createPIProgram(Img, get_context(), get_devices()[0]);
const device FirstDevice = get_devices()[0];
RTDeviceBinaryImage &Img = PM.getDeviceImage(
Module, KernelName, get_context(), FirstDevice, JITCompilationIsRequired);
MProgram = PM.createPIProgram(Img, get_context(), FirstDevice);
}

template <>
Expand Down
172 changes: 67 additions & 105 deletions sycl/source/detail/program_manager/program_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ static RT::PiProgram createBinaryProgram(const ContextImplPtr Context,
#endif

RT::PiProgram Program;
const auto PiDevice = getSyclObjImpl(Device)->getHandleRef();
const RT::PiDevice PiDevice = getSyclObjImpl(Device)->getHandleRef();
pi_int32 BinaryStatus = CL_SUCCESS;
Plugin.call<PiApiKind::piProgramCreateWithBinary>(
Context->getHandleRef(), 1 /*one binary*/, &PiDevice, &DataLen, &Data,
Expand Down Expand Up @@ -320,8 +320,7 @@ RT::PiProgram ProgramManager::createPIProgram(const RTDeviceBinaryImage &Img,
{
std::lock_guard<std::mutex> Lock(MNativeProgramsMutex);
// associate the PI program with the image it was created for
const auto PiDevice = getSyclObjImpl(Device)->getHandleRef();
NativePrograms.emplace(std::make_pair(std::make_pair(Res, PiDevice), &Img));
NativePrograms[Res] = &Img;
}

if (DbgProgMgr > 1)
Expand Down Expand Up @@ -376,42 +375,14 @@ RT::PiProgram ProgramManager::getBuiltPIProgram(OSModuleHandle M,
!SYCLConfig<SYCL_DEVICELIB_NO_FALLBACK>::get())
DeviceLibReqMask = getDeviceLibReqMask(Img);

bool ContextHasSubDevices = false;
const vector_class<device> &Devices = ContextImpl->getDevices();
for (const auto &Device : Devices) {
try {
// Device.get_info<info::device::parent_device>(); should throw
// sycl::invalid_object_error exception if Device is not a sub device.
// If the exception doesn't throw, it means that context has a sub
// device and we can quit the loop.
Device.get_info<info::device::parent_device>();
ContextHasSubDevices = true;
break;
} catch (sycl::invalid_object_error const &) {
}
}

vector_class<RT::PiDevice> PiDevices;
if (ContextHasSubDevices) {
PiDevices.resize(Devices.size());
std::transform(Devices.begin(), Devices.end(), PiDevices.begin(),
[](const device Dev) {
return getRawSyclObjImpl(Dev)->getHandleRef();
});
} else {
PiDevices.push_back(getRawSyclObjImpl(Device)->getHandleRef());
}

ProgramPtr BuiltProgram =
build(std::move(ProgramManaged), ContextImpl, Img.getCompileOptions(),
Img.getLinkOptions(), PiDevices,
Img.getLinkOptions(), getRawSyclObjImpl(Device)->getHandleRef(),
ContextImpl->getCachedLibPrograms(), DeviceLibReqMask);

{
std::lock_guard<std::mutex> Lock(MNativeProgramsMutex);
const auto PiDevice = getSyclObjImpl(Device)->getHandleRef();
NativePrograms.emplace(
std::make_pair(std::make_pair(BuiltProgram.get(), PiDevice), &Img));
NativePrograms[BuiltProgram.get()] = &Img;
}
return BuiltProgram.release();
};
Expand All @@ -420,7 +391,7 @@ RT::PiProgram ProgramManager::getBuiltPIProgram(OSModuleHandle M,
if (Prg)
Prg->stableSerializeSpecConstRegistry(SpecConsts);

const auto PiDevice = getRawSyclObjImpl(Device)->getHandleRef();
const RT::PiDevice PiDevice = getRawSyclObjImpl(Device)->getHandleRef();
auto BuildResult = getOrBuild<PiProgramT, compile_program_error>(
Cache,
std::make_pair(std::make_pair(std::move(SpecConsts), KSId), PiDevice),
Expand Down Expand Up @@ -471,7 +442,7 @@ std::pair<RT::PiKernel, std::mutex *> ProgramManager::getOrCreateKernel(
return Result;
};

const auto PiDevice = getRawSyclObjImpl(Device)->getHandleRef();
const RT::PiDevice PiDevice = getRawSyclObjImpl(Device)->getHandleRef();
auto BuildResult = getOrBuild<PiKernelT, invalid_object_error>(
Cache, std::make_pair(KernelName, PiDevice), AcquireF, GetF, BuildF);
return std::make_pair(BuildResult->Ptr.load(),
Expand Down Expand Up @@ -578,13 +549,15 @@ static const char *getDeviceLibExtensionStr(DeviceLibExt Extension) {

static RT::PiProgram loadDeviceLibFallback(
const ContextImplPtr Context, DeviceLibExt Extension,
const std::vector<RT::PiDevice> &Devices,
std::map<DeviceLibExt, RT::PiProgram> &CachedLibPrograms) {
const RT::PiDevice &Device,
std::map<std::pair<DeviceLibExt, RT::PiDevice>, RT::PiProgram>
&CachedLibPrograms) {

const char *LibFileName = getDeviceLibFilename(Extension);
auto CacheResult = CachedLibPrograms.insert({Extension, nullptr});
auto CacheResult = CachedLibPrograms.emplace(
std::make_pair(std::make_pair(Extension, Device), nullptr));
bool Cached = !CacheResult.second;
std::map<DeviceLibExt, RT::PiProgram>::iterator LibProgIt = CacheResult.first;
auto LibProgIt = CacheResult.first;
RT::PiProgram &LibProg = LibProgIt->second;

if (Cached)
Expand All @@ -600,8 +573,7 @@ static RT::PiProgram loadDeviceLibFallback(
// TODO no spec constants are used in the std libraries, support in the future
RT::PiResult Error = Plugin.call_nocheck<PiApiKind::piProgramCompile>(
LibProg,
// Assume that Devices contains all devices from Context.
Devices.size(), Devices.data(),
/*num devices = */ 1, &Device,
// Do not use compile options for library programs: it is not clear
// if user options (image options) are supposed to be applied to
// library program as well, and what actually happens to a SPIR-V
Expand Down Expand Up @@ -721,11 +693,11 @@ static bool isDeviceLibRequired(DeviceLibExt Ext, uint32_t DeviceLibReqMask) {
return ((DeviceLibReqMask & Mask) == Mask);
}

static std::vector<RT::PiProgram>
getDeviceLibPrograms(const ContextImplPtr Context,
const std::vector<RT::PiDevice> &Devices,
std::map<DeviceLibExt, RT::PiProgram> &CachedLibPrograms,
uint32_t DeviceLibReqMask) {
static std::vector<RT::PiProgram> getDeviceLibPrograms(
const ContextImplPtr Context, const RT::PiDevice &Device,
std::map<std::pair<DeviceLibExt, RT::PiDevice>, RT::PiProgram>
&CachedLibPrograms,
uint32_t DeviceLibReqMask) {
std::vector<RT::PiProgram> Programs;

std::pair<DeviceLibExt, bool> RequiredDeviceLibExt[] = {
Expand All @@ -739,68 +711,61 @@ getDeviceLibPrograms(const ContextImplPtr Context,
// Disable all devicelib extensions requiring fp64 support if at least
// one underlying device doesn't support cl_khr_fp64.
bool fp64Support = true;
for (RT::PiDevice Dev : Devices) {
std::string DevExtList =
get_device_info<std::string, info::device::extensions>::get(
Dev, Context->getPlugin());
fp64Support =
fp64Support && (DevExtList.npos != DevExtList.find("cl_khr_fp64"));
}
std::string DevExtList =
get_device_info<std::string, info::device::extensions>::get(
Device, Context->getPlugin());
fp64Support =
fp64Support && (DevExtList.npos != DevExtList.find("cl_khr_fp64"));

// Load a fallback library for an extension if at least one device does not
// Load a fallback library for an extension if the device does not
// support it.
for (RT::PiDevice Dev : Devices) {
std::string DevExtList =
get_device_info<std::string, info::device::extensions>::get(
Dev, Context->getPlugin());
for (auto &Pair : RequiredDeviceLibExt) {
DeviceLibExt Ext = Pair.first;
bool &FallbackIsLoaded = Pair.second;

if (FallbackIsLoaded) {
continue;
}
for (auto &Pair : RequiredDeviceLibExt) {
DeviceLibExt Ext = Pair.first;
bool &FallbackIsLoaded = Pair.second;

if (!isDeviceLibRequired(Ext, DeviceLibReqMask)) {
continue;
}
if ((Ext == DeviceLibExt::cl_intel_devicelib_math_fp64 ||
Ext == DeviceLibExt::cl_intel_devicelib_complex_fp64) &&
!fp64Support) {
continue;
}
if (FallbackIsLoaded) {
continue;
}

if (!isDeviceLibRequired(Ext, DeviceLibReqMask)) {
continue;
}
if ((Ext == DeviceLibExt::cl_intel_devicelib_math_fp64 ||
Ext == DeviceLibExt::cl_intel_devicelib_complex_fp64) &&
!fp64Support) {
continue;
}

const char *ExtStr = getDeviceLibExtensionStr(Ext);
const char *ExtStr = getDeviceLibExtensionStr(Ext);

bool InhibitNativeImpl = false;
if (const char *Env = getenv("SYCL_DEVICELIB_INHIBIT_NATIVE")) {
InhibitNativeImpl = strstr(Env, ExtStr) != nullptr;
}
bool InhibitNativeImpl = false;
if (const char *Env = getenv("SYCL_DEVICELIB_INHIBIT_NATIVE")) {
InhibitNativeImpl = strstr(Env, ExtStr) != nullptr;
}

bool DeviceSupports = DevExtList.npos != DevExtList.find(ExtStr);
bool DeviceSupports = DevExtList.npos != DevExtList.find(ExtStr);

if (!DeviceSupports || InhibitNativeImpl) {
Programs.push_back(
loadDeviceLibFallback(Context, Ext, Devices, CachedLibPrograms));
FallbackIsLoaded = true;
}
if (!DeviceSupports || InhibitNativeImpl) {
Programs.push_back(
loadDeviceLibFallback(Context, Ext, Device, CachedLibPrograms));
FallbackIsLoaded = true;
}
}
return Programs;
}

ProgramManager::ProgramPtr
ProgramManager::build(ProgramPtr Program, const ContextImplPtr Context,
const string_class &CompileOptions,
const string_class &LinkOptions,
const std::vector<RT::PiDevice> &Devices,
std::map<DeviceLibExt, RT::PiProgram> &CachedLibPrograms,
uint32_t DeviceLibReqMask) {
ProgramManager::ProgramPtr ProgramManager::build(
ProgramPtr Program, const ContextImplPtr Context,
const string_class &CompileOptions, const string_class &LinkOptions,
const RT::PiDevice &Device,
std::map<std::pair<DeviceLibExt, RT::PiDevice>, RT::PiProgram>
&CachedLibPrograms,
uint32_t DeviceLibReqMask) {

if (DbgProgMgr > 0) {
std::cerr << ">>> ProgramManager::build(" << Program.get() << ", "
<< CompileOptions << ", " << LinkOptions << ", ... "
<< Devices.size() << ")\n";
<< CompileOptions << ", " << LinkOptions << ", ... " << Device
<< ")\n";
}

bool LinkDeviceLibs = (DeviceLibReqMask != 0);
Expand Down Expand Up @@ -831,7 +796,7 @@ ProgramManager::build(ProgramPtr Program, const ContextImplPtr Context,

std::vector<RT::PiProgram> LinkPrograms;
if (LinkDeviceLibs) {
LinkPrograms = getDeviceLibPrograms(Context, Devices, CachedLibPrograms,
LinkPrograms = getDeviceLibPrograms(Context, Device, CachedLibPrograms,
DeviceLibReqMask);
}

Expand All @@ -840,7 +805,7 @@ ProgramManager::build(ProgramPtr Program, const ContextImplPtr Context,
std::string Opts(CompileOpts);

RT::PiResult Error = Plugin.call_nocheck<PiApiKind::piProgramBuild>(
Program.get(), Devices.size(), Devices.data(), Opts.c_str(), nullptr,
Program.get(), /*num devices =*/1, &Device, Opts.c_str(), nullptr,
nullptr);
if (Error != PI_SUCCESS)
throw compile_program_error(getProgramBuildLog(Program.get(), Context),
Expand All @@ -849,14 +814,14 @@ ProgramManager::build(ProgramPtr Program, const ContextImplPtr Context,
}

// Include the main program and compile/link everything together
Plugin.call<PiApiKind::piProgramCompile>(Program.get(), Devices.size(),
Devices.data(), CompileOpts, 0,
nullptr, nullptr, nullptr, nullptr);
Plugin.call<PiApiKind::piProgramCompile>(Program.get(), /*num devices =*/1,
&Device, CompileOpts, 0, nullptr,
nullptr, nullptr, nullptr);
LinkPrograms.push_back(Program.get());

RT::PiProgram LinkedProg = nullptr;
RT::PiResult Error = Plugin.call_nocheck<PiApiKind::piProgramLink>(
Context->getHandleRef(), Devices.size(), Devices.data(), LinkOpts,
Context->getHandleRef(), /*num devices =*/1, &Device, LinkOpts,
LinkPrograms.size(), LinkPrograms.data(), nullptr, nullptr, &LinkedProg);

// Link program call returns a new program object if all parameters are valid,
Expand Down Expand Up @@ -1037,7 +1002,7 @@ void ProgramManager::flushSpecConstants(const program_impl &Prg,
// caller hasn't provided the image object - find it
{ // make sure NativePrograms map access is synchronized
std::lock_guard<std::mutex> Lock(MNativeProgramsMutex);
auto It = NativePrograms.find(std::make_pair(NativePrg, Device));
auto It = NativePrograms.find(NativePrg);
if (It == NativePrograms.end())
throw sycl::experimental::spec_const_error(
"spec constant is set in a program w/o a binary image",
Expand Down Expand Up @@ -1080,11 +1045,9 @@ ProgramManager::KernelArgMask ProgramManager::getEliminatedKernelArgMask(
if (m_UseSpvFile && M == OSUtil::ExeModuleHandle)
return {};

const auto PiDevice = getSyclObjImpl(Device)->getHandleRef();

{
std::lock_guard<std::mutex> Lock(MNativeProgramsMutex);
auto ImgIt = NativePrograms.find(std::make_pair(NativePrg, PiDevice));
auto ImgIt = NativePrograms.find(NativePrg);
if (ImgIt != NativePrograms.end()) {
auto MapIt = m_EliminatedKernelArgMasks.find(ImgIt->second);
if (MapIt != m_EliminatedKernelArgMasks.end())
Expand Down Expand Up @@ -1114,8 +1077,7 @@ ProgramManager::KernelArgMask ProgramManager::getEliminatedKernelArgMask(
RTDeviceBinaryImage &Img = getDeviceImage(M, KSId, Context, Device);
{
std::lock_guard<std::mutex> Lock(MNativeProgramsMutex);
NativePrograms.emplace(
std::make_pair(std::make_pair(NativePrg, PiDevice), &Img));
NativePrograms[NativePrg] = &Img;
}
auto MapIt = m_EliminatedKernelArgMasks.find(&Img);
if (MapIt != m_EliminatedKernelArgMasks.end())
Expand Down
9 changes: 4 additions & 5 deletions sycl/source/detail/program_manager/program_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,9 @@ class ProgramManager {
decltype(&::piProgramRelease)>;
ProgramPtr build(ProgramPtr Program, const ContextImplPtr Context,
const string_class &CompileOptions,
const string_class &LinkOptions,
const std::vector<RT::PiDevice> &Devices,
std::map<DeviceLibExt, RT::PiProgram> &CachedLibPrograms,
const string_class &LinkOptions, const RT::PiDevice &Device,
std::map<std::pair<DeviceLibExt, RT::PiDevice>,
RT::PiProgram> &CachedLibPrograms,
uint32_t DeviceLibReqMask);
/// Provides a new kernel set id for grouping kernel names together
KernelSetId getNextKernelSetId() const;
Expand Down Expand Up @@ -209,8 +209,7 @@ class ProgramManager {
// the underlying program disposed of), so the map can't be used in any way
// other than binary image lookup with known live PiProgram as the key.
// NOTE: access is synchronized via the MNativeProgramsMutex
std::map<std::pair<pi::PiProgram, pi::PiDevice>, const RTDeviceBinaryImage *>
NativePrograms;
std::unordered_map<pi::PiProgram, const RTDeviceBinaryImage *> NativePrograms;

/// Protects NativePrograms that can be changed by class' methods.
std::mutex MNativeProgramsMutex;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ void exceptionHandler(sycl::exception_list exceptions) {
try {
std::rethrow_exception(e);
} catch (sycl::exception const &e) {
std::cout << "Caught asynchronous SYCL exception:\n"
std::cerr << "Caught asynchronous SYCL exception:\n"
<< e.what() << std::endl;
}
}
}

int main() {
std::vector DeviceList = sycl::device::get_devices();
auto DeviceList = sycl::device::get_devices();

// remove host device from the list
DeviceList.erase(std::remove_if(DeviceList.begin(), DeviceList.end(),
Expand All @@ -33,8 +33,7 @@ int main() {

std::vector<sycl::queue> QueueList;
for (const auto &Device : Context.get_devices()) {
sycl::queue Queue(Context, Device, &exceptionHandler);
QueueList.push_back(Queue);
QueueList.emplace_back(Context, Device, &exceptionHandler);
}

for (auto &Queue : QueueList) {
Expand Down