Add im2col
rileysu committed Feb 4, 2024
1 parent 6a55a41 commit 1180941
Showing 7 changed files with 256 additions and 57 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -1,2 +1,5 @@
/target
/old_src
/old_src
flamegraph.svg
perf.data
perf.data.old
83 changes: 34 additions & 49 deletions src/engine/basic.rs
@@ -2,64 +2,49 @@ use std::ops::IndexMut;

use itertools::Itertools;

use crate::{engine::{Engine, EngineError, EngineTensorFactory, EngineTensor, util::{err_if_incorrect_dimensions, err_if_dimension_mismatch}}, helper::{Position, Interval, Slice, Shape, Stride, VarArrayCompatible}};
use crate::{engine::{Engine, EngineError, EngineTensorFactory, EngineTensor, util::{err_if_incorrect_dimensions, err_if_dimension_mismatch}}, helper::{shape, Interval, Position, Shape, Slice, Stride, VarArrayCompatible}};

use super::shared::im2col_2d;

pub struct Basic {}



macro_rules! conv_fn {
($unit:ty) => {
fn conv2d<E: EngineTensorFactory<Unit = $unit>>(a: &dyn EngineTensor<Unit = $unit>, kernel: &dyn EngineTensor<Unit = $unit>, stride: usize) -> Result<Box<dyn EngineTensor<Unit = $unit>>, crate::engine::EngineError> {
//a: (batches, in_channels, y, x)
//kernel: (out_channels, in_channels, k_y, k_x)

//a: (batches, in_channels, y, x)
//kernel: (out_channels, in_channels, k_y, k_x)
fn conv2d<E: EngineTensorFactory<Unit = $unit>>(a: &dyn EngineTensor<Unit = $unit>, kernel: &dyn EngineTensor<Unit = $unit>, padding: usize, stride: usize) -> Result<Box<dyn EngineTensor<Unit = $unit>>, crate::engine::EngineError> {
err_if_incorrect_dimensions(a.shape(), 4)?;
err_if_incorrect_dimensions(kernel.shape(), 4)?;
err_if_dimension_mismatch(a.shape().get(1).unwrap(), kernel.shape().get(1).unwrap())?;

let y = a.shape().get(a.shape().len() - 2).unwrap();
let x = a.shape().get(a.shape().len() - 1).unwrap();
let k_y = kernel.shape().get(kernel.shape().len() - 2).unwrap() + 2 * (stride - 1);
let k_x = kernel.shape().get(kernel.shape().len() - 1).unwrap() + 2 * (stride - 1);


let batches = a.shape().get(0).unwrap();
let in_channels = a.shape().get(1).unwrap();

let out_channels = kernel.shape().get(0).unwrap();
let in_channels = kernel.shape().get(1).unwrap();

//(batches, out_channels, in_channels, y, x)
let a_broadcast = a.broadcast_splice(1, &[out_channels]);
//(batches, out_channels, in_channels, k_y, k_x)
let kernel_broadcast = kernel.broadcast_splice(0, &[batches]);
let k_y = kernel.shape().get(2).unwrap();
let k_x = kernel.shape().get(3).unwrap();

//(batches, out_channels, in_channels, out_y, out_x, patch_len)
let proc = im2col_2d::<$unit, E>(a, kernel.shape(), padding, stride).broadcast_splice(1, &[out_channels]);

let out_y = proc.shape().get(3).unwrap();
let out_x = proc.shape().get(4).unwrap();
let patch_len = proc.shape().get(5).unwrap();

//(batches, out_channels, in_channels, out_y, out_x, patch_len)
let kernels = kernel.reshape(&shape![out_channels, in_channels, k_y * k_x]).broadcast_splice(0, &[batches]).broadcast_splice(3, [out_y, out_x].as_slice());

let half_k_y = k_y / 2;
let half_k_x = k_x / 2;

let y_out = y - half_k_y * 2;
let x_out = x - half_k_x * 2;
let out_shape = Shape::from([batches, out_channels, y_out, x_out].as_slice());
let out_stride = Stride::default_from_shape(&out_shape);

let mut reordered_sums: Box<[$unit]> = vec![<$unit>::default(); batches * out_channels * y_out * x_out].into_boxed_slice();

for curr_y in 0..y_out {
for curr_x in 0..x_out {
let a_sliced = a_broadcast.slice(&Slice::from([Interval::all(), Interval::all(), Interval::all(), Interval::between_with_step(curr_y, curr_y + k_y, stride), Interval::between_with_step(curr_x, curr_x + k_x, stride)].as_slice()));

let chunked_products = a_sliced.iter_units().zip(kernel_broadcast.iter_units()).map(|(a_elem, k_elem)| a_elem * k_elem).chunks(in_channels * k_y * k_x);
let curr_sums = chunked_products.into_iter().map(|i| -> $unit { i.sum() });

for (i, sum) in curr_sums.enumerate() {
let batch = i / out_channels;
let out_channel = i % out_channels;

let index = Position::from([batch, out_channel, curr_y, curr_x].as_slice()).tensor_index(&out_stride)?;

*reordered_sums.index_mut(index) = sum;
}
}
}

//(batches, out_channels, y_out, x_out)
Ok(E::from_slice(&reordered_sums, out_shape))
println!("{:?}", out_y);
println!("{:?}", proc.shape());
println!("{:?}", kernels.shape());

let chunked_iter = proc.iter_units().zip(kernels.iter_units()).map(|(x, y)| x * y).chunks(patch_len);
let out_data = chunked_iter.into_iter().map(|i| i.sum());

//(batches, out_channels, out_y, out_x)
Ok(E::from_iter(out_data, shape![batches, out_channels, out_y, out_x]))
}
};
}
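
For intuition, the multiply-then-chunked-sum pipeline above reduces convolution to a dot product per im2col patch. Below is a minimal, self-contained sketch of that idea on plain slices (single batch and channel, no padding); every name here is hypothetical and independent of the crate's EngineTensor types.

fn im2col_single(img: &[f32], h: usize, w: usize, k: usize, stride: usize) -> Vec<Vec<f32>> {
    //One flat row of length k * k per output position
    let out_h = (h - k) / stride + 1;
    let out_w = (w - k) / stride + 1;
    let mut patches = Vec::with_capacity(out_h * out_w);
    for y in 0..out_h {
        for x in 0..out_w {
            let mut patch = Vec::with_capacity(k * k);
            for ky in 0..k {
                for kx in 0..k {
                    patch.push(img[(y * stride + ky) * w + (x * stride + kx)]);
                }
            }
            patches.push(patch);
        }
    }
    patches
}

fn main() {
    //4x4 image, 3x3 kernel, stride 1 -> 2x2 output, patch length 9
    let img: Vec<f32> = (1..=16).map(|v| v as f32).collect();
    let kernel = vec![1.0f32; 9];
    //Convolution is now one dot product per patch row, exactly like the
    //zip/multiply/chunks/sum pipeline in conv2d above
    let out: Vec<f32> = im2col_single(&img, 4, 4, 3, 1)
        .iter()
        .map(|p| p.iter().zip(&kernel).map(|(a, b)| a * b).sum())
        .collect();
    assert_eq!(out.len(), 2 * 2);
}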
@@ -237,10 +222,10 @@ mod test {

#[test]
pub fn conv() {
let a = Array::from_iter((1..=65536).map(|x| (x as f32) / 65536.0).cycle().take(1 * 3 * 256 * 256), shape![1, 3, 256, 256]);
let a = Array::from_iter((1..=65536).map(|x| (x as f32) / 65536.0).cycle().take(4 * 3 * 256 * 256), shape![4, 3, 256, 256]);
let kernel = Array::from_iter((1..=9).map(|x| x as f32).cycle().take(1 * 3 * 3 * 3), shape![1, 3, 3, 3]);

let res = Basic::conv2d::<Array<f32>>(a.as_ref(), kernel.as_ref(), 2).unwrap();
let res = Basic::conv2d::<Array<f32>>(a.as_ref(), kernel.as_ref(), 2, 1).unwrap();

println!("{:?}", res.shape());
//println!("{:?}", res.iter_unit().collect::<Vec<f32>>());
3 changes: 2 additions & 1 deletion src/engine/mod.rs
@@ -1,6 +1,7 @@
pub mod tensor;
pub mod basic;

mod shared;
mod util;

use crate::helper::{Shape, PositionError};
@@ -31,7 +32,7 @@ pub trait Engine<T: AllowedUnit> {
fn div<E: EngineTensorFactory<Unit = T>>(a: &dyn EngineTensor<Unit = T>, b: &dyn EngineTensor<Unit = T>) -> Result<Box<dyn EngineTensor<Unit = T>>, EngineError>;

//Conv
fn conv2d<E: EngineTensorFactory<Unit = T>>(a: &dyn EngineTensor<Unit = T>, kernel: &dyn EngineTensor<Unit = T>, stride: usize) -> Result<Box<dyn EngineTensor<Unit = T>>, EngineError>;
fn conv2d<E: EngineTensorFactory<Unit = T>>(a: &dyn EngineTensor<Unit = T>, kernel: &dyn EngineTensor<Unit = T>, padding: usize, stride: usize) -> Result<Box<dyn EngineTensor<Unit = T>>, EngineError>;
}

#[derive(Error, Debug)]
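Call sites now pass padding before stride. A hedged usage sketch mirroring the updated test in src/engine/basic.rs:

//padding = 2, stride = 1 on a (4, 3, 256, 256) input with a (1, 3, 3, 3) kernel
let a = Array::from_iter((1..=65536).map(|x| (x as f32) / 65536.0).cycle().take(4 * 3 * 256 * 256), shape![4, 3, 256, 256]);
let kernel = Array::from_iter((1..=9).map(|x| x as f32).cycle().take(1 * 3 * 3 * 3), shape![1, 3, 3, 3]);
let res = Basic::conv2d::<Array<f32>>(a.as_ref(), kernel.as_ref(), 2, 1).unwrap();
//each 256 axis becomes (256 + 2*2 - 3) / 1 + 1 = 258, so res is (4, 1, 258, 258)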
189 changes: 189 additions & 0 deletions src/engine/shared.rs
@@ -0,0 +1,189 @@
use std::iter;

use crate::{
engine::tensor::padded::Padded,
helper::{Interval, Position, Shape, Slice, Stride, VarArrayCompatible},
};

use super::tensor::{allowed_unit::AllowedUnit, factory::EngineTensorFactory, Array, EngineTensor};

//a: (batches, in_channels, img_y, img_x)
//kernel_shape: (in_channels, k_y, k_x)
//out: (batches, in_channels, out_y, out_x, k_y * k_x)
pub fn im2col_2d<T: AllowedUnit + Default, E: EngineTensorFactory<Unit = T>>(
a: &dyn EngineTensor<Unit = T>,
kernel_shape: &Shape,
padding: usize,
stride: usize,
) -> Box<dyn EngineTensor<Unit = T>> {
let batches = a.shape().get(0).unwrap();
let in_channels = a.shape().get(1).unwrap();

//Ok if zero padding
let a_padded = Padded::pad_from(
a.clone(),
[0, 0, padding, padding].as_slice().into(),
T::default(),
);

let img_y = a_padded.shape().get(2).unwrap();
let img_x = a_padded.shape().get(3).unwrap();

let k_y = kernel_shape.get(1).unwrap();
let k_x = kernel_shape.get(2).unwrap();

let out_y = (img_y - k_y) / stride + 1;
let out_x = (img_x - k_x) / stride + 1;

let patch_len = k_y * k_x;

let out_shape = Shape::from([batches, in_channels, out_y, out_x, patch_len].as_slice());
let out_stride = Stride::default_from_shape(&out_shape);

//let final_img_dims = Shape::new(kernel_shape.iter().zip(img_dims.iter()).map(|(k_d, a_d)| (a_d + 2 * padding - k_d) / stride + 1).collect());

let grouped_patches_shape = Shape::from([batches, in_channels, patch_len].as_slice());

//Buffer used for output

let mut buffer = Vec::<T>::from_iter(iter::repeat(T::default()).take(out_shape.elements()));
buffer.shrink_to_fit();

for y in 0..out_y {
for x in 0..out_x {
let grouped_patches = a_padded.slice(&Slice::from(
[
Interval::all(),
Interval::all(),
Interval::between_with_step(y, y + k_y, stride),
Interval::between_with_step(x, x + k_x, stride),
]
.as_slice(),
));

let grouped_patches = grouped_patches.reshape(&grouped_patches_shape);

for batch in 0..batches {
for channel in 0..in_channels {
let patch = grouped_patches.slice(&Slice::from(
[
Interval::only(batch),
Interval::only(channel),
Interval::all(),
]
.as_slice(),
));

let start_index = Position::from([batch, channel, y, x, 0].as_slice())
.tensor_index(&out_stride)
.unwrap();

buffer.splice(start_index..(start_index + patch_len), patch.iter_units());
}
}
}
}

E::from_slice(buffer.as_slice(), out_shape)
}
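
The extents above are the standard convolution output-size formula with the padding already folded in. A small hedged helper (hypothetical name) making the arithmetic explicit:

//out = (img + 2 * padding - k) / stride + 1; im2col_2d pads first,
//so img_y and img_x above already include 2 * padding
fn conv_out_extent(img: usize, k: usize, padding: usize, stride: usize) -> usize {
    (img + 2 * padding - k) / stride + 1
}

//Matches the test below: a 3x3 image with a 3x3 kernel, padding 1, stride 1 stays 3x3
assert_eq!(conv_out_extent(3, 3, 1, 1), 3);
//And the conv test in basic.rs: (256 + 2*2 - 3) / 1 + 1 = 258 per axis
assert_eq!(conv_out_extent(256, 3, 2, 1), 258);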

#[cfg(test)]
mod test {
use crate::engine::tensor::Array;

use super::*;

#[test]
fn simple_im2col_2d() {
//Pytorch generated im2col
let expected: [f32; 972] = [
0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 0.0, 4.0, 5.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0,
6.0, 0.0, 0.0, 0.0, 2.0, 3.0, 0.0, 5.0, 6.0, 0.0, 0.0, 1.0, 2.0, 0.0, 4.0, 5.0, 0.0,
7.0, 8.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 2.0, 3.0, 0.0, 5.0, 6.0, 0.0,
8.0, 9.0, 0.0, 0.0, 4.0, 5.0, 0.0, 7.0, 8.0, 0.0, 0.0, 0.0, 4.0, 5.0, 6.0, 7.0, 8.0,
9.0, 0.0, 0.0, 0.0, 5.0, 6.0, 0.0, 8.0, 9.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
10.0, 11.0, 0.0, 13.0, 14.0, 0.0, 0.0, 0.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 0.0,
0.0, 0.0, 11.0, 12.0, 0.0, 14.0, 15.0, 0.0, 0.0, 10.0, 11.0, 0.0, 13.0, 14.0, 0.0,
16.0, 17.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 11.0, 12.0, 0.0,
14.0, 15.0, 0.0, 17.0, 18.0, 0.0, 0.0, 13.0, 14.0, 0.0, 16.0, 17.0, 0.0, 0.0, 0.0,
13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 0.0, 0.0, 0.0, 14.0, 15.0, 0.0, 17.0, 18.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 19.0, 20.0, 0.0, 22.0, 23.0, 0.0, 0.0, 0.0, 19.0,
20.0, 21.0, 22.0, 23.0, 24.0, 0.0, 0.0, 0.0, 20.0, 21.0, 0.0, 23.0, 24.0, 0.0, 0.0,
19.0, 20.0, 0.0, 22.0, 23.0, 0.0, 25.0, 26.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0,
26.0, 27.0, 20.0, 21.0, 0.0, 23.0, 24.0, 0.0, 26.0, 27.0, 0.0, 0.0, 22.0, 23.0, 0.0,
25.0, 26.0, 0.0, 0.0, 0.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 0.0, 0.0, 0.0, 23.0,
24.0, 0.0, 26.0, 27.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 28.0, 29.0, 0.0, 31.0,
32.0, 0.0, 0.0, 0.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 0.0, 0.0, 0.0, 29.0, 30.0,
0.0, 32.0, 33.0, 0.0, 0.0, 28.0, 29.0, 0.0, 31.0, 32.0, 0.0, 34.0, 35.0, 28.0, 29.0,
30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 29.0, 30.0, 0.0, 32.0, 33.0, 0.0, 35.0, 36.0,
0.0, 0.0, 31.0, 32.0, 0.0, 34.0, 35.0, 0.0, 0.0, 0.0, 31.0, 32.0, 33.0, 34.0, 35.0,
36.0, 0.0, 0.0, 0.0, 32.0, 33.0, 0.0, 35.0, 36.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 37.0, 38.0, 0.0, 40.0, 41.0, 0.0, 0.0, 0.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0,
0.0, 0.0, 0.0, 38.0, 39.0, 0.0, 41.0, 42.0, 0.0, 0.0, 37.0, 38.0, 0.0, 40.0, 41.0, 0.0,
43.0, 44.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 38.0, 39.0, 0.0,
41.0, 42.0, 0.0, 44.0, 45.0, 0.0, 0.0, 40.0, 41.0, 0.0, 43.0, 44.0, 0.0, 0.0, 0.0,
40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 0.0, 0.0, 0.0, 41.0, 42.0, 0.0, 44.0, 45.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 46.0, 47.0, 0.0, 49.0, 50.0, 0.0, 0.0, 0.0, 46.0,
47.0, 48.0, 49.0, 50.0, 51.0, 0.0, 0.0, 0.0, 47.0, 48.0, 0.0, 50.0, 51.0, 0.0, 0.0,
46.0, 47.0, 0.0, 49.0, 50.0, 0.0, 52.0, 53.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0,
53.0, 54.0, 47.0, 48.0, 0.0, 50.0, 51.0, 0.0, 53.0, 54.0, 0.0, 0.0, 49.0, 50.0, 0.0,
52.0, 53.0, 0.0, 0.0, 0.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 0.0, 0.0, 0.0, 50.0,
51.0, 0.0, 53.0, 54.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 55.0, 56.0, 0.0, 58.0,
59.0, 0.0, 0.0, 0.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0, 0.0, 0.0, 0.0, 56.0, 57.0,
0.0, 59.0, 60.0, 0.0, 0.0, 55.0, 56.0, 0.0, 58.0, 59.0, 0.0, 61.0, 62.0, 55.0, 56.0,
57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, 56.0, 57.0, 0.0, 59.0, 60.0, 0.0, 62.0, 63.0,
0.0, 0.0, 58.0, 59.0, 0.0, 61.0, 62.0, 0.0, 0.0, 0.0, 58.0, 59.0, 60.0, 61.0, 62.0,
63.0, 0.0, 0.0, 0.0, 59.0, 60.0, 0.0, 62.0, 63.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 64.0, 65.0, 0.0, 67.0, 68.0, 0.0, 0.0, 0.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0,
0.0, 0.0, 0.0, 65.0, 66.0, 0.0, 68.0, 69.0, 0.0, 0.0, 64.0, 65.0, 0.0, 67.0, 68.0, 0.0,
70.0, 71.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 65.0, 66.0, 0.0,
68.0, 69.0, 0.0, 71.0, 72.0, 0.0, 0.0, 67.0, 68.0, 0.0, 70.0, 71.0, 0.0, 0.0, 0.0,
67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 0.0, 0.0, 0.0, 68.0, 69.0, 0.0, 71.0, 72.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 73.0, 74.0, 0.0, 76.0, 77.0, 0.0, 0.0, 0.0, 73.0,
74.0, 75.0, 76.0, 77.0, 78.0, 0.0, 0.0, 0.0, 74.0, 75.0, 0.0, 77.0, 78.0, 0.0, 0.0,
73.0, 74.0, 0.0, 76.0, 77.0, 0.0, 79.0, 80.0, 73.0, 74.0, 75.0, 76.0, 77.0, 78.0, 79.0,
80.0, 81.0, 74.0, 75.0, 0.0, 77.0, 78.0, 0.0, 80.0, 81.0, 0.0, 0.0, 76.0, 77.0, 0.0,
79.0, 80.0, 0.0, 0.0, 0.0, 76.0, 77.0, 78.0, 79.0, 80.0, 81.0, 0.0, 0.0, 0.0, 77.0,
78.0, 0.0, 80.0, 81.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 82.0, 83.0, 0.0, 85.0,
86.0, 0.0, 0.0, 0.0, 82.0, 83.0, 84.0, 85.0, 86.0, 87.0, 0.0, 0.0, 0.0, 83.0, 84.0,
0.0, 86.0, 87.0, 0.0, 0.0, 82.0, 83.0, 0.0, 85.0, 86.0, 0.0, 88.0, 89.0, 82.0, 83.0,
84.0, 85.0, 86.0, 87.0, 88.0, 89.0, 90.0, 83.0, 84.0, 0.0, 86.0, 87.0, 0.0, 89.0, 90.0,
0.0, 0.0, 85.0, 86.0, 0.0, 88.0, 89.0, 0.0, 0.0, 0.0, 85.0, 86.0, 87.0, 88.0, 89.0,
90.0, 0.0, 0.0, 0.0, 86.0, 87.0, 0.0, 89.0, 90.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 91.0, 92.0, 0.0, 94.0, 95.0, 0.0, 0.0, 0.0, 91.0, 92.0, 93.0, 94.0, 95.0, 96.0,
0.0, 0.0, 0.0, 92.0, 93.0, 0.0, 95.0, 96.0, 0.0, 0.0, 91.0, 92.0, 0.0, 94.0, 95.0, 0.0,
97.0, 98.0, 91.0, 92.0, 93.0, 94.0, 95.0, 96.0, 97.0, 98.0, 99.0, 92.0, 93.0, 0.0,
95.0, 96.0, 0.0, 98.0, 99.0, 0.0, 0.0, 94.0, 95.0, 0.0, 97.0, 98.0, 0.0, 0.0, 0.0,
94.0, 95.0, 96.0, 97.0, 98.0, 99.0, 0.0, 0.0, 0.0, 95.0, 96.0, 0.0, 98.0, 99.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 100.0, 101.0, 0.0, 103.0, 104.0, 0.0, 0.0, 0.0,
100.0, 101.0, 102.0, 103.0, 104.0, 105.0, 0.0, 0.0, 0.0, 101.0, 102.0, 0.0, 104.0,
105.0, 0.0, 0.0, 100.0, 101.0, 0.0, 103.0, 104.0, 0.0, 106.0, 107.0, 100.0, 101.0,
102.0, 103.0, 104.0, 105.0, 106.0, 107.0, 108.0, 101.0, 102.0, 0.0, 104.0, 105.0, 0.0,
107.0, 108.0, 0.0, 0.0, 103.0, 104.0, 0.0, 106.0, 107.0, 0.0, 0.0, 0.0, 103.0, 104.0,
105.0, 106.0, 107.0, 108.0, 0.0, 0.0, 0.0, 104.0, 105.0, 0.0, 107.0, 108.0, 0.0, 0.0,
0.0, 0.0,
];

let batches = 4_usize;
let in_channels = 3_usize; //expected above spans 4 batches x 3 channels: 12 blocks of 81 values
let y = 3_usize;
let x = 3_usize;

let k_y = 3_usize;
let k_x = 3_usize;

let a_shape = Shape::from([batches, in_channels, y, x].as_slice());
let kernel_shape = Shape::from([in_channels, k_y, k_x].as_slice());

let a = Array::from_iter(
(1..=(batches * in_channels * y * x)).map(|x| x as f32),
a_shape,
);

let res = im2col_2d::<_, Array<_>>(a.as_ref(), &kernel_shape, 1, 1);

for (res_element, expected_element) in res.iter_units().zip(expected.iter()) {
assert_eq!(res_element, *expected_element);
}
}
}
12 changes: 11 additions & 1 deletion src/engine/tensor/mod.rs
@@ -43,6 +43,16 @@ pub struct Array<T: AllowedArray> {
}

impl<T: AllowedArray> Array<T> {

pub fn from_data(stride: Stride, shape: Shape, data: Arc<[T]>, offset: usize) -> Self {
Self {
stride,
shape,
data,
offset,
}
}

//Am I dumb?
//This is wrong!!!
//TODO
@@ -108,7 +118,7 @@ impl<T: AllowedArray> EngineTensor for Array<T> {
//Attempts to efficiently reuse memory if tensor is contiguous
//If this is not an option it will copy from an iterator
fn reshape(&self, shape: &Shape) -> Box<dyn EngineTensor<Unit = T>> {
if shape.len() == self.shape().len() {
if shape.elements() == self.shape().elements() {
if self.is_contiguous() {
Box::new(Array::<T> {
stride: Stride::default_from_shape(shape),
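The corrected guard compares element counts rather than ranks, since a valid reshape can change the number of dimensions. A hedged standalone illustration (hypothetical helper):

//[2, 3] -> [6] changes rank but keeps 6 elements, so it must be allowed;
//[2, 3] -> [3, 5] keeps rank but changes the element count, so it must not
fn reshape_ok(from: &[usize], to: &[usize]) -> bool {
    from.iter().product::<usize>() == to.iter().product::<usize>()
}

assert!(reshape_ok(&[2, 3], &[6]));
assert!(reshape_ok(&[2, 3], &[3, 2]));
assert!(!reshape_ok(&[2, 3], &[3, 5]));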
9 changes: 6 additions & 3 deletions src/engine/tensor/padded.rs
@@ -120,7 +120,7 @@ impl<T: AllowedPadded> EngineTensor for Padded<T> {
fn slice(&self, slice: &Slice) -> Box<dyn EngineTensor<Unit = Self::Unit>> {
let slice_shape = slice.inferred_shape(self.shape());

let steps = VarArray::from_iter(slice.as_boxed_slice().iter().map(|int| int.step_index()));
let steps = VarArray::from_iter(slice.as_boxed_slice().iter().map(|int| int.step()));

let start_rel = slice.start();
let start_abs = self.relative_to_absolute_pos(&start_rel);
@@ -186,8 +186,9 @@ mod test {

fn create_examples() -> Vec<Padded<f32>> {
vec![
Padded::pad_from(Array::from_iter([0.0; 0].iter().copied(), shape![]), varr![], 0.0),
//Padded::pad_from(Array::from_iter([0.0; 0].iter().copied(), shape![]), varr![], 0.0),
Padded::pad_from(Array::from_iter((1..=9).map(|x| x as f32 / 9.0), shape![9]), varr![1], 0.0),
Padded::pad_from(Array::from_iter((1..=9).map(|x| x as f32 / 9.0), shape![3, 3]), varr![1, 1], 0.0),
Padded::pad_from(Array::from_iter((1..=105).map(|x| x as f32 / 105.0), shape![3, 5, 7]), varr![1, 2, 3], 0.0),
Padded::pad_from(Array::from_iter((1..=105).map(|x| x as f32 / 105.0), shape![1, 1, 3, 5, 7]), varr![0, 1, 1, 2, 3], 0.0),
]
@@ -201,8 +202,10 @@

let shape = example.shape();

println!("{:?}", Array::from_iter(example.iter_units(), shape.clone()));

for curr_pos in shape.first().iter_positions(&shape.last(), &shape) {
println!("{:?}", curr_pos);
//println!("{:?}", curr_pos);

let abs_pos = example.relative_to_absolute_pos(&curr_pos);

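Padded::pad_from wraps a tensor with a per-dimension pad amount and a fill value, which is how im2col_2d applies its zero padding. A hedged sketch of the shape effect, following create_examples above (assuming shape() compares against the shape! macro's output):

//padding a shape![3, 3] tensor by varr![1, 1] with fill 0.0 behaves like a
//shape![5, 5] tensor whose border is zeros: each axis grows by 2 * pad
let inner = Array::from_iter((1..=9).map(|x| x as f32), shape![3, 3]);
let padded = Padded::pad_from(inner, varr![1, 1], 0.0);
assert_eq!(padded.shape(), &shape![5, 5]);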
