Skip to content

Commit

Permalink
[ET-VK] Removing unnecessary and redundant members from StagingBuffer.
Browse files Browse the repository at this point in the history
Pull Request resolved: pytorch#7220

This diff removes unnecessary and redundant members from the StagingBuffer class in the Vulkan runtime API. Specifically, the `numel_` and `nbytes_` members are removed, as they can be calculated from the `dtype_` member. This simplifies the class and reduces the amount of memory used.
ghstack-source-id: 257227238
@exported-using-ghexport

Differential Revision: [D66742613](https://our.internmc.facebook.com/intern/diff/D66742613/)

Co-authored-by: Vivek Trivedi <[email protected]>
  • Loading branch information
pytorchbot and trivedivivek authored Dec 10, 2024
1 parent 893f690 commit 28ad3f2
Showing 1 changed file with 7 additions and 11 deletions.
18 changes: 7 additions & 11 deletions backends/vulkan/runtime/api/containers/StagingBuffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ class StagingBuffer final {
private:
Context* context_p_;
vkapi::ScalarType dtype_;
size_t numel_;
size_t nbytes_;
vkapi::VulkanBuffer vulkan_buffer_;

void* mapped_data_;
Expand All @@ -36,10 +34,8 @@ class StagingBuffer final {
const size_t numel)
: context_p_(context_p),
dtype_(dtype),
numel_(numel),
nbytes_(element_size(dtype_) * numel_),
vulkan_buffer_(
context_p_->adapter_ptr()->vma().create_staging_buffer(nbytes_)),
vulkan_buffer_(context_p_->adapter_ptr()->vma().create_staging_buffer(
element_size(dtype_) * numel)),
mapped_data_(nullptr) {}

StagingBuffer(const StagingBuffer&) = delete;
Expand Down Expand Up @@ -68,15 +64,15 @@ class StagingBuffer final {
}

inline size_t numel() {
return numel_;
return nbytes() / element_size(dtype_);
}

inline size_t nbytes() {
return nbytes_;
return vulkan_buffer_.mem_size();
}

inline void copy_from(const void* src, const size_t nbytes) {
VK_CHECK_COND(nbytes <= nbytes_);
VK_CHECK_COND(nbytes <= this->nbytes());
memcpy(data(), src, nbytes);
vmaFlushAllocation(
vulkan_buffer_.vma_allocator(),
Expand All @@ -86,7 +82,7 @@ class StagingBuffer final {
}

inline void copy_to(void* dst, const size_t nbytes) {
VK_CHECK_COND(nbytes <= nbytes_);
VK_CHECK_COND(nbytes <= this->nbytes());
vmaInvalidateAllocation(
vulkan_buffer_.vma_allocator(),
vulkan_buffer_.allocation(),
Expand All @@ -96,7 +92,7 @@ class StagingBuffer final {
}

inline void set_staging_zeros() {
memset(data(), 0, nbytes_);
memset(data(), 0, nbytes());
}
};

Expand Down

0 comments on commit 28ad3f2

Please sign in to comment.