Skip to content

[NFC][SYCL] Raw context_impl in getInteropContext and queue_impl ctor #19126

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions sycl/source/backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ __SYCL_EXPORT queue make_queue(ur_native_handle_t NativeHandle,
ur_device_handle_t UrDevice =
Device ? getSyclObjImpl(*Device)->getHandleRef() : nullptr;
const auto &Adapter = getAdapter(Backend);
const auto &ContextImpl = getSyclObjImpl(Context);
context_impl &ContextImpl = *getSyclObjImpl(Context);

if (PropList.has_property<ext::intel::property::queue::compute_index>()) {
throw sycl::exception(
Expand Down Expand Up @@ -156,7 +156,7 @@ __SYCL_EXPORT queue make_queue(ur_native_handle_t NativeHandle,
ur_queue_handle_t UrQueue = nullptr;

Adapter->call<UrApiKind::urQueueCreateWithNativeHandle>(
NativeHandle, ContextImpl->getHandleRef(), UrDevice, &NativeProperties,
NativeHandle, ContextImpl.getHandleRef(), UrDevice, &NativeProperties,
&UrQueue);
// Construct the SYCL queue from UR queue.
return detail::createSyclObjFromImpl<queue>(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changes in this file because of this queue_impl creation.

Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -996,7 +996,7 @@ exec_graph_impl::exec_graph_impl(sycl::context Context,
: MSchedule(), MGraphImpl(GraphImpl), MSyncPoints(),
MQueueImpl(sycl::detail::queue_impl::create(
*sycl::detail::getSyclObjImpl(GraphImpl->getDevice()),
sycl::detail::getSyclObjImpl(Context), sycl::async_handler{},
*sycl::detail::getSyclObjImpl(Context), sycl::async_handler{},
sycl::property_list{})),
MDevice(GraphImpl->getDevice()), MContext(Context), MRequirements(),
MSchedulerDependencies(),
Expand Down
39 changes: 18 additions & 21 deletions sycl/source/detail/queue_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,11 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
/// constructed.
/// \param AsyncHandler is a SYCL asynchronous exception handler.
/// \param PropList is a list of properties to use for queue construction.
queue_impl(device_impl &Device, const ContextImplPtr &Context,
queue_impl(device_impl &Device, std::shared_ptr<context_impl> &&Context,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need the rvalue-reference overload for the creation using getDefaultOrNew helper at line 108. Maybe this can be refactored further somehow, but that would be independent of passing SYCL RT objects by raw ptr/ref.

const async_handler &AsyncHandler, const property_list &PropList,
private_tag)
: MDevice(Device), MContext(Context), MAsyncHandler(AsyncHandler),
MPropList(PropList),
: MDevice(Device), MContext(std::move(Context)),
MAsyncHandler(AsyncHandler), MPropList(PropList),
MIsInorder(has_property<property::queue::in_order>()),
MIsProfilingEnabled(has_property<property::queue::enable_profiling>()),
MQueueID{
Expand All @@ -146,8 +146,8 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
"Queue compute index must be a non-negative number less than "
"device's number of available compute queue indices.");
}
if (!Context->isDeviceValid(Device)) {
if (Context->getBackend() == backend::opencl)
if (!MContext->isDeviceValid(Device)) {
if (MContext->getBackend() == backend::opencl)
throw sycl::exception(
make_error_code(errc::invalid),
"Queue cannot be constructed with the given context and device "
Expand Down Expand Up @@ -177,17 +177,13 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
trySwitchingToNoEventsMode();
}

sycl::detail::optional<event> getLastEvent();
queue_impl(device_impl &Device, context_impl &Context,
const async_handler &AsyncHandler, const property_list &PropList,
private_tag Tag)
: queue_impl(Device, Context.shared_from_this(), AsyncHandler, PropList,
Tag) {}

/// Constructs a SYCL queue from adapter interoperability handle.
///
/// \param UrQueue is a raw UR queue handle.
/// \param Context is a SYCL context to associate with the queue being
/// constructed.
/// \param AsyncHandler is a SYCL asynchronous exception handler.
queue_impl(ur_queue_handle_t UrQueue, const ContextImplPtr &Context,
const async_handler &AsyncHandler, private_tag tag)
: queue_impl(UrQueue, Context, AsyncHandler, {}, tag) {}
sycl::detail::optional<event> getLastEvent();

/// Constructs a SYCL queue from adapter interoperability handle.
///
Expand All @@ -196,27 +192,28 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
/// constructed.
/// \param AsyncHandler is a SYCL asynchronous exception handler.
/// \param PropList is the queue properties.
queue_impl(ur_queue_handle_t UrQueue, const ContextImplPtr &Context,
queue_impl(ur_queue_handle_t UrQueue, context_impl &Context,
const async_handler &AsyncHandler, const property_list &PropList,
private_tag)
: MDevice([&]() -> device_impl & {
ur_device_handle_t DeviceUr{};
const AdapterPtr &Adapter = Context->getAdapter();
const AdapterPtr &Adapter = Context.getAdapter();
// TODO catch an exception and put it to list of asynchronous
// exceptions
Adapter->call<UrApiKind::urQueueGetInfo>(
UrQueue, UR_QUEUE_INFO_DEVICE, sizeof(DeviceUr), &DeviceUr,
nullptr);
device_impl *Device = Context->findMatchingDeviceImpl(DeviceUr);
device_impl *Device = Context.findMatchingDeviceImpl(DeviceUr);
if (Device == nullptr) {
throw sycl::exception(
make_error_code(errc::invalid),
"Device provided by native Queue not found in Context.");
}
return *Device;
}()),
MContext(Context), MAsyncHandler(AsyncHandler), MPropList(PropList),
MQueue(UrQueue), MIsInorder(has_property<property::queue::in_order>()),
MContext(Context.shared_from_this()), MAsyncHandler(AsyncHandler),
MPropList(PropList), MQueue(UrQueue),
MIsInorder(has_property<property::queue::in_order>()),
MIsProfilingEnabled(has_property<property::queue::enable_profiling>()),
MQueueID{
MNextAvailableQueueID.fetch_add(1, std::memory_order_relaxed)} {
Expand Down Expand Up @@ -988,7 +985,7 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
mutable std::mutex MMutex;

device_impl &MDevice;
const ContextImplPtr MContext;
const std::shared_ptr<context_impl> MContext;

/// These events are tracked, but not owned, by the queue.
std::vector<std::weak_ptr<event_impl>> MEventsWeak;
Expand Down
6 changes: 3 additions & 3 deletions sycl/source/detail/scheduler/graph_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ Scheduler::GraphBuilder::getOrInsertMemObjRecord(const QueueImplPtr &Queue,
cleanupCommand(Cmd);
};

const ContextImplPtr &InteropCtxPtr = Req->MSYCLMemObj->getInteropContext();
context_impl *InteropCtxPtr = Req->MSYCLMemObj->getInteropContext();
if (InteropCtxPtr) {
// The memory object has been constructed using interoperability constructor
// which means that there is already an allocation(cl_mem) in some context.
Expand All @@ -225,10 +225,10 @@ Scheduler::GraphBuilder::getOrInsertMemObjRecord(const QueueImplPtr &Queue,
// here, we need to create a dummy queue bound to the context and one of the
// devices from the context.
std::shared_ptr<queue_impl> InteropQueuePtr = queue_impl::create(
Dev, InteropCtxPtr, async_handler{}, property_list{});
Dev, *InteropCtxPtr, async_handler{}, property_list{});

MemObject->MRecord.reset(
new MemObjRecord{InteropCtxPtr.get(), LeafLimit, AllocateDependency});
new MemObjRecord{InteropCtxPtr, LeafLimit, AllocateDependency});
std::vector<Command *> ToEnqueue;
getOrCreateAllocaForReq(MemObject->MRecord.get(), Req, InteropQueuePtr,
ToEnqueue);
Expand Down
3 changes: 1 addition & 2 deletions sycl/source/detail/sycl_mem_obj_i.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ class context_impl;
struct MemObjRecord;

using EventImplPtr = std::shared_ptr<detail::event_impl>;
using ContextImplPtr = std::shared_ptr<detail::context_impl>;

// The class serves as an interface in the scheduler for all SYCL memory
// objects.
Expand Down Expand Up @@ -72,7 +71,7 @@ class SYCLMemObjI {

// Returns the context which is passed if a memory object is created using
// interoperability constructor, nullptr otherwise.
virtual ContextImplPtr getInteropContext() const = 0;
virtual detail::context_impl *getInteropContext() const = 0;

protected:
// Pointer to the record that contains the memory commands. This is managed
Expand Down
7 changes: 4 additions & 3 deletions sycl/source/detail/sycl_mem_obj_t.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ class event_impl;
class Adapter;
using AdapterPtr = std::shared_ptr<Adapter>;

using ContextImplPtr = std::shared_ptr<context_impl>;
using EventImplPtr = std::shared_ptr<event_impl>;

// The class serves as a base for all SYCL memory objects.
Expand Down Expand Up @@ -281,7 +280,9 @@ class SYCLMemObjT : public SYCLMemObjI {

MemObjType getType() const override { return MemObjType::Undefined; }

ContextImplPtr getInteropContext() const override { return MInteropContext; }
context_impl *getInteropContext() const override {
return MInteropContext.get();
}

bool isInterop() const override;

Expand Down Expand Up @@ -339,7 +340,7 @@ class SYCLMemObjT : public SYCLMemObjI {
// Should wait on this event before start working with such memory object.
EventImplPtr MInteropEvent;
// Context passed by user to interoperability constructor.
ContextImplPtr MInteropContext;
std::shared_ptr<context_impl> MInteropContext;
// Native backend memory object handle passed by user to interoperability
// constructor.
ur_mem_handle_t MInteropMemObject;
Expand Down
6 changes: 3 additions & 3 deletions sycl/source/queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,14 @@ queue::queue(const context &SyclContext, const device_selector &DeviceSelector,
const device &SyclDevice = *std::max_element(Devs.begin(), Devs.end(), Comp);

impl = detail::queue_impl::create(*detail::getSyclObjImpl(SyclDevice),
detail::getSyclObjImpl(SyclContext),
*detail::getSyclObjImpl(SyclContext),
AsyncHandler, PropList);
}

queue::queue(const context &SyclContext, const device &SyclDevice,
const async_handler &AsyncHandler, const property_list &PropList) {
impl = detail::queue_impl::create(*detail::getSyclObjImpl(SyclDevice),
detail::getSyclObjImpl(SyclContext),
*detail::getSyclObjImpl(SyclContext),
AsyncHandler, PropList);
}

Expand Down Expand Up @@ -100,7 +100,7 @@ queue::queue(cl_command_queue clQueue, const context &SyclContext,
impl = detail::queue_impl::create(
// TODO(pi2ur): Don't cast straight from cl_command_queue
reinterpret_cast<ur_queue_handle_t>(clQueue),
detail::getSyclObjImpl(SyclContext), AsyncHandler, PropList);
*detail::getSyclObjImpl(SyclContext), AsyncHandler, PropList);
}

cl_command_queue queue::get() const { return impl->get(); }
Expand Down
8 changes: 4 additions & 4 deletions sycl/unittests/scheduler/HostTaskAndBarrier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@
namespace {
using namespace sycl;
using EventImplPtr = std::shared_ptr<sycl::detail::event_impl>;
using ContextImplPtr = std::shared_ptr<sycl::detail::context_impl>;

constexpr auto DisableCleanupName = "SYCL_DISABLE_EXECUTION_GRAPH_CLEANUP";

class TestQueueImpl : public sycl::detail::queue_impl {
public:
TestQueueImpl(ContextImplPtr SyclContext, sycl::detail::device_impl &Dev)
TestQueueImpl(sycl::detail::context_impl &SyclContext,
sycl::detail::device_impl &Dev)
: sycl::detail::queue_impl(Dev, SyclContext,
SyclContext->get_async_handler(), {},
SyclContext.get_async_handler(), {},
sycl::detail::queue_impl::private_tag{}) {}
using sycl::detail::queue_impl::MDefaultGraphDeps;
using sycl::detail::queue_impl::MExtGraphDeps;
Expand All @@ -46,7 +46,7 @@ class BarrierHandlingWithHostTask : public ::testing::Test {
sycl::device SyclDev =
sycl::detail::select_device(sycl::default_selector_v, SyclContext);
QueueDevImpl.reset(
new TestQueueImpl(sycl::detail::getSyclObjImpl(SyclContext),
new TestQueueImpl(*sycl::detail::getSyclObjImpl(SyclContext),
*sycl::detail::getSyclObjImpl(SyclDev)));

MainLock.lock();
Expand Down
3 changes: 1 addition & 2 deletions sycl/unittests/scheduler/LinkedAllocaDependencies.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ class MemObjMock : public sycl::detail::SYCLMemObjI {
bool isHostPointerReadOnly() const override { return false; }
bool usesPinnedHostMemory() const override { return false; }

std::shared_ptr<sycl::detail::context_impl>
getInteropContext() const override {
sycl::detail::context_impl *getInteropContext() const override {
return nullptr;
}
};
Expand Down