Skip to content

[DeviceMSAN] Fix urEnqueueUSMMemcpy2D return UR_RESULT_ERROR_UNSUPPORTED_FEATURE after enabling MSAN #19286

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: sycl
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 48 additions & 11 deletions unified-runtime/source/loader/layers/sanitizer/msan/msan_ddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,48 @@ ur_result_t setupContext(ur_context_handle_t Context, uint32_t numDevices,
return UR_RESULT_SUCCESS;
}

ur_result_t urEnqueueUSMFill2DFallback(ur_queue_handle_t hQueue, void *pMem,
size_t pitch, size_t patternSize,
const void *pPattern, size_t width,
size_t height,
uint32_t numEventsInWaitList,
const ur_event_handle_t *phEventWaitList,
ur_event_handle_t *phEvent) {
ur_result_t Result = getContext()->urDdiTable.Enqueue.pfnUSMFill2D(
hQueue, pMem, pitch, patternSize, pPattern, width, height,
numEventsInWaitList, phEventWaitList, phEvent);
if (Result == UR_RESULT_SUCCESS ||
Result != UR_RESULT_ERROR_UNSUPPORTED_FEATURE) {
return Result;
}

// fallback code
auto pfnUSMFill = getContext()->urDdiTable.Enqueue.pfnUSMFill;

std::vector<ur_event_handle_t> WaitEvents(numEventsInWaitList);

for (size_t HeightIndex = 0; HeightIndex < height; HeightIndex++) {
ur_event_handle_t Event = nullptr;

UR_CALL(pfnUSMFill(hQueue, (void *)((char *)pMem + pitch * HeightIndex),
patternSize, pPattern, width, WaitEvents.size(),
WaitEvents.data(), &Event));

WaitEvents.push_back(Event);
}

if (phEvent) {
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
hQueue, WaitEvents.size(), WaitEvents.data(), phEvent));
}

for (const auto Event : WaitEvents) {
UR_CALL(getContext()->urDdiTable.Event.pfnRelease(Event));
}

return UR_RESULT_SUCCESS;
}

} // namespace

///////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -1726,11 +1768,6 @@ ur_result_t urEnqueueUSMMemcpy2D(
{
auto pfnUSMMemcpy = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy;

std::vector<ur_event_handle_t> WaitEvents(numEventsInWaitList);
for (uint32_t i = 0; i < numEventsInWaitList; i++) {
WaitEvents[i] = phEventWaitList[i];
}

for (size_t HeightIndex = 0; HeightIndex < height; HeightIndex++) {
ur_event_handle_t Event = nullptr;
const auto DstOrigin =
Expand All @@ -1742,8 +1779,8 @@ ur_result_t urEnqueueUSMMemcpy2D(
width - 1) +
MSAN_ORIGIN_GRANULARITY;
pfnUSMMemcpy(hQueue, false, (void *)DstOrigin, (void *)SrcOrigin,
SrcOriginEnd - SrcOrigin, WaitEvents.size(),
WaitEvents.data(), &Event);
SrcOriginEnd - SrcOrigin, numEventsInWaitList,
phEventWaitList, &Event);
Events.push_back(Event);
}
}
Expand All @@ -1756,9 +1793,9 @@ ur_result_t urEnqueueUSMMemcpy2D(
const auto DstShadow = DstDI->Shadow->MemToShadow((uptr)pDst);
const char Pattern = 0;
ur_event_handle_t Event = nullptr;
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMFill2D(
hQueue, (void *)DstShadow, dstPitch, 1, &Pattern, width, height, 0,
nullptr, &Event));
UR_CALL(urEnqueueUSMFill2DFallback(hQueue, (void *)DstShadow, dstPitch, 1,
&Pattern, width, height, 0, nullptr,
&Event));
Events.push_back(Event);
}

Expand All @@ -1767,7 +1804,7 @@ ur_result_t urEnqueueUSMMemcpy2D(
hQueue, Events.size(), Events.data(), phEvent));
}

for (const auto &E : Events)
for (const auto E : Events)
UR_CALL(getContext()->urDdiTable.Event.pfnRelease(E));

return UR_RESULT_SUCCESS;
Expand Down