Skip to content

Commit

Permalink
Added additional enums to urPhysicalMemGetInfo along with splitting t…
Browse files Browse the repository at this point in the history
…he CTS test out from a switch statement to separate tests.

Fixed a typo in memory.yml.
  • Loading branch information
martygrant committed Nov 6, 2024
1 parent 6a96f52 commit e825f58
Show file tree
Hide file tree
Showing 9 changed files with 237 additions and 38 deletions.
11 changes: 9 additions & 2 deletions include/ur_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2516,7 +2516,7 @@ typedef enum ur_mem_type_t {
///////////////////////////////////////////////////////////////////////////////
/// @brief Memory Information type
typedef enum ur_mem_info_t {
UR_MEM_INFO_SIZE = 0, ///< [size_t] actual size of of memory object in bytes
UR_MEM_INFO_SIZE = 0, ///< [size_t] actual size of the memory object in bytes
UR_MEM_INFO_CONTEXT = 1, ///< [::ur_context_handle_t] context in which the memory object was created
/// @cond
UR_MEM_INFO_FORCE_UINT32 = 0x7fffffff
Expand Down Expand Up @@ -4122,7 +4122,14 @@ urPhysicalMemRelease(
///////////////////////////////////////////////////////////////////////////////
/// @brief Physical memory range info queries.
typedef enum ur_physical_mem_info_t {
UR_PHYSICAL_MEM_INFO_REFERENCE_COUNT = 0, ///< [uint32_t] Reference count of the physical memory object.
UR_PHYSICAL_MEM_INFO_CONTEXT = 0, ///< [::ur_context_handle_t] context in which the physical memory object
///< was created.
UR_PHYSICAL_MEM_INFO_DEVICE = 1, ///< [::ur_device_handle_t] device associated with this physical memory
///< object.
UR_PHYSICAL_MEM_INFO_SIZE = 2, ///< [size_t] actual size of the physical memory object in bytes.
UR_PHYSICAL_MEM_INFO_PROPERTIES = 3, ///< [::ur_physical_mem_properties_t] properties set when creating this
///< physical memory object.
UR_PHYSICAL_MEM_INFO_REFERENCE_COUNT = 4, ///< [uint32_t] Reference count of the physical memory object.
///< The reference count returned should be considered immediately stale.
///< It is unsuitable for general use in applications. This feature is
///< provided for identifying memory leaks.
Expand Down
62 changes: 62 additions & 0 deletions include/ur_print.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7401,6 +7401,18 @@ inline std::ostream &operator<<(std::ostream &os, const struct ur_physical_mem_p
/// std::ostream &
inline std::ostream &operator<<(std::ostream &os, enum ur_physical_mem_info_t value) {
switch (value) {
case UR_PHYSICAL_MEM_INFO_CONTEXT:
os << "UR_PHYSICAL_MEM_INFO_CONTEXT";
break;
case UR_PHYSICAL_MEM_INFO_DEVICE:
os << "UR_PHYSICAL_MEM_INFO_DEVICE";
break;
case UR_PHYSICAL_MEM_INFO_SIZE:
os << "UR_PHYSICAL_MEM_INFO_SIZE";
break;
case UR_PHYSICAL_MEM_INFO_PROPERTIES:
os << "UR_PHYSICAL_MEM_INFO_PROPERTIES";
break;
case UR_PHYSICAL_MEM_INFO_REFERENCE_COUNT:
os << "UR_PHYSICAL_MEM_INFO_REFERENCE_COUNT";
break;
Expand All @@ -7420,6 +7432,56 @@ inline ur_result_t printTagged(std::ostream &os, const void *ptr, ur_physical_me
}

switch (value) {
case UR_PHYSICAL_MEM_INFO_CONTEXT: {
const ur_context_handle_t *tptr = (const ur_context_handle_t *)ptr;
if (sizeof(ur_context_handle_t) > size) {
os << "invalid size (is: " << size << ", expected: >=" << sizeof(ur_context_handle_t) << ")";
return UR_RESULT_ERROR_INVALID_SIZE;
}
os << (const void *)(tptr) << " (";

ur::details::printPtr(os,
*tptr);

os << ")";
} break;
case UR_PHYSICAL_MEM_INFO_DEVICE: {
const ur_device_handle_t *tptr = (const ur_device_handle_t *)ptr;
if (sizeof(ur_device_handle_t) > size) {
os << "invalid size (is: " << size << ", expected: >=" << sizeof(ur_device_handle_t) << ")";
return UR_RESULT_ERROR_INVALID_SIZE;
}
os << (const void *)(tptr) << " (";

ur::details::printPtr(os,
*tptr);

os << ")";
} break;
case UR_PHYSICAL_MEM_INFO_SIZE: {
const size_t *tptr = (const size_t *)ptr;
if (sizeof(size_t) > size) {
os << "invalid size (is: " << size << ", expected: >=" << sizeof(size_t) << ")";
return UR_RESULT_ERROR_INVALID_SIZE;
}
os << (const void *)(tptr) << " (";

os << *tptr;

os << ")";
} break;
case UR_PHYSICAL_MEM_INFO_PROPERTIES: {
const ur_physical_mem_properties_t *tptr = (const ur_physical_mem_properties_t *)ptr;
if (sizeof(ur_physical_mem_properties_t) > size) {
os << "invalid size (is: " << size << ", expected: >=" << sizeof(ur_physical_mem_properties_t) << ")";
return UR_RESULT_ERROR_INVALID_SIZE;
}
os << (const void *)(tptr) << " (";

os << *tptr;

os << ")";
} break;
case UR_PHYSICAL_MEM_INFO_REFERENCE_COUNT: {
const uint32_t *tptr = (const uint32_t *)ptr;
if (sizeof(uint32_t) > size) {
Expand Down
2 changes: 1 addition & 1 deletion scripts/core/memory.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ name: $x_mem_info_t
typed_etors: True
etors:
- name: SIZE
desc: "[size_t] actual size of of memory object in bytes"
desc: "[size_t] actual size of the memory object in bytes"
- name: CONTEXT
desc: "[$x_context_handle_t] context in which the memory object was created"
--- #--------------------------------------------------------------------------
Expand Down
9 changes: 8 additions & 1 deletion scripts/core/virtual_memory.yml
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,14 @@ class: $xPhysicalMem
name: $x_physical_mem_info_t
typed_etors: True
etors:
# properties and size too
- name: CONTEXT
desc: "[$x_context_handle_t] context in which the physical memory object was created."
- name: DEVICE
desc: "[$x_device_handle_t] device associated with this physical memory object."
- name: SIZE
desc: "[size_t] actual size of the physical memory object in bytes."
- name: PROPERTIES
desc: "[$x_physical_mem_properties_t] properties set when creating this physical memory object."
- name: REFERENCE_COUNT
desc: |
[uint32_t] Reference count of the physical memory object.
Expand Down
16 changes: 14 additions & 2 deletions source/adapters/cuda/physical_mem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urPhysicalMemCreate(
UR_CHECK_ERROR(Result);
}
try {
*phPhysicalMem =
new ur_physical_mem_handle_t_(ResHandle, hContext, hDevice);
*phPhysicalMem = new ur_physical_mem_handle_t_(ResHandle, hContext, hDevice,
size, *pProperties);
} catch (std::bad_alloc &) {
return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
} catch (...) {
Expand Down Expand Up @@ -74,6 +74,18 @@ UR_APIEXPORT ur_result_t UR_APICALL urPhysicalMemGetInfo(
UrReturnHelper ReturnValue(propSize, pPropValue, pPropSizeRet);

switch (propName) {
case UR_PHYSICAL_MEM_INFO_CONTEXT: {
return ReturnValue(hPhysicalMem->getContext());
}
case UR_PHYSICAL_MEM_INFO_DEVICE: {
return ReturnValue(hPhysicalMem->getDevice());
}
case UR_PHYSICAL_MEM_INFO_SIZE: {
return ReturnValue(hPhysicalMem->getSize());
}
case UR_PHYSICAL_MEM_INFO_PROPERTIES: {
return ReturnValue(hPhysicalMem->getProperties());
}
case UR_PHYSICAL_MEM_INFO_REFERENCE_COUNT: {
return ReturnValue(hPhysicalMem->getReferenceCount());
}
Expand Down
11 changes: 9 additions & 2 deletions source/adapters/cuda/physical_mem.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,13 @@ struct ur_physical_mem_handle_t_ {
native_type PhysicalMem;
ur_context_handle_t_ *Context;
ur_device_handle_t Device;
size_t Size;
ur_physical_mem_properties_t Properties;

ur_physical_mem_handle_t_(native_type PhysMem, ur_context_handle_t_ *Ctx,
ur_device_handle_t Device)
: RefCount(1), PhysicalMem(PhysMem), Context(Ctx), Device(Device) {
ur_device_handle_t Device, size_t Size, ur_physical_mem_properties_t Properties)
: RefCount(1), PhysicalMem(PhysMem), Context(Ctx), Device(Device),
Size(Size), Properties(Properties) {
urContextRetain(Context);
urDeviceRetain(Device);
}
Expand All @@ -51,4 +54,8 @@ struct ur_physical_mem_handle_t_ {
uint32_t decrementReferenceCount() noexcept { return --RefCount; }

uint32_t getReferenceCount() const noexcept { return RefCount; }

size_t getSize() const noexcept { return Size; }

ur_physical_mem_properties_t getProperties() const noexcept { return Properties; }
};
45 changes: 45 additions & 0 deletions source/loader/ur_ldrddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2547,10 +2547,55 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemGetInfo(
hPhysicalMem =
reinterpret_cast<ur_physical_mem_object_t *>(hPhysicalMem)->handle;

// this value is needed for converting adapter handles to loader handles
size_t sizeret = 0;
if (pPropSizeRet == NULL) {
pPropSizeRet = &sizeret;
}

// forward to device-platform
result =
pfnGetInfo(hPhysicalMem, propName, propSize, pPropValue, pPropSizeRet);

if (UR_RESULT_SUCCESS != result) {
return result;
}

try {
if (pPropValue != nullptr) {
switch (propName) {
case UR_PHYSICAL_MEM_INFO_CONTEXT: {
ur_context_handle_t *handles =
reinterpret_cast<ur_context_handle_t *>(pPropValue);
size_t nelements = *pPropSizeRet / sizeof(ur_context_handle_t);
for (size_t i = 0; i < nelements; ++i) {
if (handles[i] != nullptr) {
handles[i] = reinterpret_cast<ur_context_handle_t>(
context->factories.ur_context_factory.getInstance(
handles[i], dditable));
}
}
} break;
case UR_PHYSICAL_MEM_INFO_DEVICE: {
ur_device_handle_t *handles =
reinterpret_cast<ur_device_handle_t *>(pPropValue);
size_t nelements = *pPropSizeRet / sizeof(ur_device_handle_t);
for (size_t i = 0; i < nelements; ++i) {
if (handles[i] != nullptr) {
handles[i] = reinterpret_cast<ur_device_handle_t>(
context->factories.ur_device_factory.getInstance(
handles[i], dditable));
}
}
} break;
default: {
} break;
}
}
} catch (std::bad_alloc &) {
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
}

return result;
}

Expand Down
15 changes: 8 additions & 7 deletions test/conformance/testing/include/uur/fixtures.h
Original file line number Diff line number Diff line change
Expand Up @@ -935,13 +935,9 @@ struct urPhysicalMemTest : urVirtualMemGranularityTest {
void SetUp() override {
UUR_RETURN_ON_FATAL_FAILURE(urVirtualMemGranularityTest::SetUp());
size = granularity * 256;
ur_physical_mem_properties_t props{
UR_STRUCTURE_TYPE_PHYSICAL_MEM_PROPERTIES,
nullptr,
0 /*flags*/,
};
ASSERT_SUCCESS(
urPhysicalMemCreate(context, device, size, &props, &physical_mem));

ASSERT_SUCCESS(urPhysicalMemCreate(context, device, size, &properties,
&physical_mem));
ASSERT_NE(physical_mem, nullptr);
}

Expand All @@ -954,6 +950,11 @@ struct urPhysicalMemTest : urVirtualMemGranularityTest {

size_t size = 0;
ur_physical_mem_handle_t physical_mem = nullptr;
ur_physical_mem_properties_t properties{
UR_STRUCTURE_TYPE_PHYSICAL_MEM_PROPERTIES,
nullptr,
0 /*flags*/,
};
};

template <class T>
Expand Down
104 changes: 81 additions & 23 deletions test/conformance/virtual_memory/urPhysicalMemGetInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,90 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include <uur/fixtures.h>

using urPhysicalMemGetInfoWithFlagsParamTest =
uur::urPhysicalMemTestWithParam<ur_physical_mem_info_t>;
UUR_TEST_SUITE_P(urPhysicalMemGetInfoWithFlagsParamTest,
::testing::Values(UR_PHYSICAL_MEM_INFO_REFERENCE_COUNT),
uur::deviceTestWithParamPrinter<ur_physical_mem_info_t>);
using urPhysicalMemGetInfoTest = uur::urPhysicalMemTest;
UUR_INSTANTIATE_DEVICE_TEST_SUITE_P(urPhysicalMemGetInfoTest);

TEST_P(urPhysicalMemGetInfoWithFlagsParamTest, Success) {
TEST_P(urPhysicalMemGetInfoTest, Context) {
size_t info_size = 0;
ur_physical_mem_info_t info = getParam();
ASSERT_SUCCESS(
urPhysicalMemGetInfo(physical_mem, info, 0, nullptr, &info_size));

ASSERT_SUCCESS(urPhysicalMemGetInfo(
physical_mem, UR_PHYSICAL_MEM_INFO_CONTEXT, 0, nullptr, &info_size));
ASSERT_NE(info_size, 0);

std::vector<uint8_t> data(info_size);
ASSERT_SUCCESS(urPhysicalMemGetInfo(physical_mem,
UR_PHYSICAL_MEM_INFO_CONTEXT,
data.size(), data.data(), nullptr));

auto returned_context =
reinterpret_cast<ur_context_handle_t *>(data.data());
ASSERT_EQ(context, *returned_context);
}

TEST_P(urPhysicalMemGetInfoTest, Device) {
size_t info_size = 0;

ASSERT_SUCCESS(urPhysicalMemGetInfo(
physical_mem, UR_PHYSICAL_MEM_INFO_DEVICE, 0, nullptr, &info_size));
ASSERT_NE(info_size, 0);

std::vector<uint8_t> data(info_size);
ASSERT_SUCCESS(urPhysicalMemGetInfo(physical_mem,
UR_PHYSICAL_MEM_INFO_DEVICE,
data.size(), data.data(), nullptr));

auto returned_device = reinterpret_cast<ur_device_handle_t *>(data.data());
ASSERT_EQ(device, *returned_device);
}

TEST_P(urPhysicalMemGetInfoTest, Size) {
size_t info_size = 0;

ASSERT_SUCCESS(urPhysicalMemGetInfo(physical_mem, UR_PHYSICAL_MEM_INFO_SIZE,
0, nullptr, &info_size));
ASSERT_NE(info_size, 0);

std::vector<uint8_t> data(info_size);
ASSERT_SUCCESS(urPhysicalMemGetInfo(physical_mem, info, data.size(),
data.data(), nullptr));

switch (info) {
case UR_PHYSICAL_MEM_INFO_REFERENCE_COUNT: {
const size_t ReferenceCount =
*reinterpret_cast<const uint32_t *>(data.data());
ASSERT_EQ(ReferenceCount, 1);
} break;

default:
FAIL() << "Unhandled ur_physical_mem_info_t enumeration: " << info;
break;
}
ASSERT_SUCCESS(urPhysicalMemGetInfo(physical_mem, UR_PHYSICAL_MEM_INFO_SIZE,
data.size(), data.data(), nullptr));

auto returned_size = reinterpret_cast<size_t *>(data.data());
ASSERT_EQ(size, *returned_size);
}

TEST_P(urPhysicalMemGetInfoTest, Properties) {
size_t info_size = 0;

ASSERT_SUCCESS(urPhysicalMemGetInfo(
physical_mem, UR_PHYSICAL_MEM_INFO_PROPERTIES, 0, nullptr, &info_size));
ASSERT_NE(info_size, 0);

std::vector<uint8_t> data(info_size);
ASSERT_SUCCESS(urPhysicalMemGetInfo(physical_mem,
UR_PHYSICAL_MEM_INFO_PROPERTIES,
data.size(), data.data(), nullptr));

auto returned_properties =
reinterpret_cast<ur_physical_mem_properties_t *>(data.data());
ASSERT_EQ(properties.stype, returned_properties->stype);
ASSERT_EQ(properties.pNext, returned_properties->pNext);
ASSERT_EQ(properties.flags, returned_properties->flags);
}

TEST_P(urPhysicalMemGetInfoTest, ReferenceCount) {
size_t info_size = 0;

ASSERT_SUCCESS(urPhysicalMemGetInfo(physical_mem,
UR_PHYSICAL_MEM_INFO_REFERENCE_COUNT, 0,
nullptr, &info_size));
ASSERT_NE(info_size, 0);

std::vector<uint8_t> data(info_size);
ASSERT_SUCCESS(urPhysicalMemGetInfo(physical_mem,
UR_PHYSICAL_MEM_INFO_REFERENCE_COUNT,
data.size(), data.data(), nullptr));

const size_t ReferenceCount =
*reinterpret_cast<const uint32_t *>(data.data());
ASSERT_EQ(ReferenceCount, 1);
}

0 comments on commit e825f58

Please sign in to comment.