Skip to content

Commit

Permalink
Fix shape access (#8)
Browse files Browse the repository at this point in the history
* use shape from object

* use shape from object & fix iteration

* fmt - clippy

* clippy

---------

Co-authored-by: mar1 <[email protected]>
  • Loading branch information
xxxxxxxxxmr and mar1 authored Jul 14, 2024
1 parent d2c554c commit 4483199
Show file tree
Hide file tree
Showing 6 changed files with 215 additions and 60 deletions.
12 changes: 3 additions & 9 deletions benches/benchmarks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,27 +54,21 @@ fn bench_matrix_multiplication(c: &mut Criterion) {
let m1 = Dimensional::<f64, LinearArrayStorage<f64, 2>, 2>::ones(shape1);
let m2 = Dimensional::<f64, LinearArrayStorage<f64, 2>, 2>::ones(shape2);

c.bench_function("matrix_multiplication", |b| {
b.iter(|| m1.dot(&m2))
});
c.bench_function("matrix_multiplication", |b| b.iter(|| m1.dot(&m2)));
}

fn bench_matrix_transpose(c: &mut Criterion) {
let shape = [1000, 1000];
let m = Dimensional::<f64, LinearArrayStorage<f64, 2>, 2>::ones(shape);

c.bench_function("matrix_transpose", |b| {
b.iter(|| m.transpose())
});
c.bench_function("matrix_transpose", |b| b.iter(|| m.transpose()));
}

fn bench_matrix_trace(c: &mut Criterion) {
let shape = [1000, 1000];
let m = Dimensional::<f64, LinearArrayStorage<f64, 2>, 2>::ones(shape);

c.bench_function("matrix_trace", |b| {
b.iter(|| m.trace())
});
c.bench_function("matrix_trace", |b| b.iter(|| m.trace()));
}

criterion_group!(
Expand Down
63 changes: 39 additions & 24 deletions src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,41 +126,39 @@ where
where
F: Fn([usize; N]) -> T,
{
let data = (0..shape.iter().product::<usize>())
.map(|i| {
let index = Self::unravel_index(i, &shape);
f(index)
})
.collect();

let storage = S::from_vec(shape, data);

Self {
// Initialize with zeros
let storage = S::zeros(shape);
let mut array = Self {
shape,
storage,
_marker: PhantomData,
};

// Unravel index and apply f
for i in 0..array.len() {
let index = array.unravel_index(i);
array.storage.as_mut_slice()[i] = f(index);
}
}

// TODO Seems like both of these could just use the shape already on the object
array
}

/// Converts a linear index to a multidimensional index.
///
/// # Arguments
///
/// * `index`: The linear index.
/// * `shape`: The shape of the array.
///
/// # Returns
///
/// A multidimensional index as an array of `usize`.
pub fn unravel_index(index: usize, shape: &[usize; N]) -> [usize; N] {
pub fn unravel_index(&self, index: usize) -> [usize; N] {
let mut index = index;
let mut unraveled = [0; N];

for i in (0..N).rev() {
unraveled[i] = index % shape[i];
index /= shape[i];
unraveled[i] = index % self.shape[i];
index /= self.shape[i];
}

unraveled
Expand All @@ -171,15 +169,14 @@ where
/// # Arguments
///
/// * `indices`: The multidimensional index.
/// * `shape`: The shape of the array.
///
/// # Returns
///
/// A linear index as `usize`.
pub fn ravel_index(indices: &[usize; N], shape: &[usize; N]) -> usize {
pub fn ravel_index(&self, indices: &[usize; N]) -> usize {
indices
.iter()
.zip(shape.iter())
.zip(self.shape.iter())
.fold(0, |acc, (&i, &s)| acc * s + i)
}

Expand Down Expand Up @@ -355,16 +352,34 @@ mod tests {

#[test]
fn test_unravel_and_ravel_index() {
let shape = [2, 3, 4];
let array: Dimensional<i32, LinearArrayStorage<i32, 3>, 3> = Dimensional::zeros([2, 3, 4]);
for i in 0..24 {
let unraveled =
Dimensional::<i32, LinearArrayStorage<i32, 3>, 3>::unravel_index(i, &shape);
let raveled =
Dimensional::<i32, LinearArrayStorage<i32, 3>, 3>::ravel_index(&unraveled, &shape);
let unraveled = array.unravel_index(i);
let raveled = array.ravel_index(&unraveled);
assert_eq!(i, raveled);
}
}

#[test]
fn test_ravel_unravel_consistency() {
let array: Dimensional<i32, LinearArrayStorage<i32, 3>, 3> = Dimensional::zeros([2, 3, 4]);

for i in 0..2 {
for j in 0..3 {
for k in 0..4 {
let index = [i, j, k];
let raveled = array.ravel_index(&index);
let unraveled = array.unravel_index(raveled);
assert_eq!(
index, unraveled,
"Ravel/unravel mismatch for index {:?}",
index
);
}
}
}
}

#[test]
fn test_shape_and_dimensions() {
let array: Dimensional<i32, LinearArrayStorage<i32, 3>, 3> = Dimensional::zeros([2, 3, 4]);
Expand Down
12 changes: 11 additions & 1 deletion src/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ where
let mut index_array = [0; N];
index_array[0] = i;
index_array[1] = j;
let index = Self::ravel_index(&index_array, &shape);
let index = self.ravel_index(&index_array);
// Check if a precision is specified in the formatter
if let Some(precision) = f.precision() {
write!(f, "{:.1$}", self.as_slice()[index], precision)?;
Expand Down Expand Up @@ -263,4 +263,14 @@ mod tests {
"5D array: shape [2, 2, 2, 2, 2], data [0, 1, 2, ..., 31]"
);
}

#[test]
fn test_display_consistency() {
let array: Dimensional<i32, LinearArrayStorage<i32, 2>, 2> =
Dimensional::from_fn([3, 4], |[i, j]| (i * 4 + j) as i32);

let display_output = format!("{}", array);
let expected_output = "[\n [0, 1, 2, 3],\n [4, 5, 6, 7],\n [8, 9, 10, 11]\n]";
assert_eq!(display_output, expected_output, "Display output mismatch");
}
}
158 changes: 147 additions & 11 deletions src/iterators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,22 +92,19 @@ where

let index = self.current_index;

// TODO: Actually iterate correctly here over an `N`-dimensional array
// with `N` axes each with a possibly different length.
// and determine iteration pattern
self.remaining -= 1;

// Update the index for the next iteration
for i in (0..N).rev() {
self.current_index[i] += 1;
if self.current_index[i] < self.dimensional.shape()[i] {
if self.current_index[i] < self.dimensional.shape()[i] - 1 {
self.current_index[i] += 1;
break;
} else {
self.current_index[i] = 0;
}
self.current_index[i] = 0;
}

self.remaining -= 1;

let linear_index = Dimensional::<T, S, N>::ravel_index(&index, &self.dimensional.shape());
let linear_index = self.dimensional.ravel_index(&index);
// TODO: We really don't want to use unsafe rust here
// SAFETY: This is safe because we're returning a unique reference to each element,
// and we're iterating over each element only once.
Expand Down Expand Up @@ -222,8 +219,6 @@ where
mod tests {
use crate::{matrix, storage::LinearArrayStorage, Dimensional};

// ... (previous tests remain unchanged)

#[test]
fn test_iter_mut_borrow() {
let mut m = matrix![[1, 2], [3, 4]];
Expand All @@ -234,4 +229,145 @@ mod tests {
assert_eq!(iter.next(), Some(&mut 4));
assert_eq!(iter.next(), None);
}

#[test]
fn test_iter_next() {
let array_1d: Dimensional<i32, LinearArrayStorage<i32, 1>, 1> = Dimensional::zeros([5]);
let mut iter = array_1d.iter();
assert_eq!(iter.next(), Some(&0));
assert_eq!(iter.next(), Some(&0));
assert_eq!(iter.next(), Some(&0));
assert_eq!(iter.next(), Some(&0));
assert_eq!(iter.next(), Some(&0));
assert_eq!(iter.next(), None);
}

#[test]
fn test_iter_next_matrix() {
let array_2d: Dimensional<i32, LinearArrayStorage<i32, 2>, 2> = Dimensional::zeros([2, 3]);
let mut iter = array_2d.iter();
assert_eq!(iter.next(), Some(&0));
assert_eq!(iter.next(), Some(&0));
assert_eq!(iter.next(), Some(&0));
assert_eq!(iter.next(), Some(&0));
assert_eq!(iter.next(), Some(&0));
assert_eq!(iter.next(), Some(&0));
assert_eq!(iter.next(), None);
}

#[test]
fn test_iter_mut_next() {
let mut array_1d: Dimensional<i32, LinearArrayStorage<i32, 1>, 1> = Dimensional::zeros([5]);
let mut iter = array_1d.iter_mut();
if let Some(elem) = iter.next() {
*elem = 1;
}
if let Some(elem) = iter.next() {
*elem = 2;
}
if let Some(elem) = iter.next() {
*elem = 3;
}
if let Some(elem) = iter.next() {
*elem = 4;
}
if let Some(elem) = iter.next() {
*elem = 5;
}

let mut iter = array_1d.iter_mut();
assert_eq!(iter.next(), Some(&mut 1));
assert_eq!(iter.next(), Some(&mut 2));
assert_eq!(iter.next(), Some(&mut 3));
assert_eq!(iter.next(), Some(&mut 4));
assert_eq!(iter.next(), Some(&mut 5));
assert_eq!(iter.next(), None);
}

#[test]
fn test_iter_mut_next_matrix() {
let mut array_2d: Dimensional<i32, LinearArrayStorage<i32, 2>, 2> =
Dimensional::zeros([2, 3]);
let mut iter = array_2d.iter_mut();
if let Some(elem) = iter.next() {
*elem = 1;
}
if let Some(elem) = iter.next() {
*elem = 2;
}
if let Some(elem) = iter.next() {
*elem = 3;
}
if let Some(elem) = iter.next() {
*elem = 4;
}
if let Some(elem) = iter.next() {
*elem = 5;
}
if let Some(elem) = iter.next() {
*elem = 6;
}

let mut iter = array_2d.iter_mut();
assert_eq!(iter.next(), Some(&mut 1));
assert_eq!(iter.next(), Some(&mut 2));
assert_eq!(iter.next(), Some(&mut 3));
assert_eq!(iter.next(), Some(&mut 4));
assert_eq!(iter.next(), Some(&mut 5));
assert_eq!(iter.next(), Some(&mut 6));
assert_eq!(iter.next(), None);
}

#[test]
fn test_iter_empty() {
let array_empty: Dimensional<i32, LinearArrayStorage<i32, 1>, 1> = Dimensional::zeros([0]);
let mut iter = array_empty.iter();
assert_eq!(iter.next(), None);
}

#[test]
fn test_iter_mut_empty() {
let mut array_empty: Dimensional<i32, LinearArrayStorage<i32, 1>, 1> =
Dimensional::zeros([0]);
let mut iter = array_empty.iter_mut();
assert_eq!(iter.next(), None);
}

#[test]
fn test_iter_high_dimensional() {
let array_3d: Dimensional<i32, LinearArrayStorage<i32, 3>, 3> =
Dimensional::zeros([2, 3, 2]);
let mut iter = array_3d.iter();
assert_eq!(iter.next(), Some(&0));
assert_eq!(iter.next(), Some(&0));
assert_eq!(iter.next(), Some(&0));
assert_eq!(iter.next(), Some(&0));
assert_eq!(iter.next(), Some(&0));
assert_eq!(iter.next(), Some(&0));
assert_eq!(iter.next(), Some(&0));
assert_eq!(iter.next(), Some(&0));
assert_eq!(iter.next(), Some(&0));
assert_eq!(iter.next(), Some(&0));
assert_eq!(iter.next(), Some(&0));
assert_eq!(iter.next(), Some(&0));
assert_eq!(iter.next(), None);
}

#[test]
fn test_iter_mut_high_dimensional() {
let mut array_3d: Dimensional<i32, LinearArrayStorage<i32, 3>, 3> =
Dimensional::zeros([2, 3, 2]);
let mut iter = array_3d.iter_mut();
for i in 1..=12 {
if let Some(elem) = iter.next() {
*elem = i;
}
}

let mut iter = array_3d.iter_mut();
for mut i in 1..=12 {
assert_eq!(iter.next(), Some(&mut i));
}
assert_eq!(iter.next(), None);
}
}
Loading

0 comments on commit 4483199

Please sign in to comment.