Skip to content

Commit ae77ba8

Browse files
committed
Refactor UR function ptrs
1 parent 1dc8b92 commit ae77ba8

File tree

7 files changed

+41
-19
lines changed

7 files changed

+41
-19
lines changed

sycl/include/sycl/detail/ur.hpp

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,15 +53,31 @@ enum class UrApiKind {
5353
#undef _UR_API
5454
};
5555

56+
struct UrFuncPtrMapT {
57+
#define _UR_API(api) decltype(&::api) pfn_##api = nullptr;
58+
#include <ur_api_funcs.def>
59+
#undef _UR_API
60+
};
61+
5662
template <UrApiKind UrApiOffset> struct UrFuncInfo {};
5763

5864
#ifdef _WIN32
5965
void *GetWinProcAddress(void *module, const char *funcName);
66+
inline void PopulateUrFuncPtrTable(UrFuncPtrMapT *funcs, void *module) {
67+
#define _UR_API(api) \
68+
funcs->pfn_##api = (decltype(&::api))GetWinProcAddress(module, #api);
69+
#include <ur_api_funcs.def>
70+
#undef _UR_API
71+
}
72+
6073
#define _UR_API(api) \
6174
template <> struct UrFuncInfo<UrApiKind::api> { \
6275
using FuncPtrT = decltype(&::api); \
6376
inline const char *getFuncName() { return #api; } \
64-
inline FuncPtrT getFuncPtr(void *module) { \
77+
inline FuncPtrT getFuncPtr(const UrFuncPtrMapT *funcs) { \
78+
return funcs->pfn_##api; \
79+
} \
80+
inline FuncPtrT getFuncPtrFromModule(void *module) { \
6581
return (FuncPtrT)GetWinProcAddress(module, #api); \
6682
} \
6783
};
@@ -72,7 +88,8 @@ void *GetWinProcAddress(void *module, const char *funcName);
7288
template <> struct UrFuncInfo<UrApiKind::api> { \
7389
using FuncPtrT = decltype(&::api); \
7490
inline const char *getFuncName() { return #api; } \
75-
constexpr inline FuncPtrT getFuncPtr(void *) { return &api; } \
91+
constexpr inline FuncPtrT getFuncPtr(const void *) { return &api; } \
92+
constexpr inline FuncPtrT getFuncPtrFromModule(void *) { return &api; } \
7693
};
7794
#include <ur_api_funcs.def>
7895
#undef _UR_API
@@ -106,7 +123,7 @@ int unloadOsLibrary(void *Library);
106123
// library, implementation is OS dependent.
107124
void *getOsLibraryFuncAddress(void *Library, const std::string &FunctionName);
108125

109-
void *loadURLoaderLibrary();
126+
void *getURLoaderLibrary();
110127

111128
// Performs UR one-time initialization.
112129
std::vector<PluginPtr> &

sycl/source/detail/global_handler.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ void GlobalHandler::unloadPlugins() {
275275

276276
UrFuncInfo<UrApiKind::urLoaderTearDown> loaderTearDownInfo;
277277
auto loaderTearDown =
278-
loaderTearDownInfo.getFuncPtr(ur::loadURLoaderLibrary());
278+
loaderTearDownInfo.getFuncPtrFromModule(ur::getURLoaderLibrary());
279279
loaderTearDown();
280280
// urLoaderTearDown();
281281

sycl/source/detail/plugin.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ class plugin {
5151
MPluginMutex(std::make_shared<std::mutex>()) {
5252

5353
#ifdef _WIN32
54-
UrLoaderHandle = ur::loadURLoaderLibrary();
54+
UrLoaderHandle = ur::getURLoaderLibrary();
55+
PopulateUrFuncPtrTable(&UrFuncPtrs, UrLoaderHandle);
5556
#endif
5657
}
5758

@@ -123,7 +124,7 @@ class plugin {
123124
ur_result_t R = UR_RESULT_SUCCESS;
124125
if (!adapterReleased) {
125126
detail::UrFuncInfo<UrApiOffset> UrApiInfo;
126-
auto F = UrApiInfo.getFuncPtr(UrLoaderHandle);
127+
auto F = UrApiInfo.getFuncPtr(&UrFuncPtrs);
127128
R = F(Args...);
128129
}
129130
return R;
@@ -220,6 +221,7 @@ class plugin {
220221
// index of this vector corresponds to the index in UrPlatforms vector.
221222
std::vector<int> LastDeviceIds;
222223
void *UrLoaderHandle = nullptr;
224+
UrFuncPtrMapT UrFuncPtrs;
223225
}; // class plugin
224226

225227
using PluginPtr = std::shared_ptr<plugin>;

sycl/source/detail/posix_ur.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ void *getOsLibraryFuncAddress(void *Library, const std::string &FunctionName) {
3535
return dlsym(Library, FunctionName.c_str());
3636
}
3737

38-
void *loadURLoaderLibrary() { return nullptr; }
38+
void *getURLoaderLibrary() { return nullptr; }
3939

4040
} // namespace detail::ur
4141
} // namespace _V1

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -805,7 +805,7 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
805805

806806
UrFuncInfo<UrApiKind::urProgramRelease> programReleaseInfo;
807807
auto programRelease =
808-
programReleaseInfo.getFuncPtr(ur::loadURLoaderLibrary());
808+
programReleaseInfo.getFuncPtrFromModule(ur::getURLoaderLibrary());
809809
ProgramPtr ProgramManaged(NativePrg, programRelease);
810810

811811
// Link a fallback implementation of device libraries if they are not
@@ -2555,7 +2555,7 @@ device_image_plain ProgramManager::build(const device_image_plain &DeviceImage,
25552555

25562556
UrFuncInfo<UrApiKind::urProgramRelease> programReleaseInfo;
25572557
auto programRelease =
2558-
programReleaseInfo.getFuncPtr(ur::loadURLoaderLibrary());
2558+
programReleaseInfo.getFuncPtrFromModule(ur::getURLoaderLibrary());
25592559
ProgramPtr ProgramManaged(NativePrg, programRelease);
25602560

25612561
// Link a fallback implementation of device libraries if they are not
@@ -2769,7 +2769,7 @@ ur_kernel_handle_t ProgramManager::getOrCreateMaterializedKernel(
27692769
auto &Plugin = DeviceImpl->getPlugin();
27702770
UrFuncInfo<UrApiKind::urProgramRelease> programReleaseInfo;
27712771
auto programRelease =
2772-
programReleaseInfo.getFuncPtr(ur::loadURLoaderLibrary());
2772+
programReleaseInfo.getFuncPtrFromModule(ur::getURLoaderLibrary());
27732773
ProgramPtr ProgramManaged(Program, programRelease);
27742774

27752775
std::string CompileOpts;

sycl/source/detail/ur.cpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -109,25 +109,28 @@ static void initializePlugins(std::vector<PluginPtr> &Plugins,
109109

110110
UrFuncInfo<UrApiKind::urLoaderConfigCreate> loaderConfigCreateInfo;
111111
auto loaderConfigCreate =
112-
loaderConfigCreateInfo.getFuncPtr(ur::loadURLoaderLibrary());
112+
loaderConfigCreateInfo.getFuncPtrFromModule(ur::getURLoaderLibrary());
113113
UrFuncInfo<UrApiKind::urLoaderConfigEnableLayer> loaderConfigEnableLayerInfo;
114114
auto loaderConfigEnableLayer =
115-
loaderConfigEnableLayerInfo.getFuncPtr(ur::loadURLoaderLibrary());
115+
loaderConfigEnableLayerInfo.getFuncPtrFromModule(
116+
ur::getURLoaderLibrary());
116117
UrFuncInfo<UrApiKind::urLoaderConfigRelease> loaderConfigReleaseInfo;
117118
auto loaderConfigRelease =
118-
loaderConfigReleaseInfo.getFuncPtr(ur::loadURLoaderLibrary());
119+
loaderConfigReleaseInfo.getFuncPtrFromModule(ur::getURLoaderLibrary());
119120
UrFuncInfo<UrApiKind::urLoaderConfigSetCodeLocationCallback>
120121
loaderConfigSetCodeLocationCallbackInfo;
121122
auto loaderConfigSetCodeLocationCallback =
122-
loaderConfigSetCodeLocationCallbackInfo.getFuncPtr(
123-
ur::loadURLoaderLibrary());
123+
loaderConfigSetCodeLocationCallbackInfo.getFuncPtrFromModule(
124+
ur::getURLoaderLibrary());
124125
UrFuncInfo<UrApiKind::urLoaderInit> loaderInitInfo;
125-
auto loaderInit = loaderInitInfo.getFuncPtr(ur::loadURLoaderLibrary());
126+
auto loaderInit =
127+
loaderInitInfo.getFuncPtrFromModule(ur::getURLoaderLibrary());
126128
UrFuncInfo<UrApiKind::urAdapterGet> adapterGet_Info;
127-
auto adapterGet = adapterGet_Info.getFuncPtr(ur::loadURLoaderLibrary());
129+
auto adapterGet =
130+
adapterGet_Info.getFuncPtrFromModule(ur::getURLoaderLibrary());
128131
UrFuncInfo<UrApiKind::urAdapterGetInfo> adapterGetInfoInfo;
129132
auto adapterGetInfo =
130-
adapterGetInfoInfo.getFuncPtr(ur::loadURLoaderLibrary());
133+
adapterGetInfoInfo.getFuncPtrFromModule(ur::getURLoaderLibrary());
131134

132135
bool OwnLoaderConfig = false;
133136
// If we weren't provided with a custom config handle create our own.

sycl/source/detail/windows_ur.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ static std::filesystem::path getCurrentDSODirPath() {
7575
return std::filesystem::path(Path);
7676
}
7777

78-
void *loadURLoaderLibrary() {
78+
void *getURLoaderLibrary() {
7979
const std::filesystem::path LibSYCLDir = getCurrentDSODirPath();
8080
return getPreloadedPlugin(LibSYCLDir / std::string("ur_loader.dll"));
8181
}

0 commit comments

Comments
 (0)