From 6672d0dc48ff8c7fa8ce97515011908a7e200d77 Mon Sep 17 00:00:00 2001
From: Riley Sutton
Date: Sun, 31 Dec 2023 19:57:51 +1100
Subject: [PATCH] WIP

---
 Cargo.lock                   |  16 +++
 Cargo.toml                   |   1 +
 ideas.md                     |  42 ++++++
 operations.md                |   6 +-
 src/comp_graph/mod.rs        |   8 +-
 src/engine/basic.rs          | 249 ++++++++++++++++++++++++++++++++++-
 src/engine/mod.rs            |  13 +-
 src/engine/tensor/basic.rs   | 164 -----------------------
 src/engine/tensor/iter.rs    |  45 ++++++-
 src/engine/tensor/mod.rs     | 146 +++++++++++++++++---
 src/engine/tensor/padded.rs  |  95 +++++++++++++
 src/engine/util.rs           |  18 +++
 src/helper/inferred_shape.rs |  37 ++++++
 src/helper/mod.rs            |   2 +
 src/helper/position.rs       |   4 +-
 src/helper/shape.rs          |  17 ++-
 src/helper/slice.rs          |   8 ++
 17 files changed, 672 insertions(+), 199 deletions(-)
 delete mode 100644 src/engine/tensor/basic.rs
 create mode 100644 src/engine/tensor/padded.rs
 create mode 100644 src/helper/inferred_shape.rs

diff --git a/Cargo.lock b/Cargo.lock
index 793816f..20d06c4 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -14,6 +14,21 @@ version = "1.14.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "374d28ec25809ee0e23827c2ab573d729e293f281dfe393500e7ad618baa61c6"
 
+[[package]]
+name = "either"
+version = "1.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07"
+
+[[package]]
+name = "itertools"
+version = "0.12.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "25db6b064527c5d482d0423354fcd07a89a2dfe07b67892e62411946db7f07b0"
+dependencies = [
+ "either",
+]
+
 [[package]]
 name = "num"
 version = "0.4.1"
@@ -147,6 +162,7 @@ dependencies = [
 name = "therml"
 version = "0.1.0"
 dependencies = [
+ "itertools",
  "num",
  "ord_subset",
  "slotmap",
diff --git a/Cargo.toml b/Cargo.toml
index 27d26e8..4f9a53b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -4,6 +4,7 @@ version = "0.1.0"
 edition = "2021"
 
 [dependencies]
+itertools = "0.12.0"
 num = "0.4.1"
 ord_subset = "3.1.1"
 slotmap = "1.0.6"
diff --git a/ideas.md b/ideas.md
index 92cfdd9..43797dc 100644
--- a/ideas.md
+++ b/ideas.md
@@ -21,6 +21,48 @@
 - Tensors all have an implicit ordering of position indexes which follows the significance from left being most significant and right being least significant
 - Iterators between positions, reshaping all operate on this idea
 
+## Algorithms
+
+### Conv2d
+```
+a1: (batches, in_channels, y, x)
+k1: (out_channels, in_channels, k_y, k_x)
+
+---
+
+let a2: (batches, out_channels, in_channels, y, x) = a1.broadcast_splice([out_channels], 1)
+let k2: (batches, out_channels, in_channels, k_y, k_x) = k1.broadcast_splice([batches], 0)
+
+let k_half_len_y = (k_y // 2)
+let k_half_len_x = (k_x // 2)
+
+let start_y = k_half_len_y
+let end_y = y - start_y
+let start_x = k_half_len_x
+let end_x = x - start_x
+
+let mut out_units = []
+
+for curr_y in start_y..end_y {
+    for curr_x in start_x..end_x {
+        let a3: (batches, out_channels, in_channels, k_y, k_x) = a2.slice([:, :, :, curr_y-k_half_len_y:curr_y+k_half_len_y, curr_x-k_half_len_x:curr_x+k_half_len_x])
+
+        let r1: (batches, out_channels, in_channels, k_y, k_x) = a3 * k2
+
+        let r2: (batches, out_channels, in_channels * k_y * k_x) = r1.reshape([batches, out_channels, in_channels * k_y * k_x])
+
+        let r3: (batches, out_channels, 1) = r2.sum()
+
+        // each r3 contributes (batches, out_channels) units; the loop order makes the layout (y, x, batches, out_channels)
+        out_units.extend(r3.iter_unit())
+    }
+}
+
+return tensor::from_slice(&out_units): (y, x, batches, out_channels)
+```
+
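+### Conv2d shape check
+
+A quick sanity check of the window arithmetic above in plain Rust (the sizes here are made up for illustration):
+
+```rust
+fn main() {
+    // a: (batches, in_channels, y, x) = (2, 3, 256, 256)
+    // kernel: (out_channels, in_channels, k_y, k_x) = (8, 3, 3, 3)
+    let (y, x) = (256usize, 256usize);
+    let (k_y, k_x) = (3usize, 3usize);
+
+    // start/end trim half a kernel window from each edge,
+    // so each output axis loses 2 * (k // 2) units
+    let y_out = y - 2 * (k_y / 2);
+    let x_out = x - 2 * (k_x / 2);
+
+    // output before the final reorder: (y_out, x_out, batches, out_channels)
+    assert_eq!((y_out, x_out), (254, 254));
+}
+```
+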
 ## TODO
 
 - Refactor comp_graph to improve errors (ones with no nodekey) and reduce repeated code
diff --git a/operations.md b/operations.md
index c0be14c..cdf27c8 100644
--- a/operations.md
+++ b/operations.md
@@ -6,7 +6,7 @@ Check indicates support
 * [x] abs
 * [x] neg
 
-* Pointwise Scalar
+* Pointwise Scalar (broadcast?)
 * [x] add_scalar
 * [x] sub_scalar_lh (s - t)
 * [x] sub_scalar_rh (t - s)
@@ -31,6 +31,6 @@ Check indicates support
 * [x] shape
 * [x] stride
 * [x] iter
- * [ ] slice
- * [ ] reshape
+ * [x] slice
+ * [x] reshape
 * [ ] concat
diff --git a/src/comp_graph/mod.rs b/src/comp_graph/mod.rs
index 927b1f6..a95caf7 100644
--- a/src/comp_graph/mod.rs
+++ b/src/comp_graph/mod.rs
@@ -5,7 +5,7 @@ use std::collections::{HashSet, HashMap};
 use slotmap::{SlotMap, new_key_type};
 use thiserror::Error;
 
-use crate::engine::{tensor::{EngineTensor, allowed_unit::AllowedUnit, factory::EngineTensorFactory, iter::EngineTensorIterator}, Engine, EngineError};
+use crate::engine::{tensor::{EngineTensor, allowed_unit::AllowedUnit, factory::EngineTensorFactory, iter::EngineTensorUnitIterator}, Engine, EngineError};
 
 use self::edge::Edge;
 
@@ -101,8 +101,8 @@ impl<T: AllowedUnit> CompGraph<T> {
         self.nodes.insert(Node::create_node(edge))
     }
 
-    pub fn iter(&self, tensor: &CompGraphTensor) -> EngineTensorIterator<'_, T> {
-        EngineTensorIterator::new(self.get_node(tensor.node_key()).unwrap().tensor().unwrap())
+    pub fn iter(&self, tensor: &CompGraphTensor) -> EngineTensorUnitIterator<'_, T> {
+        EngineTensorUnitIterator::new(self.get_node(tensor.node_key()).unwrap().tensor().unwrap())
     }
 
     //First return is open nodes, second is node_to_children
@@ -470,7 +470,7 @@ mod test {
 
         let tensor = new_node_keys.last().unwrap();
 
-        let expected = Array::from_iter(&mut expected_original.iter().map(|x| x * 2.0f32.pow(power)), expected_original.shape().clone());
+        let expected = Array::from_iter(&mut expected_original.iter_unit().map(|x| x * 2.0f32.pow(power)), expected_original.shape().clone());
 
         graph.non_populating_eval(&tensor).unwrap();
diff --git a/src/engine/basic.rs b/src/engine/basic.rs
index 1f999bb..b5cb440 100644
--- a/src/engine/basic.rs
+++ b/src/engine/basic.rs
@@ -1 +1,248 @@
-pub struct Basic {}
\ No newline at end of file
+use std::ops::IndexMut;
+
+use itertools::Itertools;
+
+use crate::{engine::{Engine, EngineError, EngineTensorFactory, EngineTensor, util::{err_if_incorrect_dimensions, err_if_dimension_mismatch}}, helper::{Position, Interval, Slice, Shape, Stride}};
+
+pub struct Basic {}
+
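+//conv2d below follows the Conv2d sketch in ideas.md: both operands are
+//broadcast to (batches, out_channels, in_channels, k_y, k_x) views (stride 0
+//along the new axes, so no data is copied), a kernel-sized window is sliced
+//out per output unit and the pointwise products are summed in chunks of
+//in_channels * k_y * k_x.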
conv_fn { + ($unit:ty) => { + fn conv2d>(a: &dyn EngineTensor, kernel: &dyn EngineTensor, stride: usize) -> Result>, crate::engine::EngineError> { + //a: (batches, in_channels, y, x) + //kernel: (out_channels, in_channels, k_y, k_x) + + err_if_incorrect_dimensions(a.shape(), 4)?; + err_if_incorrect_dimensions(kernel.shape(), 4)?; + err_if_dimension_mismatch(a.shape().dim(1), kernel.shape().dim(1))?; + + let y = a.shape().dim(a.shape().dims() - 2); + let x = a.shape().dim(a.shape().dims() - 1); + let k_y = kernel.shape().dim(kernel.shape().dims() - 2) + 2 * (stride - 1); + let k_x = kernel.shape().dim(kernel.shape().dims() - 1) + 2 * (stride - 1); + + let batches = a.shape().dim(0); + let out_channels = kernel.shape().dim(0); + let in_channels = kernel.shape().dim(1); + + //(batches, out_channels, in_channels, y, x) + let a_broadcast = a.broadcast_splice(1, &[out_channels]); + //(batches, out_channels, in_channels, k_y, k_x) + let kernel_broadcast = kernel.broadcast_splice(0, &[batches]); + + let half_k_y = k_y / 2; + let half_k_x = k_x / 2; + + let y_out = y - half_k_y * 2; + let x_out = x - half_k_x * 2; + let out_shape = Shape::from([batches, out_channels, y_out, x_out].as_slice()); + let out_stride = Stride::from(&out_shape); + + let mut reordered_sums: Box<[$unit]> = vec![<$unit>::default(); batches * out_channels * y_out * x_out].into_boxed_slice(); + + for curr_y in 0..y_out { + for curr_x in 0..x_out { + let a_sliced = a_broadcast.slice(&Slice::from([Interval::all(), Interval::all(), Interval::all(), Interval::between_with_step(curr_y, curr_y + k_y, stride), Interval::between_with_step(curr_x, curr_x + k_x, stride)].as_slice())); + + let chunked_products = a_sliced.iter_unit().zip(kernel_broadcast.iter_unit()).map(|(a_elem, k_elem)| a_elem * k_elem).chunks(in_channels * k_y * k_x); + let curr_sums = chunked_products.into_iter().map(|i| -> $unit { i.sum() }); + + for (i, sum) in curr_sums.enumerate() { + let batch = i / out_channels; + let out_channel = i % out_channels; + + let index = Position::from([batch, out_channel, curr_y, curr_x].as_slice()).tensor_index(&out_stride)?; + + *reordered_sums.index_mut(index) = sum; + } + } + } + + //(batches, out_channels, y_out, x_out) + Ok(E::from_slice(&reordered_sums, out_shape)) + } + }; +} + +macro_rules! 
basic_impl { + ($unit:ty) => { + impl Engine<$unit> for Basic { + + //Pointwise Single + fn abs>(a: &dyn EngineTensor) -> Result>, crate::engine::EngineError> { + + Ok(E::from_iter(a.iter_unit().map(|x| x.abs()), a.shape().clone())) + } + + fn neg>(a: &dyn EngineTensor) -> Result>, crate::engine::EngineError> { + Ok(E::from_iter(a.iter_unit().map(|x| -x), a.shape().clone())) + } + + //Scalar + fn add_scalar>(s: $unit, a: &dyn EngineTensor) -> Result>, EngineError> { + Ok(E::from_iter(a.iter_unit().map(|x| x + s), a.shape().clone())) + } + + fn sub_scalar_lh>(s: $unit, a: &dyn EngineTensor) -> Result>, EngineError> { + Ok(E::from_iter(a.iter_unit().map(|x| s - x), a.shape().clone())) + } + + fn sub_scalar_rh>(a: &dyn EngineTensor, s: $unit) -> Result>, EngineError> { + Ok(E::from_iter(a.iter_unit().map(|x| x - s), a.shape().clone())) + } + + fn mul_scalar>(s: $unit, a: &dyn EngineTensor) -> Result>, EngineError> { + Ok(E::from_iter(a.iter_unit().map(|x| x * s), a.shape().clone())) + } + + fn div_scalar_lh>(s: $unit, a: &dyn EngineTensor) -> Result>, EngineError> { + Ok(E::from_iter(a.iter_unit().map(|x| s / x), a.shape().clone())) + } + + fn div_scalar_rh>(a: &dyn EngineTensor, s: $unit) -> Result>, EngineError> { + Ok(E::from_iter(a.iter_unit().map(|x| x / s), a.shape().clone())) + } + + //Pointwise Double + fn add>(a: &dyn EngineTensor, b: &dyn EngineTensor) -> Result>, crate::engine::EngineError> { + if a.shape() == b.shape() { + Ok(E::from_iter(&mut a.iter_unit().zip(b.iter_unit()).map(|(x, y)| x + y), a.shape().clone())) + } else { + Err(EngineError::ShapeMismatch(a.shape().clone(), b.shape().clone())) + } + } + + fn sub>(a: &dyn EngineTensor, b: &dyn EngineTensor) -> Result>, crate::engine::EngineError> { + if a.shape() == b.shape() { + Ok(E::from_iter(&mut a.iter_unit().zip(b.iter_unit()).map(|(x, y)| x - y), a.shape().clone())) + } else { + Err(EngineError::ShapeMismatch(a.shape().clone(), b.shape().clone())) + } + } + + fn mul>(a: &dyn EngineTensor, b: &dyn EngineTensor) -> Result>, crate::engine::EngineError> { + if a.shape() == b.shape() { + Ok(E::from_iter(&mut a.iter_unit().zip(b.iter_unit()).map(|(x, y)| x * y), a.shape().clone())) + } else { + Err(EngineError::ShapeMismatch(a.shape().clone(), b.shape().clone())) + } + } + + fn div>(a: &dyn EngineTensor, b: &dyn EngineTensor) -> Result>, crate::engine::EngineError> { + if a.shape() == b.shape() { + Ok(E::from_iter(&mut a.iter_unit().zip(b.iter_unit()).map(|(x, y)| x / y), a.shape().clone())) + } else { + Err(EngineError::ShapeMismatch(a.shape().clone(), b.shape().clone())) + } + } + + conv_fn!($unit); + } + }; +} + +macro_rules! 
basic_unsigned_impl { + ($unit:ty) => { + impl Engine<$unit> for Basic { + + //Pointwise Single + fn abs>(a: &dyn EngineTensor) -> Result>, crate::engine::EngineError> { + + Ok(E::from_iter(&mut a.iter_unit().map(|x| x), a.shape().clone())) + } + + fn neg>(_a: &dyn EngineTensor) -> Result>, crate::engine::EngineError> { + Err(crate::engine::EngineError::OperationUnsupportedForType()) + } + + //Scalar + fn add_scalar>(s: $unit, a: &dyn EngineTensor) -> Result>, EngineError> { + Ok(E::from_iter(a.iter_unit().map(|x| x + s), a.shape().clone())) + } + + fn sub_scalar_lh>(s: $unit, a: &dyn EngineTensor) -> Result>, EngineError> { + Ok(E::from_iter(a.iter_unit().map(|x| s - x), a.shape().clone())) + } + + fn sub_scalar_rh>(a: &dyn EngineTensor, s: $unit) -> Result>, EngineError> { + Ok(E::from_iter(a.iter_unit().map(|x| x - s), a.shape().clone())) + } + + fn mul_scalar>(s: $unit, a: &dyn EngineTensor) -> Result>, EngineError> { + Ok(E::from_iter(a.iter_unit().map(|x| x * s), a.shape().clone())) + } + + fn div_scalar_lh>(s: $unit, a: &dyn EngineTensor) -> Result>, EngineError> { + Ok(E::from_iter(a.iter_unit().map(|x| s / x), a.shape().clone())) + } + + fn div_scalar_rh>(a: &dyn EngineTensor, s: $unit) -> Result>, EngineError> { + Ok(E::from_iter(a.iter_unit().map(|x| x / s), a.shape().clone())) + } + + //Pointwise Double + fn add>(a: &dyn EngineTensor, b: &dyn EngineTensor) -> Result>, crate::engine::EngineError> { + if a.shape() == b.shape() { + Ok(E::from_iter(&mut a.iter_unit().zip(b.iter_unit()).map(|(x, y)| x + y), a.shape().clone())) + } else { + Err(EngineError::ShapeMismatch(a.shape().clone(), b.shape().clone())) + } + } + + fn sub>(a: &dyn EngineTensor, b: &dyn EngineTensor) -> Result>, crate::engine::EngineError> { + if a.shape() == b.shape() { + Ok(E::from_iter(&mut a.iter_unit().zip(b.iter_unit()).map(|(x, y)| x - y), a.shape().clone())) + } else { + Err(EngineError::ShapeMismatch(a.shape().clone(), b.shape().clone())) + } + } + + fn mul>(a: &dyn EngineTensor, b: &dyn EngineTensor) -> Result>, crate::engine::EngineError> { + if a.shape() == b.shape() { + Ok(E::from_iter(&mut a.iter_unit().zip(b.iter_unit()).map(|(x, y)| x * y), a.shape().clone())) + } else { + Err(EngineError::ShapeMismatch(a.shape().clone(), b.shape().clone())) + } + } + + fn div>(a: &dyn EngineTensor, b: &dyn EngineTensor) -> Result>, crate::engine::EngineError> { + if a.shape() == b.shape() { + Ok(E::from_iter(&mut a.iter_unit().zip(b.iter_unit()).map(|(x, y)| x / y), a.shape().clone())) + } else { + Err(EngineError::ShapeMismatch(a.shape().clone(), b.shape().clone())) + } + } + + conv_fn!($unit); + } + }; +} + +basic_impl!(f32); +basic_impl!(f64); +basic_impl!(i8); +basic_impl!(i16); +basic_impl!(i32); +basic_impl!(i64); +basic_unsigned_impl!(u8); +basic_unsigned_impl!(u16); +basic_unsigned_impl!(u32); +basic_unsigned_impl!(u64); + +#[cfg(test)] +mod test { + use crate::{helper::shape, engine::tensor::Array}; + + use super::*; + + #[test] + pub fn conv() { + let a = Array::from_iter((1..=65536).map(|x| (x as f32) / 65536.0).cycle().take(1 * 3 * 256 * 256), shape![1, 3, 256, 256]); + let kernel = Array::from_iter((1..=9).map(|x| x as f32).cycle().take(1 * 3 * 3 * 3), shape![1, 3, 3, 3]); + + let res = Basic::conv2d::>(a.as_ref(), kernel.as_ref(), 2).unwrap(); + + println!("{:?}", res.shape()); + //println!("{:?}", res.iter_unit().collect::>()); + } +} \ No newline at end of file diff --git a/src/engine/mod.rs b/src/engine/mod.rs index 63ad59b..8457d4d 100644 --- a/src/engine/mod.rs +++ b/src/engine/mod.rs @@ -3,7 
@@ -3,7 +3,7 @@ pub mod basic;
 
 mod util;
 
-use crate::helper::Shape;
+use crate::helper::{Shape, PositionError};
 
 use self::tensor::{EngineTensor, allowed_unit::AllowedUnit, factory::EngineTensorFactory};
 use thiserror::Error;
 
@@ -29,12 +29,21 @@ pub trait Engine<T: AllowedUnit> {
     fn sub<E: EngineTensorFactory<Unit = T>>(a: &dyn EngineTensor<Unit = T>, b: &dyn EngineTensor<Unit = T>) -> Result<Box<dyn EngineTensor<Unit = T>>, EngineError>;
     fn mul<E: EngineTensorFactory<Unit = T>>(a: &dyn EngineTensor<Unit = T>, b: &dyn EngineTensor<Unit = T>) -> Result<Box<dyn EngineTensor<Unit = T>>, EngineError>;
     fn div<E: EngineTensorFactory<Unit = T>>(a: &dyn EngineTensor<Unit = T>, b: &dyn EngineTensor<Unit = T>) -> Result<Box<dyn EngineTensor<Unit = T>>, EngineError>;
+
+    //Conv
+    fn conv2d<E: EngineTensorFactory<Unit = T>>(a: &dyn EngineTensor<Unit = T>, kernel: &dyn EngineTensor<Unit = T>, stride: usize) -> Result<Box<dyn EngineTensor<Unit = T>>, EngineError>;
 }
 
 #[derive(Error, Debug)]
 pub enum EngineError {
-    #[error("The tensor of size {0} does not match {1}")]
+    #[error("The tensor of shape {0} does not match expected {1}")]
     ShapeMismatch(Shape, Shape),
+    #[error("The dimension {0} does not match expected {1}")]
+    DimensionMismatch(usize, usize),
+    #[error("Got {0} dimensions but expected {1}")]
+    IncorrectDimensions(usize, usize),
+    #[error("Position operation failed: {0}")]
+    Tensor(#[from] PositionError),
     #[error("The operation is not supported on this data type")]
     OperationUnsupportedForType(),
 }
diff --git a/src/engine/tensor/basic.rs b/src/engine/tensor/basic.rs
deleted file mode 100644
index 82c8aed..0000000
--- a/src/engine/tensor/basic.rs
+++ /dev/null
@@ -1,164 +0,0 @@
-use crate::engine::{Engine, basic::Basic, EngineError};
-
-use super::{EngineTensor, factory::EngineTensorFactory};
-
-macro_rules! basic_impl {
-    ($unit:ty) => {
-        impl Engine<$unit> for Basic {
-
-            //Pointwise Single
-            fn abs<E: EngineTensorFactory<Unit = $unit>>(a: &dyn EngineTensor<Unit = $unit>) -> Result<Box<dyn EngineTensor<Unit = $unit>>, crate::engine::EngineError> {
-                Ok(E::from_iter(a.iter().map(|x| x.abs()), a.shape().clone()))
-            }
-
-            fn neg<E: EngineTensorFactory<Unit = $unit>>(a: &dyn EngineTensor<Unit = $unit>) -> Result<Box<dyn EngineTensor<Unit = $unit>>, crate::engine::EngineError> {
-                Ok(E::from_iter(a.iter().map(|x| -x), a.shape().clone()))
-            }
-
-            //Scalar
-            fn add_scalar<E: EngineTensorFactory<Unit = $unit>>(s: $unit, a: &dyn EngineTensor<Unit = $unit>) -> Result<Box<dyn EngineTensor<Unit = $unit>>, EngineError> {
-                Ok(E::from_iter(a.iter().map(|x| x + s), a.shape().clone()))
-            }
-
-            fn sub_scalar_lh<E: EngineTensorFactory<Unit = $unit>>(s: $unit, a: &dyn EngineTensor<Unit = $unit>) -> Result<Box<dyn EngineTensor<Unit = $unit>>, EngineError> {
-                Ok(E::from_iter(a.iter().map(|x| s - x), a.shape().clone()))
-            }
-
-            fn sub_scalar_rh<E: EngineTensorFactory<Unit = $unit>>(a: &dyn EngineTensor<Unit = $unit>, s: $unit) -> Result<Box<dyn EngineTensor<Unit = $unit>>, EngineError> {
-                Ok(E::from_iter(a.iter().map(|x| x - s), a.shape().clone()))
-            }
-
-            fn mul_scalar<E: EngineTensorFactory<Unit = $unit>>(s: $unit, a: &dyn EngineTensor<Unit = $unit>) -> Result<Box<dyn EngineTensor<Unit = $unit>>, EngineError> {
-                Ok(E::from_iter(a.iter().map(|x| x * s), a.shape().clone()))
-            }
-
-            fn div_scalar_lh<E: EngineTensorFactory<Unit = $unit>>(s: $unit, a: &dyn EngineTensor<Unit = $unit>) -> Result<Box<dyn EngineTensor<Unit = $unit>>, EngineError> {
-                Ok(E::from_iter(a.iter().map(|x| s / x), a.shape().clone()))
-            }
-
-            fn div_scalar_rh<E: EngineTensorFactory<Unit = $unit>>(a: &dyn EngineTensor<Unit = $unit>, s: $unit) -> Result<Box<dyn EngineTensor<Unit = $unit>>, EngineError> {
-                Ok(E::from_iter(a.iter().map(|x| x / s), a.shape().clone()))
-            }
-
-            //Pointwise Double
-            fn add<E: EngineTensorFactory<Unit = $unit>>(a: &dyn EngineTensor<Unit = $unit>, b: &dyn EngineTensor<Unit = $unit>) -> Result<Box<dyn EngineTensor<Unit = $unit>>, crate::engine::EngineError> {
-                if a.shape() == b.shape() {
-                    Ok(E::from_iter(&mut a.iter().zip(b.iter()).map(|(x, y)| x + y), a.shape().clone()))
-                } else {
-                    Err(EngineError::ShapeMismatch(a.shape().clone(), b.shape().clone()))
-                }
-            }
-
-            fn sub<E: EngineTensorFactory<Unit = $unit>>(a: &dyn EngineTensor<Unit = $unit>, b: &dyn EngineTensor<Unit = $unit>) -> Result<Box<dyn EngineTensor<Unit = $unit>>, crate::engine::EngineError> {
-                if a.shape() == b.shape() {
-                    Ok(E::from_iter(&mut a.iter().zip(b.iter()).map(|(x, y)| x - y), a.shape().clone()))
-                } else {
-                    Err(EngineError::ShapeMismatch(a.shape().clone(), b.shape().clone()))
-                }
-            }
-
-            fn mul<E: EngineTensorFactory<Unit = $unit>>(a: &dyn EngineTensor<Unit = $unit>, b: &dyn EngineTensor<Unit = $unit>) -> Result<Box<dyn EngineTensor<Unit = $unit>>, crate::engine::EngineError> {
-                if a.shape() == b.shape() {
-                    Ok(E::from_iter(&mut a.iter().zip(b.iter()).map(|(x, y)| x * y), a.shape().clone()))
-                } else {
-                    Err(EngineError::ShapeMismatch(a.shape().clone(), b.shape().clone()))
-                }
-            }
-
-            fn div<E: EngineTensorFactory<Unit = $unit>>(a: &dyn EngineTensor<Unit = $unit>, b: &dyn EngineTensor<Unit = $unit>) -> Result<Box<dyn EngineTensor<Unit = $unit>>, crate::engine::EngineError> {
-                if a.shape() == b.shape() {
-                    Ok(E::from_iter(&mut a.iter().zip(b.iter()).map(|(x, y)| x / y), a.shape().clone()))
-                } else {
-                    Err(EngineError::ShapeMismatch(a.shape().clone(), b.shape().clone()))
-                }
-            }
-        }
-    };
-}
-
-macro_rules! basic_unsigned_impl {
-    ($unit:ty) => {
-        impl Engine<$unit> for Basic {
-
-            //Pointwise Single
-            fn abs<E: EngineTensorFactory<Unit = $unit>>(a: &dyn EngineTensor<Unit = $unit>) -> Result<Box<dyn EngineTensor<Unit = $unit>>, crate::engine::EngineError> {
-                Ok(E::from_iter(&mut a.iter().map(|x| x), a.shape().clone()))
-            }
-
-            fn neg<E: EngineTensorFactory<Unit = $unit>>(_a: &dyn EngineTensor<Unit = $unit>) -> Result<Box<dyn EngineTensor<Unit = $unit>>, crate::engine::EngineError> {
-                Err(crate::engine::EngineError::OperationUnsupportedForType())
-            }
-
-            //Scalar
-            fn add_scalar<E: EngineTensorFactory<Unit = $unit>>(s: $unit, a: &dyn EngineTensor<Unit = $unit>) -> Result<Box<dyn EngineTensor<Unit = $unit>>, EngineError> {
-                Ok(E::from_iter(a.iter().map(|x| x + s), a.shape().clone()))
-            }
-
-            fn sub_scalar_lh<E: EngineTensorFactory<Unit = $unit>>(s: $unit, a: &dyn EngineTensor<Unit = $unit>) -> Result<Box<dyn EngineTensor<Unit = $unit>>, EngineError> {
-                Ok(E::from_iter(a.iter().map(|x| s - x), a.shape().clone()))
-            }
-
-            fn sub_scalar_rh<E: EngineTensorFactory<Unit = $unit>>(a: &dyn EngineTensor<Unit = $unit>, s: $unit) -> Result<Box<dyn EngineTensor<Unit = $unit>>, EngineError> {
-                Ok(E::from_iter(a.iter().map(|x| x - s), a.shape().clone()))
-            }
-
-            fn mul_scalar<E: EngineTensorFactory<Unit = $unit>>(s: $unit, a: &dyn EngineTensor<Unit = $unit>) -> Result<Box<dyn EngineTensor<Unit = $unit>>, EngineError> {
-                Ok(E::from_iter(a.iter().map(|x| x * s), a.shape().clone()))
-            }
-
-            fn div_scalar_lh<E: EngineTensorFactory<Unit = $unit>>(s: $unit, a: &dyn EngineTensor<Unit = $unit>) -> Result<Box<dyn EngineTensor<Unit = $unit>>, EngineError> {
-                Ok(E::from_iter(a.iter().map(|x| s / x), a.shape().clone()))
-            }
-
-            fn div_scalar_rh<E: EngineTensorFactory<Unit = $unit>>(a: &dyn EngineTensor<Unit = $unit>, s: $unit) -> Result<Box<dyn EngineTensor<Unit = $unit>>, EngineError> {
-                Ok(E::from_iter(a.iter().map(|x| x / s), a.shape().clone()))
-            }
-
-            //Pointwise Double
-            fn add<E: EngineTensorFactory<Unit = $unit>>(a: &dyn EngineTensor<Unit = $unit>, b: &dyn EngineTensor<Unit = $unit>) -> Result<Box<dyn EngineTensor<Unit = $unit>>, crate::engine::EngineError> {
-                if a.shape() == b.shape() {
-                    Ok(E::from_iter(&mut a.iter().zip(b.iter()).map(|(x, y)| x + y), a.shape().clone()))
-                } else {
-                    Err(EngineError::ShapeMismatch(a.shape().clone(), b.shape().clone()))
-                }
-            }
-
-            fn sub<E: EngineTensorFactory<Unit = $unit>>(a: &dyn EngineTensor<Unit = $unit>, b: &dyn EngineTensor<Unit = $unit>) -> Result<Box<dyn EngineTensor<Unit = $unit>>, crate::engine::EngineError> {
-                if a.shape() == b.shape() {
-                    Ok(E::from_iter(&mut a.iter().zip(b.iter()).map(|(x, y)| x - y), a.shape().clone()))
-                } else {
-                    Err(EngineError::ShapeMismatch(a.shape().clone(), b.shape().clone()))
-                }
-            }
-
-            fn mul<E: EngineTensorFactory<Unit = $unit>>(a: &dyn EngineTensor<Unit = $unit>, b: &dyn EngineTensor<Unit = $unit>) -> Result<Box<dyn EngineTensor<Unit = $unit>>, crate::engine::EngineError> {
-                if a.shape() == b.shape() {
-                    Ok(E::from_iter(&mut a.iter().zip(b.iter()).map(|(x, y)| x * y), a.shape().clone()))
-                } else {
-                    Err(EngineError::ShapeMismatch(a.shape().clone(), b.shape().clone()))
-                }
-            }
-
-            fn div<E: EngineTensorFactory<Unit = $unit>>(a: &dyn EngineTensor<Unit = $unit>, b: &dyn EngineTensor<Unit = $unit>) -> Result<Box<dyn EngineTensor<Unit = $unit>>, crate::engine::EngineError> {
-                if a.shape() == b.shape() {
-                    Ok(E::from_iter(&mut a.iter().zip(b.iter()).map(|(x, y)| x / y), a.shape().clone()))
-                } else {
-                    Err(EngineError::ShapeMismatch(a.shape().clone(), b.shape().clone()))
-                }
-            }
-        }
-    };
-}
-
-basic_impl!(f32);
-basic_impl!(f64);
-basic_impl!(i8);
-basic_impl!(i16);
-basic_impl!(i32);
-basic_impl!(i64);
-basic_unsigned_impl!(u8);
-basic_unsigned_impl!(u16);
-basic_unsigned_impl!(u32);
-basic_unsigned_impl!(u64);
\ No newline at end of file
diff --git a/src/engine/tensor/iter.rs b/src/engine/tensor/iter.rs
index 7867249..eb79c0a 100644
--- a/src/engine/tensor/iter.rs
+++ b/src/engine/tensor/iter.rs
@@ -1,17 +1,54 @@
 use crate::helper::Position;
 use super::{allowed_unit::AllowedUnit, EngineTensor};
 
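+//The commented-out iterator below is an old draft kept for reference; the
+//active EngineTensorUnitIterator further down yields units one Position at a
+//time instead of whole sub-tensors.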
+/*pub struct EngineTensorIterator<'a, T: AllowedUnit> {
+    tensor: &'a dyn EngineTensor<Unit = T>,
+    curr: usize,
+    ended: bool,
+}
+
+impl<'a, T: AllowedUnit> EngineTensorIterator<'a, T> {
+    pub fn new(tensor: &'a dyn EngineTensor<Unit = T>) -> Self {
+        Self {
+            tensor,
+            curr: *tensor.shape().as_boxed_slice().first().unwrap(),
+        }
+    }
+}
+
+impl<'a, T: AllowedUnit> Iterator for EngineTensorIterator<'a, T> {
+    type Item = &'a dyn EngineTensor<Unit = T>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.curr < *self.tensor.shape().as_boxed_slice().first().unwrap() {
+            let mut pos = self.tensor.shape().first();
+            *pos.as_mut_boxed_slice().get_mut(0).unwrap() = self.curr;
+
+            let out = Some(self.tensor.get(pos));
+
+            if self.curr != self.finish {
+                self.curr.incdec_mut(self.tensor.shape(), 1);
+            } else {
+                self.ended = true;
+            }
+
+            out
+        } else {
+            None
+        }
+    }
+}*/
+
 //TODO basic impl that isn't optimised
 //It can be enhanced by fetching chunks of contig memory if available
-
-pub struct EngineTensorIterator<'a, T: AllowedUnit> {
+pub struct EngineTensorUnitIterator<'a, T: AllowedUnit> {
     tensor: &'a dyn EngineTensor<Unit = T>,
     curr: Position,
     finish: Position,
     ended: bool,
 }
 
-impl<'a, T: AllowedUnit> EngineTensorIterator<'a, T> {
+impl<'a, T: AllowedUnit> EngineTensorUnitIterator<'a, T> {
     pub fn new(tensor: &'a dyn EngineTensor<Unit = T>) -> Self {
         Self {
             tensor,
@@ -22,7 +59,7 @@ impl<'a, T: AllowedUnit> EngineTensorIterator<'a, T> {
     }
 }
 
-impl<T: AllowedUnit> Iterator for EngineTensorIterator<'_, T> {
+impl<T: AllowedUnit> Iterator for EngineTensorUnitIterator<'_, T> {
     type Item = T;
 
     fn next(&mut self) -> Option<Self::Item> {
diff --git a/src/engine/tensor/mod.rs b/src/engine/tensor/mod.rs
index c5e115a..97ad181 100644
--- a/src/engine/tensor/mod.rs
+++ b/src/engine/tensor/mod.rs
@@ -1,33 +1,38 @@
 pub mod extension;
 pub mod iter;
-pub mod basic;
 pub mod allowed_unit;
 pub mod factory;
+pub mod padded;
 
 use std::sync::Arc;
 
 use crate::helper::{Shape, Stride, Position, Slice};
 use self::extension::{ExtensionProvider, EmptyExtensionProvider};
-use self::{iter::EngineTensorIterator, allowed_unit::{AllowedUnit, AllowedArray, AllowedQuant}};
+use self::factory::EngineTensorFactory;
+use self::{iter::EngineTensorUnitIterator, allowed_unit::{AllowedUnit, AllowedArray, AllowedQuant}};
 
 use std::fmt::Debug;
 
-pub trait EngineTensor<>: Debug {
+pub trait EngineTensor: Debug {
     type Unit: AllowedUnit;
 
     fn shape(&self) -> &Shape;
     fn stride(&self) -> &Stride;
     fn get(&self, pos: &Position) -> Self::Unit;
-    fn iter(&self) -> EngineTensorIterator<'_, Self::Unit>;
+    fn iter_unit(&self) -> EngineTensorUnitIterator<'_, Self::Unit>;
+
+    fn clone(&self) -> Box<dyn EngineTensor<Unit = Self::Unit>>;
 
     fn slice(&self, slice: &Slice) -> Box<dyn EngineTensor<Unit = Self::Unit>>;
-    fn reshape(&self, shape: Shape) -> Box<dyn EngineTensor<Unit = Self::Unit>>;
+    fn reshape(&self, shape: &Shape) -> Box<dyn EngineTensor<Unit = Self::Unit>>;
+    fn broadcast_splice(&self, pos: usize, sub: &[usize]) -> Box<dyn EngineTensor<Unit = Self::Unit>>;
+
     fn extensions(&self) -> Box<dyn ExtensionProvider + '_>;
 }
 
 impl<T: AllowedUnit> PartialEq for dyn EngineTensor<Unit = T> + '_ {
     fn eq(&self, other: &Self) -> bool {
-        self.shape() == other.shape() && self.iter().zip(other.iter()).all(|(a, b)| a == b)
+        self.shape() == other.shape() && self.iter_unit().zip(other.iter_unit()).all(|(a, b)| a == b)
     }
 }
 
@@ -39,6 +44,32 @@ pub struct Array<T: AllowedArray> {
     offset: usize,
 }
 
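+//Array is a dense view over an Arc'd buffer: slice and broadcast_splice only
+//rewrite shape/stride/offset, so reshape can reuse the buffer whenever the
+//view is still contiguous and has to copy otherwise.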
+impl<T: AllowedArray> Array<T> {
+    //Contiguous here means the view's stride is the canonical packed stride
+    //for its shape, so unit order and memory order agree
+    fn is_contiguous(&self) -> bool {
+        self.stride().as_boxed_slice() == Stride::from(self.shape()).as_boxed_slice()
+    }
+}
+
 impl<T: AllowedArray> EngineTensor for Array<T> {
     type Unit = T;
 
@@ -51,28 +82,75 @@ impl<T: AllowedArray> EngineTensor for Array<T> {
     }
 
     fn get(&self, pos: &Position) -> T {
-        let index = pos.index(&self.stride).unwrap() + self.offset;
+        let index = pos.tensor_index(&self.stride).unwrap() + self.offset;
 
         *self.data.as_ref().get(index).unwrap()
     }
 
-    fn iter(&self) -> EngineTensorIterator<'_, T> {
-        EngineTensorIterator::new(self)
+    fn iter_unit(&self) -> EngineTensorUnitIterator<'_, T> {
+        EngineTensorUnitIterator::new(self)
+    }
+
+    fn clone(&self) -> Box<dyn EngineTensor<Unit = T>> {
+        Box::new(Self {
+            stride: self.stride.clone(),
+            shape: self.shape.clone(),
+            data: self.data.clone(),
+            offset: self.offset,
+        })
     }
 
     fn slice(&self, slice: &Slice) -> Box<dyn EngineTensor<Unit = T>> {
-        let offset = slice.start().index(self.stride()).unwrap();
+        let offset = slice.start().tensor_index(self.stride()).unwrap();
 
         Box::from(Self {
             stride: self.stride.clone(),
-            shape: self.shape.clone(),
+            shape: slice.inferred_shape(self.shape()),
             data: self.data.clone(),
             offset,
         })
     }
 
-    fn reshape(&self, shape: Shape) -> Box<dyn EngineTensor<Unit = T>> {
-        todo!()
+    //Attempts to efficiently reuse memory if the tensor is contiguous
+    //If this is not an option it will copy from an iterator
+    fn reshape(&self, shape: &Shape) -> Box<dyn EngineTensor<Unit = T>> {
+        if shape.len() == self.shape().len() {
+            if self.is_contiguous() {
+                Box::new(Array::<T> {
+                    stride: Stride::from(shape),
+                    shape: shape.clone(),
+                    data: self.data.clone(),
+                    offset: self.offset,
+                })
+            } else {
+                Array::<T>::from_iter(self.iter_unit(), shape.clone())
+            }
+        } else {
+            todo!()
+        }
+    }
+
+    fn broadcast_splice(&self, pos: usize, sub: &[usize]) -> Box<dyn EngineTensor<Unit = T>> {
+        if pos <= self.shape().dims() {
+            let mut shape_buffer = self.shape().as_boxed_slice().to_vec();
+            shape_buffer.splice(pos..pos, sub.iter().copied());
+
+            let broadcast_shape = Shape::new(shape_buffer.into());
+
+            let mut stride_buffer = self.stride().as_boxed_slice().to_vec();
+            stride_buffer.splice(pos..pos, std::iter::repeat(0).take(sub.len()));
+
+            let broadcast_stride = Stride::new(stride_buffer.into());
+
+            Box::new(Self {
+                stride: broadcast_stride,
+                shape: broadcast_shape,
+                data: self.data.clone(),
+                offset: self.offset,
+            })
+        } else {
+            todo!()
+        }
     }
 
     fn extensions(&self) -> Box<dyn ExtensionProvider + '_> {
@@ -100,17 +178,26 @@ impl<T: AllowedQuant> EngineTensor for Quant<T> {
     }
 
     fn get(&self, pos: &Position) -> T {
-        let index = pos.index(&self.stride).unwrap() + self.offset;
+        let index = pos.tensor_index(&self.stride).unwrap() + self.offset;
 
         *self.data.as_ref().get(index).unwrap()
     }
 
-    fn iter(&self) -> EngineTensorIterator<'_, T> {
-        EngineTensorIterator::new(self)
+    fn iter_unit(&self) -> EngineTensorUnitIterator<'_, T> {
+        EngineTensorUnitIterator::new(self)
+    }
+
+    fn clone(&self) -> Box<dyn EngineTensor<Unit = T>> {
+        Box::new(Self {
+            stride: self.stride.clone(),
+            shape: self.shape.clone(),
+            data: self.data.clone(),
+            offset: self.offset,
+        })
     }
 
     fn slice(&self, slice: &Slice) -> Box<dyn EngineTensor<Unit = T>> {
-        let offset = slice.start().index(self.stride()).unwrap();
+        let offset = slice.start().tensor_index(self.stride()).unwrap();
 
         Box::from(Self {
             stride: self.stride.clone(),
             shape: self.shape.clone(),
             data: self.data.clone(),
             offset,
         })
     }
 
-    fn reshape(&self, shape: Shape) -> Box<dyn EngineTensor<Unit = T>> {
+    fn reshape(&self, shape: &Shape) -> Box<dyn EngineTensor<Unit = T>> {
         todo!()
     }
 
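+    //Mirrors Array::broadcast_splice: the spliced axes get stride 0, so they
+    //alias the existing data instead of materialising copies.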
+    fn broadcast_splice(&self, pos: usize, sub: &[usize]) -> Box<dyn EngineTensor<Unit = T>> {
+        if pos <= self.shape().dims() {
+            let mut shape_buffer = self.shape().as_boxed_slice().to_vec();
+            shape_buffer.splice(pos..pos, sub.iter().copied());
+
+            let broadcast_shape = Shape::new(shape_buffer.into());
+
+            let mut stride_buffer = self.stride().as_boxed_slice().to_vec();
+            stride_buffer.splice(pos..pos, std::iter::repeat(0).take(sub.len()));
+
+            let broadcast_stride = Stride::new(stride_buffer.into());
+
+            Box::new(Self {
+                stride: broadcast_stride,
+                shape: broadcast_shape,
+                data: self.data.clone(),
+                offset: self.offset,
+            })
+        } else {
+            todo!()
+        }
+    }
+
     fn extensions(&self) -> Box<dyn ExtensionProvider + '_> {
         Box::from(EmptyExtensionProvider::from(self))
     }
diff --git a/src/engine/tensor/padded.rs b/src/engine/tensor/padded.rs
new file mode 100644
index 0000000..91fb7f0
--- /dev/null
+++ b/src/engine/tensor/padded.rs
@@ -0,0 +1,95 @@
+use crate::helper::{Shape, Stride, Position};
+
+use super::{EngineTensor, allowed_unit::AllowedUnit, factory::EngineTensorFactory, iter::EngineTensorUnitIterator};
+
+pub trait AllowedPadded: AllowedUnit {}
+impl<T: AllowedUnit> AllowedPadded for T {}
+
+#[derive(Debug)]
+pub struct Padded<T: AllowedPadded> {
+    tensor: Box<dyn EngineTensor<Unit = T>>,
+
+    stride: Stride,
+    shape: Shape,
+
+    high_padding: Box<[usize]>,
+    low_padding: Box<[usize]>,
+
+    padding_val: T,
+}
+
+impl<T: AllowedPadded> Padded<T> {
+    pub fn pad_from(a: Box<dyn EngineTensor<Unit = T>>, padding: Shape, padding_val: T) -> Self {
+        if a.shape().dims() == padding.dims() {
+            let shape = Shape::new(a.shape().as_boxed_slice().iter().zip(padding.as_boxed_slice().iter()).map(|(o, p)| o + 2 * p).collect());
+            let stride = Stride::from(&shape);
+
+            //Exclusive upper bound of the unpadded region along each axis
+            let high_padding = a.shape().as_boxed_slice().iter().zip(padding.as_boxed_slice().iter()).map(|(o, p)| o + p).collect();
+            //Inclusive lower bound of the unpadded region along each axis
+            let low_padding = padding.as_boxed_slice().clone();
+
+            Self {
+                tensor: a,
+                stride,
+                shape,
+
+                high_padding,
+                low_padding,
+
+                padding_val,
+            }
+        } else {
+            todo!()
+        }
+    }
+}
+
+impl<T: AllowedPadded> EngineTensor for Padded<T> {
+    type Unit = T;
+
+    fn shape(&self) -> &Shape {
+        &self.shape
+    }
+
+    fn stride(&self) -> &Stride {
+        &self.stride
+    }
+
+    fn get(&self, pos: &Position) -> Self::Unit {
+        if pos.within_bounds(self.shape()) {
+            let pos_in_unpadded_bounds = pos.as_boxed_slice().iter().zip(self.high_padding.iter()).zip(self.low_padding.iter()).all(|((pos, hi), low)| (*low..*hi).contains(pos));
+
+            if pos_in_unpadded_bounds {
+                let middle_pos = Position::new(pos.as_boxed_slice().iter().zip(self.low_padding.iter()).map(|(pos, pad)| pos - pad).collect());
+
+                self.tensor.get(&middle_pos)
+            } else {
+                self.padding_val
+            }
+        } else {
+            todo!()
+        }
+    }
+
+    fn iter_unit(&self) -> super::iter::EngineTensorUnitIterator<'_, Self::Unit> {
+        EngineTensorUnitIterator::new(self)
+    }
+
+    fn clone(&self) -> Box<dyn EngineTensor<Unit = Self::Unit>> {
+        todo!()
+    }
+
+    fn slice(&self, slice: &crate::helper::Slice) -> Box<dyn EngineTensor<Unit = Self::Unit>> {
+        todo!()
+    }
+
+    fn reshape(&self, shape: &Shape) -> Box<dyn EngineTensor<Unit = Self::Unit>> {
+        todo!()
+    }
+
+    fn broadcast_splice(&self, pos: usize, sub: &[usize]) -> Box<dyn EngineTensor<Unit = Self::Unit>> {
+        todo!()
+    }
+
+    fn extensions(&self) -> Box<dyn super::extension::ExtensionProvider + '_> {
+        todo!()
+    }
+}
\ No newline at end of file
diff --git a/src/engine/util.rs b/src/engine/util.rs
index 652882e..d1aa92b 100644
--- a/src/engine/util.rs
+++ b/src/engine/util.rs
@@ -8,4 +8,22 @@ pub fn return_if_matched_shape<T>(a: &Shape, b: &Shape, out: T) -> Result<T, En
     } else {
         Err(EngineError::ShapeMismatch(a.clone(), b.clone()))
     }
+}
+
+pub fn err_if_incorrect_dimensions(a: &Shape, expected_dims: usize) -> Result<(), EngineError> {
+    let a_dims = a.dims();
+
+    if a_dims == expected_dims {
+        Ok(())
+    } else {
+        Err(EngineError::IncorrectDimensions(a_dims, expected_dims))
+    }
+}
+
+pub fn err_if_dimension_mismatch(provided_dim: usize, expected_dim: usize) -> Result<(), EngineError> {
+    if provided_dim == expected_dim {
+        Ok(())
+    } else {
+        Err(EngineError::DimensionMismatch(provided_dim, expected_dim))
+    }
 }
\ No newline at end of file
diff --git a/src/helper/inferred_shape.rs b/src/helper/inferred_shape.rs
new file mode 100644
index 0000000..cd682e4
--- /dev/null
+++ b/src/helper/inferred_shape.rs
@@ -0,0 +1,37 @@
+use super::Shape;
+
+//inferred_pos is the index at which the inferred dimension sits.
+//eg:
+//  0 means it is inserted before the first dimension
+//  with 3 known dims, inferred_pos = 3 means the inferred dimension comes last
+pub struct InferredShape {
+    inferred_pos: usize,
+    dims: Box<[usize]>,
+}
+
+impl InferredShape {
+    pub fn new(inferred_pos: usize, dims: Box<[usize]>) -> Self {
+        Self {
+            inferred_pos,
+            dims,
+        }
+    }
+
+    pub fn infer(&self, length: usize) -> Shape {
+        let inferred = length / self.dims.iter().product::<usize>();
+
+        let mut new_dims = self.dims.to_vec();
+        new_dims.insert(self.inferred_pos, inferred);
+
+        Shape::from(new_dims.as_slice())
+    }
+}
+
+impl From<(&[usize], &[usize])> for InferredShape {
+    fn from(value: (&[usize], &[usize])) -> Self {
+        Self {
+            inferred_pos: value.0.len(),
+            dims: value.0.iter().chain(value.1.iter()).copied().collect(),
+        }
+    }
+}
\ No newline at end of file
diff --git a/src/helper/mod.rs b/src/helper/mod.rs
index 2f13772..beb5340 100644
--- a/src/helper/mod.rs
+++ b/src/helper/mod.rs
@@ -1,8 +1,10 @@
+mod inferred_shape;
 mod shape;
 mod stride;
 mod position;
 mod slice;
 
+pub use inferred_shape::*;
 pub use shape::*;
 pub use stride::*;
 pub use position::*;
diff --git a/src/helper/position.rs b/src/helper/position.rs
index 378b779..01c2e22 100644
--- a/src/helper/position.rs
+++ b/src/helper/position.rs
@@ -18,7 +18,7 @@ impl Position {
         &mut self.0
     }
 
-    pub fn index(&self, stride: &Stride) -> Result<usize, PositionError> {
+    pub fn tensor_index(&self, stride: &Stride) -> Result<usize, PositionError> {
         let position_length = self.as_boxed_slice().len();
         let stride_length = stride.as_boxed_slice().len();
 
@@ -35,7 +35,7 @@ impl Position {
     pub fn incdec_mut(&mut self, bounds: &Shape, off: i64) {
         let mut curr = off;
 
-        for i in (0..bounds.as_boxed_slice().len()).rev() {
+        for i in (0..bounds.dims()).rev() {
             let signed_bound = bounds[i] as i64;
 
             curr += self[i] as i64;
diff --git a/src/helper/shape.rs b/src/helper/shape.rs
index 2e69740..a565246 100644
--- a/src/helper/shape.rs
+++ b/src/helper/shape.rs
@@ -22,6 +22,14 @@ impl Shape {
         self.0.iter().product()
     }
 
+    pub fn dim(&self, ind: usize) -> usize {
+        self.0[ind]
+    }
+
+    pub fn dims(&self) -> usize {
+        self.0.len()
+    }
+
     //first valid position
     pub fn first(&self) -> Position {
         Position::new(vec![0; self.as_boxed_slice().len()].into())
@@ -53,4 +61,11 @@ impl Display for Shape {
 
         write!(f, "({})", conv_sizes.join(","))
     }
-}
\ No newline at end of file
+}
+
+macro_rules! shape {
+    ($($x:expr),+) => {
+        Shape::from([$($x),+].as_slice())
+    };
+}
+pub(crate) use shape;
\ No newline at end of file
diff --git a/src/helper/slice.rs b/src/helper/slice.rs
index 47a186d..dba443d 100644
--- a/src/helper/slice.rs
+++ b/src/helper/slice.rs
@@ -45,6 +45,14 @@ impl Interval {
         }
     }
 
+    pub fn between_with_step(start: usize, finish: usize, step: usize) -> Self {
+        Self {
+            start: Some(start),
+            finish: Some(finish),
+            step: Some(step),
+        }
+    }
+
     pub fn all() -> Self {
         Self {
             start: None,