diff --git a/include/ctranslate2/storage_view.h b/include/ctranslate2/storage_view.h index d6f3e0941..c4ca311f1 100644 --- a/include/ctranslate2/storage_view.h +++ b/include/ctranslate2/storage_view.h @@ -133,6 +133,8 @@ namespace ctranslate2 { return _size; } + dim_t item_size() const; + bool is_scalar() const { return _size == 1 && _shape.empty(); } diff --git a/src/storage_view.cc b/src/storage_view.cc index 093c9f04e..1535e65d5 100644 --- a/src/storage_view.cc +++ b/src/storage_view.cc @@ -118,9 +118,7 @@ namespace ctranslate2 { } dim_t StorageView::reserved_memory() const { - dim_t buffer_size = 0; - TYPE_DISPATCH(_dtype, buffer_size = _allocated_size * sizeof (T)); - return buffer_size; + return _allocated_size * item_size(); } StorageView& StorageView::clear() { @@ -143,10 +141,8 @@ namespace ctranslate2 { if (size <= _allocated_size) return *this; release(); - dim_t required_bytes = 0; - TYPE_DISPATCH(_dtype, required_bytes = size * sizeof (T)); _allocator = &get_allocator(_device); - _data = _allocator->allocate(required_bytes, _device_index); + _data = _allocator->allocate(size * item_size(), _device_index); if (_data == nullptr) THROW_RUNTIME_ERROR("failed to allocated memory"); _allocated_size = size; @@ -157,6 +153,12 @@ namespace ctranslate2 { return _allocator; } + dim_t StorageView::item_size() const { + dim_t size = 0; + TYPE_DISPATCH(_dtype, size = sizeof (T)); + return size; + } + StorageView& StorageView::reshape(Shape new_shape) { dim_t unknown_dim = -1; dim_t known_size = 1;