Skip to content

Commit f1c3834

Browse files
committed
[SYCL] Support connection with multiple plugins
This commit enables including multiple devices of the same device_type and changed the logic of device selection to just prefer a SYCL_BE device if present. If someone uses SYCL_BE but appropriate device is not present, we will simply use any other device. Signed-off-by: Artur Gainullin <[email protected]>
1 parent 86acff3 commit f1c3834

File tree

11 files changed

+213
-106
lines changed

11 files changed

+213
-106
lines changed

sycl/include/CL/sycl/detail/pi.hpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,17 @@ enum class PiApiKind {
4343
class plugin;
4444
namespace pi {
4545

46+
// The SYCL_PI_TRACE sets what we will trace.
47+
// This is a bit-mask of various things we'd want to trace.
48+
enum TraceLevel {
49+
PI_TRACE_BASIC = 0x1,
50+
PI_TRACE_CALLS = 0x2,
51+
PI_TRACE_ALL = -1
52+
};
53+
54+
// Return true if we want to trace PI related activities.
55+
bool trace(TraceLevel level);
56+
4657
#ifdef SYCL_RT_OS_WINDOWS
4758
#define OPENCL_PLUGIN_NAME "pi_opencl.dll"
4859
#define CUDA_PLUGIN_NAME "pi_cuda.dll"
@@ -115,8 +126,8 @@ void *getOsLibraryFuncAddress(void *Library, const std::string &FunctionName);
115126
// environment variable.
116127
enum Backend { SYCL_BE_PI_OPENCL, SYCL_BE_PI_CUDA, SYCL_BE_PI_OTHER };
117128

118-
// Check for manually selected BE at run-time.
119-
bool useBackend(Backend Backend);
129+
// Get the preferred BE (selected with SYCL_BE).
130+
Backend getPreferredBE();
120131

121132
// Get a string representing a _pi_platform_info enum
122133
std::string platformInfoToString(pi_platform_info info);

sycl/source/detail/pi.cpp

Lines changed: 93 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222
#include <cstring>
2323
#include <iostream>
2424
#include <map>
25+
#include <sstream>
2526
#include <stddef.h>
2627
#include <string>
27-
#include <sstream>
2828

2929
#ifdef XPTI_ENABLE_INSTRUMENTATION
3030
// Include the headers necessary for emitting
@@ -141,39 +141,80 @@ std::string memFlagsToString(pi_mem_flags Flags) {
141141
return Sstream.str();
142142
}
143143

144-
// Check for manually selected BE at run-time.
145-
static Backend getBackend() {
146-
static const char *GetEnv = std::getenv("SYCL_BE");
147-
// Current default backend as SYCL_BE_PI_OPENCL
148-
// Valid values of GetEnv are "PI_OPENCL", "PI_CUDA" and "PI_OTHER"
149-
std::string StringGetEnv = (GetEnv ? GetEnv : "PI_OPENCL");
150-
static const Backend Use =
151-
std::map<std::string, Backend>{
152-
{ "PI_OPENCL", SYCL_BE_PI_OPENCL },
153-
{ "PI_CUDA", SYCL_BE_PI_CUDA },
154-
{ "PI_OTHER", SYCL_BE_PI_OTHER }
155-
}[ GetEnv ? StringGetEnv : "PI_OPENCL"];
156-
return Use;
144+
// A singleton class to aid that PI configuration parameters
145+
// are processed only once, like reading a string from environment
146+
// and converting it into a typed object.
147+
//
148+
template <typename T, const char *E> class Config {
149+
static Config *m_Instance;
150+
T m_Data;
151+
Config();
152+
153+
public:
154+
static T get() {
155+
if (!m_Instance) {
156+
m_Instance = new Config();
157+
}
158+
return m_Instance->m_Data;
159+
}
160+
};
161+
162+
template <typename T, const char *E>
163+
Config<T, E> *Config<T, E>::m_Instance = nullptr;
164+
165+
// Lists valid configuration environment variables.
166+
static constexpr char SYCL_BE[] = "SYCL_BE";
167+
static constexpr char SYCL_INTEROP_BE[] = "SYCL_INTEROP_BE";
168+
static constexpr char SYCL_PI_TRACE[] = "SYCL_PI_TRACE";
169+
170+
// SYCL_PI_TRACE gives the mask of enabled tracing components (0 default)
171+
template <> Config<int, SYCL_PI_TRACE>::Config() {
172+
const char *Env = std::getenv(SYCL_PI_TRACE);
173+
m_Data = (Env ? std::atoi(Env) : 0);
174+
}
175+
176+
static Backend getBE(const char *EnvVar) {
177+
const char *BE = std::getenv(EnvVar);
178+
const std::map<std::string, Backend> SyclBeMap{
179+
{"PI_OTHER", SYCL_BE_PI_OTHER},
180+
{"PI_CUDA", SYCL_BE_PI_CUDA},
181+
{"PI_OPENCL", SYCL_BE_PI_OPENCL}};
182+
if (BE) {
183+
auto It = SyclBeMap.find(BE);
184+
if (It == SyclBeMap.end())
185+
pi::die("Invalid backend. "
186+
"Valid values are PI_OPENCL/PI_CUDA");
187+
return It->second;
188+
}
189+
// Default backend
190+
return SYCL_BE_PI_OPENCL;
157191
}
158192

159-
// Check for manually selected BE at run-time.
160-
bool useBackend(Backend TheBackend) {
161-
return TheBackend == getBackend();
193+
template <> Config<Backend, SYCL_BE>::Config() { m_Data = getBE(SYCL_BE); }
194+
195+
// SYCL_INTEROP_BE is a way to specify the interoperability plugin.
196+
template <> Config<Backend, SYCL_INTEROP_BE>::Config() {
197+
m_Data = getBE(SYCL_INTEROP_BE);
162198
}
163199

200+
// Helper interface to not expose "pi::Config" outside of pi.cpp
201+
Backend getPreferredBE() { return Config<Backend, SYCL_BE>::get(); }
202+
164203
// GlobalPlugin is a global Plugin used with Interoperability constructors that
165204
// use OpenCL objects to construct SYCL class objects.
166205
std::shared_ptr<plugin> GlobalPlugin;
167206

168207
// Find the plugin at the appropriate location and return the location.
169-
// TODO: Change the function appropriately when there are multiple plugins.
170-
bool findPlugins(vector_class<std::string> &PluginNames) {
208+
bool findPlugins(vector_class<std::pair<std::string, Backend>> &PluginNames) {
171209
// TODO: Based on final design discussions, change the location where the
172210
// plugin must be searched; how to identify the plugins etc. Currently the
173211
// search is done for libpi_opencl.so/pi_opencl.dll file in LD_LIBRARY_PATH
174212
// env only.
175-
PluginNames.push_back(OPENCL_PLUGIN_NAME);
176-
PluginNames.push_back(CUDA_PLUGIN_NAME);
213+
//
214+
PluginNames.push_back(std::make_pair<std::string, Backend>(
215+
OPENCL_PLUGIN_NAME, SYCL_BE_PI_OPENCL));
216+
PluginNames.push_back(
217+
std::make_pair<std::string, Backend>(CUDA_PLUGIN_NAME, SYCL_BE_PI_CUDA));
177218
return true;
178219
}
179220

@@ -207,52 +248,51 @@ bool bindPlugin(void *Library, PiPlugin *PluginInformation) {
207248
return true;
208249
}
209250

210-
// Load the plugin based on SYCL_BE.
211-
// TODO: Currently only accepting OpenCL and CUDA plugins. Edit it to identify
212-
// and load other kinds of plugins, do the required changes in the
213-
// findPlugins, loadPlugin and bindPlugin functions.
251+
bool trace(TraceLevel Level) {
252+
auto TraceLevelMask = Config<int, SYCL_PI_TRACE>::get();
253+
return (TraceLevelMask & Level) == Level;
254+
}
255+
256+
// Initializes all available Plugins.
214257
vector_class<plugin> initialize() {
215258
vector_class<plugin> Plugins;
216-
217-
if (!useBackend(SYCL_BE_PI_OPENCL) && !useBackend(SYCL_BE_PI_CUDA)) {
218-
die("Unknown SYCL_BE");
219-
}
220-
221-
bool EnableTrace = (std::getenv("SYCL_PI_TRACE") != nullptr);
222-
223-
vector_class<std::string> PluginNames;
259+
vector_class<std::pair<std::string, Backend>> PluginNames;
224260
findPlugins(PluginNames);
225261

226-
if (PluginNames.empty() && EnableTrace)
227-
std::cerr << "No Plugins Found." << std::endl;
262+
if (PluginNames.empty() && trace(PI_TRACE_ALL))
263+
std::cerr << "SYCL_PI_TRACE[-1]: No Plugins Found." << std::endl;
228264

229-
PiPlugin PluginInformation; // TODO: include.
265+
PiPlugin PluginInformation;
230266
for (unsigned int I = 0; I < PluginNames.size(); I++) {
231-
void *Library = loadPlugin(PluginNames[I]);
267+
void *Library = loadPlugin(PluginNames[I].first);
232268

233269
if (!Library) {
234-
if (EnableTrace) {
235-
std::cerr << "Check if plugin is present. Failed to load plugin: "
236-
<< PluginNames[I] << std::endl;
270+
if (trace(PI_TRACE_ALL)) {
271+
std::cerr << "SYCL_PI_TRACE[-1]: Check if plugin is present. "
272+
<< "Failed to load plugin: " << PluginNames[I].first
273+
<< std::endl;
237274
}
238275
continue;
239276
}
240277

241-
if (!bindPlugin(Library, &PluginInformation) && EnableTrace) {
242-
std::cerr << "Failed to bind PI APIs to the plugin: " << PluginNames[I]
243-
<< std::endl;
244-
}
245-
if (useBackend(SYCL_BE_PI_OPENCL) &&
246-
PluginNames[I].find("opencl") != std::string::npos) {
247-
// Use the OpenCL plugin as the GlobalPlugin
248-
GlobalPlugin = std::make_shared<plugin>(PluginInformation);
278+
if (!bindPlugin(Library, &PluginInformation)) {
279+
if (trace(PI_TRACE_ALL)) {
280+
std::cerr << "SYCL_PI_TRACE[-1]: Failed to bind PI APIs to the plugin: "
281+
<< PluginNames[I].first << std::endl;
282+
}
283+
continue;
249284
}
250-
if (useBackend(SYCL_BE_PI_CUDA) &&
251-
PluginNames[I].find("cuda") != std::string::npos) {
252-
// Use the CUDA plugin as the GlobalPlugin
253-
GlobalPlugin = std::make_shared<plugin>(PluginInformation);
285+
// Set the Global Plugin based on SYCL_INTEROP_BE.
286+
// Rework this when it will be explicit in the code which BE is used in the
287+
// interoperability methods.
288+
if (Config<Backend, SYCL_INTEROP_BE>::get() == PluginNames[I].second) {
289+
GlobalPlugin =
290+
std::make_shared<plugin>(PluginInformation, PluginNames[I].second);
254291
}
255-
Plugins.push_back(plugin(PluginInformation));
292+
Plugins.emplace_back(plugin(PluginInformation, PluginNames[I].second));
293+
if (trace(TraceLevel::PI_TRACE_BASIC))
294+
std::cerr << "SYCL_PI_TRACE[1]: Plugin found and successfully loaded: "
295+
<< PluginNames[I].first << std::endl;
256296
}
257297

258298
#ifdef XPTI_ENABLE_INSTRUMENTATION

sycl/source/detail/plugin.hpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,8 @@ class plugin {
2323
public:
2424
plugin() = delete;
2525

26-
plugin(RT::PiPlugin Plugin) : MPlugin(Plugin) {
27-
MPiEnableTrace = (std::getenv("SYCL_PI_TRACE") != nullptr);
28-
}
26+
plugin(RT::PiPlugin Plugin, RT::Backend UseBackend)
27+
: MPlugin(Plugin), MBackend(UseBackend) {}
2928

3029
~plugin() = default;
3130

@@ -52,13 +51,13 @@ class plugin {
5251
template <PiApiKind PiApiOffset, typename... ArgsT>
5352
RT::PiResult call_nocheck(ArgsT... Args) const {
5453
RT::PiFuncInfo<PiApiOffset> PiCallInfo;
55-
if (MPiEnableTrace) {
54+
if (pi::trace(pi::TraceLevel::PI_TRACE_CALLS)) {
5655
std::string FnName = PiCallInfo.getFuncName();
5756
std::cout << "---> " << FnName << "(" << std::endl;
5857
RT::printArgs(Args...);
5958
}
6059
RT::PiResult R = PiCallInfo.getFuncPtr(MPlugin)(Args...);
61-
if (MPiEnableTrace) {
60+
if (pi::trace(pi::TraceLevel::PI_TRACE_CALLS)) {
6261
std::cout << ") ---> ";
6362
RT::printArgs(R);
6463
}
@@ -74,10 +73,11 @@ class plugin {
7473
checkPiResult(Err);
7574
}
7675

76+
RT::Backend getBackend(void) const { return MBackend; }
77+
7778
private:
7879
RT::PiPlugin MPlugin;
79-
bool MPiEnableTrace;
80-
80+
const RT::Backend MBackend;
8181
}; // class plugin
8282
} // namespace detail
8383
} // namespace sycl

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,8 @@ static bool isDeviceBinaryTypeSupported(const context &C,
270270
}
271271

272272
// OpenCL 2.1 and greater require clCreateProgramWithIL
273-
if (pi::useBackend(pi::SYCL_BE_PI_OPENCL) &&
273+
pi::Backend CBackend = (detail::getSyclObjImpl(C)->getPlugin()).getBackend();
274+
if ((CBackend == pi::SYCL_BE_PI_OPENCL) &&
274275
C.get_platform().get_info<info::platform::version>() >= "2.1")
275276
return true;
276277

sycl/source/detail/scheduler/commands.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1673,7 +1673,7 @@ cl_int ExecCGCommand::enqueueImp() {
16731673
Requirement *Req = (Requirement *)(Arg.MPtr);
16741674
AllocaCommandBase *AllocaCmd = getAllocaForReq(Req);
16751675
RT::PiMem MemArg = (RT::PiMem)AllocaCmd->getMemAllocation();
1676-
if (RT::useBackend(pi::Backend::SYCL_BE_PI_OPENCL)) {
1676+
if (Plugin.getBackend() == (pi::Backend::SYCL_BE_PI_OPENCL)) {
16771677
Plugin.call<PiApiKind::piKernelSetArg>(Kernel, Arg.MIndex,
16781678
sizeof(RT::PiMem), &MemArg);
16791679
} else {

0 commit comments

Comments
 (0)