Skip to content

Commit

Permalink
Merge pull request #2520 from zhaomaosu/fix-buffer-shadow
Browse files (browse the repository at this point in the history)
[DevMSAN] Propagate shadow memory in buffer related APIs
  • Loading branch information
kbenzie committed Jan 17, 2025
1 parent 1026b47 commit 868c56d
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 27 deletions.
61 changes: 56 additions & 5 deletions source/loader/layers/sanitizer/msan/msan_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,22 +48,67 @@ ur_result_t EnqueueMemCopyRectHelper(
char *DstOrigin = pDst + DstOffset.x + DstRowPitch * DstOffset.y +
DstSlicePitch * DstOffset.z;

const bool IsDstDeviceUSM = getMsanInterceptor()
->findAllocInfoByAddress((uptr)DstOrigin)
.has_value();
const bool IsSrcDeviceUSM = getMsanInterceptor()
->findAllocInfoByAddress((uptr)SrcOrigin)
.has_value();

ur_device_handle_t Device = GetDevice(Queue);
std::shared_ptr<DeviceInfo> DeviceInfo =
getMsanInterceptor()->getDeviceInfo(Device);
std::vector<ur_event_handle_t> Events;
Events.reserve(Region.depth);

// For now, USM doesn't support 3D memory copy operation, so we can only
// loop call 2D memory copy function to implement it.
for (size_t i = 0; i < Region.depth; i++) {
ur_event_handle_t NewEvent{};
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy2D(
Queue, Blocking, DstOrigin + (i * DstSlicePitch), DstRowPitch,
Queue, false, DstOrigin + (i * DstSlicePitch), DstRowPitch,
SrcOrigin + (i * SrcSlicePitch), SrcRowPitch, Region.width,
Region.height, NumEventsInWaitList, EventWaitList, &NewEvent));

Events.push_back(NewEvent);

// Update shadow memory
if (IsDstDeviceUSM && IsSrcDeviceUSM) {
NewEvent = nullptr;
uptr DstShadowAddr = DeviceInfo->Shadow->MemToShadow(
(uptr)DstOrigin + (i * DstSlicePitch));
uptr SrcShadowAddr = DeviceInfo->Shadow->MemToShadow(
(uptr)SrcOrigin + (i * SrcSlicePitch));
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy2D(
Queue, false, (void *)DstShadowAddr, DstRowPitch,
(void *)SrcShadowAddr, SrcRowPitch, Region.width, Region.height,
NumEventsInWaitList, EventWaitList, &NewEvent));
Events.push_back(NewEvent);
} else if (IsDstDeviceUSM && !IsSrcDeviceUSM) {
uptr DstShadowAddr = DeviceInfo->Shadow->MemToShadow(
(uptr)DstOrigin + (i * DstSlicePitch));
const char Val = 0;
// opencl & l0 adapter doesn't implement urEnqueueUSMFill2D, so
// emulate the operation with urEnqueueUSMFill.
for (size_t HeightIndex = 0; HeightIndex < Region.height;
HeightIndex++) {
NewEvent = nullptr;
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMFill(
Queue, (void *)(DstShadowAddr + HeightIndex * DstRowPitch),
1, &Val, Region.width, NumEventsInWaitList, EventWaitList,
&NewEvent));
Events.push_back(NewEvent);
}
}
}

UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
Queue, Events.size(), Events.data(), Event));
if (Blocking) {
UR_CALL(
getContext()->urDdiTable.Event.pfnWait(Events.size(), &Events[0]));
}

if (Event) {
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
Queue, Events.size(), &Events[0], Event));
}

return UR_RESULT_SUCCESS;
}
Expand Down Expand Up @@ -112,6 +157,12 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {
Size, HostPtr, this);
return URes;
}

// Update shadow memory
std::shared_ptr<DeviceInfo> DeviceInfo =
getMsanInterceptor()->getDeviceInfo(Device);
UR_CALL(DeviceInfo->Shadow->EnqueuePoisonShadow(
Queue, (uptr)Allocation, Size, 0));
}
}

Expand Down
129 changes: 107 additions & 22 deletions source/loader/layers/sanitizer/msan/msan_ddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,12 @@ ur_result_t urMemBufferCreate(
UR_CALL(pMemBuffer->getHandle(hDevice, Handle));
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
InternalQueue, true, Handle, Host, size, 0, nullptr, nullptr));

// Update shadow memory
std::shared_ptr<DeviceInfo> DeviceInfo =
getMsanInterceptor()->getDeviceInfo(hDevice);
UR_CALL(DeviceInfo->Shadow->EnqueuePoisonShadow(
InternalQueue, (uptr)Handle, size, 0));
}
}

Expand Down Expand Up @@ -730,10 +736,29 @@ ur_result_t urEnqueueMemBufferWrite(
if (auto MemBuffer = getMsanInterceptor()->getMemBuffer(hBuffer)) {
ur_device_handle_t Device = GetDevice(hQueue);
char *pDst = nullptr;
std::vector<ur_event_handle_t> Events;
ur_event_handle_t Event{};
UR_CALL(MemBuffer->getHandle(Device, pDst));
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
hQueue, blockingWrite, pDst + offset, pSrc, size,
numEventsInWaitList, phEventWaitList, phEvent));
numEventsInWaitList, phEventWaitList, &Event));
Events.push_back(Event);

// Update shadow memory
std::shared_ptr<DeviceInfo> DeviceInfo =
getMsanInterceptor()->getDeviceInfo(Device);
const char Val = 0;
uptr ShadowAddr = DeviceInfo->Shadow->MemToShadow((uptr)pDst + offset);
Event = nullptr;
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMFill(
hQueue, (void *)ShadowAddr, 1, &Val, size, numEventsInWaitList,
phEventWaitList, &Event));
Events.push_back(Event);

if (phEvent) {
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
hQueue, Events.size(), Events.data(), phEvent));
}
} else {
UR_CALL(pfnMemBufferWrite(hQueue, hBuffer, blockingWrite, offset, size,
pSrc, numEventsInWaitList, phEventWaitList,
Expand Down Expand Up @@ -893,15 +918,36 @@ ur_result_t urEnqueueMemBufferCopy(

if (SrcBuffer && DstBuffer) {
ur_device_handle_t Device = GetDevice(hQueue);
std::shared_ptr<DeviceInfo> DeviceInfo =
getMsanInterceptor()->getDeviceInfo(Device);
char *SrcHandle = nullptr;
UR_CALL(SrcBuffer->getHandle(Device, SrcHandle));

char *DstHandle = nullptr;
UR_CALL(DstBuffer->getHandle(Device, DstHandle));

std::vector<ur_event_handle_t> Events;
ur_event_handle_t Event{};
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
hQueue, false, DstHandle + dstOffset, SrcHandle + srcOffset, size,
numEventsInWaitList, phEventWaitList, phEvent));
numEventsInWaitList, phEventWaitList, &Event));
Events.push_back(Event);

// Update shadow memory
uptr DstShadowAddr =
DeviceInfo->Shadow->MemToShadow((uptr)DstHandle + dstOffset);
uptr SrcShadowAddr =
DeviceInfo->Shadow->MemToShadow((uptr)SrcHandle + srcOffset);
Event = nullptr;
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
hQueue, false, (void *)DstShadowAddr, (void *)SrcShadowAddr, size,
numEventsInWaitList, phEventWaitList, &Event));
Events.push_back(Event);

if (phEvent) {
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
hQueue, Events.size(), Events.data(), phEvent));
}
} else {
UR_CALL(pfnMemBufferCopy(hQueue, hBufferSrc, hBufferDst, srcOffset,
dstOffset, size, numEventsInWaitList,
Expand Down Expand Up @@ -1000,11 +1046,31 @@ ur_result_t urEnqueueMemBufferFill(

if (auto MemBuffer = getMsanInterceptor()->getMemBuffer(hBuffer)) {
char *Handle = nullptr;
std::vector<ur_event_handle_t> Events;
ur_event_handle_t Event{};
ur_device_handle_t Device = GetDevice(hQueue);
UR_CALL(MemBuffer->getHandle(Device, Handle));
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMFill(
hQueue, Handle + offset, patternSize, pPattern, size,
numEventsInWaitList, phEventWaitList, phEvent));
numEventsInWaitList, phEventWaitList, &Event));
Events.push_back(Event);

// Update shadow memory
std::shared_ptr<DeviceInfo> DeviceInfo =
getMsanInterceptor()->getDeviceInfo(Device);
const char Val = 0;
uptr ShadowAddr =
DeviceInfo->Shadow->MemToShadow((uptr)Handle + offset);
Event = nullptr;
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMFill(
hQueue, (void *)ShadowAddr, 1, &Val, size, numEventsInWaitList,
phEventWaitList, &Event));
Events.push_back(Event);

if (phEvent) {
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
hQueue, Events.size(), Events.data(), phEvent));
}
} else {
UR_CALL(pfnMemBufferFill(hQueue, hBuffer, pPattern, patternSize, offset,
size, numEventsInWaitList, phEventWaitList,
Expand Down Expand Up @@ -1270,9 +1336,11 @@ ur_result_t UR_APICALL urEnqueueUSMFill(
auto pfnUSMFill = getContext()->urDdiTable.Enqueue.pfnUSMFill;
getContext()->logger.debug("==== urEnqueueUSMFill");

ur_event_handle_t hEvents[2] = {};
std::vector<ur_event_handle_t> Events;
ur_event_handle_t Event{};
UR_CALL(pfnUSMFill(hQueue, pMem, patternSize, pPattern, size,
numEventsInWaitList, phEventWaitList, &hEvents[0]));
numEventsInWaitList, phEventWaitList, &Event));
Events.push_back(Event);

const auto Mem = (uptr)pMem;
auto MemInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Mem);
Expand All @@ -1283,13 +1351,15 @@ ur_result_t UR_APICALL urEnqueueUSMFill(
getMsanInterceptor()->getDeviceInfo(MemInfo->Device);
const auto MemShadow = DeviceInfo->Shadow->MemToShadow(Mem);

Event = nullptr;
UR_CALL(EnqueueUSMBlockingSet(hQueue, (void *)MemShadow, 0, size, 0,
nullptr, &hEvents[1]));
nullptr, &Event));
Events.push_back(Event);
}

if (phEvent) {
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
hQueue, 2, hEvents, phEvent));
hQueue, Events.size(), Events.data(), phEvent));
}

return UR_RESULT_SUCCESS;
Expand Down Expand Up @@ -1319,9 +1389,11 @@ ur_result_t UR_APICALL urEnqueueUSMMemcpy(
auto pfnUSMMemcpy = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy;
getContext()->logger.debug("==== pfnUSMMemcpy");

ur_event_handle_t hEvents[2] = {};
std::vector<ur_event_handle_t> Events;
ur_event_handle_t Event{};
UR_CALL(pfnUSMMemcpy(hQueue, blocking, pDst, pSrc, size,
numEventsInWaitList, phEventWaitList, &hEvents[0]));
numEventsInWaitList, phEventWaitList, &Event));
Events.push_back(Event);

const auto Src = (uptr)pSrc, Dst = (uptr)pDst;
auto SrcInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Src);
Expand All @@ -1336,22 +1408,26 @@ ur_result_t UR_APICALL urEnqueueUSMMemcpy(
const auto SrcShadow = DeviceInfo->Shadow->MemToShadow(Src);
const auto DstShadow = DeviceInfo->Shadow->MemToShadow(Dst);

Event = nullptr;
UR_CALL(pfnUSMMemcpy(hQueue, blocking, (void *)DstShadow,
(void *)SrcShadow, size, 0, nullptr, &hEvents[1]));
(void *)SrcShadow, size, 0, nullptr, &Event));
Events.push_back(Event);
} else if (DstInfoItOp) {
auto DstInfo = (*DstInfoItOp)->second;

const auto &DeviceInfo =
getMsanInterceptor()->getDeviceInfo(DstInfo->Device);
auto DstShadow = DeviceInfo->Shadow->MemToShadow(Dst);

Event = nullptr;
UR_CALL(EnqueueUSMBlockingSet(hQueue, (void *)DstShadow, 0, size, 0,
nullptr, &hEvents[1]));
nullptr, &Event));
Events.push_back(Event);
}

if (phEvent) {
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
hQueue, 2, hEvents, phEvent));
hQueue, Events.size(), Events.data(), phEvent));
}

return UR_RESULT_SUCCESS;
Expand Down Expand Up @@ -1387,10 +1463,11 @@ ur_result_t UR_APICALL urEnqueueUSMFill2D(
auto pfnUSMFill2D = getContext()->urDdiTable.Enqueue.pfnUSMFill2D;
getContext()->logger.debug("==== urEnqueueUSMFill2D");

ur_event_handle_t hEvents[2] = {};
std::vector<ur_event_handle_t> Events;
ur_event_handle_t Event{};
UR_CALL(pfnUSMFill2D(hQueue, pMem, pitch, patternSize, pPattern, width,
height, numEventsInWaitList, phEventWaitList,
&hEvents[0]));
height, numEventsInWaitList, phEventWaitList, &Event));
Events.push_back(Event);

const auto Mem = (uptr)pMem;
auto MemInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Mem);
Expand All @@ -1402,13 +1479,15 @@ ur_result_t UR_APICALL urEnqueueUSMFill2D(
const auto MemShadow = DeviceInfo->Shadow->MemToShadow(Mem);

const char Pattern = 0;
Event = nullptr;
UR_CALL(pfnUSMFill2D(hQueue, (void *)MemShadow, pitch, 1, &Pattern,
width, height, 0, nullptr, &hEvents[1]));
width, height, 0, nullptr, &Event));
Events.push_back(Event);
}

if (phEvent) {
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
hQueue, 2, hEvents, phEvent));
hQueue, Events.size(), Events.data(), phEvent));
}

return UR_RESULT_SUCCESS;
Expand Down Expand Up @@ -1443,10 +1522,12 @@ ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
auto pfnUSMMemcpy2D = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy2D;
getContext()->logger.debug("==== pfnUSMMemcpy2D");

ur_event_handle_t hEvents[2] = {};
std::vector<ur_event_handle_t> Events;
ur_event_handle_t Event{};
UR_CALL(pfnUSMMemcpy2D(hQueue, blocking, pDst, dstPitch, pSrc, srcPitch,
width, height, numEventsInWaitList, phEventWaitList,
&hEvents[0]));
&Event));
Events.push_back(Event);

const auto Src = (uptr)pSrc, Dst = (uptr)pDst;
auto SrcInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Src);
Expand All @@ -1461,9 +1542,11 @@ ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
const auto SrcShadow = DeviceInfo->Shadow->MemToShadow(Src);
const auto DstShadow = DeviceInfo->Shadow->MemToShadow(Dst);

Event = nullptr;
UR_CALL(pfnUSMMemcpy2D(hQueue, blocking, (void *)DstShadow, dstPitch,
(void *)SrcShadow, srcPitch, width, height, 0,
nullptr, &hEvents[1]));
nullptr, &Event));
Events.push_back(Event);
} else if (DstInfoItOp) {
auto DstInfo = (*DstInfoItOp)->second;

Expand All @@ -1472,14 +1555,16 @@ ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
const auto DstShadow = DeviceInfo->Shadow->MemToShadow(Dst);

const char Pattern = 0;
Event = nullptr;
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMFill2D(
hQueue, (void *)DstShadow, dstPitch, 1, &Pattern, width, height, 0,
nullptr, &hEvents[1]));
nullptr, &Event));
Events.push_back(Event);
}

if (phEvent) {
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
hQueue, 2, hEvents, phEvent));
hQueue, Events.size(), Events.data(), phEvent));
}

return UR_RESULT_SUCCESS;
Expand Down
1 change: 1 addition & 0 deletions source/loader/layers/sanitizer/msan/msan_interceptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ ur_result_t MsanInterceptor::allocateMemory(ur_context_handle_t Context,
m_AllocationMap.emplace(AI->AllocBegin, AI);
}

// Update shadow memory
ManagedQueue Queue(Context, Device);
DeviceInfo->Shadow->EnqueuePoisonShadow(Queue, AI->AllocBegin,
AI->AllocSize, 0xff);
Expand Down

0 comments on commit 868c56d

Please sign in to comment.