From 1180941939ea5a57cca4469c3b2dec7b3ec1451f Mon Sep 17 00:00:00 2001
From: Riley Sutton
Date: Sun, 4 Feb 2024 23:31:34 +1100
Subject: [PATCH] Add im2col

---
 .gitignore                  |   5 +-
 src/engine/basic.rs         |  83 +++++++---------
 src/engine/mod.rs           |   3 +-
 src/engine/shared.rs        | 189 ++++++++++++++++++++++++++++++++++++
 src/engine/tensor/mod.rs    |  12 ++-
 src/engine/tensor/padded.rs |   9 +-
 src/helper/slice.rs         |  12 ++-
 7 files changed, 256 insertions(+), 57 deletions(-)
 create mode 100644 src/engine/shared.rs

diff --git a/.gitignore b/.gitignore
index 211413a..8cfdc31 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,5 @@
 /target
-/old_src
\ No newline at end of file
+/old_src
+flamegraph.svg
+perf.data
+perf.data.old
\ No newline at end of file

diff --git a/src/engine/basic.rs b/src/engine/basic.rs
index 3dd6a7e..711f205 100644
--- a/src/engine/basic.rs
+++ b/src/engine/basic.rs
@@ -2,64 +2,49 @@ use std::ops::IndexMut;
 
 use itertools::Itertools;
 
-use crate::{engine::{Engine, EngineError, EngineTensorFactory, EngineTensor, util::{err_if_incorrect_dimensions, err_if_dimension_mismatch}}, helper::{Position, Interval, Slice, Shape, Stride, VarArrayCompatible}};
+use crate::{engine::{Engine, EngineError, EngineTensorFactory, EngineTensor, util::{err_if_incorrect_dimensions, err_if_dimension_mismatch}}, helper::{shape, Interval, Position, Shape, Slice, Stride, VarArrayCompatible}};
+
+use super::shared::im2col_2d;
 
 pub struct Basic {}
 
+
+
 macro_rules! conv_fn {
     ($unit:ty) => {
-        fn conv2d<E: EngineTensorFactory<Unit = $unit>>(a: &dyn EngineTensor<Unit = $unit>, kernel: &dyn EngineTensor<Unit = $unit>, stride: usize) -> Result<Box<dyn EngineTensor<Unit = $unit>>, crate::engine::EngineError> {
-            //a: (batches, in_channels, y, x)
-            //kernel: (out_channels, in_channels, k_y, k_x)
-
+        //a: (batches, in_channels, y, x)
+        //kernel: (out_channels, in_channels, k_y, k_x)
+        fn conv2d<E: EngineTensorFactory<Unit = $unit>>(a: &dyn EngineTensor<Unit = $unit>, kernel: &dyn EngineTensor<Unit = $unit>, padding: usize, stride: usize) -> Result<Box<dyn EngineTensor<Unit = $unit>>, crate::engine::EngineError> {
             err_if_incorrect_dimensions(a.shape(), 4)?;
             err_if_incorrect_dimensions(kernel.shape(), 4)?;
             err_if_dimension_mismatch(a.shape().get(1).unwrap(), kernel.shape().get(1).unwrap())?;
-
-            let y = a.shape().get(a.shape().len() - 2).unwrap();
-            let x = a.shape().get(a.shape().len() - 1).unwrap();
-
-            let k_y = kernel.shape().get(kernel.shape().len() - 2).unwrap() + 2 * (stride - 1);
-            let k_x = kernel.shape().get(kernel.shape().len() - 1).unwrap() + 2 * (stride - 1);
-
             let batches = a.shape().get(0).unwrap();
+            let in_channels = a.shape().get(1).unwrap();
+
             let out_channels = kernel.shape().get(0).unwrap();
-            let in_channels = kernel.shape().get(1).unwrap();
-
-            //(batches, out_channels, in_channels, y, x)
-            let a_broadcast = a.broadcast_splice(1, &[out_channels]);
-            //(batches, out_channels, in_channels, k_y, k_x)
-            let kernel_broadcast = kernel.broadcast_splice(0, &[batches]);
+            let k_y = kernel.shape().get(2).unwrap();
+            let k_x = kernel.shape().get(3).unwrap();
+
+            //(batches, out_channels, in_channels, out_y, out_x, patch_len)
+            let proc = im2col_2d::<$unit, E>(a, kernel.shape(), padding, stride).broadcast_splice(1, &[out_channels]);
+
+            let out_y = proc.shape().get(3).unwrap();
+            let out_x = proc.shape().get(4).unwrap();
+            let patch_len = proc.shape().get(5).unwrap();
+
+            //(batches, out_channels, in_channels, out_y, out_x, patch_len)
+            let kernels = kernel.reshape(&shape![out_channels, in_channels, k_y * k_x]).broadcast_splice(0, &[batches]).broadcast_splice(3, [out_y, out_x].as_slice());
 
-            let half_k_y = k_y / 2;
-            let half_k_x = k_x / 2;
-
-            let y_out = y - half_k_y * 2;
-            let x_out = x - half_k_x * 2;
-            let out_shape = Shape::from([batches, out_channels, y_out, x_out].as_slice());
-            let out_stride = Stride::default_from_shape(&out_shape);
-
-            let mut reordered_sums: Box<[$unit]> = vec![<$unit>::default(); batches * out_channels * y_out * x_out].into_boxed_slice();
-
-            for curr_y in 0..y_out {
-                for curr_x in 0..x_out {
-                    let a_sliced = a_broadcast.slice(&Slice::from([Interval::all(), Interval::all(), Interval::all(), Interval::between_with_step(curr_y, curr_y + k_y, stride), Interval::between_with_step(curr_x, curr_x + k_x, stride)].as_slice()));
-
-                    let chunked_products = a_sliced.iter_units().zip(kernel_broadcast.iter_units()).map(|(a_elem, k_elem)| a_elem * k_elem).chunks(in_channels * k_y * k_x);
-                    let curr_sums = chunked_products.into_iter().map(|i| -> $unit { i.sum() });
-
-                    for (i, sum) in curr_sums.enumerate() {
-                        let batch = i / out_channels;
-                        let out_channel = i % out_channels;
-
-                        let index = Position::from([batch, out_channel, curr_y, curr_x].as_slice()).tensor_index(&out_stride)?;
-
-                        *reordered_sums.index_mut(index) = sum;
-                    }
-                }
-            }
-
-            //(batches, out_channels, y_out, x_out)
-            Ok(E::from_slice(&reordered_sums, out_shape))
+            println!("{:?}", out_y);
+            println!("{:?}", proc.shape());
+            println!("{:?}", kernels.shape());
+
+            let chunked_iter = proc.iter_units().zip(kernels.iter_units()).map(|(x, y)| x * y).chunks(patch_len);
+            let out_data = chunked_iter.into_iter().map(|i| i.sum());
+
+            //(batches, out_channels, out_y, out_x)
+            Ok(E::from_iter(out_data, shape![batches, out_channels, out_y, out_x]))
        }
    };
}
@@ -237,10 +222,10 @@ mod test {
 
     #[test]
     pub fn conv() {
-        let a = Array::from_iter((1..=65536).map(|x| (x as f32) / 65536.0).cycle().take(1 * 3 * 256 * 256), shape![1, 3, 256, 256]);
+        let a = Array::from_iter((1..=65536).map(|x| (x as f32) / 65536.0).cycle().take(4 * 3 * 256 * 256), shape![4, 3, 256, 256]);
         let kernel = Array::from_iter((1..=9).map(|x| x as f32).cycle().take(1 * 3 * 3 * 3), shape![1, 3, 3, 3]);
 
-        let res = Basic::conv2d::<Array<f32>>(a.as_ref(), kernel.as_ref(), 2).unwrap();
+        let res = Basic::conv2d::<Array<f32>>(a.as_ref(), kernel.as_ref(), 2, 1).unwrap();
 
         println!("{:?}", res.shape());
         //println!("{:?}", res.iter_unit().collect::<Vec<f32>>());
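The rewritten `conv2d` above stands or falls with the patch layout that `im2col_2d` (added in `src/engine/shared.rs` below) produces, so a compact, self-contained model of the two steps is useful for review. The sketch below is illustrative only: plain slices instead of `EngineTensor`, a single batch and channel, and hypothetical names (`im2col_ref`, `conv_from_cols`); the zero padding the real code gets from the `Padded` wrapper is done here with signed index arithmetic.

```rust
use itertools::Itertools;

/// Lay out every k*k window of an h*w image (with zero padding) as one
/// contiguous patch, in row-major output order - the im2col step.
fn im2col_ref(img: &[f32], h: usize, w: usize, k: usize, pad: usize, stride: usize) -> Vec<f32> {
    let out_h = (h + 2 * pad - k) / stride + 1;
    let out_w = (w + 2 * pad - k) / stride + 1;
    let mut cols = Vec::with_capacity(out_h * out_w * k * k);
    for oy in 0..out_h {
        for ox in 0..out_w {
            for ky in 0..k {
                for kx in 0..k {
                    // Coordinates that fall outside the image read as padding (zero).
                    let y = (oy * stride + ky) as isize - pad as isize;
                    let x = (ox * stride + kx) as isize - pad as isize;
                    let inside = y >= 0 && x >= 0 && (y as usize) < h && (x as usize) < w;
                    cols.push(if inside { img[y as usize * w + x as usize] } else { 0.0 });
                }
            }
        }
    }
    cols
}

/// The reduction conv2d performs afterwards: elementwise products, chunked
/// into patch-sized groups, each group summed into one output element.
fn conv_from_cols(cols: &[f32], kernel: &[f32], patch_len: usize) -> Vec<f32> {
    let products = cols
        .iter()
        .zip(kernel.iter().cycle()) // cycling stands in for broadcast_splice of the kernel
        .map(|(c, k)| c * k)
        .chunks(patch_len);
    products.into_iter().map(|chunk| chunk.sum()).collect()
}

fn main() {
    // 3x3 image 1..=9, 3x3 kernel of ones, padding 1, stride 1.
    let img: Vec<f32> = (1..=9).map(|v| v as f32).collect();
    let cols = im2col_ref(&img, 3, 3, 3, 1, 1);
    assert_eq!(&cols[..9], &[0., 0., 0., 0., 1., 2., 0., 4., 5.]); // first patch
    let out = conv_from_cols(&cols, &[1.0; 9], 9);
    assert_eq!(out[4], 45.0); // centre output element: sum of all nine pixels
}
```

One thing the sketch makes visible: a chunk spans only a single patch of `k_y * k_x` products, so with more than one input channel the chunked sum yields one partial result per `(batch, out_channel, in_channel)` cell, and those `in_channels` partial sums still have to be accumulated to fill the `(batches, out_channels, out_y, out_x)` output shape.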
diff --git a/src/engine/mod.rs b/src/engine/mod.rs
index 76b0828..c2885ee 100644
--- a/src/engine/mod.rs
+++ b/src/engine/mod.rs
@@ -1,6 +1,7 @@
 pub mod tensor;
 pub mod basic;
+mod shared;
 mod util;
 
 use crate::helper::{Shape, PositionError};
@@ -31,7 +32,7 @@ pub trait Engine<T: AllowedUnit> {
     fn div<E: EngineTensorFactory<Unit = T>>(a: &dyn EngineTensor<Unit = T>, b: &dyn EngineTensor<Unit = T>) -> Result<Box<dyn EngineTensor<Unit = T>>, EngineError>;
 
     //Conv
-    fn conv2d<E: EngineTensorFactory<Unit = T>>(a: &dyn EngineTensor<Unit = T>, kernel: &dyn EngineTensor<Unit = T>, stride: usize) -> Result<Box<dyn EngineTensor<Unit = T>>, EngineError>;
+    fn conv2d<E: EngineTensorFactory<Unit = T>>(a: &dyn EngineTensor<Unit = T>, kernel: &dyn EngineTensor<Unit = T>, padding: usize, stride: usize) -> Result<Box<dyn EngineTensor<Unit = T>>, EngineError>;
 }

diff --git a/src/engine/shared.rs b/src/engine/shared.rs
new file mode 100644
index 0000000..c0eaa89
--- /dev/null
+++ b/src/engine/shared.rs
@@ -0,0 +1,189 @@
+use std::iter;
+
+use crate::{
+    engine::tensor::padded::Padded,
+    helper::{Interval, Position, Shape, Slice, Stride, VarArrayCompatible},
+};
+
+use super::tensor::{allowed_unit::AllowedUnit, factory::EngineTensorFactory, Array, EngineTensor};
+
+//a: (batches, in_channels, img_y, img_x)
+//kernel_shape: (in_channels, k_y, k_x)
+//out: (batches, in_channels, out_y, out_x, k_y * k_x)
+pub fn im2col_2d<T: AllowedUnit, E: EngineTensorFactory<Unit = T>>(
+    a: &dyn EngineTensor<Unit = T>,
+    kernel_shape: &Shape,
+    padding: usize,
+    stride: usize,
+) -> Box<dyn EngineTensor<Unit = T>> {
+    let batches = a.shape().get(0).unwrap();
+    let in_channels = a.shape().get(1).unwrap();
+
+    //Ok if zero padding
+    let a_padded = Padded::pad_from(
+        a.clone(),
+        [0, 0, padding, padding].as_slice().into(),
+        T::default(),
+    );
+
+    let img_y = a_padded.shape().get(2).unwrap();
+    let img_x = a_padded.shape().get(3).unwrap();
+
+    let k_y = kernel_shape.get(1).unwrap();
+    let k_x = kernel_shape.get(2).unwrap();
+
+    let out_y = (img_y - k_y) / stride + 1;
+    let out_x = (img_x - k_x) / stride + 1;
+
+    let patch_len = k_y * k_x;
+
+    let out_shape = Shape::from([batches, in_channels, out_y, out_x, patch_len].as_slice());
+    let out_stride = Stride::default_from_shape(&out_shape);
+
+    //let final_img_dims = Shape::new(kernel_shape.iter().zip(img_dims.iter()).map(|(k_d, a_d)| (a_d + 2 * padding - k_d) / stride + 1).collect());
+
+    let grouped_patches_shape = Shape::from([batches, in_channels, patch_len].as_slice());
+
+    //Buffer used for output
+    let mut buffer = Vec::<T>::from_iter(iter::repeat(T::default()).take(out_shape.elements()));
+    buffer.shrink_to_fit();
+
+    for y in 0..out_y {
+        for x in 0..out_x {
+            let grouped_patches = a_padded.slice(&Slice::from(
+                [
+                    Interval::all(),
+                    Interval::all(),
+                    Interval::between_with_step(y, y + k_y, stride),
+                    Interval::between_with_step(x, x + k_x, stride),
+                ]
+                .as_slice(),
+            ));
+
+            let grouped_patches = grouped_patches.reshape(&grouped_patches_shape);
+
+            for batch in 0..batches {
+                for channel in 0..in_channels {
+                    let patch = grouped_patches.slice(&Slice::from(
+                        [
+                            Interval::only(batch),
+                            Interval::only(channel),
+                            Interval::all(),
+                        ]
+                        .as_slice(),
+                    ));
+
+                    let start_index = Position::from([batch, channel, y, x, 0].as_slice())
+                        .tensor_index(&out_stride)
+                        .unwrap();
+
+                    buffer.splice(start_index..(start_index + patch_len), patch.iter_units());
+                }
+            }
+        }
+    }
+
+    E::from_slice(buffer.as_slice(), out_shape)
+}
+
+#[cfg(test)]
+mod test {
+    use crate::engine::tensor::Array;
+
+    use super::*;
+
+    #[test]
+    fn simple_im2col_2d() {
+        //Pytorch generated im2col
+        let expected: [f32; 972] = [
+            0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 0.0, 4.0, 5.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0,
+            6.0, 0.0, 0.0, 0.0, 2.0, 3.0, 0.0, 5.0, 6.0, 0.0, 0.0, 1.0, 2.0, 0.0, 4.0, 5.0, 0.0,
+            7.0, 8.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 2.0, 3.0, 0.0, 5.0, 6.0, 0.0,
+            8.0, 9.0, 0.0, 0.0, 4.0, 5.0, 0.0, 7.0, 8.0, 0.0, 0.0, 0.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+            9.0, 0.0, 0.0, 0.0, 5.0, 6.0, 0.0, 8.0, 9.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+            10.0, 11.0, 0.0, 13.0, 14.0, 0.0, 0.0, 0.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 0.0,
+            0.0, 0.0, 11.0, 12.0, 0.0, 14.0, 15.0, 0.0, 0.0, 10.0, 11.0, 0.0, 13.0, 14.0, 0.0,
+            16.0, 17.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 11.0, 12.0, 0.0,
+            14.0, 15.0, 0.0, 17.0, 18.0, 0.0, 0.0, 13.0, 14.0, 0.0, 16.0, 17.0, 0.0, 0.0, 0.0,
+            13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 0.0, 0.0, 0.0, 14.0, 15.0, 0.0, 17.0, 18.0, 0.0,
+            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 19.0, 20.0, 0.0, 22.0, 23.0, 0.0, 0.0, 0.0, 19.0,
+            20.0, 21.0, 22.0, 23.0, 24.0, 0.0, 0.0, 0.0, 20.0, 21.0, 0.0, 23.0, 24.0, 0.0, 0.0,
+            19.0, 20.0, 0.0, 22.0, 23.0, 0.0, 25.0, 26.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0,
+            26.0, 27.0, 20.0, 21.0, 0.0, 23.0, 24.0, 0.0, 26.0, 27.0, 0.0, 0.0, 22.0, 23.0, 0.0,
+            25.0, 26.0, 0.0, 0.0, 0.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 0.0, 0.0, 0.0, 23.0,
+            24.0, 0.0, 26.0, 27.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 28.0, 29.0, 0.0, 31.0,
+            32.0, 0.0, 0.0, 0.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 0.0, 0.0, 0.0, 29.0, 30.0,
+            0.0, 32.0, 33.0, 0.0, 0.0, 28.0, 29.0, 0.0, 31.0, 32.0, 0.0, 34.0, 35.0, 28.0, 29.0,
+            30.0, 31.0,
32.0, 33.0, 34.0, 35.0, 36.0, 29.0, 30.0, 0.0, 32.0, 33.0, 0.0, 35.0, 36.0, + 0.0, 0.0, 31.0, 32.0, 0.0, 34.0, 35.0, 0.0, 0.0, 0.0, 31.0, 32.0, 33.0, 34.0, 35.0, + 36.0, 0.0, 0.0, 0.0, 32.0, 33.0, 0.0, 35.0, 36.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 37.0, 38.0, 0.0, 40.0, 41.0, 0.0, 0.0, 0.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, + 0.0, 0.0, 0.0, 38.0, 39.0, 0.0, 41.0, 42.0, 0.0, 0.0, 37.0, 38.0, 0.0, 40.0, 41.0, 0.0, + 43.0, 44.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 38.0, 39.0, 0.0, + 41.0, 42.0, 0.0, 44.0, 45.0, 0.0, 0.0, 40.0, 41.0, 0.0, 43.0, 44.0, 0.0, 0.0, 0.0, + 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 0.0, 0.0, 0.0, 41.0, 42.0, 0.0, 44.0, 45.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 46.0, 47.0, 0.0, 49.0, 50.0, 0.0, 0.0, 0.0, 46.0, + 47.0, 48.0, 49.0, 50.0, 51.0, 0.0, 0.0, 0.0, 47.0, 48.0, 0.0, 50.0, 51.0, 0.0, 0.0, + 46.0, 47.0, 0.0, 49.0, 50.0, 0.0, 52.0, 53.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, + 53.0, 54.0, 47.0, 48.0, 0.0, 50.0, 51.0, 0.0, 53.0, 54.0, 0.0, 0.0, 49.0, 50.0, 0.0, + 52.0, 53.0, 0.0, 0.0, 0.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 0.0, 0.0, 0.0, 50.0, + 51.0, 0.0, 53.0, 54.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 55.0, 56.0, 0.0, 58.0, + 59.0, 0.0, 0.0, 0.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0, 0.0, 0.0, 0.0, 56.0, 57.0, + 0.0, 59.0, 60.0, 0.0, 0.0, 55.0, 56.0, 0.0, 58.0, 59.0, 0.0, 61.0, 62.0, 55.0, 56.0, + 57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, 56.0, 57.0, 0.0, 59.0, 60.0, 0.0, 62.0, 63.0, + 0.0, 0.0, 58.0, 59.0, 0.0, 61.0, 62.0, 0.0, 0.0, 0.0, 58.0, 59.0, 60.0, 61.0, 62.0, + 63.0, 0.0, 0.0, 0.0, 59.0, 60.0, 0.0, 62.0, 63.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 64.0, 65.0, 0.0, 67.0, 68.0, 0.0, 0.0, 0.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, + 0.0, 0.0, 0.0, 65.0, 66.0, 0.0, 68.0, 69.0, 0.0, 0.0, 64.0, 65.0, 0.0, 67.0, 68.0, 0.0, + 70.0, 71.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 65.0, 66.0, 0.0, + 68.0, 69.0, 0.0, 71.0, 72.0, 0.0, 0.0, 67.0, 68.0, 0.0, 70.0, 71.0, 0.0, 0.0, 0.0, + 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 0.0, 0.0, 0.0, 68.0, 69.0, 0.0, 71.0, 72.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 73.0, 74.0, 0.0, 76.0, 77.0, 0.0, 0.0, 0.0, 73.0, + 74.0, 75.0, 76.0, 77.0, 78.0, 0.0, 0.0, 0.0, 74.0, 75.0, 0.0, 77.0, 78.0, 0.0, 0.0, + 73.0, 74.0, 0.0, 76.0, 77.0, 0.0, 79.0, 80.0, 73.0, 74.0, 75.0, 76.0, 77.0, 78.0, 79.0, + 80.0, 81.0, 74.0, 75.0, 0.0, 77.0, 78.0, 0.0, 80.0, 81.0, 0.0, 0.0, 76.0, 77.0, 0.0, + 79.0, 80.0, 0.0, 0.0, 0.0, 76.0, 77.0, 78.0, 79.0, 80.0, 81.0, 0.0, 0.0, 0.0, 77.0, + 78.0, 0.0, 80.0, 81.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 82.0, 83.0, 0.0, 85.0, + 86.0, 0.0, 0.0, 0.0, 82.0, 83.0, 84.0, 85.0, 86.0, 87.0, 0.0, 0.0, 0.0, 83.0, 84.0, + 0.0, 86.0, 87.0, 0.0, 0.0, 82.0, 83.0, 0.0, 85.0, 86.0, 0.0, 88.0, 89.0, 82.0, 83.0, + 84.0, 85.0, 86.0, 87.0, 88.0, 89.0, 90.0, 83.0, 84.0, 0.0, 86.0, 87.0, 0.0, 89.0, 90.0, + 0.0, 0.0, 85.0, 86.0, 0.0, 88.0, 89.0, 0.0, 0.0, 0.0, 85.0, 86.0, 87.0, 88.0, 89.0, + 90.0, 0.0, 0.0, 0.0, 86.0, 87.0, 0.0, 89.0, 90.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 91.0, 92.0, 0.0, 94.0, 95.0, 0.0, 0.0, 0.0, 91.0, 92.0, 93.0, 94.0, 95.0, 96.0, + 0.0, 0.0, 0.0, 92.0, 93.0, 0.0, 95.0, 96.0, 0.0, 0.0, 91.0, 92.0, 0.0, 94.0, 95.0, 0.0, + 97.0, 98.0, 91.0, 92.0, 93.0, 94.0, 95.0, 96.0, 97.0, 98.0, 99.0, 92.0, 93.0, 0.0, + 95.0, 96.0, 0.0, 98.0, 99.0, 0.0, 0.0, 94.0, 95.0, 0.0, 97.0, 98.0, 0.0, 0.0, 0.0, + 94.0, 95.0, 96.0, 97.0, 98.0, 99.0, 0.0, 0.0, 0.0, 95.0, 96.0, 0.0, 98.0, 99.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 100.0, 101.0, 0.0, 103.0, 104.0, 0.0, 0.0, 
0.0,
+            100.0, 101.0, 102.0, 103.0, 104.0, 105.0, 0.0, 0.0, 0.0, 101.0, 102.0, 0.0, 104.0,
+            105.0, 0.0, 0.0, 100.0, 101.0, 0.0, 103.0, 104.0, 0.0, 106.0, 107.0, 100.0, 101.0,
+            102.0, 103.0, 104.0, 105.0, 106.0, 107.0, 108.0, 101.0, 102.0, 0.0, 104.0, 105.0, 0.0,
+            107.0, 108.0, 0.0, 0.0, 103.0, 104.0, 0.0, 106.0, 107.0, 0.0, 0.0, 0.0, 103.0, 104.0,
+            105.0, 106.0, 107.0, 108.0, 0.0, 0.0, 0.0, 104.0, 105.0, 0.0, 107.0, 108.0, 0.0, 0.0,
+            0.0, 0.0,
+        ];
+
+        let batches = 4_usize;
+        let in_channels = 3_usize;
+        let y = 3_usize;
+        let x = 3_usize;
+
+        let k_y = 3_usize;
+        let k_x = 3_usize;
+
+        let a_shape = Shape::from([batches, in_channels, y, x].as_slice());
+        let kernel_shape = Shape::from([in_channels, k_y, k_x].as_slice());
+
+        let a = Array::from_iter(
+            (1..=(batches * in_channels * y * x)).map(|x| x as f32),
+            a_shape,
+        );
+
+        let res = im2col_2d::<_, Array<_>>(a.as_ref(), &kernel_shape, 1, 1);
+
+        for (res_element, expected_element) in res.iter_units().zip(expected.iter()) {
+            assert_eq!(res_element, *expected_element);
+        }
+    }
+}

diff --git a/src/engine/tensor/mod.rs b/src/engine/tensor/mod.rs
index 52605b9..699ee1e 100644
--- a/src/engine/tensor/mod.rs
+++ b/src/engine/tensor/mod.rs
@@ -43,6 +43,16 @@ pub struct Array<T: AllowedUnit> {
 }
 
 impl<T: AllowedUnit> Array<T> {
+
+    pub fn from_data(stride: Stride, shape: Shape, data: Arc<[T]>, offset: usize) -> Self {
+        Self {
+            stride,
+            shape,
+            data,
+            offset,
+        }
+    }
+
     //Am I dumb?
     //This is wrong!!!
     //TODO
@@ -108,7 +118,7 @@ impl<T: AllowedUnit> EngineTensor for Array<T> {
     //Attempts to efficiently reuse memory if tensor is contiguous
     //If this is not an option it will copy from an iterator
     fn reshape(&self, shape: &Shape) -> Box<dyn EngineTensor<Unit = T>> {
-        if shape.len() == self.shape().len() {
+        if shape.elements() == self.shape().elements() {
             if self.is_contiguous() {
                 Box::new(Array::<T> {
                     stride: Stride::default_from_shape(shape),

diff --git a/src/engine/tensor/padded.rs b/src/engine/tensor/padded.rs
index a12a77f..ded4e07 100644
--- a/src/engine/tensor/padded.rs
+++ b/src/engine/tensor/padded.rs
@@ -120,7 +120,7 @@ impl<T: AllowedUnit> EngineTensor for Padded<T> {
     fn slice(&self, slice: &Slice) -> Box<dyn EngineTensor<Unit = T>> {
         let slice_shape = slice.inferred_shape(self.shape());
 
-        let steps = VarArray::from_iter(slice.as_boxed_slice().iter().map(|int| int.step_index()));
+        let steps = VarArray::from_iter(slice.as_boxed_slice().iter().map(|int| int.step()));
 
         let start_rel = slice.start();
         let start_abs = self.relative_to_absolute_pos(&start_rel);
@@ -186,8 +186,9 @@ mod test {
 
     fn create_examples() -> Vec<Padded<f32>> {
         vec![
-            Padded::pad_from(Array::from_iter([0.0; 0].iter().copied(), shape![]), varr![], 0.0),
+            //Padded::pad_from(Array::from_iter([0.0; 0].iter().copied(), shape![]), varr![], 0.0),
             Padded::pad_from(Array::from_iter((1..=9).map(|x| x as f32 / 9.0), shape![9]), varr![1], 0.0),
+            Padded::pad_from(Array::from_iter((1..=9).map(|x| x as f32 / 9.0), shape![3, 3]), varr![1, 1], 0.0),
             Padded::pad_from(Array::from_iter((1..=105).map(|x| x as f32 / 105.0), shape![3, 5, 7]), varr![1, 2, 3], 0.0),
             Padded::pad_from(Array::from_iter((1..=105).map(|x| x as f32 / 105.0), shape![1, 1, 3, 5, 7]), varr![0, 1, 1, 2, 3], 0.0),
         ]
@@ -201,8 +202,10 @@ mod test {
 
         let shape = example.shape();
 
+        println!("{:?}", Array::from_iter(example.iter_units(), shape.clone()));
+
         for curr_pos in shape.first().iter_positions(&shape.last(), &shape) {
-            println!("{:?}", curr_pos);
+            //println!("{:?}", curr_pos);
 
             let abs_pos = example.relative_to_absolute_pos(&curr_pos);
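The `reshape` guard change in `src/engine/tensor/mod.rs` above is load-bearing for this patch: `conv2d` flattens the kernel from 4-D to 3-D, and a rank-equality check would reject exactly that. A reshape is valid when element counts match, not when ranks match. A standalone model of the corrected check (`reshape_ok` is a hypothetical helper over plain slices, not the crate's `Shape` API):

```rust
// A reshape preserves the number of elements; the rank is free to change.
fn reshape_ok(from: &[usize], to: &[usize]) -> bool {
    from.iter().product::<usize>() == to.iter().product::<usize>()
}

fn main() {
    // conv2d reshapes (out_channels, in_channels, k_y, k_x) -> (out_channels, in_channels, k_y * k_x):
    assert!(reshape_ok(&[1, 3, 3, 3], &[1, 3, 9])); // ranks differ (4 vs 3), elements match
    assert!(!reshape_ok(&[1, 3, 3, 3], &[1, 3, 3])); // same rank as the target, wrong element count
}
```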
diff --git a/src/helper/slice.rs b/src/helper/slice.rs
index c40c46a..36cf0a8 100644
--- a/src/helper/slice.rs
+++ b/src/helper/slice.rs
@@ -53,6 +53,14 @@ impl Interval {
         }
     }
 
+    pub fn only(start: usize) -> Self {
+        Self {
+            start: Some(start),
+            finish: Some(start + 1),
+            step: None,
+        }
+    }
+
     pub fn all() -> Self {
         Self {
             start: None,
@@ -69,14 +77,14 @@ impl Interval {
         self.finish.unwrap_or(dim)
     }
 
-    pub fn step_index(&self) -> usize {
+    pub fn step(&self) -> usize {
         self.step.unwrap_or(1)
     }
 
     pub fn len(&self, dim: usize) -> usize {
         let start_index = self.start_index();
         let finish_index = self.finish_index(dim);
-        let step_index = self.step_index();
+        let step_index = self.step();
 
         (finish_index - start_index) / step_index
     }
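The `Interval` helpers just added are easiest to check against their half-open-range reading: `only(i)` selects `[i, i + 1)`, an unset bound defaults to the dimension edge, and an unset step defaults to 1. A free-standing model of `len` as defined above (note the truncating division: a range whose extent is not a multiple of the step is undercounted, e.g. `[0, 3)` with step 2 reports 1 element even though it selects indices 0 and 2):

```rust
/// Model of Interval::len: start/finish/step mirror the Option<usize> fields,
/// `dim` is the size of the dimension being sliced.
fn interval_len(start: Option<usize>, finish: Option<usize>, step: Option<usize>, dim: usize) -> usize {
    let start_index = start.unwrap_or(0);
    let finish_index = finish.unwrap_or(dim);
    (finish_index - start_index) / step.unwrap_or(1)
}

fn main() {
    assert_eq!(interval_len(Some(2), Some(3), None, 8), 1);    // Interval::only(2)
    assert_eq!(interval_len(None, None, None, 8), 8);          // Interval::all()
    assert_eq!(interval_len(Some(0), Some(3), Some(2), 8), 1); // truncation: {0, 2} actually has 2 elements
}
```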