Skip to content

Commit

Permalink
Add iterator for slice and refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Riley Sutton authored and Riley Sutton committed Mar 25, 2024
1 parent abf594b commit 6bcf07f
Show file tree
Hide file tree
Showing 8 changed files with 190 additions and 29 deletions.
4 changes: 2 additions & 2 deletions src/engine/tensor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ pub mod iter;
pub mod builder;
pub mod factory;

use crate::helper::{Shape, Position, Slice, };
use crate::helper::{Interval, Position, Shape, Slice };
use self::extension::ExtensionProvider;
use self::iter::EngineTensorUnitIterator;
use std::fmt::Debug;
Expand All @@ -19,7 +19,7 @@ pub trait EngineTensor: Debug {

fn clone(&self) -> Box<dyn EngineTensor<Unit = Self::Unit>>;
fn mat(&self) -> Box<dyn EngineTensor<Unit = Self::Unit>>;
fn slice(&self, slice: &Slice) -> Box<dyn EngineTensor<Unit = Self::Unit>>;
fn slice(&self, intervals: &[Interval]) -> Box<dyn EngineTensor<Unit = Self::Unit>>;
fn reshape(&self, shape: &Shape) -> Box<dyn EngineTensor<Unit = Self::Unit>>;
fn broadcast_splice(&self, pos: usize, sub: &[usize]) -> Box<dyn EngineTensor<Unit = Self::Unit>>;

Expand Down
8 changes: 4 additions & 4 deletions src/engine_impl/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,28 +48,28 @@ pub fn im2col_2d<T: UnitCompatible, E: EngineTensorFactory<Unit = T>>(

for y in 0..out_y {
for x in 0..out_x {
let grouped_patches = a_padded.slice(&Slice::from(
let grouped_patches = a_padded.slice(
[
Interval::all(),
Interval::all(),
Interval::between_with_step(y, y + k_y, stride),
Interval::between_with_step(x, x + k_x, stride),
]
.as_slice(),
));
);

let grouped_patches = grouped_patches.reshape(&grouped_patches_shape);

for batch in 0..batches {
for channel in 0..in_channels {
let patch = grouped_patches.slice(&Slice::from(
let patch = grouped_patches.slice(
[
Interval::only(batch),
Interval::only(channel),
Interval::all(),
]
.as_slice(),
));
);

let start_index = Position::from([batch, channel, y, x, 0].as_slice())
.tensor_index(&out_stride)
Expand Down
8 changes: 5 additions & 3 deletions src/engine_impl/tensor/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::{
},
unit::UnitCompatible,
},
helper::{Position, Shape, Slice, Stride, VarArrayCompatible},
helper::{Interval, Position, Shape, Slice, Stride, VarArrayCompatible},
};

#[derive(Debug)]
Expand Down Expand Up @@ -90,12 +90,14 @@ impl<T: AllowedArray> EngineTensor for Array<T> {
Array::from_iter(self.iter_units(), self.shape().clone()).generic()
}

fn slice(&self, slice: &Slice) -> Box<dyn EngineTensor<Unit = T>> {
fn slice(&self, intervals: &[Interval]) -> Box<dyn EngineTensor<Unit = T>> {
let slice = Slice::new(intervals.into(), self.shape().clone());

let offset = slice.start().tensor_index(&self.stride).unwrap();

Box::from(Self {
stride: self.stride.clone(),
shape: slice.inferred_shape(self.shape()),
shape: slice.inferred_shape(),
data: self.data.clone(),
offset,
})
Expand Down
8 changes: 5 additions & 3 deletions src/engine_impl/tensor/padded.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{engine::{tensor::{extension::{EmptyExtensionProvider, ExtensionProvider}, factory::EngineTensorFactory, iter::EngineTensorUnitIterator, EngineTensor}, unit::UnitCompatible}, helper::{Position, Shape, Slice, VarArray, VarArrayCompatible}};
use crate::{engine::{tensor::{extension::{EmptyExtensionProvider, ExtensionProvider}, factory::EngineTensorFactory, iter::EngineTensorUnitIterator, EngineTensor}, unit::UnitCompatible}, helper::{Interval, Position, Shape, Slice, VarArray, VarArrayCompatible}};

use super::array::Array;

Expand Down Expand Up @@ -122,8 +122,10 @@ impl<T: AllowedPadded> EngineTensor for Padded<T> {
}

//We can handle slices but changing anything more drastic needs a deep copy
fn slice(&self, slice: &Slice) -> Box<dyn EngineTensor<Unit = Self::Unit>> {
let slice_shape = slice.inferred_shape(self.shape());
fn slice(&self, intervals: &[Interval]) -> Box<dyn EngineTensor<Unit = Self::Unit>> {
let slice = Slice::new(intervals.into(), self.shape().clone());

let slice_shape = slice.inferred_shape();

let steps = VarArray::from_iter(slice.as_boxed_slice().iter().map(|int| int.step()));

Expand Down
8 changes: 7 additions & 1 deletion src/helper/position/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pub struct Iter<'a> {
pos: Position,
until: &'a Position,
bounds: &'a Shape,
is_done: bool,
}

impl<'a> Iter<'a> {
Expand All @@ -16,6 +17,7 @@ impl<'a> Iter<'a> {
pos,
until,
bounds,
is_done: false,
}
}
}
Expand All @@ -25,8 +27,12 @@ impl<'a> Iterator for Iter<'a> {
type Item = Position;

fn next(&mut self) -> Option<Self::Item> {
if self.pos == *self.until {
if self.is_done {
None
} else if self.pos == *self.until {
self.is_done = true;

Some(self.pos.clone())
} else {
self.pos.incdec_mut(self.bounds, 1).unwrap();

Expand Down
10 changes: 10 additions & 0 deletions src/helper/position/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,16 @@ impl From<&[usize]> for Position {
}
}

macro_rules! position {
() => {
Position::from([].as_slice())
};
($($x:expr),+) => {
Position::from([$($x),+].as_slice())
};
}
pub(crate) use position;

#[derive(Error, Debug)]
pub enum PositionError {
#[error("Stride length was: {0}, expected: {1}")]
Expand Down
128 changes: 128 additions & 0 deletions src/helper/slice/iter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
use crate::helper::{Position, Shape, VarArrayCompatible};

use super::Slice;

pub struct Iter<'a> {
pos: Position,
until: Position,
slice_shape: Shape,
starts: Position,
slice: &'a Slice,
is_done: bool,
}

impl<'a> Iter<'a> {
pub fn new(slice: &'a Slice) -> Iter<'a> {
let slice_shape = slice.inferred_shape();

Self {
pos: slice_shape.first(),
until: slice_shape.last(),
slice_shape,
starts: slice.start(), //Will be
slice,
is_done: false,
}
}
}

impl<'a> Iterator for Iter<'a> {
type Item = Position;

fn next(&mut self) -> Option<Self::Item> {
if self.is_done {
None
} else if self.pos == self.until {
self.is_done = true;

Some(self.pos.add(&self.starts).unwrap())
} else {
let out = self.pos.add(&self.starts).unwrap();

self.pos.incdec_mut(&self.slice_shape, 1).unwrap();

Some(out)
}
}
}

#[cfg(test)]
mod test {
use crate::helper::{position, shape, Interval};

use super::*;

#[test]
fn basic_intervals() {
let base_shape = shape![6, 5, 4, 3, 2];
let intervals = Box::from([
Interval::all(),
Interval::between(2, 4),
Interval::start_to(2),
Interval::end_from(1),
Interval::only(1),
]);

let slice = Slice::new(intervals, base_shape);

println!("{:?}", slice.start());
println!("{:?}", slice.last());

let ref_positions = [
position![0, 2, 0, 1, 1],
position![0, 2, 0, 2, 1],
position![0, 2, 1, 1, 1],
position![0, 2, 1, 2, 1],
position![0, 3, 0, 1, 1],
position![0, 3, 0, 2, 1],
position![0, 3, 1, 1, 1],
position![0, 3, 1, 2, 1],
position![1, 2, 0, 1, 1],
position![1, 2, 0, 2, 1],
position![1, 2, 1, 1, 1],
position![1, 2, 1, 2, 1],
position![1, 3, 0, 1, 1],
position![1, 3, 0, 2, 1],
position![1, 3, 1, 1, 1],
position![1, 3, 1, 2, 1],
position![2, 2, 0, 1, 1],
position![2, 2, 0, 2, 1],
position![2, 2, 1, 1, 1],
position![2, 2, 1, 2, 1],
position![2, 3, 0, 1, 1],
position![2, 3, 0, 2, 1],
position![2, 3, 1, 1, 1],
position![2, 3, 1, 2, 1],
position![3, 2, 0, 1, 1],
position![3, 2, 0, 2, 1],
position![3, 2, 1, 1, 1],
position![3, 2, 1, 2, 1],
position![3, 3, 0, 1, 1],
position![3, 3, 0, 2, 1],
position![3, 3, 1, 1, 1],
position![3, 3, 1, 2, 1],
position![4, 2, 0, 1, 1],
position![4, 2, 0, 2, 1],
position![4, 2, 1, 1, 1],
position![4, 2, 1, 2, 1],
position![4, 3, 0, 1, 1],
position![4, 3, 0, 2, 1],
position![4, 3, 1, 1, 1],
position![4, 3, 1, 2, 1],
position![5, 2, 0, 1, 1],
position![5, 2, 0, 2, 1],
position![5, 2, 1, 1, 1],
position![5, 2, 1, 2, 1],
position![5, 3, 0, 1, 1],
position![5, 3, 0, 2, 1],
position![5, 3, 1, 1, 1],
position![5, 3, 1, 2, 1],
];

assert_eq!(slice.iter().collect::<Vec<Position>>().len(), slice.elements());

for (e, r) in slice.iter().zip(ref_positions.iter()) {
assert_eq!(e, *r);
}
}
}
45 changes: 29 additions & 16 deletions src/helper/slice.rs → src/helper/slice/mod.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
use crate::helper::Shape;

use self::iter::Iter;

use super::{Position, VarArrayCompatible};

#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub mod iter;

#[derive(Clone, PartialEq, Eq, Debug)]
pub struct Interval {
start: Option<usize>,
end: Option<usize>,
step: Option<usize>,
}

#[derive(Clone, PartialEq, Eq, Debug)]
pub struct Slice(Box<[Interval]>);
pub struct Slice {
intervals: Box<[Interval]>,
shape: Shape,
}

impl Interval {
pub fn new(start: Option<usize>, end: Option<usize>, step: Option<usize>) -> Self {
Expand Down Expand Up @@ -93,39 +100,45 @@ impl Interval {
}

impl Slice {
pub fn new(data: Box<[Interval]>) -> Self {
Self(data)
pub fn new(intervals: Box<[Interval]>, shape: Shape) -> Self {
Self {
intervals,
shape,
}
}

pub fn as_boxed_slice(&self) -> &Box<[Interval]> {
&self.0
&self.intervals
}

pub fn as_mut_boxed_slice(&mut self) -> &mut Box<[Interval]> {
&mut self.0
&mut self.intervals
}

//Should be possible to have a different len to self.shape
pub fn inferred_shape(&self) -> Shape {
Shape::new(self.as_boxed_slice().iter().zip(self.shape.iter()).map(|(interval, dim)| interval.len(dim)).collect())
}

pub fn inferred_shape(&self, shape: &Shape) -> Shape {
Shape::new(self.as_boxed_slice().iter().zip(shape.iter()).map(|(interval, dim)| interval.len(dim)).collect())
pub fn len(&self) -> usize {
self.inferred_shape().len()
}

pub fn len(&self, shape: &Shape) -> usize {
self.inferred_shape(shape).len()
pub fn elements(&self) -> usize {
self.inferred_shape().elements()
}

pub fn start(&self) -> Position {
Position::new(self.as_boxed_slice().iter().map(|interval| interval.start_index()).collect())
}

//This is called last to differentiate between end which wouldn't be a valid position
pub fn last(&self, shape: &Shape) -> Position {
Position::new(self.as_boxed_slice().iter().zip(shape.iter()).map(|(interval, dim)| interval.end_index(dim).saturating_sub(1)).collect())
pub fn last(&self) -> Position {
Position::new(self.as_boxed_slice().iter().zip(self.shape.iter()).map(|(interval, dim)| interval.end_index(dim).saturating_sub(1)).collect())
}
}

impl From<&[Interval]> for Slice {
fn from(value: &[Interval]) -> Self {
Self(Box::from(value))
pub fn iter(&self) -> Iter {
Iter::new(self)
}
}

Expand Down

0 comments on commit 6bcf07f

Please sign in to comment.