Skip to content

[SYCL][NFC] Remove AdapterPtr from SYCL RT #19315

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: sycl
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions sycl/source/detail/allowlist.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,7 @@ bool deviceIsAllowed(const DeviceDescT &DeviceDesc,
}

void applyAllowList(std::vector<ur_device_handle_t> &UrDevices,
ur_platform_handle_t UrPlatform,
const AdapterPtr &Adapter) {
ur_platform_handle_t UrPlatform, adapter_impl &Adapter) {

AllowListParsedT AllowListParsed =
parseAllowList(SYCLConfig<SYCL_DEVICE_ALLOWLIST>::get());
Expand All @@ -375,7 +374,7 @@ void applyAllowList(std::vector<ur_device_handle_t> &UrDevices,
// Get platform's backend and put it to DeviceDesc
DeviceDescT DeviceDesc;
platform_impl &PlatformImpl =
platform_impl::getOrMakePlatformImpl(UrPlatform, *Adapter);
platform_impl::getOrMakePlatformImpl(UrPlatform, Adapter);
backend Backend = PlatformImpl.getBackend();

for (const auto &SyclBe : getSyclBeMap()) {
Expand All @@ -396,7 +395,7 @@ void applyAllowList(std::vector<ur_device_handle_t> &UrDevices,
device_impl &DeviceImpl = PlatformImpl.getOrMakeDeviceImpl(Device);
// get DeviceType value and put it to DeviceDesc
ur_device_type_t UrDevType = UR_DEVICE_TYPE_ALL;
Adapter->call<UrApiKind::urDeviceGetInfo>(
Adapter.call<UrApiKind::urDeviceGetInfo>(
Device, UR_DEVICE_INFO_TYPE, sizeof(UrDevType), &UrDevType, nullptr);
// TODO need mechanism to do these casts, there's a bunch of this sort of
// thing
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/allowlist.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ bool deviceIsAllowed(const DeviceDescT &DeviceDesc,
const AllowListParsedT &AllowListParsed);

void applyAllowList(std::vector<ur_device_handle_t> &UrDevices,
ur_platform_handle_t UrPlatform, const AdapterPtr &Adapter);
ur_platform_handle_t UrPlatform, adapter_impl &Adapter);

} // namespace detail
} // namespace _V1
Expand Down
9 changes: 4 additions & 5 deletions sycl/source/detail/context_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ context_impl::~context_impl() {
try {
// Free all events associated with the initialization of device globals.
for (auto &DeviceGlobalInitializer : MDeviceGlobalInitializers)
DeviceGlobalInitializer.second.ClearEvents(&getAdapter());
DeviceGlobalInitializer.second.ClearEvents(getAdapter());
// Free all device_global USM allocations associated with this context.
for (const void *DeviceGlobal : MAssociatedDeviceGlobals) {
DeviceGlobalMapEntry *DGEntry =
Expand All @@ -146,7 +146,7 @@ const async_handler &context_impl::get_async_handler() const {
template <>
uint32_t context_impl::get_info<info::context::reference_count>() const {
return get_context_info<info::context::reference_count>(this->getHandleRef(),
&this->getAdapter());
this->getAdapter());
}
template <> platform context_impl::get_info<info::context::platform>() const {
return createSyclObjFromImpl<platform>(*MPlatform);
Expand Down Expand Up @@ -449,10 +449,9 @@ std::vector<ur_event_handle_t> context_impl::initializeDeviceGlobals(
}
}

void context_impl::DeviceGlobalInitializer::ClearEvents(
const AdapterPtr &Adapter) {
void context_impl::DeviceGlobalInitializer::ClearEvents(adapter_impl &Adapter) {
for (const ur_event_handle_t &Event : MDeviceGlobalInitEvents)
Adapter->call<UrApiKind::urEventRelease>(Event);
Adapter.call<UrApiKind::urEventRelease>(Event);
MDeviceGlobalInitEvents.clear();
}

Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/context_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ class context_impl : public std::enable_shared_from_this<context_impl> {
}

/// Clears all events of the initializer. This will not acquire the lock.
void ClearEvents(const AdapterPtr &Adapter);
void ClearEvents(adapter_impl &Adapter);

/// The binary image of the program.
const RTDeviceBinaryImage *MBinImage = nullptr;
Expand Down
6 changes: 3 additions & 3 deletions sycl/source/detail/context_info.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ namespace detail {

template <typename Param>
typename Param::return_type get_context_info(ur_context_handle_t Ctx,
const AdapterPtr &Adapter) {
adapter_impl &Adapter) {
static_assert(is_context_info_desc<Param>::value,
"Invalid context information descriptor");
typename Param::return_type Result = 0;
// TODO catch an exception and put it to list of asynchronous exceptions
Adapter->call<UrApiKind::urContextGetInfo>(Ctx, UrInfoCode<Param>::value,
sizeof(Result), &Result, nullptr);
Adapter.call<UrApiKind::urContextGetInfo>(Ctx, UrInfoCode<Param>::value,
sizeof(Result), &Result, nullptr);
return Result;
}

Expand Down
20 changes: 10 additions & 10 deletions sycl/source/detail/platform_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,19 +119,19 @@ static bool IsBannedPlatform(platform Platform) {
// replace uses of this with a helper in adapter object, the adapter
// objects will own the ur adapter handles and they'll need to pass them to
// urPlatformsGet - so urPlatformsGet will need to be wrapped with a helper
std::vector<platform> platform_impl::getAdapterPlatforms(AdapterPtr &Adapter,
std::vector<platform> platform_impl::getAdapterPlatforms(adapter_impl &Adapter,
bool Supported) {
std::vector<platform> Platforms;

auto UrPlatforms = Adapter->getUrPlatforms();
auto UrPlatforms = Adapter.getUrPlatforms();

if (UrPlatforms.empty()) {
return Platforms;
}

for (const auto &UrPlatform : UrPlatforms) {
platform Platform = detail::createSyclObjFromImpl<platform>(
getOrMakePlatformImpl(UrPlatform, *Adapter));
getOrMakePlatformImpl(UrPlatform, Adapter));
const bool IsBanned = IsBannedPlatform(Platform);
bool HasAnyDevices = false;

Expand Down Expand Up @@ -168,12 +168,12 @@ std::vector<platform> platform_impl::get_platforms() {

// See which platform we want to be served by which adapter.
// There should be just one adapter serving each backend.
std::vector<AdapterPtr> &Adapters = ur::initializeUr();
std::vector<std::pair<platform, AdapterPtr>> PlatformsWithAdapter;
std::vector<adapter_impl *> &Adapters = ur::initializeUr();
std::vector<std::pair<platform, adapter_impl *>> PlatformsWithAdapter;

// Then check backend-specific adapters
for (auto &Adapter : Adapters) {
const auto &AdapterPlatforms = getAdapterPlatforms(Adapter);
const auto &AdapterPlatforms = getAdapterPlatforms(*Adapter);
for (const auto &P : AdapterPlatforms) {
PlatformsWithAdapter.push_back({P, Adapter});
}
Expand Down Expand Up @@ -504,13 +504,13 @@ platform_impl::get_devices(info::device_type DeviceType) const {
// analysis. Doing adjustment by simple copy of last device num from
// previous platform.
// Needs non const adapter reference.
std::vector<AdapterPtr> &Adapters = ur::initializeUr();
std::vector<adapter_impl *> &Adapters = ur::initializeUr();
auto It = std::find_if(Adapters.begin(), Adapters.end(),
[&Platform = MPlatform](AdapterPtr &Adapter) {
[&Platform = MPlatform](adapter_impl *&Adapter) {
return Adapter->containsUrPlatform(Platform);
});
if (It != Adapters.end()) {
AdapterPtr &Adapter = *It;
adapter_impl *&Adapter = *It;
std::lock_guard<std::mutex> Guard(*Adapter->getAdapterMutex());
Adapter->adjustLastDeviceId(MPlatform);
}
Expand All @@ -530,7 +530,7 @@ platform_impl::get_devices(info::device_type DeviceType) const {

// Filter out devices that are not present in the SYCL_DEVICE_ALLOWLIST
if (SYCLConfig<SYCL_DEVICE_ALLOWLIST>::get())
applyAllowList(UrDevices, MPlatform, MAdapter);
applyAllowList(UrDevices, MPlatform, *MAdapter);

// The first step is to filter out devices that are not compatible with
// ONEAPI_DEVICE_SELECTOR. This is also the mechanism by which top level
Expand Down
4 changes: 2 additions & 2 deletions sycl/source/detail/platform_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ class platform_impl : public std::enable_shared_from_this<platform_impl> {
device_impl *getDeviceImplHelper(ur_device_handle_t UrDevice);

// Helper to get the vector of platforms supported by a given UR adapter
static std::vector<platform> getAdapterPlatforms(AdapterPtr &Adapter,
static std::vector<platform> getAdapterPlatforms(adapter_impl &Adapter,
bool Supported = true);

// Helper to filter reportable devices in the platform
Expand All @@ -216,7 +216,7 @@ class platform_impl : public std::enable_shared_from_this<platform_impl> {
ur_platform_handle_t MPlatform = 0;
backend MBackend;

AdapterPtr MAdapter;
adapter_impl *MAdapter;

std::vector<std::shared_ptr<device_impl>> MDevices;
friend class GlobalHandler;
Expand Down
1 change: 0 additions & 1 deletion sycl/source/detail/sycl_mem_obj_t.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ namespace detail {
class context_impl;
class event_impl;
class adapter_impl;
using AdapterPtr = adapter_impl *;

using EventImplPtr = std::shared_ptr<event_impl>;

Expand Down
3 changes: 2 additions & 1 deletion sycl/source/detail/ur.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ static void initializeAdapters(std::vector<adapter_impl *> &Adapters,
bool XPTIInitDone = false;

// Initializes all available Adapters.
std::vector<AdapterPtr> &initializeUr(ur_loader_config_handle_t LoaderConfig) {
std::vector<adapter_impl *> &
initializeUr(ur_loader_config_handle_t LoaderConfig) {
// This uses static variable initialization to work around a gcc bug with
// std::call_once and exceptions.
// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=66146
Expand Down
3 changes: 1 addition & 2 deletions sycl/source/detail/ur.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,12 @@ inline namespace _V1 {
enum class backend : char;
namespace detail {
class adapter_impl;
using AdapterPtr = adapter_impl *;

namespace ur {
void *getURLoaderLibrary();

// Performs UR one-time initialization.
std::vector<AdapterPtr> &
std::vector<adapter_impl *> &
initializeUr(ur_loader_config_handle_t LoaderConfig = nullptr);

// Get the adapter serving given backend.
Expand Down
Loading