diff --git a/src/engine/tensor/mod.rs b/src/engine/tensor/mod.rs index 9a096ed..b26cad7 100644 --- a/src/engine/tensor/mod.rs +++ b/src/engine/tensor/mod.rs @@ -3,7 +3,7 @@ pub mod iter; pub mod builder; pub mod factory; -use crate::helper::{Shape, Position, Slice, }; +use crate::helper::{Interval, Position, Shape, Slice }; use self::extension::ExtensionProvider; use self::iter::EngineTensorUnitIterator; use std::fmt::Debug; @@ -19,7 +19,7 @@ pub trait EngineTensor: Debug { fn clone(&self) -> Box>; fn mat(&self) -> Box>; - fn slice(&self, slice: &Slice) -> Box>; + fn slice(&self, intervals: &[Interval]) -> Box>; fn reshape(&self, shape: &Shape) -> Box>; fn broadcast_splice(&self, pos: usize, sub: &[usize]) -> Box>; diff --git a/src/engine_impl/shared.rs b/src/engine_impl/shared.rs index 8ef82c9..9619a04 100644 --- a/src/engine_impl/shared.rs +++ b/src/engine_impl/shared.rs @@ -48,7 +48,7 @@ pub fn im2col_2d>( for y in 0..out_y { for x in 0..out_x { - let grouped_patches = a_padded.slice(&Slice::from( + let grouped_patches = a_padded.slice( [ Interval::all(), Interval::all(), @@ -56,20 +56,20 @@ pub fn im2col_2d>( Interval::between_with_step(x, x + k_x, stride), ] .as_slice(), - )); + ); let grouped_patches = grouped_patches.reshape(&grouped_patches_shape); for batch in 0..batches { for channel in 0..in_channels { - let patch = grouped_patches.slice(&Slice::from( + let patch = grouped_patches.slice( [ Interval::only(batch), Interval::only(channel), Interval::all(), ] .as_slice(), - )); + ); let start_index = Position::from([batch, channel, y, x, 0].as_slice()) .tensor_index(&out_stride) diff --git a/src/engine_impl/tensor/array.rs b/src/engine_impl/tensor/array.rs index 481fc88..e531b31 100644 --- a/src/engine_impl/tensor/array.rs +++ b/src/engine_impl/tensor/array.rs @@ -11,7 +11,7 @@ use crate::{ }, unit::UnitCompatible, }, - helper::{Position, Shape, Slice, Stride, VarArrayCompatible}, + helper::{Interval, Position, Shape, Slice, Stride, VarArrayCompatible}, }; #[derive(Debug)] @@ -90,12 +90,14 @@ impl EngineTensor for Array { Array::from_iter(self.iter_units(), self.shape().clone()).generic() } - fn slice(&self, slice: &Slice) -> Box> { + fn slice(&self, intervals: &[Interval]) -> Box> { + let slice = Slice::new(intervals.into(), self.shape().clone()); + let offset = slice.start().tensor_index(&self.stride).unwrap(); Box::from(Self { stride: self.stride.clone(), - shape: slice.inferred_shape(self.shape()), + shape: slice.inferred_shape(), data: self.data.clone(), offset, }) diff --git a/src/engine_impl/tensor/padded.rs b/src/engine_impl/tensor/padded.rs index cf6a6b6..e590426 100644 --- a/src/engine_impl/tensor/padded.rs +++ b/src/engine_impl/tensor/padded.rs @@ -1,4 +1,4 @@ -use crate::{engine::{tensor::{extension::{EmptyExtensionProvider, ExtensionProvider}, factory::EngineTensorFactory, iter::EngineTensorUnitIterator, EngineTensor}, unit::UnitCompatible}, helper::{Position, Shape, Slice, VarArray, VarArrayCompatible}}; +use crate::{engine::{tensor::{extension::{EmptyExtensionProvider, ExtensionProvider}, factory::EngineTensorFactory, iter::EngineTensorUnitIterator, EngineTensor}, unit::UnitCompatible}, helper::{Interval, Position, Shape, Slice, VarArray, VarArrayCompatible}}; use super::array::Array; @@ -122,8 +122,10 @@ impl EngineTensor for Padded { } //We can handle slices but changing anything more drastic needs a deep copy - fn slice(&self, slice: &Slice) -> Box> { - let slice_shape = slice.inferred_shape(self.shape()); + fn slice(&self, intervals: &[Interval]) -> Box> { + let slice = Slice::new(intervals.into(), self.shape().clone()); + + let slice_shape = slice.inferred_shape(); let steps = VarArray::from_iter(slice.as_boxed_slice().iter().map(|int| int.step())); diff --git a/src/helper/position/iter.rs b/src/helper/position/iter.rs index faa88ca..b48570f 100644 --- a/src/helper/position/iter.rs +++ b/src/helper/position/iter.rs @@ -8,6 +8,7 @@ pub struct Iter<'a> { pos: Position, until: &'a Position, bounds: &'a Shape, + is_done: bool, } impl<'a> Iter<'a> { @@ -16,6 +17,7 @@ impl<'a> Iter<'a> { pos, until, bounds, + is_done: false, } } } @@ -25,8 +27,12 @@ impl<'a> Iterator for Iter<'a> { type Item = Position; fn next(&mut self) -> Option { - if self.pos == *self.until { + if self.is_done { None + } else if self.pos == *self.until { + self.is_done = true; + + Some(self.pos.clone()) } else { self.pos.incdec_mut(self.bounds, 1).unwrap(); diff --git a/src/helper/position/mod.rs b/src/helper/position/mod.rs index beedede..e83ed86 100644 --- a/src/helper/position/mod.rs +++ b/src/helper/position/mod.rs @@ -80,6 +80,16 @@ impl From<&[usize]> for Position { } } +macro_rules! position { + () => { + Position::from([].as_slice()) + }; + ($($x:expr),+) => { + Position::from([$($x),+].as_slice()) + }; +} +pub(crate) use position; + #[derive(Error, Debug)] pub enum PositionError { #[error("Stride length was: {0}, expected: {1}")] diff --git a/src/helper/slice/iter.rs b/src/helper/slice/iter.rs new file mode 100644 index 0000000..39c4600 --- /dev/null +++ b/src/helper/slice/iter.rs @@ -0,0 +1,128 @@ +use crate::helper::{Position, Shape, VarArrayCompatible}; + +use super::Slice; + +pub struct Iter<'a> { + pos: Position, + until: Position, + slice_shape: Shape, + starts: Position, + slice: &'a Slice, + is_done: bool, +} + +impl<'a> Iter<'a> { + pub fn new(slice: &'a Slice) -> Iter<'a> { + let slice_shape = slice.inferred_shape(); + + Self { + pos: slice_shape.first(), + until: slice_shape.last(), + slice_shape, + starts: slice.start(), //Will be + slice, + is_done: false, + } + } +} + +impl<'a> Iterator for Iter<'a> { + type Item = Position; + + fn next(&mut self) -> Option { + if self.is_done { + None + } else if self.pos == self.until { + self.is_done = true; + + Some(self.pos.add(&self.starts).unwrap()) + } else { + let out = self.pos.add(&self.starts).unwrap(); + + self.pos.incdec_mut(&self.slice_shape, 1).unwrap(); + + Some(out) + } + } +} + +#[cfg(test)] +mod test { + use crate::helper::{position, shape, Interval}; + + use super::*; + + #[test] + fn basic_intervals() { + let base_shape = shape![6, 5, 4, 3, 2]; + let intervals = Box::from([ + Interval::all(), + Interval::between(2, 4), + Interval::start_to(2), + Interval::end_from(1), + Interval::only(1), + ]); + + let slice = Slice::new(intervals, base_shape); + + println!("{:?}", slice.start()); + println!("{:?}", slice.last()); + + let ref_positions = [ + position![0, 2, 0, 1, 1], + position![0, 2, 0, 2, 1], + position![0, 2, 1, 1, 1], + position![0, 2, 1, 2, 1], + position![0, 3, 0, 1, 1], + position![0, 3, 0, 2, 1], + position![0, 3, 1, 1, 1], + position![0, 3, 1, 2, 1], + position![1, 2, 0, 1, 1], + position![1, 2, 0, 2, 1], + position![1, 2, 1, 1, 1], + position![1, 2, 1, 2, 1], + position![1, 3, 0, 1, 1], + position![1, 3, 0, 2, 1], + position![1, 3, 1, 1, 1], + position![1, 3, 1, 2, 1], + position![2, 2, 0, 1, 1], + position![2, 2, 0, 2, 1], + position![2, 2, 1, 1, 1], + position![2, 2, 1, 2, 1], + position![2, 3, 0, 1, 1], + position![2, 3, 0, 2, 1], + position![2, 3, 1, 1, 1], + position![2, 3, 1, 2, 1], + position![3, 2, 0, 1, 1], + position![3, 2, 0, 2, 1], + position![3, 2, 1, 1, 1], + position![3, 2, 1, 2, 1], + position![3, 3, 0, 1, 1], + position![3, 3, 0, 2, 1], + position![3, 3, 1, 1, 1], + position![3, 3, 1, 2, 1], + position![4, 2, 0, 1, 1], + position![4, 2, 0, 2, 1], + position![4, 2, 1, 1, 1], + position![4, 2, 1, 2, 1], + position![4, 3, 0, 1, 1], + position![4, 3, 0, 2, 1], + position![4, 3, 1, 1, 1], + position![4, 3, 1, 2, 1], + position![5, 2, 0, 1, 1], + position![5, 2, 0, 2, 1], + position![5, 2, 1, 1, 1], + position![5, 2, 1, 2, 1], + position![5, 3, 0, 1, 1], + position![5, 3, 0, 2, 1], + position![5, 3, 1, 1, 1], + position![5, 3, 1, 2, 1], + ]; + + assert_eq!(slice.iter().collect::>().len(), slice.elements()); + + for (e, r) in slice.iter().zip(ref_positions.iter()) { + assert_eq!(e, *r); + } + } +} \ No newline at end of file diff --git a/src/helper/slice.rs b/src/helper/slice/mod.rs similarity index 80% rename from src/helper/slice.rs rename to src/helper/slice/mod.rs index 690fc7e..7c00efc 100644 --- a/src/helper/slice.rs +++ b/src/helper/slice/mod.rs @@ -1,8 +1,12 @@ use crate::helper::Shape; +use self::iter::Iter; + use super::{Position, VarArrayCompatible}; -#[derive(Clone, Copy, PartialEq, Eq, Debug)] +pub mod iter; + +#[derive(Clone, PartialEq, Eq, Debug)] pub struct Interval { start: Option, end: Option, @@ -10,7 +14,10 @@ pub struct Interval { } #[derive(Clone, PartialEq, Eq, Debug)] -pub struct Slice(Box<[Interval]>); +pub struct Slice { + intervals: Box<[Interval]>, + shape: Shape, +} impl Interval { pub fn new(start: Option, end: Option, step: Option) -> Self { @@ -93,24 +100,32 @@ impl Interval { } impl Slice { - pub fn new(data: Box<[Interval]>) -> Self { - Self(data) + pub fn new(intervals: Box<[Interval]>, shape: Shape) -> Self { + Self { + intervals, + shape, + } } pub fn as_boxed_slice(&self) -> &Box<[Interval]> { - &self.0 + &self.intervals } pub fn as_mut_boxed_slice(&mut self) -> &mut Box<[Interval]> { - &mut self.0 + &mut self.intervals + } + + //Should be possible to have a different len to self.shape + pub fn inferred_shape(&self) -> Shape { + Shape::new(self.as_boxed_slice().iter().zip(self.shape.iter()).map(|(interval, dim)| interval.len(dim)).collect()) } - pub fn inferred_shape(&self, shape: &Shape) -> Shape { - Shape::new(self.as_boxed_slice().iter().zip(shape.iter()).map(|(interval, dim)| interval.len(dim)).collect()) + pub fn len(&self) -> usize { + self.inferred_shape().len() } - pub fn len(&self, shape: &Shape) -> usize { - self.inferred_shape(shape).len() + pub fn elements(&self) -> usize { + self.inferred_shape().elements() } pub fn start(&self) -> Position { @@ -118,14 +133,12 @@ impl Slice { } //This is called last to differentiate between end which wouldn't be a valid position - pub fn last(&self, shape: &Shape) -> Position { - Position::new(self.as_boxed_slice().iter().zip(shape.iter()).map(|(interval, dim)| interval.end_index(dim).saturating_sub(1)).collect()) + pub fn last(&self) -> Position { + Position::new(self.as_boxed_slice().iter().zip(self.shape.iter()).map(|(interval, dim)| interval.end_index(dim).saturating_sub(1)).collect()) } -} -impl From<&[Interval]> for Slice { - fn from(value: &[Interval]) -> Self { - Self(Box::from(value)) + pub fn iter(&self) -> Iter { + Iter::new(self) } }