Skip to content

Commit

Permalink
Merge pull request #2620 from RossBrunton/ross/l0devirtual
Browse files Browse the repository at this point in the history
Remove virtual methods from ur_mem_handle_t_
  • Loading branch information
kbenzie authored Feb 4, 2025
2 parents be34bcb + 80fa413 commit 0e111ff
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 50 deletions.
55 changes: 42 additions & 13 deletions source/adapters/level_zero/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1747,11 +1747,12 @@ ur_result_t urMemRelease(
if (ZeResult && ZeResult != ZE_RESULT_ERROR_UNINITIALIZED)
return ze2urResult(ZeResult);
}
delete Image;
} else {
auto Buffer = reinterpret_cast<_ur_buffer *>(Mem);
Buffer->free();
delete Buffer;
}
delete Mem;

return UR_RESULT_SUCCESS;
}
Expand Down Expand Up @@ -2081,10 +2082,11 @@ static ur_result_t ZeDeviceMemAllocHelper(void **ResultPtr,
return UR_RESULT_SUCCESS;
}

ur_result_t _ur_buffer::getZeHandle(char *&ZeHandle, access_mode_t AccessMode,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) {
ur_result_t _ur_buffer::getBufferZeHandle(char *&ZeHandle,
access_mode_t AccessMode,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) {

// NOTE: There might be no valid allocation at all yet and we get
// here from piEnqueueKernelLaunch that would be doing the buffer
Expand Down Expand Up @@ -2393,7 +2395,7 @@ ur_result_t _ur_buffer::free() {
// Buffer constructor
_ur_buffer::_ur_buffer(ur_context_handle_t Context, size_t Size, char *HostPtr,
bool ImportedHostPtr = false)
: ur_mem_handle_t_(Context), Size(Size) {
: ur_mem_handle_t_(mem_type_t::buffer, Context), Size(Size) {

// We treat integrated devices (physical memory shared with the CPU)
// differently from discrete devices (those with distinct memories).
Expand Down Expand Up @@ -2422,13 +2424,13 @@ _ur_buffer::_ur_buffer(ur_context_handle_t Context, size_t Size, char *HostPtr,

_ur_buffer::_ur_buffer(ur_context_handle_t Context, ur_device_handle_t Device,
size_t Size)
: ur_mem_handle_t_(Context, Device), Size(Size) {}
: ur_mem_handle_t_(mem_type_t::buffer, Context, Device), Size(Size) {}

// Interop-buffer constructor
_ur_buffer::_ur_buffer(ur_context_handle_t Context, size_t Size,
ur_device_handle_t Device, char *ZeMemHandle,
bool OwnZeMemHandle)
: ur_mem_handle_t_(Context, Device), Size(Size) {
: ur_mem_handle_t_(mem_type_t::buffer, Context, Device), Size(Size) {

// Device == nullptr means host allocation
Allocations[Device].ZeHandle = ZeMemHandle;
Expand All @@ -2449,11 +2451,38 @@ _ur_buffer::_ur_buffer(ur_context_handle_t Context, size_t Size,
LastDeviceWithValidAllocation = Device;
}

ur_result_t _ur_buffer::getZeHandlePtr(char **&ZeHandlePtr,
access_mode_t AccessMode,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) {
ur_result_t ur_mem_handle_t_::getZeHandle(char *&ZeHandle, access_mode_t mode,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) {
switch (mem_type) {
case ur_mem_handle_t_::image:
return reinterpret_cast<_ur_image *>(this)->getImageZeHandle(
ZeHandle, mode, Device, phWaitEvents, numWaitEvents);
case ur_mem_handle_t_::buffer:
return reinterpret_cast<_ur_buffer *>(this)->getBufferZeHandle(
ZeHandle, mode, Device, phWaitEvents, numWaitEvents);
}
ur::unreachable();
}

ur_result_t ur_mem_handle_t_::getZeHandlePtr(
char **&ZeHandlePtr, access_mode_t mode, ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents, uint32_t numWaitEvents) {
switch (mem_type) {
case ur_mem_handle_t_::image:
return reinterpret_cast<_ur_image *>(this)->getImageZeHandlePtr(
ZeHandlePtr, mode, Device, phWaitEvents, numWaitEvents);
case ur_mem_handle_t_::buffer:
return reinterpret_cast<_ur_buffer *>(this)->getBufferZeHandlePtr(
ZeHandlePtr, mode, Device, phWaitEvents, numWaitEvents);
}
ur::unreachable();
}

ur_result_t _ur_buffer::getBufferZeHandlePtr(
char **&ZeHandlePtr, access_mode_t AccessMode, ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents, uint32_t numWaitEvents) {
char *ZeHandle;
UR_CALL(
getZeHandle(ZeHandle, AccessMode, Device, phWaitEvents, numWaitEvents));
Expand Down
78 changes: 41 additions & 37 deletions source/adapters/level_zero/memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,34 +70,41 @@ struct ur_mem_handle_t_ : _ur_object {
// Keeps device of this memory handle
ur_device_handle_t UrDevice;

// Whether this is an image or buffer
enum mem_type_t { image, buffer };
mem_type_t mem_type;

// Enumerates all possible types of accesses.
enum access_mode_t { unknown, read_write, read_only, write_only };

// Interface of the _ur_mem object

// Get the Level Zero handle of the current memory object
virtual ur_result_t getZeHandle(char *&ZeHandle, access_mode_t,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) = 0;
ur_result_t getZeHandle(char *&ZeHandle, access_mode_t,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents);

// Get a pointer to the Level Zero handle of the current memory object
virtual ur_result_t getZeHandlePtr(char **&ZeHandlePtr, access_mode_t,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) = 0;
ur_result_t getZeHandlePtr(char **&ZeHandlePtr, access_mode_t,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents);

// Method to get type of the derived object (image or buffer)
virtual bool isImage() const = 0;

virtual ~ur_mem_handle_t_() = default;
bool isImage() const { return mem_type == mem_type_t::image; }

protected:
ur_mem_handle_t_(ur_context_handle_t Context)
: UrContext{Context}, UrDevice{nullptr} {}
ur_mem_handle_t_(mem_type_t type, ur_context_handle_t Context)
: UrContext{Context}, UrDevice{nullptr}, mem_type(type) {}

ur_mem_handle_t_(ur_context_handle_t Context, ur_device_handle_t Device)
: UrContext{Context}, UrDevice(Device) {}
ur_mem_handle_t_(mem_type_t type, ur_context_handle_t Context,
ur_device_handle_t Device)
: UrContext{Context}, UrDevice(Device), mem_type(type) {}

// Since the destructor isn't virtual, callers must destruct it via _ur_buffer
// or _ur_image
~ur_mem_handle_t_() {};
};

struct _ur_buffer final : ur_mem_handle_t_ {
Expand All @@ -110,7 +117,7 @@ struct _ur_buffer final : ur_mem_handle_t_ {

// Sub-buffer constructor
_ur_buffer(_ur_buffer *Parent, size_t Origin, size_t Size)
: ur_mem_handle_t_(Parent->UrContext), Size(Size),
: ur_mem_handle_t_(mem_type_t::buffer, Parent->UrContext), Size(Size),
SubBuffer{{Parent, Origin}} {
// Retain the Parent Buffer due to the Creation of the SubBuffer.
Parent->RefCount.increment();
Expand All @@ -127,16 +134,15 @@ struct _ur_buffer final : ur_mem_handle_t_ {
// up-to-date and any data copies needed for that are performed under
// the hood.
//
virtual ur_result_t getZeHandle(char *&ZeHandle, access_mode_t,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) override;
virtual ur_result_t getZeHandlePtr(char **&ZeHandlePtr, access_mode_t,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) override;
ur_result_t getBufferZeHandle(char *&ZeHandle, access_mode_t,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents);
ur_result_t getBufferZeHandlePtr(char **&ZeHandlePtr, access_mode_t,
ur_device_handle_t Device,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents);

bool isImage() const override { return false; }
bool isSubBuffer() const { return SubBuffer != std::nullopt; }

// Frees all allocations made for the buffer.
Expand Down Expand Up @@ -206,35 +212,33 @@ struct _ur_buffer final : ur_mem_handle_t_ {
struct _ur_image final : ur_mem_handle_t_ {
// Image constructor
_ur_image(ur_context_handle_t UrContext, ze_image_handle_t ZeImage)
: ur_mem_handle_t_(UrContext), ZeImage{ZeImage} {}
: ur_mem_handle_t_(mem_type_t::image, UrContext), ZeImage{ZeImage} {}

_ur_image(ur_context_handle_t UrContext, ze_image_handle_t ZeImage,
bool OwnZeMemHandle)
: ur_mem_handle_t_(UrContext), ZeImage{ZeImage} {
: ur_mem_handle_t_(mem_type_t::image, UrContext), ZeImage{ZeImage} {
OwnNativeHandle = OwnZeMemHandle;
}

virtual ur_result_t getZeHandle(char *&ZeHandle, access_mode_t,
ur_device_handle_t,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) override {
ur_result_t getImageZeHandle(char *&ZeHandle, access_mode_t,
ur_device_handle_t,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) {
std::ignore = phWaitEvents;
std::ignore = numWaitEvents;
ZeHandle = reinterpret_cast<char *>(ZeImage);
return UR_RESULT_SUCCESS;
}
virtual ur_result_t getZeHandlePtr(char **&ZeHandlePtr, access_mode_t,
ur_device_handle_t,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) override {
ur_result_t getImageZeHandlePtr(char **&ZeHandlePtr, access_mode_t,
ur_device_handle_t,
const ur_event_handle_t *phWaitEvents,
uint32_t numWaitEvents) {
std::ignore = phWaitEvents;
std::ignore = numWaitEvents;
ZeHandlePtr = reinterpret_cast<char **>(&ZeImage);
return UR_RESULT_SUCCESS;
}

bool isImage() const override { return true; }

// Keep the descriptor of the image
ZeStruct<ze_image_desc_t> ZeImageDesc;

Expand Down

0 comments on commit 0e111ff

Please sign in to comment.