Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP #2530

Closed
wants to merge 1 commit into from
Closed

WIP #2530

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions source/loader/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ if(UR_ENABLE_SANITIZER)
${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/ur_sanddi.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/ur_sanitizer_layer.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/ur_sanitizer_layer.hpp
${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/ur_obj_handler.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/ur_obj_handler.hpp
)

if(UR_ENABLE_SYMBOLIZER)
Expand Down
10 changes: 8 additions & 2 deletions source/loader/layers/sanitizer/asan/asan_ddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,7 @@ __urdlllocal ur_result_t UR_APICALL urContextCreate(
pfnCreate(numDevices, phDevices, pProperties, phContext);

if (result == UR_RESULT_SUCCESS) {
getContext()->objectHandler.add(*phContext);
UR_CALL(setupContext(*phContext, numDevices, phDevices));
}

Expand Down Expand Up @@ -543,6 +544,7 @@ __urdlllocal ur_result_t UR_APICALL urContextCreateWithNativeHandle(
phDevices, pProperties, phContext);

if (result == UR_RESULT_SUCCESS) {
getContext()->objectHandler.add(*phContext);
UR_CALL(setupContext(*phContext, numDevices, phDevices));
}

Expand All @@ -563,7 +565,8 @@ __urdlllocal ur_result_t UR_APICALL urContextRetain(

getContext()->logger.debug("==== urContextRetain");

UR_CALL(pfnRetain(hContext));
// UR_CALL(pfnRetain(hContext));
UR_CALL(getContext()->objectHandler.retain(hContext));

auto ContextInfo = getAsanInterceptor()->getContextInfo(hContext);
UR_ASSERT(ContextInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE);
Expand All @@ -585,7 +588,8 @@ __urdlllocal ur_result_t UR_APICALL urContextRelease(

getContext()->logger.debug("==== urContextRelease");

UR_CALL(pfnRelease(hContext));
// UR_CALL(pfnRelease(hContext));
UR_CALL(getContext()->objectHandler.release(hContext));

auto ContextInfo = getAsanInterceptor()->getContextInfo(hContext);
UR_ASSERT(ContextInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE);
Expand Down Expand Up @@ -2037,6 +2041,8 @@ ur_result_t initAsanDDITable(ur_dditable_t *dditable) {

getContext()->logger.always("==== DeviceSanitizer: ASAN");

getContext()->objectHandler.installDdiTable(dditable);

if (UR_RESULT_SUCCESS == result) {
result = ur_sanitizer_layer::asan::urGetGlobalProcAddrTable(
UR_API_VERSION_CURRENT, &dditable->Global);
Expand Down
6 changes: 4 additions & 2 deletions source/loader/layers/sanitizer/asan/asan_interceptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -896,7 +896,8 @@ ContextInfo::~ContextInfo() {
assert(URes == UR_RESULT_SUCCESS);
}

URes = getContext()->urDdiTable.Context.pfnRelease(Handle);
// URes = getContext()->urDdiTable.Context.pfnRelease(Handle);
URes = getContext()->objectHandler.release(Handle);
assert(URes == UR_RESULT_SUCCESS);

// check memory leaks
Expand Down Expand Up @@ -944,7 +945,8 @@ AsanRuntimeDataWrapper::~AsanRuntimeDataWrapper() {

LaunchInfo::~LaunchInfo() {
[[maybe_unused]] ur_result_t Result;
Result = getContext()->urDdiTable.Context.pfnRelease(Context);
// Result = getContext()->urDdiTable.Context.pfnRelease(Context);
Result = getContext()->objectHandler.release(Context);
assert(Result == UR_RESULT_SUCCESS);
Result = getContext()->urDdiTable.Device.pfnRelease(Device);
assert(Result == UR_RESULT_SUCCESS);
Expand Down
7 changes: 3 additions & 4 deletions source/loader/layers/sanitizer/asan/asan_interceptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,7 @@ struct ContextInfo {
AsanStatsWrapper Stats;

explicit ContextInfo(ur_context_handle_t Context) : Handle(Context) {
[[maybe_unused]] auto Result =
getContext()->urDdiTable.Context.pfnRetain(Context);
[[maybe_unused]] auto Result = getContext()->objectHandler.retain(Handle);
assert(Result == UR_RESULT_SUCCESS);
}

Expand Down Expand Up @@ -252,9 +251,9 @@ struct LaunchInfo {
this->LocalWorkSize =
std::vector<size_t>(LocalWorkSize, LocalWorkSize + WorkDim);
}
[[maybe_unused]] auto Result =
getContext()->urDdiTable.Context.pfnRetain(Context);
[[maybe_unused]] auto Result = getContext()->objectHandler.retain(Context);
assert(Result == UR_RESULT_SUCCESS);

Result = getContext()->urDdiTable.Device.pfnRetain(Device);
assert(Result == UR_RESULT_SUCCESS);
}
Expand Down
13 changes: 10 additions & 3 deletions source/loader/layers/sanitizer/asan/asan_shadow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ namespace asan {
std::shared_ptr<ShadowMemory> GetShadowMemory(ur_context_handle_t Context,
ur_device_handle_t Device,
DeviceType Type) {
getContext()->objectHandler.use(Context);
if (Type == DeviceType::CPU) {
static std::shared_ptr<ShadowMemory> ShadowCPU =
std::make_shared<ShadowMemoryCPU>(Context, Device);
Expand Down Expand Up @@ -109,6 +110,8 @@ ur_result_t ShadowMemoryGPU::Setup() {
// the SVM range, so that GFX driver will automatically switch to reservation on the GPU
// heap.
const void *StartAddress = (void *)(0x100'0000'0000'0000ULL);

getContext()->objectHandler.use(Context);
// TODO: Protect Bad Zone
auto Result = getContext()->urDdiTable.VirtualMem.pfnReserve(
Context, StartAddress, ShadowSize, (void **)&ShadowBegin);
Expand All @@ -120,7 +123,8 @@ ur_result_t ShadowMemoryGPU::Setup() {
}
ShadowEnd = ShadowBegin + ShadowSize;
// Retain the context which reserves shadow memory
getContext()->urDdiTable.Context.pfnRetain(Context);
// getContext()->urDdiTable.Context.pfnRetain(Context);
getContext()->objectHandler.retain(Context);

// Set shadow memory for null pointer
// For GPU, wu use up to 1 page of shadow memory
Expand All @@ -147,6 +151,7 @@ ur_result_t ShadowMemoryGPU::Destory() {
}

static ur_result_t Result = [this]() {
getContext()->objectHandler.use(Context);
const size_t PageSize = GetVirtualMemGranularity(Context, Device);
for (auto [MappedPtr, PhysicalMem] : VirtualMemMaps) {
UR_CALL(getContext()->urDdiTable.VirtualMem.pfnUnmap(
Expand All @@ -156,7 +161,8 @@ ur_result_t ShadowMemoryGPU::Destory() {
}
UR_CALL(getContext()->urDdiTable.VirtualMem.pfnFree(
Context, (const void *)ShadowBegin, GetShadowSize()));
UR_CALL(getContext()->urDdiTable.Context.pfnRelease(Context));
// UR_CALL(getContext()->urDdiTable.Context.pfnRelease(Context));
UR_CALL(getContext()->objectHandler.release(Context));
return UR_RESULT_SUCCESS;
}();
if (!Result) {
Expand All @@ -171,7 +177,8 @@ ur_result_t ShadowMemoryGPU::Destory() {
if (ShadowBegin != 0) {
UR_CALL(getContext()->urDdiTable.VirtualMem.pfnFree(
Context, (const void *)ShadowBegin, GetShadowSize()));
UR_CALL(getContext()->urDdiTable.Context.pfnRelease(Context));
// UR_CALL(getContext()->urDdiTable.Context.pfnRelease(Context));
UR_CALL(getContext()->objectHandler.release(Context));
ShadowBegin = ShadowEnd = 0;
}
return UR_RESULT_SUCCESS;
Expand Down
Empty file.
116 changes: 116 additions & 0 deletions source/loader/layers/sanitizer/ur_obj_handler.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
/*
The UrObjectHandler is intend to provide a global maps for all UrObjects and corresponding XXXInfo objects that used with sanitizer layers.
Also, it provides a checker that checks for UrObjects' status to avoid any use-after-released cases.
*/

/**
* 20250107: first we impl this as a checker to check for use-after-released cases.
*/

#include "ur/ur.hpp"
#include "ur_api.h"

#include <atomic>
#include <cassert>
#include <unordered_map>
#include <variant>

#pragma once

namespace ur_sanitizer_layer {

typedef std::variant<ur_context_handle_t, ur_device_handle_t> UrObjectT;

class UrObjectHandler {
public:
void add(UrObjectT urObject) {
std::scoped_lock<ur_shared_mutex> Guard(urObjectStatusMapMutex);
// if (urObjectStatusMap.find(urObject) != urObjectStatusMap.end()) {
// if (urObjectStatusMap[urObject].refCount > 0) {
// assert(false && "Add of a exist object");
// } else {
// // remove an old object
// ; // Nothing to do for now as we only do ref-counting
// }
// }
assert(urObjectStatusMap.find(urObject) == urObjectStatusMap.end() &&
"Add of a exist object");
std::ignore = urObjectStatusMap[urObject];
}

ur_result_t retain(UrObjectT urObject) {
assert(ddiTableInstalled && "DdiTable is not installed");
assert(urObjectStatusMap.find(urObject) != urObjectStatusMap.end() &&
"Retain of a nonexistent object");
urObjectStatusMap[urObject].retain();

if (std::holds_alternative<ur_context_handle_t>(urObject)) {
return urDdiTable.Context.pfnRetain(
std::get<ur_context_handle_t>(urObject));
} else if (std::holds_alternative<ur_device_handle_t>(urObject)) {
return urDdiTable.Device.pfnRetain(
std::get<ur_device_handle_t>(urObject));
}
assert(false && "Abonomal object type");
return UR_RESULT_SUCCESS;
}

ur_result_t release(UrObjectT urObject) {
assert(ddiTableInstalled && "DdiTable is not installed");
assert(urObjectStatusMap.find(urObject) != urObjectStatusMap.end() &&
"Release of a nonexistent object");
urObjectStatusMap[urObject].release();
if (urObjectStatusMap[urObject].refCount == 0) {
std::scoped_lock<ur_shared_mutex> Guard(urObjectStatusMapMutex);
urObjectStatusMap.erase(urObject);
}

if (std::holds_alternative<ur_context_handle_t>(urObject)) {
return urDdiTable.Context.pfnRelease(
std::get<ur_context_handle_t>(urObject));
} else if (std::holds_alternative<ur_device_handle_t>(urObject)) {
return urDdiTable.Device.pfnRelease(
std::get<ur_device_handle_t>(urObject));
}
assert(false && "Abonomal object type");
return UR_RESULT_SUCCESS;
}

bool use(UrObjectT urObject) {
assert(urObjectStatusMap.find(urObject) != urObjectStatusMap.end() &&
"Use of a nonexistent object");
return urObjectStatusMap[urObject].use();
}

void installDdiTable(ur_dditable_t *dditable) {
urDdiTable = *dditable;
ddiTableInstalled = true;
}

~UrObjectHandler() {
// Check for not released objects
}

private:
struct UrObjectInfo {
UrObjectInfo() : refCount(1) {}
std::atomic<int> refCount;

void retain() { refCount += 1; }
void release() {
assert(refCount > 0 && "Release of a invalid object");
refCount -= 1;
}
bool use() {
assert(refCount > 0 && "Use of a invalid object");
return refCount > 0;
}
};

ur_dditable_t urDdiTable;
bool ddiTableInstalled = false;
ur_shared_mutex urObjectStatusMapMutex;
std::unordered_map<UrObjectT, UrObjectInfo> urObjectStatusMap;
};

} // namespace ur_sanitizer_layer
2 changes: 2 additions & 0 deletions source/loader/layers/sanitizer/ur_sanitizer_layer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "logger/ur_logger.hpp"
#include "ur_proxy_layer.hpp"
#include "ur_obj_handler.hpp"

#define SANITIZER_COMP_NAME "sanitizer layer"

Expand All @@ -33,6 +34,7 @@ class __urdlllocal context_t : public proxy_layer_context_t,
ur_dditable_t urDdiTable = {};
logger::Logger logger;
SanitizerType enabledType = SanitizerType::None;
UrObjectHandler objectHandler;

context_t();
~context_t();
Expand Down
Loading