Skip to content

Commit

Permalink
Refactor Unit
Browse files Browse the repository at this point in the history
  • Loading branch information
rileysu committed Mar 7, 2024
1 parent 22f995f commit 26fa9f4
Show file tree
Hide file tree
Showing 17 changed files with 399 additions and 258 deletions.
12 changes: 6 additions & 6 deletions src/comp_graph/edge.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use crate::engine::{tensor::{allowed_unit::AllowedUnit, EngineTensor}, EngineError};
use crate::engine::{tensor::EngineTensor, unit::UnitCompatible, EngineError};

use super::{NodeKey, ComputationGraphError};

#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Edge<T: AllowedUnit> {
pub enum Edge<T: UnitCompatible> {
Root,

Abs(NodeKey, fn(&dyn EngineTensor<Unit = T>) -> Result<Box<dyn EngineTensor<Unit = T>>, EngineError>),
Expand All @@ -22,7 +22,7 @@ pub enum Edge<T: AllowedUnit> {
Div(NodeKey, NodeKey, fn(&dyn EngineTensor<Unit = T>, &dyn EngineTensor<Unit = T>) -> Result<Box<dyn EngineTensor<Unit = T>>, EngineError>),
}

impl<T: AllowedUnit> Edge<T> {
impl<T: UnitCompatible> Edge<T> {
pub fn nodes(&self) -> EdgeNodesIterator<T> {
EdgeNodesIterator::<T>::new(self)
}
Expand Down Expand Up @@ -63,12 +63,12 @@ impl<T: AllowedUnit> Edge<T> {
}
}

pub struct EdgeNodesIterator<'a, T: AllowedUnit> {
pub struct EdgeNodesIterator<'a, T: UnitCompatible> {
edge: &'a Edge<T>,
pos: usize,
}

impl<'a, T: AllowedUnit> EdgeNodesIterator<'a, T> {
impl<'a, T: UnitCompatible> EdgeNodesIterator<'a, T> {
pub fn new(edge: &'a Edge<T>) -> Self {
Self {
edge,
Expand All @@ -77,7 +77,7 @@ impl<'a, T: AllowedUnit> EdgeNodesIterator<'a, T> {
}
}

impl<'a, T: AllowedUnit> Iterator for EdgeNodesIterator<'a, T> {
impl<'a, T: UnitCompatible> Iterator for EdgeNodesIterator<'a, T> {
type Item = NodeKey;

fn next(&mut self) -> Option<Self::Item> {
Expand Down
10 changes: 5 additions & 5 deletions src/comp_graph/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@ use std::collections::{HashSet, HashMap};
use slotmap::{SlotMap, new_key_type};
use thiserror::Error;

use crate::engine::{tensor::{EngineTensor, allowed_unit::AllowedUnit, factory::EngineTensorFactory, iter::EngineTensorUnitIterator}, Engine, EngineError};
use crate::engine::{tensor::{factory::EngineTensorFactory, iter::EngineTensorUnitIterator, EngineTensor}, unit::UnitCompatible, Engine, EngineError};

use self::edge::Edge;

#[derive(Debug)]
pub struct Node<T: AllowedUnit> {
pub struct Node<T: UnitCompatible> {
tensor: Option<Box<dyn EngineTensor<Unit = T>>>,
edge: Edge<T>,
}

impl<T: AllowedUnit> Node<T> {
impl<T: UnitCompatible> Node<T> {
fn create_root(tensor: Box<dyn EngineTensor<Unit = T>>) -> Self {
Self {
tensor: Some(tensor),
Expand Down Expand Up @@ -61,11 +61,11 @@ impl<T: AllowedUnit> Node<T> {
new_key_type! { pub struct NodeKey; }

#[derive(Debug)]
pub struct CompGraph<T: AllowedUnit> {
pub struct CompGraph<T: UnitCompatible> {
nodes: SlotMap<NodeKey, Node<T>>,
}

impl<T: AllowedUnit> CompGraph<T> {
impl<T: UnitCompatible> CompGraph<T> {
pub fn new() -> Self {
Self {
nodes: SlotMap::with_key(),
Expand Down
280 changes: 89 additions & 191 deletions src/engine/basic.rs

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions src/engine/mod.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
pub mod tensor;
pub mod unit;
pub mod basic;

mod shared;
mod util;

use crate::helper::{Shape, PositionError};
use self::tensor::{EngineTensor, allowed_unit::AllowedUnit, factory::EngineTensorFactory};
use self::{tensor::{factory::EngineTensorFactory, EngineTensor}, unit::UnitCompatible};
use thiserror::Error;

//Using PyTorch operations as a base
//Using a trait over an enum has little extra cost and allows for extension
//Engines provide different optimisations for Tensor operations
//Factory defines the unit as well as output tensor type
pub trait Engine<T: AllowedUnit> {
pub trait Engine<T: UnitCompatible> {
//Pointwise Single
fn abs<E: EngineTensorFactory<Unit = T>>(a: &dyn EngineTensor<Unit = T>) -> Result<Box<dyn EngineTensor<Unit = T>>, EngineError>;
fn neg<E: EngineTensorFactory<Unit = T>>(a: &dyn EngineTensor<Unit = T>) -> Result<Box<dyn EngineTensor<Unit = T>>, EngineError>;
Expand All @@ -33,6 +34,7 @@ pub trait Engine<T: AllowedUnit> {

//Conv
fn conv2d<E: EngineTensorFactory<Unit = T>>(a: &dyn EngineTensor<Unit = T>, kernel: &dyn EngineTensor<Unit = T>, padding: usize, stride: usize) -> Result<Box<dyn EngineTensor<Unit = T>>, EngineError>;
//fn im2col_2d<E: EngineTensorFactory<Unit = T>>(a: &dyn EngineTensor<Unit = T>, kernel_shape: &Shape, padding: usize, stride: usize) -> Result<Box<dyn EngineTensor<Unit = T>>, EngineError>;
}

#[derive(Error, Debug)]
Expand Down
10 changes: 6 additions & 4 deletions src/engine/shared.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
use std::iter;

use num::Zero;

use crate::{
engine::tensor::padded::Padded,
helper::{Interval, Position, Shape, Slice, Stride, VarArrayCompatible},
};

use super::tensor::{allowed_unit::AllowedUnit, factory::EngineTensorFactory, Array, EngineTensor};
use super::{tensor::{factory::EngineTensorFactory, Array, EngineTensor}, unit::UnitCompatible};

//a: (batches, in_channels, img_y, img_x)
//kernel_shape: (in_channels, k_y, k_x)
//out: (batches, in_channels, out_y, out_x, k_y * k_x)
pub fn im2col_2d<T: AllowedUnit + Default, E: EngineTensorFactory<Unit = T>>(
pub fn im2col_2d<T: UnitCompatible, E: EngineTensorFactory<Unit = T>>(
a: &dyn EngineTensor<Unit = T>,
kernel_shape: &Shape,
padding: usize,
Expand All @@ -23,7 +25,7 @@ pub fn im2col_2d<T: AllowedUnit + Default, E: EngineTensorFactory<Unit = T>>(
let a_padded = Padded::pad_from(
a.clone(),
[0, 0, padding, padding].as_slice().into(),
T::default(),
T::zero(),
);

let img_y = a_padded.shape().get(2).unwrap();
Expand All @@ -46,7 +48,7 @@ pub fn im2col_2d<T: AllowedUnit + Default, E: EngineTensorFactory<Unit = T>>(

//Buffer used for output

let mut buffer = Vec::<T>::from_iter(iter::repeat(T::default()).take(out_shape.elements()));
let mut buffer = Vec::<T>::from_iter(iter::repeat(T::zero()).take(out_shape.elements()));
buffer.shrink_to_fit();

for y in 0..out_y {
Expand Down
13 changes: 4 additions & 9 deletions src/engine/tensor/allowed_unit.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
use std::fmt::Debug;
use crate::engine::unit::UnitCompatible;

use num::Num;
pub trait AllowedArray: UnitCompatible {}
impl<T: UnitCompatible> AllowedArray for T {}

pub trait AllowedUnit: Num + Sized + Copy + Debug + 'static {}
impl<T: Num + Sized + Copy + Debug + 'static> AllowedUnit for T {}

pub trait AllowedArray: AllowedUnit {}
impl<T: AllowedUnit> AllowedArray for T {}

pub trait AllowedQuant: AllowedUnit {}
pub trait AllowedQuant: UnitCompatible {}
impl AllowedQuant for f32 {}
impl AllowedQuant for f64 {}
6 changes: 3 additions & 3 deletions src/engine/tensor/factory.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use std::sync::Arc;
use crate::helper::{Shape, Stride};
use super::{allowed_unit::{AllowedUnit, AllowedArray, AllowedQuant}, EngineTensor, Array, Quant};
use crate::{engine::unit::UnitCompatible, helper::{Shape, Stride}};
use super::{allowed_unit::{AllowedArray, AllowedQuant}, EngineTensor, Array, Quant};

pub trait EngineTensorFactory
where Self: Sized
{
type Unit: AllowedUnit;
type Unit: UnitCompatible;

fn from_iter(iter: impl Iterator<Item = Self::Unit>, shape: Shape) -> Box<dyn EngineTensor<Unit = Self::Unit>>;
fn from_slice(data: &[Self::Unit], shape: Shape) -> Box<dyn EngineTensor<Unit = Self::Unit>>;
Expand Down
10 changes: 5 additions & 5 deletions src/engine/tensor/iter.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::helper::Position;
use super::{allowed_unit::AllowedUnit, EngineTensor};
use crate::{engine::unit::UnitCompatible, helper::Position};
use super::EngineTensor;

/*pub struct EngineTensorIterator<'a, T: AllowedUnit> {
tensor: &'a dyn EngineTensor<Unit = T>,
Expand Down Expand Up @@ -41,14 +41,14 @@ impl<'a, T: AllowedUnit> Iterator for EngineTensorIterator<'a, T> {

//TODO basic impl that isn't optimised
//It can be enhanced by fetching chunks of contig memory if available
pub struct EngineTensorUnitIterator<'a, T: AllowedUnit> {
pub struct EngineTensorUnitIterator<'a, T: UnitCompatible> {
tensor: &'a dyn EngineTensor<Unit = T>,
curr: Position,
finish: Position,
ended: bool,
}

impl<'a, T: AllowedUnit> EngineTensorUnitIterator<'a, T> {
impl<'a, T: UnitCompatible> EngineTensorUnitIterator<'a, T> {
pub fn new(tensor: &'a dyn EngineTensor<Unit = T>) -> Self {
Self {
tensor,
Expand All @@ -59,7 +59,7 @@ impl<'a, T: AllowedUnit> EngineTensorUnitIterator<'a, T> {
}
}

impl<T: AllowedUnit> Iterator for EngineTensorUnitIterator<'_, T> {
impl<T: UnitCompatible> Iterator for EngineTensorUnitIterator<'_, T> {
type Item = T;

fn next(&mut self) -> Option<Self::Item> {
Expand Down
7 changes: 4 additions & 3 deletions src/engine/tensor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@ use std::sync::Arc;
use crate::helper::{Shape, Stride, Position, Slice, VarArrayCompatible};
use self::extension::{ExtensionProvider, EmptyExtensionProvider};
use self::factory::EngineTensorFactory;
use self::{iter::EngineTensorUnitIterator, allowed_unit::{AllowedUnit, AllowedArray, AllowedQuant}};
use self::{iter::EngineTensorUnitIterator, allowed_unit::{AllowedArray, AllowedQuant}};
use std::fmt::Debug;
use super::unit::UnitCompatible;

//Unless otherwise specified every function should make as shallow of a copy as possible
pub trait EngineTensor: Debug {
type Unit: AllowedUnit;
type Unit: UnitCompatible;

fn shape(&self) -> &Shape;
fn get(&self, pos: &Position) -> Self::Unit;
Expand All @@ -28,7 +29,7 @@ pub trait EngineTensor: Debug {
fn extensions(&self)-> Box<dyn ExtensionProvider + '_>;
}

impl<T: AllowedUnit> PartialEq for dyn EngineTensor<Unit = T> + '_ {
impl<T: UnitCompatible> PartialEq for dyn EngineTensor<Unit = T> + '_ {
fn eq(&self, other: &Self) -> bool {
self.shape() == other.shape() && self.iter_units().zip(other.iter_units()).all(|(a, b)| a == b)
}
Expand Down
8 changes: 4 additions & 4 deletions src/engine/tensor/padded.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::helper::{Shape, Stride, Position, VarArrayCompatible, VarArray, Slice};
use super::{EngineTensor, allowed_unit::AllowedUnit, factory::EngineTensorFactory, iter::EngineTensorUnitIterator, Array, extension::EmptyExtensionProvider};
use crate::{engine::unit::UnitCompatible, helper::{Position, Shape, Slice, VarArray, VarArrayCompatible}};
use super::{EngineTensor, factory::EngineTensorFactory, iter::EngineTensorUnitIterator, Array, extension::EmptyExtensionProvider};

pub trait AllowedPadded: AllowedUnit {}
impl<T: AllowedUnit> AllowedPadded for T {}
pub trait AllowedPadded: UnitCompatible {}
impl<T: UnitCompatible> AllowedPadded for T {}

#[derive(Debug)]
pub struct Padded<T: AllowedPadded> {
Expand Down
52 changes: 52 additions & 0 deletions src/engine/unit/basic_op.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use super::Base;

pub trait BasicOp: Base {
fn add(self, other: Self) -> Self;
fn sub(self, other: Self) -> Self;
fn mul(self, other: Self) -> Self;
fn div(self, other: Self) -> Self;
fn rem(self, other: Self) -> Self;
}

macro_rules! basic_op {
($unit:ty) => {
impl BasicOp for $unit {
fn add(self, other: Self) -> Self {
self + other
}

fn sub(self, other: Self) -> Self {
self - other
}

fn mul(self, other: Self) -> Self {
self * other
}

fn div(self, other: Self) -> Self {
self / other
}

fn rem(self, other: Self) -> Self {
self % other
}
}
};
}

basic_op!{f32}
basic_op!{f64}

basic_op!{i8}
basic_op!{i16}
basic_op!{i32}
basic_op!{i64}
basic_op!{i128}
basic_op!{isize}

basic_op!{u8}
basic_op!{u16}
basic_op!{u32}
basic_op!{u64}
basic_op!{u128}
basic_op!{usize}
49 changes: 49 additions & 0 deletions src/engine/unit/core_value.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
pub trait CoreValue {
fn zero() -> Self;
fn one() -> Self;
}

macro_rules! core_value_float {
($unit:ty) => {
impl CoreValue for $unit {
fn zero() -> Self {
0.0
}

fn one() -> Self {
1.0
}
}
};
}

macro_rules! core_value_int {
($unit:ty) => {
impl CoreValue for $unit {
fn zero() -> Self {
0
}

fn one() -> Self {
1
}
}
};
}

core_value_float!{f32}
core_value_float!{f64}

core_value_int!{i8}
core_value_int!{i16}
core_value_int!{i32}
core_value_int!{i64}
core_value_int!{i128}
core_value_int!{isize}

core_value_int!{u8}
core_value_int!{u16}
core_value_int!{u32}
core_value_int!{u64}
core_value_int!{u128}
core_value_int!{usize}
Loading

0 comments on commit 26fa9f4

Please sign in to comment.