diff --git a/arrow-array/src/array/boolean_array.rs b/arrow-array/src/array/boolean_array.rs index 0f95adacf10c..9c2d4af8c454 100644 --- a/arrow-array/src/array/boolean_array.rs +++ b/arrow-array/src/array/boolean_array.rs @@ -308,6 +308,13 @@ impl Array for BooleanArray { self.values.is_empty() } + fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + if let Some(nulls) = &mut self.nulls { + nulls.shrink_to_fit(); + } + } + fn offset(&self) -> usize { self.values.offset() } diff --git a/arrow-array/src/array/byte_array.rs b/arrow-array/src/array/byte_array.rs index bec0caab1045..f2b22507081d 100644 --- a/arrow-array/src/array/byte_array.rs +++ b/arrow-array/src/array/byte_array.rs @@ -453,6 +453,14 @@ impl Array for GenericByteArray { self.value_offsets.len() <= 1 } + fn shrink_to_fit(&mut self) { + self.value_offsets.shrink_to_fit(); + self.value_data.shrink_to_fit(); + if let Some(nulls) = &mut self.nulls { + nulls.shrink_to_fit(); + } + } + fn offset(&self) -> usize { 0 } diff --git a/arrow-array/src/array/byte_view_array.rs b/arrow-array/src/array/byte_view_array.rs index 81bb6a38550b..9d2d396a5266 100644 --- a/arrow-array/src/array/byte_view_array.rs +++ b/arrow-array/src/array/byte_view_array.rs @@ -430,31 +430,31 @@ impl GenericByteViewArray { /// /// Before GC: /// ```text - /// ┌──────┐ - /// │......│ - /// │......│ - /// ┌────────────────────┐ ┌ ─ ─ ─ ▶ │Data1 │ Large buffer + /// ┌──────┐ + /// │......│ + /// │......│ + /// ┌────────────────────┐ ┌ ─ ─ ─ ▶ │Data1 │ Large buffer /// │ View 1 │─ ─ ─ ─ │......│ with data that /// ├────────────────────┤ │......│ is not referred /// │ View 2 │─ ─ ─ ─ ─ ─ ─ ─▶ │Data2 │ to by View 1 or - /// └────────────────────┘ │......│ View 2 - /// │......│ - /// 2 views, refer to │......│ - /// small portions of a └──────┘ - /// large buffer + /// └────────────────────┘ │......│ View 2 + /// │......│ + /// 2 views, refer to │......│ + /// small portions of a └──────┘ + /// large buffer /// ``` - /// + /// /// After GC: /// /// ```text /// ┌────────────────────┐ ┌─────┐ After gc, only - /// │ View 1 │─ ─ ─ ─ ─ ─ ─ ─▶ │Data1│ data that is - /// ├────────────────────┤ ┌ ─ ─ ─ ▶ │Data2│ pointed to by - /// │ View 2 │─ ─ ─ ─ └─────┘ the views is - /// └────────────────────┘ left - /// - /// - /// 2 views + /// │ View 1 │─ ─ ─ ─ ─ ─ ─ ─▶ │Data1│ data that is + /// ├────────────────────┤ ┌ ─ ─ ─ ▶ │Data2│ pointed to by + /// │ View 2 │─ ─ ─ ─ └─────┘ the views is + /// └────────────────────┘ left + /// + /// + /// 2 views /// ``` /// This method will compact the data buffers by recreating the view array and only include the data /// that is pointed to by the views. @@ -575,6 +575,15 @@ impl Array for GenericByteViewArray { self.views.is_empty() } + fn shrink_to_fit(&mut self) { + self.views.shrink_to_fit(); + self.buffers.iter_mut().for_each(|b| b.shrink_to_fit()); + self.buffers.shrink_to_fit(); + if let Some(nulls) = &mut self.nulls { + nulls.shrink_to_fit(); + } + } + fn offset(&self) -> usize { 0 } diff --git a/arrow-array/src/array/dictionary_array.rs b/arrow-array/src/array/dictionary_array.rs index 1187e16769a0..988bdbc7c9b4 100644 --- a/arrow-array/src/array/dictionary_array.rs +++ b/arrow-array/src/array/dictionary_array.rs @@ -720,6 +720,11 @@ impl Array for DictionaryArray { self.keys.is_empty() } + fn shrink_to_fit(&mut self) { + self.keys.shrink_to_fit(); + self.values.shrink_to_fit(); + } + fn offset(&self) -> usize { self.keys.offset() } diff --git a/arrow-array/src/array/fixed_size_binary_array.rs b/arrow-array/src/array/fixed_size_binary_array.rs index 8f1489ee4c3c..25fe2e3dfe88 100644 --- a/arrow-array/src/array/fixed_size_binary_array.rs +++ b/arrow-array/src/array/fixed_size_binary_array.rs @@ -602,6 +602,13 @@ impl Array for FixedSizeBinaryArray { self.len == 0 } + fn shrink_to_fit(&mut self) { + self.value_data.shrink_to_fit(); + if let Some(nulls) = &mut self.nulls { + nulls.shrink_to_fit(); + } + } + fn offset(&self) -> usize { 0 } diff --git a/arrow-array/src/array/fixed_size_list_array.rs b/arrow-array/src/array/fixed_size_list_array.rs index 00a3144a87ad..2e784343b064 100644 --- a/arrow-array/src/array/fixed_size_list_array.rs +++ b/arrow-array/src/array/fixed_size_list_array.rs @@ -401,6 +401,13 @@ impl Array for FixedSizeListArray { self.len == 0 } + fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + if let Some(nulls) = &mut self.nulls { + nulls.shrink_to_fit(); + } + } + fn offset(&self) -> usize { 0 } diff --git a/arrow-array/src/array/list_array.rs b/arrow-array/src/array/list_array.rs index 1fab0009f2cc..ee02b887720b 100644 --- a/arrow-array/src/array/list_array.rs +++ b/arrow-array/src/array/list_array.rs @@ -485,6 +485,14 @@ impl Array for GenericListArray { self.value_offsets.len() <= 1 } + fn shrink_to_fit(&mut self) { + if let Some(nulls) = &mut self.nulls { + nulls.shrink_to_fit(); + } + self.values.shrink_to_fit(); + self.value_offsets.shrink_to_fit(); + } + fn offset(&self) -> usize { 0 } diff --git a/arrow-array/src/array/list_view_array.rs b/arrow-array/src/array/list_view_array.rs index 4e949a642701..c59b178f7b2f 100644 --- a/arrow-array/src/array/list_view_array.rs +++ b/arrow-array/src/array/list_view_array.rs @@ -326,6 +326,15 @@ impl Array for GenericListViewArray { self.value_sizes.is_empty() } + fn shrink_to_fit(&mut self) { + if let Some(nulls) = &mut self.nulls { + nulls.shrink_to_fit(); + } + self.values.shrink_to_fit(); + self.value_offsets.shrink_to_fit(); + self.value_sizes.shrink_to_fit(); + } + fn offset(&self) -> usize { 0 } diff --git a/arrow-array/src/array/map_array.rs b/arrow-array/src/array/map_array.rs index 254437630a44..18a7c491aa16 100644 --- a/arrow-array/src/array/map_array.rs +++ b/arrow-array/src/array/map_array.rs @@ -372,6 +372,14 @@ impl Array for MapArray { self.value_offsets.len() <= 1 } + fn shrink_to_fit(&mut self) { + if let Some(nulls) = &mut self.nulls { + nulls.shrink_to_fit(); + } + self.entries.shrink_to_fit(); + self.value_offsets.shrink_to_fit(); + } + fn offset(&self) -> usize { 0 } diff --git a/arrow-array/src/array/mod.rs b/arrow-array/src/array/mod.rs index 04d9883f5bd8..456b3acac009 100644 --- a/arrow-array/src/array/mod.rs +++ b/arrow-array/src/array/mod.rs @@ -167,6 +167,12 @@ pub trait Array: std::fmt::Debug + Send + Sync { /// ``` fn is_empty(&self) -> bool; + /// Shrinks the capacity of any exclusively owned buffer as much as possible + /// + /// Shared or externally allocated buffers will be ignored, and + /// any buffer offsets will be preserved. + fn shrink_to_fit(&mut self) {} + /// Returns the offset into the underlying data used by this array(-slice). /// Note that the underlying data can be shared by many arrays. /// This defaults to `0`. @@ -366,6 +372,15 @@ impl Array for ArrayRef { self.as_ref().is_empty() } + /// For shared buffers, this is a no-op. + fn shrink_to_fit(&mut self) { + if let Some(slf) = Arc::get_mut(self) { + slf.shrink_to_fit(); + } else { + // We ignore shared buffers. + } + } + fn offset(&self) -> usize { self.as_ref().offset() } diff --git a/arrow-array/src/array/primitive_array.rs b/arrow-array/src/array/primitive_array.rs index 7b0d6c5ca1b6..bb7413bb859e 100644 --- a/arrow-array/src/array/primitive_array.rs +++ b/arrow-array/src/array/primitive_array.rs @@ -1152,6 +1152,13 @@ impl Array for PrimitiveArray { self.values.is_empty() } + fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + if let Some(nulls) = &mut self.nulls { + nulls.shrink_to_fit(); + } + } + fn offset(&self) -> usize { 0 } diff --git a/arrow-array/src/array/run_array.rs b/arrow-array/src/array/run_array.rs index dc4e6c96d9da..b340bf9a9065 100644 --- a/arrow-array/src/array/run_array.rs +++ b/arrow-array/src/array/run_array.rs @@ -330,6 +330,11 @@ impl Array for RunArray { self.run_ends.is_empty() } + fn shrink_to_fit(&mut self) { + self.run_ends.shrink_to_fit(); + self.values.shrink_to_fit(); + } + fn offset(&self) -> usize { self.run_ends.offset() } diff --git a/arrow-array/src/array/struct_array.rs b/arrow-array/src/array/struct_array.rs index 41eb8235e540..16deba1063d7 100644 --- a/arrow-array/src/array/struct_array.rs +++ b/arrow-array/src/array/struct_array.rs @@ -370,6 +370,13 @@ impl Array for StructArray { self.len == 0 } + fn shrink_to_fit(&mut self) { + if let Some(nulls) = &mut self.nulls { + nulls.shrink_to_fit(); + } + self.fields.iter_mut().for_each(|n| n.shrink_to_fit()); + } + fn offset(&self) -> usize { 0 } diff --git a/arrow-array/src/array/union_array.rs b/arrow-array/src/array/union_array.rs index 3c6da5a7b5c0..43019c659f0a 100644 --- a/arrow-array/src/array/union_array.rs +++ b/arrow-array/src/array/union_array.rs @@ -744,6 +744,17 @@ impl Array for UnionArray { self.type_ids.is_empty() } + fn shrink_to_fit(&mut self) { + self.type_ids.shrink_to_fit(); + if let Some(offsets) = &mut self.offsets { + offsets.shrink_to_fit(); + } + for array in self.fields.iter_mut().flatten() { + array.shrink_to_fit(); + } + self.fields.shrink_to_fit(); + } + fn offset(&self) -> usize { 0 } diff --git a/arrow-buffer/src/buffer/boolean.rs b/arrow-buffer/src/buffer/boolean.rs index 49a75b468dbe..e2c892315b5f 100644 --- a/arrow-buffer/src/buffer/boolean.rs +++ b/arrow-buffer/src/buffer/boolean.rs @@ -125,6 +125,12 @@ impl BooleanBuffer { self.len == 0 } + /// Free up unused memory. + pub fn shrink_to_fit(&mut self) { + // TODO(emilk): we could shrink even more in the case where we are a small sub-slice of the full buffer + self.buffer.shrink_to_fit(); + } + /// Returns the boolean value at index `i`. /// /// # Panics diff --git a/arrow-buffer/src/buffer/immutable.rs b/arrow-buffer/src/buffer/immutable.rs index 8d1a46583fca..820ad04bf61a 100644 --- a/arrow-buffer/src/buffer/immutable.rs +++ b/arrow-buffer/src/buffer/immutable.rs @@ -167,6 +167,41 @@ impl Buffer { self.data.capacity() } + /// Tried to shrink the capacity of the buffer as much as possible, freeing unused memory. + /// + /// If the buffer is shared, this is a no-op. + /// + /// If the memory was allocated with a custom allocator, this is a no-op. + /// + /// If the capacity is already less than or equal to the desired capacity, this is a no-op. + /// + /// The memory region will be reallocated using `std::alloc::realloc`. + pub fn shrink_to_fit(&mut self) { + let offset = self.ptr_offset(); + let is_empty = self.is_empty(); + let desired_capacity = if is_empty { + 0 + } else { + // For realloc to work, we cannot free the elements before the offset + offset + self.len() + }; + if desired_capacity < self.capacity() { + if let Some(bytes) = Arc::get_mut(&mut self.data) { + if bytes.try_realloc(desired_capacity).is_ok() { + // Realloc complete - update our pointer into `bytes`: + self.ptr = if is_empty { + bytes.as_ptr() + } else { + // SAFETY: we kept all elements leading up to the offset + unsafe { bytes.as_ptr().add(offset) } + } + } else { + // Failure to reallocate is fine; we just failed to free up memory. + } + } + } + } + /// Returns whether the buffer is empty. #[inline] pub fn is_empty(&self) -> bool { @@ -562,6 +597,34 @@ mod tests { assert_eq!(buf2.slice_with_length(2, 1).as_slice(), &[10]); } + #[test] + fn test_shrink_to_fit() { + let original = Buffer::from(&[0, 1, 2, 3, 4, 5, 6, 7]); + assert_eq!(original.as_slice(), &[0, 1, 2, 3, 4, 5, 6, 7]); + assert_eq!(original.capacity(), 64); + + let slice = original.slice_with_length(2, 3); + drop(original); // Make sure the buffer isn't shared (or shrink_to_fit won't work) + assert_eq!(slice.as_slice(), &[2, 3, 4]); + assert_eq!(slice.capacity(), 64); + + let mut shrunk = slice; + shrunk.shrink_to_fit(); + assert_eq!(shrunk.as_slice(), &[2, 3, 4]); + assert_eq!(shrunk.capacity(), 5); // shrink_to_fit is allowed to keep the elements before the offset + + // Test that we can handle empty slices: + let empty_slice = shrunk.slice_with_length(1, 0); + drop(shrunk); // Make sure the buffer isn't shared (or shrink_to_fit won't work) + assert_eq!(empty_slice.as_slice(), &[]); + assert_eq!(empty_slice.capacity(), 5); + + let mut shrunk_empty = empty_slice; + shrunk_empty.shrink_to_fit(); + assert_eq!(shrunk_empty.as_slice(), &[]); + assert_eq!(shrunk_empty.capacity(), 0); + } + #[test] #[should_panic(expected = "the offset of the new Buffer cannot exceed the existing length")] fn test_slice_offset_out_of_bound() { diff --git a/arrow-buffer/src/buffer/mutable.rs b/arrow-buffer/src/buffer/mutable.rs index 7fcbd89dd262..285c7b10ef4c 100644 --- a/arrow-buffer/src/buffer/mutable.rs +++ b/arrow-buffer/src/buffer/mutable.rs @@ -483,10 +483,13 @@ impl MutableBuffer { } } +/// Creates a non-null pointer with alignment of [`ALIGNMENT`] +/// +/// This is similar to [`NonNull::dangling`] #[inline] -fn dangling_ptr() -> NonNull { - // SAFETY: ALIGNMENT is a non-zero usize which is then casted - // to a *mut T. Therefore, `ptr` is not null and the conditions for +pub(crate) fn dangling_ptr() -> NonNull { + // SAFETY: ALIGNMENT is a non-zero usize which is then cast + // to a *mut u8. Therefore, `ptr` is not null and the conditions for // calling new_unchecked() are respected. #[cfg(miri)] { diff --git a/arrow-buffer/src/buffer/null.rs b/arrow-buffer/src/buffer/null.rs index c79aef398059..137d900ac8fa 100644 --- a/arrow-buffer/src/buffer/null.rs +++ b/arrow-buffer/src/buffer/null.rs @@ -130,6 +130,11 @@ impl NullBuffer { self.buffer.is_empty() } + /// Free up unused memory. + pub fn shrink_to_fit(&mut self) { + self.buffer.shrink_to_fit(); + } + /// Returns the null count for this [`NullBuffer`] #[inline] pub fn null_count(&self) -> usize { diff --git a/arrow-buffer/src/buffer/offset.rs b/arrow-buffer/src/buffer/offset.rs index e9087d30098c..a6be2b67af84 100644 --- a/arrow-buffer/src/buffer/offset.rs +++ b/arrow-buffer/src/buffer/offset.rs @@ -133,6 +133,11 @@ impl OffsetBuffer { Self(out.into()) } + /// Free up unused memory. + pub fn shrink_to_fit(&mut self) { + self.0.shrink_to_fit(); + } + /// Returns the inner [`ScalarBuffer`] pub fn inner(&self) -> &ScalarBuffer { &self.0 diff --git a/arrow-buffer/src/buffer/run.rs b/arrow-buffer/src/buffer/run.rs index 3dbbe344a025..cc6d19044feb 100644 --- a/arrow-buffer/src/buffer/run.rs +++ b/arrow-buffer/src/buffer/run.rs @@ -136,6 +136,12 @@ where self.len == 0 } + /// Free up unused memory. + pub fn shrink_to_fit(&mut self) { + // TODO(emilk): we could shrink even more in the case where we are a small sub-slice of the full buffer + self.run_ends.shrink_to_fit(); + } + /// Returns the values of this [`RunEndBuffer`] not including any offset #[inline] pub fn values(&self) -> &[E] { diff --git a/arrow-buffer/src/buffer/scalar.rs b/arrow-buffer/src/buffer/scalar.rs index 343b8549e93d..ab6c87168e5c 100644 --- a/arrow-buffer/src/buffer/scalar.rs +++ b/arrow-buffer/src/buffer/scalar.rs @@ -72,6 +72,11 @@ impl ScalarBuffer { buffer.slice_with_length(byte_offset, byte_len).into() } + /// Free up unused memory. + pub fn shrink_to_fit(&mut self) { + self.buffer.shrink_to_fit(); + } + /// Returns a zero-copy slice of this buffer with length `len` and starting at `offset` pub fn slice(&self, offset: usize, len: usize) -> Self { Self::new(self.buffer.clone(), offset, len) diff --git a/arrow-buffer/src/bytes.rs b/arrow-buffer/src/bytes.rs index ba61342d8e39..77724137aef7 100644 --- a/arrow-buffer/src/bytes.rs +++ b/arrow-buffer/src/bytes.rs @@ -24,6 +24,7 @@ use std::ptr::NonNull; use std::{fmt::Debug, fmt::Formatter}; use crate::alloc::Deallocation; +use crate::buffer::dangling_ptr; /// A continuous, fixed-size, immutable memory region that knows how to de-allocate itself. /// @@ -96,6 +97,48 @@ impl Bytes { } } + /// Try to reallocate the underlying memory region to a new size (smaller or larger). + /// + /// Only works for bytes allocated with the standard allocator. + /// Returns `Err` if the memory was allocated with a custom allocator, + /// or the call to `realloc` failed, for whatever reason. + /// In case of `Err`, the [`Bytes`] will remain as it was (i.e. have the old size). + pub fn try_realloc(&mut self, new_len: usize) -> Result<(), ()> { + if let Deallocation::Standard(old_layout) = self.deallocation { + if old_layout.size() == new_len { + return Ok(()); // Nothing to do + } + + if let Ok(new_layout) = std::alloc::Layout::from_size_align(new_len, old_layout.align()) + { + let old_ptr = self.ptr.as_ptr(); + + let new_ptr = match new_layout.size() { + 0 => { + // SAFETY: Verified that old_layout.size != new_len (0) + unsafe { std::alloc::dealloc(self.ptr.as_ptr(), old_layout) }; + Some(dangling_ptr()) + } + // SAFETY: the call to `realloc` is safe if all the following hold (from https://doc.rust-lang.org/stable/std/alloc/trait.GlobalAlloc.html#method.realloc): + // * `old_ptr` must be currently allocated via this allocator (guaranteed by the invariant/contract of `Bytes`) + // * `old_layout` must be the same layout that was used to allocate that block of memory (same) + // * `new_len` must be greater than zero + // * `new_len`, when rounded up to the nearest multiple of `layout.align()`, must not overflow `isize` (guaranteed by the success of `Layout::from_size_align`) + _ => NonNull::new(unsafe { std::alloc::realloc(old_ptr, old_layout, new_len) }), + }; + + if let Some(ptr) = new_ptr { + self.ptr = ptr; + self.len = new_len; + self.deallocation = Deallocation::Standard(new_layout); + return Ok(()); + } + } + } + + Err(()) + } + #[inline] pub(crate) fn deallocation(&self) -> &Deallocation { &self.deallocation diff --git a/arrow/tests/shrink_to_fit.rs b/arrow/tests/shrink_to_fit.rs new file mode 100644 index 000000000000..5d7c2cf98bc9 --- /dev/null +++ b/arrow/tests/shrink_to_fit.rs @@ -0,0 +1,159 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::{ + array::{Array, ArrayRef, ListArray, PrimitiveArray}, + buffer::OffsetBuffer, + datatypes::{Field, UInt8Type}, +}; + +/// Test that `shrink_to_fit` frees memory after concatenating a large number of arrays. +#[test] +fn test_shrink_to_fit_after_concat() { + let array_len = 6_000; + let num_concats = 100; + + let primitive_array: PrimitiveArray = (0..array_len) + .map(|v| (v % 255) as u8) + .collect::>() + .into(); + let primitive_array: ArrayRef = Arc::new(primitive_array); + + let list_array: ArrayRef = Arc::new(ListArray::new( + Field::new_list_field(primitive_array.data_type().clone(), false).into(), + OffsetBuffer::from_lengths([primitive_array.len()]), + primitive_array.clone(), + None, + )); + + // Num bytes allocated globally and by this thread, respectively. + let (concatenated, _bytes_allocated_globally, bytes_allocated_by_this_thread) = + memory_use(|| { + let mut concatenated = concatenate(num_concats, list_array.clone()); + concatenated.shrink_to_fit(); // This is what we're testing! + dbg!(concatenated.data_type()); + concatenated + }); + let expected_len = num_concats * array_len; + assert_eq!(bytes_used(concatenated.clone()), expected_len); + eprintln!("The concatenated array is {expected_len} B long. Amount of memory used by this thread: {bytes_allocated_by_this_thread} B"); + + assert!( + expected_len <= bytes_allocated_by_this_thread, + "We must allocate at least as much space as the concatenated array" + ); + assert!( + bytes_allocated_by_this_thread <= expected_len + expected_len / 100, + "We shouldn't have more than 1% memory overhead. In fact, we are using {bytes_allocated_by_this_thread} B of memory for {expected_len} B of data" + ); +} + +fn concatenate(num_times: usize, array: ArrayRef) -> ArrayRef { + let mut concatenated = array.clone(); + for _ in 0..num_times - 1 { + concatenated = arrow::compute::kernels::concat::concat(&[&*concatenated, &*array]).unwrap(); + } + concatenated +} + +fn bytes_used(array: ArrayRef) -> usize { + let mut array = array; + loop { + match array.data_type() { + arrow::datatypes::DataType::UInt8 => break, + arrow::datatypes::DataType::List(_) => { + let list = array.as_any().downcast_ref::().unwrap(); + array = list.values().clone(); + } + _ => unreachable!(), + } + } + + array.len() +} + +// --- Memory tracking --- + +use std::{ + alloc::Layout, + sync::{ + atomic::{AtomicUsize, Ordering::Relaxed}, + Arc, + }, +}; + +static LIVE_BYTES_GLOBAL: AtomicUsize = AtomicUsize::new(0); + +thread_local! { + static LIVE_BYTES_IN_THREAD: AtomicUsize = const { AtomicUsize::new(0) } ; +} + +pub struct TrackingAllocator { + allocator: std::alloc::System, +} + +#[global_allocator] +pub static GLOBAL_ALLOCATOR: TrackingAllocator = TrackingAllocator { + allocator: std::alloc::System, +}; + +#[allow(unsafe_code)] +// SAFETY: +// We just do book-keeping and then let another allocator do all the actual work. +unsafe impl std::alloc::GlobalAlloc for TrackingAllocator { + #[allow(clippy::let_and_return)] + unsafe fn alloc(&self, layout: Layout) -> *mut u8 { + // SAFETY: + // Just deferring + let ptr = unsafe { self.allocator.alloc(layout) }; + if !ptr.is_null() { + LIVE_BYTES_IN_THREAD.with(|bytes| bytes.fetch_add(layout.size(), Relaxed)); + LIVE_BYTES_GLOBAL.fetch_add(layout.size(), Relaxed); + } + ptr + } + + unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) { + LIVE_BYTES_IN_THREAD.with(|bytes| bytes.fetch_sub(layout.size(), Relaxed)); + LIVE_BYTES_GLOBAL.fetch_sub(layout.size(), Relaxed); + + // SAFETY: + // Just deferring + unsafe { self.allocator.dealloc(ptr, layout) }; + } + + // No need to override `alloc_zeroed` or `realloc`, + // since they both by default just defer to `alloc` and `dealloc`. +} + +fn live_bytes_local() -> usize { + LIVE_BYTES_IN_THREAD.with(|bytes| bytes.load(Relaxed)) +} + +fn live_bytes_global() -> usize { + LIVE_BYTES_GLOBAL.load(Relaxed) +} + +/// Returns `(num_bytes_allocated, num_bytes_allocated_by_this_thread)`. +fn memory_use(run: impl Fn() -> R) -> (R, usize, usize) { + let used_bytes_start_local = live_bytes_local(); + let used_bytes_start_global = live_bytes_global(); + let ret = run(); + let bytes_used_local = live_bytes_local() - used_bytes_start_local; + let bytes_used_global = live_bytes_global() - used_bytes_start_global; + (ret, bytes_used_global, bytes_used_local) +}