Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DeviceASAN] Fix urKernelCreateWithNativeHandle segfault #2506

Merged
merged 3 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 17 additions & 40 deletions source/loader/layers/sanitizer/asan/asan_ddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1335,28 +1335,6 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemUnmap(
return UR_RESULT_SUCCESS;
}

///////////////////////////////////////////////////////////////////////////////
/// @brief Intercept function for urKernelCreate
__urdlllocal ur_result_t UR_APICALL urKernelCreate(
ur_program_handle_t hProgram, ///< [in] handle of the program instance
const char *pKernelName, ///< [in] pointer to null-terminated string.
ur_kernel_handle_t
*phKernel ///< [out] pointer to handle of kernel object created.
) {
auto pfnCreate = getContext()->urDdiTable.Kernel.pfnCreate;

if (nullptr == pfnCreate) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

getContext()->logger.debug("==== urKernelCreate");

UR_CALL(pfnCreate(hProgram, pKernelName, phKernel));
UR_CALL(getAsanInterceptor()->insertKernel(*phKernel));

return UR_RESULT_SUCCESS;
}

///////////////////////////////////////////////////////////////////////////////
/// @brief Intercept function for urKernelRetain
__urdlllocal ur_result_t UR_APICALL urKernelRetain(
Expand All @@ -1372,8 +1350,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain(

UR_CALL(pfnRetain(hKernel));

auto KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel);
KernelInfo->RefCount++;
auto &KernelInfo = getAsanInterceptor()->getOrCreateKernelInfo(hKernel);
KernelInfo.RefCount++;

return UR_RESULT_SUCCESS;
}
Expand All @@ -1392,9 +1370,9 @@ __urdlllocal ur_result_t urKernelRelease(
getContext()->logger.debug("==== urKernelRelease");
UR_CALL(pfnRelease(hKernel));

auto KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel);
if (--KernelInfo->RefCount == 0) {
UR_CALL(getAsanInterceptor()->eraseKernel(hKernel));
auto &KernelInfo = getAsanInterceptor()->getOrCreateKernelInfo(hKernel);
if (--KernelInfo.RefCount == 0) {
UR_CALL(getAsanInterceptor()->eraseKernelInfo(hKernel));
}

return UR_RESULT_SUCCESS;
Expand Down Expand Up @@ -1423,9 +1401,9 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgValue(
if (argSize == sizeof(ur_mem_handle_t) &&
(MemBuffer = getAsanInterceptor()->getMemBuffer(
*ur_cast<const ur_mem_handle_t *>(pArgValue)))) {
auto KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel);
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo->Mutex);
KernelInfo->BufferArgs[argIndex] = std::move(MemBuffer);
auto &KernelInfo = getAsanInterceptor()->getOrCreateKernelInfo(hKernel);
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo.Mutex);
KernelInfo.BufferArgs[argIndex] = std::move(MemBuffer);
} else {
UR_CALL(
pfnSetArgValue(hKernel, argIndex, argSize, pProperties, pArgValue));
Expand Down Expand Up @@ -1453,9 +1431,9 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgMemObj(

std::shared_ptr<MemBuffer> MemBuffer;
if ((MemBuffer = getAsanInterceptor()->getMemBuffer(hArgValue))) {
auto KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel);
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo->Mutex);
KernelInfo->BufferArgs[argIndex] = std::move(MemBuffer);
auto &KernelInfo = getAsanInterceptor()->getOrCreateKernelInfo(hKernel);
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo.Mutex);
KernelInfo.BufferArgs[argIndex] = std::move(MemBuffer);
} else {
UR_CALL(pfnSetArgMemObj(hKernel, argIndex, pProperties, hArgValue));
}
Expand Down Expand Up @@ -1484,12 +1462,12 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgLocal(
argSize);

{
auto KI = getAsanInterceptor()->getKernelInfo(hKernel);
std::scoped_lock<ur_shared_mutex> Guard(KI->Mutex);
auto &KI = getAsanInterceptor()->getOrCreateKernelInfo(hKernel);
std::scoped_lock<ur_shared_mutex> Guard(KI.Mutex);
// TODO: get local variable alignment
auto argSizeWithRZ = GetSizeAndRedzoneSizeForLocal(
argSize, ASAN_SHADOW_GRANULARITY, ASAN_SHADOW_GRANULARITY);
KI->LocalArgs[argIndex] = LocalArgsInfo{argSize, argSizeWithRZ};
KI.LocalArgs[argIndex] = LocalArgsInfo{argSize, argSizeWithRZ};
argSize = argSizeWithRZ;
}

Expand Down Expand Up @@ -1522,9 +1500,9 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgPointer(

std::shared_ptr<KernelInfo> KI;
if (getAsanInterceptor()->getOptions().DetectKernelArguments) {
auto KI = getAsanInterceptor()->getKernelInfo(hKernel);
std::scoped_lock<ur_shared_mutex> Guard(KI->Mutex);
KI->PointerArgs[argIndex] = {pArgValue, GetCurrentBacktrace()};
auto &KI = getAsanInterceptor()->getOrCreateKernelInfo(hKernel);
std::scoped_lock<ur_shared_mutex> Guard(KI.Mutex);
KI.PointerArgs[argIndex] = {pArgValue, GetCurrentBacktrace()};
}

ur_result_t result =
Expand Down Expand Up @@ -1708,7 +1686,6 @@ __urdlllocal ur_result_t UR_APICALL urGetKernelProcAddrTable(

ur_result_t result = UR_RESULT_SUCCESS;

pDdiTable->pfnCreate = ur_sanitizer_layer::asan::urKernelCreate;
pDdiTable->pfnRetain = ur_sanitizer_layer::asan::urKernelRetain;
pDdiTable->pfnRelease = ur_sanitizer_layer::asan::urKernelRelease;
pDdiTable->pfnSetArgValue = ur_sanitizer_layer::asan::urKernelSetArgValue;
Expand Down
37 changes: 21 additions & 16 deletions source/loader/layers/sanitizer/asan/asan_interceptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -639,22 +639,26 @@ ur_result_t AsanInterceptor::eraseProgram(ur_program_handle_t Program) {
return UR_RESULT_SUCCESS;
}

ur_result_t AsanInterceptor::insertKernel(ur_kernel_handle_t Kernel) {
std::scoped_lock<ur_shared_mutex> Guard(m_KernelMapMutex);
if (m_KernelMap.find(Kernel) != m_KernelMap.end()) {
return UR_RESULT_SUCCESS;
KernelInfo &AsanInterceptor::getOrCreateKernelInfo(ur_kernel_handle_t Kernel) {
{
std::shared_lock<ur_shared_mutex> Guard(m_KernelMapMutex);
if (m_KernelMap.find(Kernel) != m_KernelMap.end()) {
return *m_KernelMap[Kernel].get();
}
}

auto hProgram = GetProgram(Kernel);
auto PI = getAsanInterceptor()->getProgramInfo(hProgram);
// Create new KernelInfo
auto Program = GetProgram(Kernel);
auto PI = getProgramInfo(Program);
bool IsInstrumented = PI->isKernelInstrumented(Kernel);

std::scoped_lock<ur_shared_mutex> Guard(m_KernelMapMutex);
m_KernelMap.emplace(Kernel,
std::make_shared<KernelInfo>(Kernel, IsInstrumented));
return UR_RESULT_SUCCESS;
std::make_unique<KernelInfo>(Kernel, IsInstrumented));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why bother using a unique_ptr if there is no ownership change? We can just use raw pointer here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to manually delete them if it's raw pointer. unique_ptr is more convenient here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then we can just use the object created by the map. It should be deleted when being erased from then map.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then we can just use the object created by the map. It should be deleted when being erased from then map.

I've tried this but encountered a difficult-to-solve compilation error.
Using unqiue_ptr is a simple solution here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then I am fine with this as a temporary solution. We plan to refactor related code in near future anyway. We can discuss for any better solution in that time.

return *m_KernelMap[Kernel].get();
}

ur_result_t AsanInterceptor::eraseKernel(ur_kernel_handle_t Kernel) {
ur_result_t AsanInterceptor::eraseKernelInfo(ur_kernel_handle_t Kernel) {
std::scoped_lock<ur_shared_mutex> Guard(m_KernelMapMutex);
assert(m_KernelMap.find(Kernel) != m_KernelMap.end());
m_KernelMap.erase(Kernel);
Expand Down Expand Up @@ -691,7 +695,8 @@ ur_result_t AsanInterceptor::prepareLaunch(
std::shared_ptr<ContextInfo> &ContextInfo,
std::shared_ptr<DeviceInfo> &DeviceInfo, ur_queue_handle_t Queue,
ur_kernel_handle_t Kernel, LaunchInfo &LaunchInfo) {
auto KernelInfo = getKernelInfo(Kernel);
auto &KernelInfo = getOrCreateKernelInfo(Kernel);
std::shared_lock<ur_shared_mutex> Guard(KernelInfo.Mutex);

auto ArgNums = GetKernelNumArgs(Kernel);
auto LocalMemoryUsage =
Expand All @@ -703,11 +708,11 @@ ur_result_t AsanInterceptor::prepareLaunch(
"KernelInfo {} (Name={}, ArgNums={}, IsInstrumented={}, "
"LocalMemory={}, PrivateMemory={})",
(void *)Kernel, GetKernelName(Kernel), ArgNums,
KernelInfo->IsInstrumented, LocalMemoryUsage, PrivateMemoryUsage);
KernelInfo.IsInstrumented, LocalMemoryUsage, PrivateMemoryUsage);

// Validate pointer arguments
if (getOptions().DetectKernelArguments) {
for (const auto &[ArgIndex, PtrPair] : KernelInfo->PointerArgs) {
for (const auto &[ArgIndex, PtrPair] : KernelInfo.PointerArgs) {
auto Ptr = PtrPair.first;
if (Ptr == nullptr) {
continue;
Expand All @@ -722,7 +727,7 @@ ur_result_t AsanInterceptor::prepareLaunch(
}

// Set membuffer arguments
for (const auto &[ArgIndex, MemBuffer] : KernelInfo->BufferArgs) {
for (const auto &[ArgIndex, MemBuffer] : KernelInfo.BufferArgs) {
char *ArgPointer = nullptr;
UR_CALL(MemBuffer->getHandle(DeviceInfo->Handle, ArgPointer));
ur_result_t URes = getContext()->urDdiTable.Kernel.pfnSetArgPointer(
Expand All @@ -735,7 +740,7 @@ ur_result_t AsanInterceptor::prepareLaunch(
}
}

if (!KernelInfo->IsInstrumented) {
if (!KernelInfo.IsInstrumented) {
return UR_RESULT_SUCCESS;
}

Expand Down Expand Up @@ -830,9 +835,9 @@ ur_result_t AsanInterceptor::prepareLaunch(
}

// Write local arguments info
if (!KernelInfo->LocalArgs.empty()) {
if (!KernelInfo.LocalArgs.empty()) {
std::vector<LocalArgsInfo> LocalArgsInfo;
for (auto [ArgIndex, ArgInfo] : KernelInfo->LocalArgs) {
for (auto [ArgIndex, ArgInfo] : KernelInfo.LocalArgs) {
LocalArgsInfo.push_back(ArgInfo);
getContext()->logger.debug(
"local_args (argIndex={}, size={}, sizeWithRZ={})", ArgIndex,
Expand Down
12 changes: 3 additions & 9 deletions source/loader/layers/sanitizer/asan/asan_interceptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,6 @@ class AsanInterceptor {
ur_result_t insertProgram(ur_program_handle_t Program);
ur_result_t eraseProgram(ur_program_handle_t Program);

ur_result_t insertKernel(ur_kernel_handle_t Kernel);
ur_result_t eraseKernel(ur_kernel_handle_t Kernel);

ur_result_t insertMemBuffer(std::shared_ptr<MemBuffer> MemBuffer);
ur_result_t eraseMemBuffer(ur_mem_handle_t MemHandle);
std::shared_ptr<MemBuffer> getMemBuffer(ur_mem_handle_t MemHandle);
Expand Down Expand Up @@ -350,11 +347,8 @@ class AsanInterceptor {
return nullptr;
}

std::shared_ptr<KernelInfo> getKernelInfo(ur_kernel_handle_t Kernel) {
std::shared_lock<ur_shared_mutex> Guard(m_KernelMapMutex);
assert(m_KernelMap.find(Kernel) != m_KernelMap.end());
return m_KernelMap[Kernel];
}
KernelInfo &getOrCreateKernelInfo(ur_kernel_handle_t Kernel);
ur_result_t eraseKernelInfo(ur_kernel_handle_t Kernel);

const AsanOptions &getOptions() { return m_Options; }

Expand Down Expand Up @@ -401,7 +395,7 @@ class AsanInterceptor {
m_ProgramMap;
ur_shared_mutex m_ProgramMapMutex;

std::unordered_map<ur_kernel_handle_t, std::shared_ptr<KernelInfo>>
std::unordered_map<ur_kernel_handle_t, std::unique_ptr<KernelInfo>>
m_KernelMap;
ur_shared_mutex m_KernelMapMutex;

Expand Down
64 changes: 13 additions & 51 deletions source/loader/layers/sanitizer/msan/msan_ddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,6 @@ ur_result_t setupContext(ur_context_handle_t Context, uint32_t numDevices,
return UR_RESULT_SUCCESS;
}

bool isInstrumentedKernel(ur_kernel_handle_t hKernel) {
auto hProgram = GetProgram(hKernel);
auto PI = getMsanInterceptor()->getProgramInfo(hProgram);
return PI->isKernelInstrumented(hKernel);
}

} // namespace

///////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -354,12 +348,6 @@ ur_result_t urEnqueueKernelLaunch(

getContext()->logger.debug("==== urEnqueueKernelLaunch");

if (!isInstrumentedKernel(hKernel)) {
return pfnKernelLaunch(hQueue, hKernel, workDim, pGlobalWorkOffset,
pGlobalWorkSize, pLocalWorkSize,
numEventsInWaitList, phEventWaitList, phEvent);
}

USMLaunchInfo LaunchInfo(GetContext(hQueue), GetDevice(hQueue),
pGlobalWorkSize, pLocalWorkSize, pGlobalWorkOffset,
workDim);
Expand Down Expand Up @@ -1155,26 +1143,6 @@ ur_result_t urEnqueueMemUnmap(
return UR_RESULT_SUCCESS;
}

///////////////////////////////////////////////////////////////////////////////
/// @brief Intercept function for urKernelCreate
ur_result_t urKernelCreate(
ur_program_handle_t hProgram, ///< [in] handle of the program instance
const char *pKernelName, ///< [in] pointer to null-terminated string.
ur_kernel_handle_t
*phKernel ///< [out] pointer to handle of kernel object created.
) {
auto pfnCreate = getContext()->urDdiTable.Kernel.pfnCreate;

getContext()->logger.debug("==== urKernelCreate");

UR_CALL(pfnCreate(hProgram, pKernelName, phKernel));
if (isInstrumentedKernel(*phKernel)) {
UR_CALL(getMsanInterceptor()->insertKernel(*phKernel));
}

return UR_RESULT_SUCCESS;
}

///////////////////////////////////////////////////////////////////////////////
/// @brief Intercept function for urKernelRetain
ur_result_t urKernelRetain(
Expand All @@ -1186,10 +1154,8 @@ ur_result_t urKernelRetain(

UR_CALL(pfnRetain(hKernel));

auto KernelInfo = getMsanInterceptor()->getKernelInfo(hKernel);
if (KernelInfo) {
KernelInfo->RefCount++;
}
auto &KernelInfo = getMsanInterceptor()->getOrCreateKernelInfo(hKernel);
KernelInfo.RefCount++;

return UR_RESULT_SUCCESS;
}
Expand All @@ -1204,11 +1170,9 @@ ur_result_t urKernelRelease(
getContext()->logger.debug("==== urKernelRelease");
UR_CALL(pfnRelease(hKernel));

auto KernelInfo = getMsanInterceptor()->getKernelInfo(hKernel);
if (KernelInfo) {
if (--KernelInfo->RefCount == 0) {
UR_CALL(getMsanInterceptor()->eraseKernel(hKernel));
}
auto &KernelInfo = getMsanInterceptor()->getOrCreateKernelInfo(hKernel);
if (--KernelInfo.RefCount == 0) {
UR_CALL(getMsanInterceptor()->eraseKernelInfo(hKernel));
}

return UR_RESULT_SUCCESS;
Expand All @@ -1230,13 +1194,12 @@ ur_result_t urKernelSetArgValue(
getContext()->logger.debug("==== urKernelSetArgValue");

std::shared_ptr<MemBuffer> MemBuffer;
std::shared_ptr<KernelInfo> KernelInfo;
if (argSize == sizeof(ur_mem_handle_t) &&
(MemBuffer = getMsanInterceptor()->getMemBuffer(
*ur_cast<const ur_mem_handle_t *>(pArgValue))) &&
(KernelInfo = getMsanInterceptor()->getKernelInfo(hKernel))) {
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo->Mutex);
KernelInfo->BufferArgs[argIndex] = std::move(MemBuffer);
*ur_cast<const ur_mem_handle_t *>(pArgValue)))) {
auto &KernelInfo = getMsanInterceptor()->getOrCreateKernelInfo(hKernel);
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo.Mutex);
KernelInfo.BufferArgs[argIndex] = std::move(MemBuffer);
} else {
UR_CALL(
pfnSetArgValue(hKernel, argIndex, argSize, pProperties, pArgValue));
Expand All @@ -1260,10 +1223,10 @@ ur_result_t urKernelSetArgMemObj(

std::shared_ptr<MemBuffer> MemBuffer;
std::shared_ptr<KernelInfo> KernelInfo;
if ((MemBuffer = getMsanInterceptor()->getMemBuffer(hArgValue)) &&
(KernelInfo = getMsanInterceptor()->getKernelInfo(hKernel))) {
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo->Mutex);
KernelInfo->BufferArgs[argIndex] = std::move(MemBuffer);
if ((MemBuffer = getMsanInterceptor()->getMemBuffer(hArgValue))) {
auto &KernelInfo = getMsanInterceptor()->getOrCreateKernelInfo(hKernel);
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo.Mutex);
KernelInfo.BufferArgs[argIndex] = std::move(MemBuffer);
} else {
UR_CALL(pfnSetArgMemObj(hKernel, argIndex, pProperties, hArgValue));
}
Expand Down Expand Up @@ -1348,7 +1311,6 @@ ur_result_t urGetKernelProcAddrTable(
) {
ur_result_t result = UR_RESULT_SUCCESS;

pDdiTable->pfnCreate = ur_sanitizer_layer::msan::urKernelCreate;
pDdiTable->pfnRetain = ur_sanitizer_layer::msan::urKernelRetain;
pDdiTable->pfnRelease = ur_sanitizer_layer::msan::urKernelRelease;
pDdiTable->pfnSetArgValue = ur_sanitizer_layer::msan::urKernelSetArgValue;
Expand Down
Loading
Loading