From abf594ba32c00fd3f93dde2ed5aaced6c2e919ca Mon Sep 17 00:00:00 2001 From: Riley Sutton Date: Thu, 21 Mar 2024 13:46:06 +1100 Subject: [PATCH] Refactor impls and spike for tensor builder --- Cargo.lock | 57 +++-- src/comp_graph/mod.rs | 2 +- src/engine/mod.rs | 4 - src/engine/tensor/allowed_unit.rs | 8 - src/engine/tensor/builder.rs | 12 +- src/engine/tensor/factory.rs | 49 +---- src/engine/tensor/generic.rs | 13 -- src/engine/tensor/mod.rs | 219 +------------------ src/engine/unit/core_value.rs | 2 +- src/{engine => engine_impl}/basic.rs | 6 +- src/engine_impl/mod.rs | 5 + src/{engine => engine_impl}/shared.rs | 11 +- src/engine_impl/tensor/array.rs | 209 ++++++++++++++++++ src/engine_impl/tensor/mod.rs | 2 + src/{engine => engine_impl}/tensor/padded.rs | 17 +- src/{engine => engine_impl}/util.rs | 4 +- src/helper/position/iter.rs | 14 +- src/helper/stride.rs | 2 - src/main.rs | 7 +- 19 files changed, 286 insertions(+), 357 deletions(-) delete mode 100644 src/engine/tensor/allowed_unit.rs delete mode 100644 src/engine/tensor/generic.rs rename src/{engine => engine_impl}/basic.rs (93%) create mode 100644 src/engine_impl/mod.rs rename src/{engine => engine_impl}/shared.rs (96%) create mode 100644 src/engine_impl/tensor/array.rs create mode 100644 src/engine_impl/tensor/mod.rs rename src/{engine => engine_impl}/tensor/padded.rs (91%) rename src/{engine => engine_impl}/util.rs (92%) diff --git a/Cargo.lock b/Cargo.lock index 20d06c4..3c4151a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10,21 +10,21 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "bytemuck" -version = "1.14.0" +version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "374d28ec25809ee0e23827c2ab573d729e293f281dfe393500e7ad618baa61c6" +checksum = "5d6d68c57235a3a081186990eca2867354726650f42f7516ca50c28d6281fd15" [[package]] name = "either" -version = "1.9.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" +checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a" [[package]] name = "itertools" -version = "0.12.0" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25db6b064527c5d482d0423354fcd07a89a2dfe07b67892e62411946db7f07b0" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" dependencies = [ "either", ] @@ -56,28 +56,27 @@ dependencies = [ [[package]] name = "num-complex" -version = "0.4.4" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214" +checksum = "23c6602fda94a57c990fe0df199a035d83576b496aa29f4e634a8ac6004e68a6" dependencies = [ "num-traits", ] [[package]] name = "num-integer" -version = "0.1.45" +version = "0.1.46" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" dependencies = [ - "autocfg", "num-traits", ] [[package]] name = "num-iter" -version = "0.1.43" +version = "0.1.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d03e6c028c5dc5cac6e2dec0efda81fc887605bb3d884578bb6d6bf7514e252" +checksum = "d869c01cc0c455284163fd0092f1f93835385ccab5a98a0dcc497b2f8bf055a9" dependencies = [ "autocfg", "num-integer", @@ -98,9 +97,9 @@ dependencies = [ [[package]] name = "num-traits" -version = "0.2.16" +version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f30b0abd723be7e2ffca1272140fac1a2f084c77ec3e123c192b66af1ee9e6c2" +checksum = "da0df0e5185db44f69b44f26786fe401b6c293d1907744beaa7fa62b2e5a517a" dependencies = [ "autocfg", ] @@ -113,18 +112,18 @@ checksum = "d7ce14664caf5b27f5656ff727defd68ae1eb75ef3c4d95259361df1eb376bef" [[package]] name = "proc-macro2" -version = "1.0.67" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d433d9f1a3e8c1263d9456598b16fec66f4acc9a74dacffd35c7bb09b3a1328" +checksum = "e835ff2298f5721608eb1a980ecaee1aef2c132bf95ecc026a11b7bf3c01c02e" dependencies = [ "unicode-ident", ] [[package]] name = "quote" -version = "1.0.33" +version = "1.0.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" +checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" dependencies = [ "proc-macro2", ] @@ -140,18 +139,18 @@ dependencies = [ [[package]] name = "slotmap" -version = "1.0.6" +version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1e08e261d0e8f5c43123b7adf3e4ca1690d655377ac93a03b2c9d3e98de1342" +checksum = "dbff4acf519f630b3a3ddcfaea6c06b42174d9a44bc70c620e9ed1649d58b82a" dependencies = [ "version_check", ] [[package]] name = "syn" -version = "2.0.37" +version = "2.0.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7303ef2c05cd654186cb250d29049a24840ca25d2747c25c0381c8d9e2f582e8" +checksum = "7383cd0e49fff4b6b90ca5670bfd3e9d6a733b3f90c686605aa7eec8c4996032" dependencies = [ "proc-macro2", "quote", @@ -172,18 +171,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.48" +version = "1.0.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d6d7a740b8a666a7e828dd00da9c0dc290dff53154ea77ac109281de90589b7" +checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.48" +version = "1.0.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49922ecae66cc8a249b77e68d1d0623c1b2c514f0060c27cdc68bd62a1219d35" +checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" dependencies = [ "proc-macro2", "quote", @@ -204,9 +203,9 @@ checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" [[package]] name = "wide" -version = "0.7.11" +version = "0.7.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa469ffa65ef7e0ba0f164183697b89b854253fd31aeb92358b7b6155177d62f" +checksum = "89beec544f246e679fc25490e3f8e08003bc4bf612068f325120dad4cea02c1c" dependencies = [ "bytemuck", "safe_arch", diff --git a/src/comp_graph/mod.rs b/src/comp_graph/mod.rs index 04146b1..ecf625b 100644 --- a/src/comp_graph/mod.rs +++ b/src/comp_graph/mod.rs @@ -320,7 +320,7 @@ pub enum ComputationGraphError { mod test { use num::traits::Pow; - use crate::{engine::{basic::Basic, tensor::{generic::EngineTensorGeneric, Array}}, helper::Shape}; + use crate::{engine_impl::{basic::Basic, tensor::array::Array}, helper::Shape}; use super::*; diff --git a/src/engine/mod.rs b/src/engine/mod.rs index 7a77961..129bcc1 100644 --- a/src/engine/mod.rs +++ b/src/engine/mod.rs @@ -1,9 +1,5 @@ pub mod tensor; pub mod unit; -pub mod basic; - -mod shared; -mod util; use crate::helper::{Shape, PositionError}; use self::{tensor::{factory::EngineTensorFactory, EngineTensor}, unit::UnitCompatible}; diff --git a/src/engine/tensor/allowed_unit.rs b/src/engine/tensor/allowed_unit.rs deleted file mode 100644 index 96217c3..0000000 --- a/src/engine/tensor/allowed_unit.rs +++ /dev/null @@ -1,8 +0,0 @@ -use crate::engine::unit::UnitCompatible; - -pub trait AllowedArray: UnitCompatible {} -impl AllowedArray for T {} - -pub trait AllowedQuant: UnitCompatible {} -impl AllowedQuant for f32 {} -impl AllowedQuant for f64 {} \ No newline at end of file diff --git a/src/engine/tensor/builder.rs b/src/engine/tensor/builder.rs index 18edce4..d8d71d0 100644 --- a/src/engine/tensor/builder.rs +++ b/src/engine/tensor/builder.rs @@ -1,16 +1,16 @@ use std::ops::RangeBounds; -use crate::engine::unit::UnitCompatible; +use crate::{engine::unit::UnitCompatible, helper::Shape}; use super::EngineTensor; -trait EngineTensorBuilder { +pub trait EngineTensorBuilder { type Unit: UnitCompatible; type Tensor: EngineTensor; - fn splice, I: IntoIterator>(range: R, replace_with: I); + fn new(shape: Shape) -> Self; - fn construct() -> Box>; -} + fn splice, I: IntoIterator>(&mut self, range: R, replace_with: I); -struct ArrayBuilder; \ No newline at end of file + fn construct(self) -> Self::Tensor; +} \ No newline at end of file diff --git a/src/engine/tensor/factory.rs b/src/engine/tensor/factory.rs index 475c7b3..90feb7e 100644 --- a/src/engine/tensor/factory.rs +++ b/src/engine/tensor/factory.rs @@ -1,51 +1,14 @@ -use std::sync::Arc; -use crate::helper::{Shape, Stride}; -use super::{allowed_unit::{AllowedArray, AllowedQuant}, generic::EngineTensorGeneric, Array, EngineTensor, Quant}; +use crate::helper::Shape; +use super::EngineTensor; -pub trait EngineTensorFactory: EngineTensor + EngineTensorGeneric -where Self: Sized +pub trait EngineTensorFactory: EngineTensor +where Self: Sized + 'static { fn from_iter(iter: impl Iterator, shape: Shape) -> Self; fn from_slice(data: &[Self::Unit], shape: Shape) -> Self; //fn builder() -> impl EngineTensorBuilder; -} -impl EngineTensorFactory for Array { - fn from_iter(iter: impl Iterator, shape: Shape) -> Self { - Array { - stride: Stride::default_from_shape(&shape), - shape, - data: iter.collect(), - offset: 0, - } - } - - fn from_slice(data: &[T], shape: Shape) -> Self { - Array { - stride: Stride::default_from_shape(&shape), - shape, - data: Arc::from(data), - offset: 0, - } - } -} - -impl EngineTensorFactory for Quant { - fn from_iter(iter: impl Iterator, shape: Shape) -> Self { - Quant { - stride: Stride::default_from_shape(&shape), - shape, - data: iter.collect(), - offset: 0, - } - } - - fn from_slice(data: &[T], shape: Shape) -> Self { - Quant { - stride: Stride::default_from_shape(&shape), - shape, - data: Arc::from(data), - offset: 0, - } + fn generic(self) -> Box> { + Box::from(self) } } \ No newline at end of file diff --git a/src/engine/tensor/generic.rs b/src/engine/tensor/generic.rs deleted file mode 100644 index 32a05cf..0000000 --- a/src/engine/tensor/generic.rs +++ /dev/null @@ -1,13 +0,0 @@ -use super::{allowed_unit::{AllowedArray, AllowedQuant}, Array, EngineTensor, Quant}; - -pub trait EngineTensorGeneric: EngineTensor -where - Self: Sized + 'static, -{ - fn generic(self) -> Box> { - Box::from(self) - } -} - -impl EngineTensorGeneric for Array {} -impl EngineTensorGeneric for Quant {} diff --git a/src/engine/tensor/mod.rs b/src/engine/tensor/mod.rs index 97e1b92..9a096ed 100644 --- a/src/engine/tensor/mod.rs +++ b/src/engine/tensor/mod.rs @@ -1,18 +1,11 @@ pub mod extension; pub mod iter; -pub mod allowed_unit; pub mod builder; pub mod factory; -pub mod generic; -pub mod padded; -use std::sync::Arc; - -use crate::helper::{Shape, Stride, Position, Slice, VarArrayCompatible}; -use self::extension::{ExtensionProvider, EmptyExtensionProvider}; -use self::factory::EngineTensorFactory; -use self::generic::EngineTensorGeneric; -use self::{iter::EngineTensorUnitIterator, allowed_unit::{AllowedArray, AllowedQuant}}; +use crate::helper::{Shape, Position, Slice, }; +use self::extension::ExtensionProvider; +use self::iter::EngineTensorUnitIterator; use std::fmt::Debug; use super::unit::UnitCompatible; @@ -25,6 +18,7 @@ pub trait EngineTensor: Debug { fn iter_units(&self) -> EngineTensorUnitIterator<'_, Self::Unit>; fn clone(&self) -> Box>; + fn mat(&self) -> Box>; fn slice(&self, slice: &Slice) -> Box>; fn reshape(&self, shape: &Shape) -> Box>; fn broadcast_splice(&self, pos: usize, sub: &[usize]) -> Box>; @@ -36,209 +30,4 @@ impl PartialEq for dyn EngineTensor + '_ { fn eq(&self, other: &Self) -> bool { self.shape() == other.shape() && self.iter_units().zip(other.iter_units()).all(|(a, b)| a == b) } -} - -#[derive(Debug)] -pub struct Array { - stride: Stride, - shape: Shape, - data: Arc<[T]>, - offset: usize, -} - -impl Array { - - pub fn from_data(stride: Stride, shape: Shape, data: Arc<[T]>, offset: usize) -> Self { - Self { - stride, - shape, - data, - offset, - } - } - - //Am I dumb? - //This is wrong!!! - //TODO - fn is_contiguous(&self) -> bool { - let mut check: Option = None; - - for curr in self.stride.iter() { - match check { - Some(prev) => { - if prev * prev == curr { - check = Some(curr); - } else { - return false; - } - }, - None => { - check = Some(curr); - }, - } - } - - return true; - } -} - -impl EngineTensor for Array { - type Unit = T; - - fn shape(&self) -> &Shape { - &self.shape - } - - fn get(&self, pos: &Position) -> T { - let index = pos.tensor_index(&self.stride).unwrap() + self.offset; - - *self.data.as_ref().get(index).unwrap() - } - - fn iter_units(&self) -> EngineTensorUnitIterator<'_, T> { - EngineTensorUnitIterator::new(self) - } - - fn clone(&self) -> Box> { - Box::new(Self { - stride: self.stride.clone() , - shape: self.shape.clone(), - data: self.data.clone(), - offset: self.offset, - }) - } - - fn slice(&self, slice: &Slice) -> Box> { - let offset = slice.start().tensor_index(&self.stride).unwrap(); - - Box::from(Self { - stride: self.stride.clone(), - shape: slice.inferred_shape(self.shape()), - data: self.data.clone(), - offset, - }) - } - - //Attempts to efficiently reuse memory if tensor is contiguous - //If this is not an option it will copy from an iterator - fn reshape(&self, shape: &Shape) -> Box> { - if shape.elements() == self.shape().elements() { - if self.is_contiguous() { - Box::new(Array:: { - stride: Stride::default_from_shape(shape), - shape: shape.clone(), - data: self.data.clone(), - offset: self.offset, - }) - } else { - Array::::from_iter(self.iter_units(), shape.clone()).generic() - } - } else { - todo!() - } - } - - fn broadcast_splice(&self, pos: usize, sub: &[usize]) -> Box> { - if pos <= self.shape().len() { - let mut shape_buffer = self.shape().as_slice().to_vec(); - shape_buffer.splice(pos..pos, sub.iter().copied()); - - let broadcast_shape = Shape::new(shape_buffer.as_slice().into()); - - let mut stride_buffer = self.stride.as_slice().to_vec(); - stride_buffer.splice(pos..pos, std::iter::repeat(0).take(sub.len())); - - let broadcast_stride = Stride::new(stride_buffer.as_slice().into()); - - Box::new(Self { - stride: broadcast_stride, - shape: broadcast_shape, - data: self.data.clone(), - offset: self.offset, - }) - } else { - todo!() - } - } - - fn extensions(&self) -> Box { - Box::from(EmptyExtensionProvider::from(self)) - } -} - -#[derive(Debug)] -pub struct Quant { - stride: Stride, - shape: Shape, - data: Arc<[T]>, - offset: usize, -} - -impl EngineTensor for Quant { - type Unit = T; - - fn shape(&self) -> &Shape { - &self.shape - } - - fn get(&self, pos: &Position) -> T { - let index = pos.tensor_index(&self.stride).unwrap() + self.offset; - - *self.data.as_ref().get(index).unwrap() - } - - fn iter_units(&self) -> EngineTensorUnitIterator<'_, T> { - EngineTensorUnitIterator::new(self) - } - - fn clone(&self) -> Box> { - Box::new(Self { - stride: self.stride.clone() , - shape: self.shape.clone(), - data: self.data.clone(), - offset: self.offset, - }) - } - - fn slice(&self, slice: &Slice) -> Box> { - let offset = slice.start().tensor_index(&self.stride).unwrap(); - - Box::from(Self { - stride: self.stride.clone(), - shape: self.shape.clone(), - data: self.data.clone(), - offset, - }) - } - - fn reshape(&self, shape: &Shape) -> Box> { - todo!() - } - - fn broadcast_splice(&self, pos: usize, sub: &[usize]) -> Box> { - if pos <= self.shape().len() { - let mut shape_buffer = self.shape().as_slice().to_vec(); - shape_buffer.splice(pos..pos, sub.iter().copied()); - - let broadcast_shape = Shape::new(shape_buffer.as_slice().into()); - - let mut stride_buffer = self.stride.as_slice().to_vec(); - stride_buffer.splice(pos..pos, std::iter::repeat(0).take(sub.len())); - - let broadcast_stride = Stride::new(stride_buffer.as_slice().into()); - - Box::new(Self { - stride: broadcast_stride, - shape: broadcast_shape, - data: self.data.clone(), - offset: self.offset, - }) - } else { - todo!() - } - } - - fn extensions(&self) -> Box { - Box::from(EmptyExtensionProvider::from(self)) - } } \ No newline at end of file diff --git a/src/engine/unit/core_value.rs b/src/engine/unit/core_value.rs index 7753104..0fa6572 100644 --- a/src/engine/unit/core_value.rs +++ b/src/engine/unit/core_value.rs @@ -1,4 +1,4 @@ -pub trait CoreValue { +pub trait CoreValue: Default { fn zero() -> Self; fn one() -> Self; } diff --git a/src/engine/basic.rs b/src/engine_impl/basic.rs similarity index 93% rename from src/engine/basic.rs rename to src/engine_impl/basic.rs index 1881ccd..64074ed 100644 --- a/src/engine/basic.rs +++ b/src/engine_impl/basic.rs @@ -1,8 +1,6 @@ use itertools::Itertools; -use crate::{engine::{util::{err_if_dimension_mismatch, err_if_incorrect_dimensions, err_if_too_few_dimensions}, Engine, EngineError, EngineTensor, EngineTensorFactory}, helper::{shape, Interval, Position, Shape, Slice, Stride, VarArrayCompatible}}; - -use super::{shared::im2col_2d, unit::UnitCompatible}; +use crate::{engine::{tensor::{factory::EngineTensorFactory, EngineTensor}, unit::UnitCompatible, Engine, EngineError}, engine_impl::{shared::im2col_2d, util::{err_if_dimension_mismatch, err_if_incorrect_dimensions, err_if_too_few_dimensions}}, helper::{shape, Shape, VarArrayCompatible}}; pub struct Basic {} @@ -117,7 +115,7 @@ impl Engine for Basic { #[cfg(test)] mod test { - use crate::{engine::tensor::{generic::EngineTensorGeneric, Array}, helper::shape}; + use crate::{engine_impl::tensor::array::Array, helper::shape}; use super::*; diff --git a/src/engine_impl/mod.rs b/src/engine_impl/mod.rs new file mode 100644 index 0000000..7fa0263 --- /dev/null +++ b/src/engine_impl/mod.rs @@ -0,0 +1,5 @@ +pub mod tensor; +pub mod basic; + +mod shared; +mod util; \ No newline at end of file diff --git a/src/engine/shared.rs b/src/engine_impl/shared.rs similarity index 96% rename from src/engine/shared.rs rename to src/engine_impl/shared.rs index 0bc0fdd..8ef82c9 100644 --- a/src/engine/shared.rs +++ b/src/engine_impl/shared.rs @@ -1,13 +1,8 @@ use std::iter; -use num::Zero; +use crate::{engine::{tensor::{factory::EngineTensorFactory, EngineTensor}, unit::UnitCompatible}, helper::{Interval, Position, Shape, Slice, Stride, VarArrayCompatible}}; -use crate::{ - engine::tensor::padded::Padded, - helper::{Interval, Position, Shape, Slice, Stride, VarArrayCompatible}, -}; - -use super::{tensor::{factory::EngineTensorFactory, Array, EngineTensor}, unit::UnitCompatible}; +use super::tensor::padded::Padded; //a: (batches, in_channels, img_y, img_x) //kernel_shape: (in_channels, k_y, k_x) @@ -91,7 +86,7 @@ pub fn im2col_2d>( #[cfg(test)] mod test { - use crate::engine::tensor::{generic::EngineTensorGeneric, Array}; + use crate::engine_impl::tensor::array::Array; use super::*; diff --git a/src/engine_impl/tensor/array.rs b/src/engine_impl/tensor/array.rs new file mode 100644 index 0000000..481fc88 --- /dev/null +++ b/src/engine_impl/tensor/array.rs @@ -0,0 +1,209 @@ +use std::{iter, sync::Arc}; + +use crate::{ + engine::{ + tensor::{ + builder::EngineTensorBuilder, + extension::{EmptyExtensionProvider, ExtensionProvider}, + factory::EngineTensorFactory, + iter::EngineTensorUnitIterator, + EngineTensor, + }, + unit::UnitCompatible, + }, + helper::{Position, Shape, Slice, Stride, VarArrayCompatible}, +}; + +#[derive(Debug)] +pub struct Array { + stride: Stride, + shape: Shape, + data: Arc<[T]>, + offset: usize, +} + +pub trait AllowedArray: UnitCompatible {} +impl AllowedArray for T {} + +impl Array { + pub fn from_data(stride: Stride, shape: Shape, data: Arc<[T]>, offset: usize) -> Self { + Self { + stride, + shape, + data, + offset, + } + } + + //Am I dumb? + //This is wrong!!! + //TODO + fn is_contiguous(&self) -> bool { + let mut check: Option = None; + + for curr in self.stride.iter() { + match check { + Some(prev) => { + if prev * prev == curr { + check = Some(curr); + } else { + return false; + } + } + None => { + check = Some(curr); + } + } + } + + return true; + } +} + +impl EngineTensor for Array { + type Unit = T; + + fn shape(&self) -> &Shape { + &self.shape + } + + fn get(&self, pos: &Position) -> T { + let index = pos.tensor_index(&self.stride).unwrap() + self.offset; + + *self.data.as_ref().get(index).unwrap() + } + + fn iter_units(&self) -> EngineTensorUnitIterator<'_, T> { + EngineTensorUnitIterator::new(self) + } + + fn clone(&self) -> Box> { + Box::new(Self { + stride: self.stride.clone(), + shape: self.shape.clone(), + data: self.data.clone(), + offset: self.offset, + }) + } + + fn mat(&self) -> Box> { + Array::from_iter(self.iter_units(), self.shape().clone()).generic() + } + + fn slice(&self, slice: &Slice) -> Box> { + let offset = slice.start().tensor_index(&self.stride).unwrap(); + + Box::from(Self { + stride: self.stride.clone(), + shape: slice.inferred_shape(self.shape()), + data: self.data.clone(), + offset, + }) + } + + //Attempts to efficiently reuse memory if tensor is contiguous + //If this is not an option it will copy from an iterator + fn reshape(&self, shape: &Shape) -> Box> { + if shape.elements() == self.shape().elements() { + if self.is_contiguous() { + Box::new(Array:: { + stride: Stride::default_from_shape(shape), + shape: shape.clone(), + data: self.data.clone(), + offset: self.offset, + }) + } else { + Array::::from_iter(self.iter_units(), shape.clone()).generic() + } + } else { + todo!() + } + } + + fn broadcast_splice( + &self, + pos: usize, + sub: &[usize], + ) -> Box> { + if pos <= self.shape().len() { + let mut shape_buffer = self.shape().as_slice().to_vec(); + shape_buffer.splice(pos..pos, sub.iter().copied()); + + let broadcast_shape = Shape::new(shape_buffer.as_slice().into()); + + let mut stride_buffer = self.stride.as_slice().to_vec(); + stride_buffer.splice(pos..pos, std::iter::repeat(0).take(sub.len())); + + let broadcast_stride = Stride::new(stride_buffer.as_slice().into()); + + Box::new(Self { + stride: broadcast_stride, + shape: broadcast_shape, + data: self.data.clone(), + offset: self.offset, + }) + } else { + todo!() + } + } + + fn extensions(&self) -> Box { + Box::from(EmptyExtensionProvider::from(self)) + } +} + +impl EngineTensorFactory for Array { + fn from_iter(iter: impl Iterator, shape: Shape) -> Self { + Array { + stride: Stride::default_from_shape(&shape), + shape, + data: iter.collect(), + offset: 0, + } + } + + fn from_slice(data: &[T], shape: Shape) -> Self { + Array { + stride: Stride::default_from_shape(&shape), + shape, + data: Arc::from(data), + offset: 0, + } + } +} + +struct ArrayBuilder { + shape: Shape, + data: Vec, +} + +impl EngineTensorBuilder for ArrayBuilder { + type Unit = T; + type Tensor = Array; + + fn new(shape: Shape) -> Self { + let elements = shape.elements(); + + Self { + shape, + data: Vec::from_iter(iter::repeat(T::default()).take(elements)), + } + } + + fn splice, I: IntoIterator>( + &mut self, + range: R, + replace_with: I, + ) { + self.data.splice(range, replace_with); + } + + fn construct(self) -> Self::Tensor { + Array::from_data( + Stride::default_from_shape(&self.shape), + self.shape, + self.data.into(), + 0, + ) + } +} diff --git a/src/engine_impl/tensor/mod.rs b/src/engine_impl/tensor/mod.rs new file mode 100644 index 0000000..1dfe971 --- /dev/null +++ b/src/engine_impl/tensor/mod.rs @@ -0,0 +1,2 @@ +pub mod array; +pub mod padded; \ No newline at end of file diff --git a/src/engine/tensor/padded.rs b/src/engine_impl/tensor/padded.rs similarity index 91% rename from src/engine/tensor/padded.rs rename to src/engine_impl/tensor/padded.rs index 7175796..cf6a6b6 100644 --- a/src/engine/tensor/padded.rs +++ b/src/engine_impl/tensor/padded.rs @@ -1,5 +1,6 @@ -use crate::{engine::unit::UnitCompatible, helper::{Position, Shape, Slice, VarArray, VarArrayCompatible}}; -use super::{extension::EmptyExtensionProvider, factory::EngineTensorFactory, generic::EngineTensorGeneric, iter::EngineTensorUnitIterator, Array, EngineTensor}; +use crate::{engine::{tensor::{extension::{EmptyExtensionProvider, ExtensionProvider}, factory::EngineTensorFactory, iter::EngineTensorUnitIterator, EngineTensor}, unit::UnitCompatible}, helper::{Position, Shape, Slice, VarArray, VarArrayCompatible}}; + +use super::array::Array; pub trait AllowedPadded: UnitCompatible {} impl AllowedPadded for T {} @@ -71,8 +72,6 @@ impl Padded { } } -impl EngineTensorGeneric for Padded {} - impl EngineTensor for Padded { type Unit = T; @@ -98,7 +97,7 @@ impl EngineTensor for Padded { } } - fn iter_units(&self) -> super::iter::EngineTensorUnitIterator<'_, Self::Unit> { + fn iter_units(&self) -> EngineTensorUnitIterator<'_, Self::Unit> { EngineTensorUnitIterator::new(self) } @@ -118,6 +117,10 @@ impl EngineTensor for Padded { }) } + fn mat(&self) -> Box> { + Array::from_iter(self.iter_units(), self.shape().clone()).generic() + } + //We can handle slices but changing anything more drastic needs a deep copy fn slice(&self, slice: &Slice) -> Box> { let slice_shape = slice.inferred_shape(self.shape()); @@ -174,14 +177,14 @@ impl EngineTensor for Padded { }) } - fn extensions(&self)-> Box { + fn extensions(&self)-> Box { Box::new(EmptyExtensionProvider::from(self)) } } #[cfg(test)] mod test { - use crate::helper::{shape, varr}; + use crate::{engine::tensor::factory::EngineTensorFactory, engine_impl::tensor::array::Array, helper::{shape, varr}}; use super::*; diff --git a/src/engine/util.rs b/src/engine_impl/util.rs similarity index 92% rename from src/engine/util.rs rename to src/engine_impl/util.rs index 1ec2934..ccf0a0b 100644 --- a/src/engine/util.rs +++ b/src/engine_impl/util.rs @@ -1,6 +1,4 @@ -use crate::helper::{Shape, VarArrayCompatible}; - -use super::EngineError; +use crate::{engine::EngineError, helper::{Shape, VarArrayCompatible}}; pub fn return_if_matched_shape(a: &Shape, b: &Shape, out: T) -> Result { if a == b { diff --git a/src/helper/position/iter.rs b/src/helper/position/iter.rs index f52fcb5..faa88ca 100644 --- a/src/helper/position/iter.rs +++ b/src/helper/position/iter.rs @@ -1,4 +1,4 @@ -use crate::helper::{Shape, VarArrayCompatible}; +use crate::helper::Shape; use super::Position; @@ -8,7 +8,6 @@ pub struct Iter<'a> { pos: Position, until: &'a Position, bounds: &'a Shape, - is_done: bool, } impl<'a> Iter<'a> { @@ -17,7 +16,6 @@ impl<'a> Iter<'a> { pos, until, bounds, - is_done: bounds.len() == 0, //If it's an empty shape then nothing to iterate on } } } @@ -27,16 +25,12 @@ impl<'a> Iterator for Iter<'a> { type Item = Position; fn next(&mut self) -> Option { - if !self.is_done { + if self.pos == *self.until { + None + } else { self.pos.incdec_mut(self.bounds, 1).unwrap(); - if self.pos == *self.until { - self.is_done = true; - } - Some(self.pos.clone()) - } else { - None } } } \ No newline at end of file diff --git a/src/helper/stride.rs b/src/helper/stride.rs index a02f654..4216150 100644 --- a/src/helper/stride.rs +++ b/src/helper/stride.rs @@ -1,5 +1,3 @@ -use std::ops::Index; - use crate::helper::Shape; use super::{VarArray, VarArrayCompatible, Unit}; diff --git a/src/main.rs b/src/main.rs index 62ef2ce..ec781e1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,14 +1,14 @@ #![allow(dead_code)] #![allow(unused_macros)] -use engine::{tensor::Array, basic::Basic}; use helper::Shape; -use crate::{comp_graph::CompGraph, engine::tensor::factory::EngineTensorFactory}; +use crate::{comp_graph::CompGraph, engine::tensor::factory::EngineTensorFactory, engine_impl::{basic::Basic, tensor::array::Array}, helper::shape}; mod engine; mod helper; mod comp_graph; +mod engine_impl; fn main() { let mut graph = CompGraph::::new(); @@ -17,8 +17,9 @@ fn main() { let b = graph.create_root(Box::new(Array::from_slice([2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13.].as_slice(), Shape::from([4, 3].as_slice())))); let mut c = a; + let divider = graph.create_root(Array::from_slice(&[0.99], shape![1]).generic().broadcast_splice(0, &[4, 3]).reshape(&shape![4, 3]).mat()); for _ in 0..100000 { - //c = graph.div_scalar_rh::>(&c, 1.00001); + c = graph.div::>(&c, ÷r); } graph.non_populating_eval(&c).unwrap();