Skip to content

Commit

Permalink
Merge pull request #2506 from AllanZyne/review/yang/fix_kernel_native
Browse files Browse the repository at this point in the history
[DeviceASAN] Fix urKernelCreateWithNativeHandle segfault
  • Loading branch information
kbenzie committed Jan 6, 2025
1 parent 326d800 commit 71ad3b7
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 137 deletions.
57 changes: 17 additions & 40 deletions source/loader/layers/sanitizer/asan/asan_ddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1335,28 +1335,6 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemUnmap(
return UR_RESULT_SUCCESS;
}

///////////////////////////////////////////////////////////////////////////////
/// @brief Intercept function for urKernelCreate
///
/// Forwards the request to the pfnCreate entry of the Kernel DDI table held
/// by getContext(), then hands the resulting kernel handle to
/// getAsanInterceptor()->insertKernel().
///
/// @return UR_RESULT_ERROR_UNSUPPORTED_FEATURE when the DDI table has no
/// pfnCreate entry; UR_RESULT_SUCCESS once both calls have completed.
/// NOTE(review): UR_CALL is presumably an early-return-on-error macro, in
/// which case a failure of pfnCreate or insertKernel is returned to the
/// caller -- confirm against its definition.
__urdlllocal ur_result_t UR_APICALL urKernelCreate(
ur_program_handle_t hProgram, ///< [in] handle of the program instance
const char *pKernelName, ///< [in] pointer to null-terminated string.
ur_kernel_handle_t
*phKernel ///< [out] pointer to handle of kernel object created.
) {
auto pfnCreate = getContext()->urDdiTable.Kernel.pfnCreate;

// Nothing to forward to if the DDI table does not provide this entry point.
if (nullptr == pfnCreate) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

getContext()->logger.debug("==== urKernelCreate");

// Create the kernel first; *phKernel is only read after pfnCreate has run.
UR_CALL(pfnCreate(hProgram, pKernelName, phKernel));
// NOTE(review): a kernel handle obtained without going through this
// interceptor (e.g. urKernelCreateWithNativeHandle, per the commit title)
// would not reach this insertKernel call -- confirm.
UR_CALL(getAsanInterceptor()->insertKernel(*phKernel));

return UR_RESULT_SUCCESS;
}

///////////////////////////////////////////////////////////////////////////////
/// @brief Intercept function for urKernelRetain
__urdlllocal ur_result_t UR_APICALL urKernelRetain(
Expand All @@ -1372,8 +1350,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain(

UR_CALL(pfnRetain(hKernel));

auto KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel);
KernelInfo->RefCount++;
auto &KernelInfo = getAsanInterceptor()->getOrCreateKernelInfo(hKernel);
KernelInfo.RefCount++;

return UR_RESULT_SUCCESS;
}
Expand All @@ -1392,9 +1370,9 @@ __urdlllocal ur_result_t urKernelRelease(
getContext()->logger.debug("==== urKernelRelease");
UR_CALL(pfnRelease(hKernel));

auto KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel);
if (--KernelInfo->RefCount == 0) {
UR_CALL(getAsanInterceptor()->eraseKernel(hKernel));
auto &KernelInfo = getAsanInterceptor()->getOrCreateKernelInfo(hKernel);
if (--KernelInfo.RefCount == 0) {
UR_CALL(getAsanInterceptor()->eraseKernelInfo(hKernel));
}

return UR_RESULT_SUCCESS;
Expand Down Expand Up @@ -1423,9 +1401,9 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgValue(
if (argSize == sizeof(ur_mem_handle_t) &&
(MemBuffer = getAsanInterceptor()->getMemBuffer(
*ur_cast<const ur_mem_handle_t *>(pArgValue)))) {
auto KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel);
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo->Mutex);
KernelInfo->BufferArgs[argIndex] = std::move(MemBuffer);
auto &KernelInfo = getAsanInterceptor()->getOrCreateKernelInfo(hKernel);
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo.Mutex);
KernelInfo.BufferArgs[argIndex] = std::move(MemBuffer);
} else {
UR_CALL(
pfnSetArgValue(hKernel, argIndex, argSize, pProperties, pArgValue));
Expand Down Expand Up @@ -1453,9 +1431,9 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgMemObj(

std::shared_ptr<MemBuffer> MemBuffer;
if ((MemBuffer = getAsanInterceptor()->getMemBuffer(hArgValue))) {
auto KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel);
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo->Mutex);
KernelInfo->BufferArgs[argIndex] = std::move(MemBuffer);
auto &KernelInfo = getAsanInterceptor()->getOrCreateKernelInfo(hKernel);
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo.Mutex);
KernelInfo.BufferArgs[argIndex] = std::move(MemBuffer);
} else {
UR_CALL(pfnSetArgMemObj(hKernel, argIndex, pProperties, hArgValue));
}
Expand Down Expand Up @@ -1484,12 +1462,12 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgLocal(
argSize);

{
auto KI = getAsanInterceptor()->getKernelInfo(hKernel);
std::scoped_lock<ur_shared_mutex> Guard(KI->Mutex);
auto &KI = getAsanInterceptor()->getOrCreateKernelInfo(hKernel);
std::scoped_lock<ur_shared_mutex> Guard(KI.Mutex);
// TODO: get local variable alignment
auto argSizeWithRZ = GetSizeAndRedzoneSizeForLocal(
argSize, ASAN_SHADOW_GRANULARITY, ASAN_SHADOW_GRANULARITY);
KI->LocalArgs[argIndex] = LocalArgsInfo{argSize, argSizeWithRZ};
KI.LocalArgs[argIndex] = LocalArgsInfo{argSize, argSizeWithRZ};
argSize = argSizeWithRZ;
}

Expand Down Expand Up @@ -1522,9 +1500,9 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgPointer(

std::shared_ptr<KernelInfo> KI;
if (getAsanInterceptor()->getOptions().DetectKernelArguments) {
auto KI = getAsanInterceptor()->getKernelInfo(hKernel);
std::scoped_lock<ur_shared_mutex> Guard(KI->Mutex);
KI->PointerArgs[argIndex] = {pArgValue, GetCurrentBacktrace()};
auto &KI = getAsanInterceptor()->getOrCreateKernelInfo(hKernel);
std::scoped_lock<ur_shared_mutex> Guard(KI.Mutex);
KI.PointerArgs[argIndex] = {pArgValue, GetCurrentBacktrace()};
}

ur_result_t result =
Expand Down Expand Up @@ -1708,7 +1686,6 @@ __urdlllocal ur_result_t UR_APICALL urGetKernelProcAddrTable(

ur_result_t result = UR_RESULT_SUCCESS;

pDdiTable->pfnCreate = ur_sanitizer_layer::asan::urKernelCreate;
pDdiTable->pfnRetain = ur_sanitizer_layer::asan::urKernelRetain;
pDdiTable->pfnRelease = ur_sanitizer_layer::asan::urKernelRelease;
pDdiTable->pfnSetArgValue = ur_sanitizer_layer::asan::urKernelSetArgValue;
Expand Down
37 changes: 21 additions & 16 deletions source/loader/layers/sanitizer/asan/asan_interceptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -639,22 +639,26 @@ ur_result_t AsanInterceptor::eraseProgram(ur_program_handle_t Program) {
return UR_RESULT_SUCCESS;
}

ur_result_t AsanInterceptor::insertKernel(ur_kernel_handle_t Kernel) {
std::scoped_lock<ur_shared_mutex> Guard(m_KernelMapMutex);
if (m_KernelMap.find(Kernel) != m_KernelMap.end()) {
return UR_RESULT_SUCCESS;
KernelInfo &AsanInterceptor::getOrCreateKernelInfo(ur_kernel_handle_t Kernel) {
{
std::shared_lock<ur_shared_mutex> Guard(m_KernelMapMutex);
if (m_KernelMap.find(Kernel) != m_KernelMap.end()) {
return *m_KernelMap[Kernel].get();
}
}

auto hProgram = GetProgram(Kernel);
auto PI = getAsanInterceptor()->getProgramInfo(hProgram);
// Create new KernelInfo
auto Program = GetProgram(Kernel);
auto PI = getProgramInfo(Program);
bool IsInstrumented = PI->isKernelInstrumented(Kernel);

std::scoped_lock<ur_shared_mutex> Guard(m_KernelMapMutex);
m_KernelMap.emplace(Kernel,
std::make_shared<KernelInfo>(Kernel, IsInstrumented));
return UR_RESULT_SUCCESS;
std::make_unique<KernelInfo>(Kernel, IsInstrumented));
return *m_KernelMap[Kernel].get();
}

ur_result_t AsanInterceptor::eraseKernel(ur_kernel_handle_t Kernel) {
ur_result_t AsanInterceptor::eraseKernelInfo(ur_kernel_handle_t Kernel) {
std::scoped_lock<ur_shared_mutex> Guard(m_KernelMapMutex);
assert(m_KernelMap.find(Kernel) != m_KernelMap.end());
m_KernelMap.erase(Kernel);
Expand Down Expand Up @@ -691,7 +695,8 @@ ur_result_t AsanInterceptor::prepareLaunch(
std::shared_ptr<ContextInfo> &ContextInfo,
std::shared_ptr<DeviceInfo> &DeviceInfo, ur_queue_handle_t Queue,
ur_kernel_handle_t Kernel, LaunchInfo &LaunchInfo) {
auto KernelInfo = getKernelInfo(Kernel);
auto &KernelInfo = getOrCreateKernelInfo(Kernel);
std::shared_lock<ur_shared_mutex> Guard(KernelInfo.Mutex);

auto ArgNums = GetKernelNumArgs(Kernel);
auto LocalMemoryUsage =
Expand All @@ -703,11 +708,11 @@ ur_result_t AsanInterceptor::prepareLaunch(
"KernelInfo {} (Name={}, ArgNums={}, IsInstrumented={}, "
"LocalMemory={}, PrivateMemory={})",
(void *)Kernel, GetKernelName(Kernel), ArgNums,
KernelInfo->IsInstrumented, LocalMemoryUsage, PrivateMemoryUsage);
KernelInfo.IsInstrumented, LocalMemoryUsage, PrivateMemoryUsage);

// Validate pointer arguments
if (getOptions().DetectKernelArguments) {
for (const auto &[ArgIndex, PtrPair] : KernelInfo->PointerArgs) {
for (const auto &[ArgIndex, PtrPair] : KernelInfo.PointerArgs) {
auto Ptr = PtrPair.first;
if (Ptr == nullptr) {
continue;
Expand All @@ -722,7 +727,7 @@ ur_result_t AsanInterceptor::prepareLaunch(
}

// Set membuffer arguments
for (const auto &[ArgIndex, MemBuffer] : KernelInfo->BufferArgs) {
for (const auto &[ArgIndex, MemBuffer] : KernelInfo.BufferArgs) {
char *ArgPointer = nullptr;
UR_CALL(MemBuffer->getHandle(DeviceInfo->Handle, ArgPointer));
ur_result_t URes = getContext()->urDdiTable.Kernel.pfnSetArgPointer(
Expand All @@ -735,7 +740,7 @@ ur_result_t AsanInterceptor::prepareLaunch(
}
}

if (!KernelInfo->IsInstrumented) {
if (!KernelInfo.IsInstrumented) {
return UR_RESULT_SUCCESS;
}

Expand Down Expand Up @@ -830,9 +835,9 @@ ur_result_t AsanInterceptor::prepareLaunch(
}

// Write local arguments info
if (!KernelInfo->LocalArgs.empty()) {
if (!KernelInfo.LocalArgs.empty()) {
std::vector<LocalArgsInfo> LocalArgsInfo;
for (auto [ArgIndex, ArgInfo] : KernelInfo->LocalArgs) {
for (auto [ArgIndex, ArgInfo] : KernelInfo.LocalArgs) {
LocalArgsInfo.push_back(ArgInfo);
getContext()->logger.debug(
"local_args (argIndex={}, size={}, sizeWithRZ={})", ArgIndex,
Expand Down
12 changes: 3 additions & 9 deletions source/loader/layers/sanitizer/asan/asan_interceptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,6 @@ class AsanInterceptor {
ur_result_t insertProgram(ur_program_handle_t Program);
ur_result_t eraseProgram(ur_program_handle_t Program);

ur_result_t insertKernel(ur_kernel_handle_t Kernel);
ur_result_t eraseKernel(ur_kernel_handle_t Kernel);

ur_result_t insertMemBuffer(std::shared_ptr<MemBuffer> MemBuffer);
ur_result_t eraseMemBuffer(ur_mem_handle_t MemHandle);
std::shared_ptr<MemBuffer> getMemBuffer(ur_mem_handle_t MemHandle);
Expand Down Expand Up @@ -350,11 +347,8 @@ class AsanInterceptor {
return nullptr;
}

// Returns the entry stored for Kernel in m_KernelMap, holding a
// std::shared_lock on m_KernelMapMutex for the duration of the lookup.
// The kernel is expected to be in the map already; that precondition is
// checked only by the assert below.
std::shared_ptr<KernelInfo> getKernelInfo(ur_kernel_handle_t Kernel) {
std::shared_lock<ur_shared_mutex> Guard(m_KernelMapMutex);
assert(m_KernelMap.find(Kernel) != m_KernelMap.end());
// std::unordered_map::operator[] inserts a value-initialized (null) entry
// when the key is missing, so with asserts compiled out an unknown kernel
// yields a null pointer here rather than an error.
return m_KernelMap[Kernel];
}
KernelInfo &getOrCreateKernelInfo(ur_kernel_handle_t Kernel);
ur_result_t eraseKernelInfo(ur_kernel_handle_t Kernel);

const AsanOptions &getOptions() { return m_Options; }

Expand Down Expand Up @@ -401,7 +395,7 @@ class AsanInterceptor {
m_ProgramMap;
ur_shared_mutex m_ProgramMapMutex;

std::unordered_map<ur_kernel_handle_t, std::shared_ptr<KernelInfo>>
std::unordered_map<ur_kernel_handle_t, std::unique_ptr<KernelInfo>>
m_KernelMap;
ur_shared_mutex m_KernelMapMutex;

Expand Down
64 changes: 13 additions & 51 deletions source/loader/layers/sanitizer/msan/msan_ddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,6 @@ ur_result_t setupContext(ur_context_handle_t Context, uint32_t numDevices,
return UR_RESULT_SUCCESS;
}

// Reports whether hKernel is instrumented, as answered by the ProgramInfo
// record that the MSAN interceptor holds for the kernel's program.
bool isInstrumentedKernel(ur_kernel_handle_t hKernel) {
// Resolve the kernel's program, then look up that program's record.
auto hProgram = GetProgram(hKernel);
auto PI = getMsanInterceptor()->getProgramInfo(hProgram);
// NOTE(review): PI is dereferenced without a null check; this assumes
// getProgramInfo always finds a record for hProgram -- confirm.
return PI->isKernelInstrumented(hKernel);
}

} // namespace

///////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -354,12 +348,6 @@ ur_result_t urEnqueueKernelLaunch(

getContext()->logger.debug("==== urEnqueueKernelLaunch");

if (!isInstrumentedKernel(hKernel)) {
return pfnKernelLaunch(hQueue, hKernel, workDim, pGlobalWorkOffset,
pGlobalWorkSize, pLocalWorkSize,
numEventsInWaitList, phEventWaitList, phEvent);
}

USMLaunchInfo LaunchInfo(GetContext(hQueue), GetDevice(hQueue),
pGlobalWorkSize, pLocalWorkSize, pGlobalWorkOffset,
workDim);
Expand Down Expand Up @@ -1155,26 +1143,6 @@ ur_result_t urEnqueueMemUnmap(
return UR_RESULT_SUCCESS;
}

///////////////////////////////////////////////////////////////////////////////
/// @brief Intercept function for urKernelCreate
///
/// Forwards the request to the pfnCreate entry of the Kernel DDI table held
/// by getContext(). If isInstrumentedKernel() reports the new kernel as
/// instrumented, the handle is also passed to
/// getMsanInterceptor()->insertKernel(); otherwise it is not registered.
///
/// @return UR_RESULT_SUCCESS once the calls above have completed.
/// NOTE(review): UR_CALL is presumably an early-return-on-error macro, in
/// which case a failure of pfnCreate or insertKernel is returned to the
/// caller -- confirm against its definition.
ur_result_t urKernelCreate(
ur_program_handle_t hProgram, ///< [in] handle of the program instance
const char *pKernelName, ///< [in] pointer to null-terminated string.
ur_kernel_handle_t
*phKernel ///< [out] pointer to handle of kernel object created.
) {
// pfnCreate is used below without a null check.
auto pfnCreate = getContext()->urDdiTable.Kernel.pfnCreate;

getContext()->logger.debug("==== urKernelCreate");

UR_CALL(pfnCreate(hProgram, pKernelName, phKernel));
// Only instrumented kernels are registered with the interceptor.
if (isInstrumentedKernel(*phKernel)) {
UR_CALL(getMsanInterceptor()->insertKernel(*phKernel));
}

return UR_RESULT_SUCCESS;
}

///////////////////////////////////////////////////////////////////////////////
/// @brief Intercept function for urKernelRetain
ur_result_t urKernelRetain(
Expand All @@ -1186,10 +1154,8 @@ ur_result_t urKernelRetain(

UR_CALL(pfnRetain(hKernel));

auto KernelInfo = getMsanInterceptor()->getKernelInfo(hKernel);
if (KernelInfo) {
KernelInfo->RefCount++;
}
auto &KernelInfo = getMsanInterceptor()->getOrCreateKernelInfo(hKernel);
KernelInfo.RefCount++;

return UR_RESULT_SUCCESS;
}
Expand All @@ -1204,11 +1170,9 @@ ur_result_t urKernelRelease(
getContext()->logger.debug("==== urKernelRelease");
UR_CALL(pfnRelease(hKernel));

auto KernelInfo = getMsanInterceptor()->getKernelInfo(hKernel);
if (KernelInfo) {
if (--KernelInfo->RefCount == 0) {
UR_CALL(getMsanInterceptor()->eraseKernel(hKernel));
}
auto &KernelInfo = getMsanInterceptor()->getOrCreateKernelInfo(hKernel);
if (--KernelInfo.RefCount == 0) {
UR_CALL(getMsanInterceptor()->eraseKernelInfo(hKernel));
}

return UR_RESULT_SUCCESS;
Expand All @@ -1230,13 +1194,12 @@ ur_result_t urKernelSetArgValue(
getContext()->logger.debug("==== urKernelSetArgValue");

std::shared_ptr<MemBuffer> MemBuffer;
std::shared_ptr<KernelInfo> KernelInfo;
if (argSize == sizeof(ur_mem_handle_t) &&
(MemBuffer = getMsanInterceptor()->getMemBuffer(
*ur_cast<const ur_mem_handle_t *>(pArgValue))) &&
(KernelInfo = getMsanInterceptor()->getKernelInfo(hKernel))) {
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo->Mutex);
KernelInfo->BufferArgs[argIndex] = std::move(MemBuffer);
*ur_cast<const ur_mem_handle_t *>(pArgValue)))) {
auto &KernelInfo = getMsanInterceptor()->getOrCreateKernelInfo(hKernel);
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo.Mutex);
KernelInfo.BufferArgs[argIndex] = std::move(MemBuffer);
} else {
UR_CALL(
pfnSetArgValue(hKernel, argIndex, argSize, pProperties, pArgValue));
Expand All @@ -1260,10 +1223,10 @@ ur_result_t urKernelSetArgMemObj(

std::shared_ptr<MemBuffer> MemBuffer;
std::shared_ptr<KernelInfo> KernelInfo;
if ((MemBuffer = getMsanInterceptor()->getMemBuffer(hArgValue)) &&
(KernelInfo = getMsanInterceptor()->getKernelInfo(hKernel))) {
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo->Mutex);
KernelInfo->BufferArgs[argIndex] = std::move(MemBuffer);
if ((MemBuffer = getMsanInterceptor()->getMemBuffer(hArgValue))) {
auto &KernelInfo = getMsanInterceptor()->getOrCreateKernelInfo(hKernel);
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo.Mutex);
KernelInfo.BufferArgs[argIndex] = std::move(MemBuffer);
} else {
UR_CALL(pfnSetArgMemObj(hKernel, argIndex, pProperties, hArgValue));
}
Expand Down Expand Up @@ -1348,7 +1311,6 @@ ur_result_t urGetKernelProcAddrTable(
) {
ur_result_t result = UR_RESULT_SUCCESS;

pDdiTable->pfnCreate = ur_sanitizer_layer::msan::urKernelCreate;
pDdiTable->pfnRetain = ur_sanitizer_layer::msan::urKernelRetain;
pDdiTable->pfnRelease = ur_sanitizer_layer::msan::urKernelRelease;
pDdiTable->pfnSetArgValue = ur_sanitizer_layer::msan::urKernelSetArgValue;
Expand Down
Loading

0 comments on commit 71ad3b7

Please sign in to comment.