Skip to content

Commit

Permalink
moved flat index computation to an inline function
Browse files Browse the repository at this point in the history
  • Loading branch information
imrn99 committed Nov 10, 2023
1 parent 529f3bc commit b1b2179
Showing 1 changed file with 22 additions and 12 deletions.
34 changes: 22 additions & 12 deletions src/view/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ use std::{
};

#[derive(Debug)]
/// Enum used to classify view-related errors.
///
/// In all variants, the internal value is a description of the error.
pub enum ViewError<'a> {
ValueError(&'a str),
}
Expand Down Expand Up @@ -129,17 +132,25 @@ where
))
}
}

#[inline(always)]
pub fn flat_idx(&self, index: [usize; N]) -> usize {
index
.iter()
.zip(self.stride.iter())
.map(|(i, s_i)| *i * *s_i)
.sum()
}
}

impl<'a, const N: usize, T> Index<[usize; N]> for ViewBase<'a, N, T> {
impl<'a, const N: usize, T> Index<[usize; N]> for ViewBase<'a, N, T>
where
T: Default + Clone,
{
type Output = T;

fn index(&self, index: [usize; N]) -> &Self::Output {
let flat_idx: usize = index
.iter()
.zip(self.stride.iter())
.map(|(i, s_i)| *i * *s_i)
.sum();
let flat_idx: usize = self.flat_idx(index);
match &self.data {
DataType::Owned(v) => {
assert!(flat_idx < v.len()); // remove bounds check
Expand All @@ -157,13 +168,12 @@ impl<'a, const N: usize, T> Index<[usize; N]> for ViewBase<'a, N, T> {
}
}

impl<'a, const N: usize, T> IndexMut<[usize; N]> for ViewBase<'a, N, T> {
impl<'a, const N: usize, T> IndexMut<[usize; N]> for ViewBase<'a, N, T>
where
T: Default + Clone,
{
fn index_mut(&mut self, index: [usize; N]) -> &mut Self::Output {
let flat_idx: usize = index
.iter()
.zip(self.stride.iter())
.map(|(i, s_i)| *i * *s_i)
.sum();
let flat_idx: usize = self.flat_idx(index);
match &mut self.data {
DataType::Owned(v) => {
assert!(flat_idx < v.len()); // remove bounds check
Expand Down

0 comments on commit b1b2179

Please sign in to comment.