Skip to content

Commit

Permalink
Add matmul and unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
Riley Sutton authored and Riley Sutton committed Mar 28, 2024
1 parent ff03e7f commit 90a0cd3
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 19 deletions.
2 changes: 1 addition & 1 deletion src/engine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub trait Engine<T: UnitCompatible> {
fn neg<E: EngineTensorFactory<Unit = T>>(a: &dyn EngineTensor<Unit = T>) -> Result<Box<dyn EngineTensor<Unit = T>>, EngineError>;

fn relu<E: EngineTensorFactory<Unit = T>>(a: &dyn EngineTensor<Unit = T>) -> Result<Box<dyn EngineTensor<Unit = T>>, EngineError>;
fn leaky_relu<E: EngineTensorFactory<Unit = T>>(a: &dyn EngineTensor<Unit = T>, alpha: f32) -> Result<Box<dyn EngineTensor<Unit = T>>, EngineError>;
fn leaky_relu<E: EngineTensorFactory<Unit = T>>(a: &dyn EngineTensor<Unit = T>, alpha: f64) -> Result<Box<dyn EngineTensor<Unit = T>>, EngineError>;
fn sigmoid<E: EngineTensorFactory<Unit = T>>(a: &dyn EngineTensor<Unit = T>) -> Result<Box<dyn EngineTensor<Unit = T>>, EngineError>;

//Pointwise Double
Expand Down
8 changes: 4 additions & 4 deletions src/engine/unit/core_func.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use std::{iter::{Product, Sum}, ops::{Add, Div, Mul, Rem, Sub}};

use super::{core_value::CoreValue, exponential_op::ExponentialOp, Base};
use super::{core_value::CoreValue, exponential_op::ExponentialOp, scale::Scale, Base};

pub trait CoreFunc: CoreValue + ExponentialOp + Add<Output = Self>
pub trait CoreFunc: CoreValue + ExponentialOp + Scale + Add<Output = Self>
+ Sub<Output = Self>
+ Mul<Output = Self>
+ Div<Output = Self>
Expand All @@ -16,8 +16,8 @@ pub trait CoreFunc: CoreValue + ExponentialOp + Add<Output = Self>
fn relu(self) -> Self {
if self > Self::zero() { self } else { Self::zero() }
}
fn leaky_relu(self) -> Self {
if self > Self::zero() { self } else { Self::zero() }
fn leaky_relu(self, alpha: f64) -> Self {
if self > Self::zero() { self } else { self.scale_double(alpha) }
}
fn sigmoid(self) -> Self {
let exp = self.exp();
Expand Down
112 changes: 98 additions & 14 deletions src/engine_impl/basic.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use itertools::Itertools;
use std::iter;

use crate::{engine::{tensor::{factory::EngineTensorFactory, EngineTensor}, unit::UnitCompatible, Engine, EngineError}, engine_impl::{shared::im2col_2d, util::{err_if_dimension_mismatch, err_if_dimensions_mistmatch, err_if_incorrect_num_dimensions, err_if_too_few_dimensions}}, helper::{shape, varr, Shape, VarArray, VarArrayCompatible}};
use itertools::Itertools;

use crate::{engine::{tensor::{factory::EngineTensorFactory, EngineTensor}, unit::UnitCompatible, Engine, EngineError}, engine_impl::{shared::im2col_2d, util::{err_if_dimension_mismatch, err_if_dimensions_mistmatch, err_if_incorrect_num_dimensions, err_if_too_few_dimensions}}, helper::{shape, varr, Interval, Shape, VarArray, VarArrayCompatible}};
use crate::engine::tensor::builder::EngineTensorBuilder;
pub struct Basic {}

impl<T: UnitCompatible> Engine<T> for Basic {
Expand All @@ -19,16 +21,12 @@ impl<T: UnitCompatible> Engine<T> for Basic {
Ok(E::from_iter(a.iter_units().map(|x: T| if x > T::zero() { x } else { T::zero() }), a.shape().clone()).generic())
}

fn leaky_relu<E: EngineTensorFactory<Unit = T>>(a: &dyn EngineTensor<Unit = T>, alpha: f32) -> Result<Box<dyn EngineTensor<Unit = T>>, EngineError> {
Ok(E::from_iter(a.iter_units().map(|x: T| if x > T::zero() { x } else { x.scale_single(alpha) }), a.shape().clone()).generic())
fn leaky_relu<E: EngineTensorFactory<Unit = T>>(a: &dyn EngineTensor<Unit = T>, alpha: f64) -> Result<Box<dyn EngineTensor<Unit = T>>, EngineError> {
Ok(E::from_iter(a.iter_units().map(|x: T| x.leaky_relu(alpha)), a.shape().clone()).generic())
}

fn sigmoid<E: EngineTensorFactory<Unit = T>>(a: &dyn EngineTensor<Unit = T>) -> Result<Box<dyn EngineTensor<Unit = T>>, EngineError> {
Ok(E::from_iter(a.iter_units().map(|x: T| {
let x_exp = x.exp();

x_exp / (T::one() + x_exp)
}), a.shape().clone()).generic())
Ok(E::from_iter(a.iter_units().map(|x: T| x.sigmoid()), a.shape().clone()).generic())
}

//Pointwise Double
Expand Down Expand Up @@ -79,14 +77,61 @@ impl<T: UnitCompatible> Engine<T> for Basic {
b = b.broadcast_splice(0, &a.shape().as_slice()[0..(a.shape().len() - b.shape().len())]);
}

err_if_dimensions_mistmatch(&a.shape().as_slice()[0..(a.shape().len() - 2)], &b.shape().as_slice()[0..(b.shape().len() - 2)])?;
err_if_dimension_mismatch(a.shape().get(a.shape().len() - 1).unwrap(), b.shape().get(b.shape().len() - 2).unwrap())?;
let out_rows = a.shape().get(a.shape().len() - 2).unwrap();
let out_columns = b.shape().get(b.shape().len() - 1).unwrap();

let a_batches_shape = Shape::from(&a.shape().as_slice()[0..(a.shape().len() - 2)]);
let b_batches_shape = Shape::from(&b.shape().as_slice()[0..(b.shape().len() - 2)]);

let a_columns = a.shape().get(a.shape().len() - 1).unwrap();
let b_rows = b.shape().get(b.shape().len() - 2).unwrap();

err_if_dimensions_mistmatch(a_batches_shape.as_slice(), b_batches_shape.as_slice())?;
err_if_dimension_mismatch(a_columns, b_rows)?;

let out_shape = Shape::new(VarArray::concat(a_batches_shape.vararray(), &varr![out_rows, out_columns]));

let mut builder = E::builder(out_shape, T::default());

let mut a_intervals = [
(0..a_batches_shape.len()).map(|_| Interval::all()).collect::<Vec<Interval>>().as_slice(),
[Interval::only(out_rows), Interval::all()].as_slice()
].concat().into_boxed_slice();
let a_intervals_row_index = a_intervals.len() - 2;

let out_shape = Shape::new(VarArray::concat(&VarArray::from(&a.shape().as_slice()[0..(a.shape().len() - 2)]), &varr![a.shape().get(a.shape().len() - 2).unwrap(), b.shape().get(b.shape().len() - 1).unwrap()]));
let mut b_intervals = [
(0..b_batches_shape.len()).map(|_| Interval::all()).collect::<Vec<Interval>>().as_slice(),
[Interval::all(), Interval::only(out_columns)].as_slice()
].concat().into_boxed_slice();
let b_intervals_column_index = b_intervals.len() - 1;

let builder = E::builder(out_shape, T::default());
let mut out_intervals: Box<[Interval]> = [
(0..b_batches_shape.len()).map(|_| Interval::all()).collect::<Vec<Interval>>().as_slice(),
[Interval::only(out_rows), Interval::only(out_columns)].as_slice()
].concat().into_boxed_slice();
let out_intervals_row_index = out_intervals.len() - 2;
let out_intervals_column_index = out_intervals.len() - 1;

todo!()
for row in 0..out_rows {
for column in 0..out_columns {
// sum(a(row, 1..n) * b(1..n, col))

*a_intervals.get_mut(a_intervals_row_index).unwrap() = Interval::only(row);
*b_intervals.get_mut(b_intervals_column_index).unwrap() = Interval::only(column);

let a_slice = a.slice(&a_intervals);
let b_slice = b.slice(&b_intervals);

let chunks = a_slice.iter_units().zip(b_slice.iter_units()).map(|(a_e, b_e)| a_e * b_e).chunks(a_columns);

*out_intervals.get_mut(out_intervals_row_index).unwrap() = Interval::only(row);
*out_intervals.get_mut(out_intervals_column_index).unwrap() = Interval::only(column);

builder.splice_slice(&out_intervals, chunks.into_iter().map(|c| c.sum()))
}
}

Ok(builder.construct().generic())
}

//Conv
Expand Down Expand Up @@ -142,4 +187,43 @@ mod test {
println!("{:?}", res.shape());
//println!("{:?}", res.iter_unit().collect::<Vec<f32>>());
}

#[test]
pub fn matmul_basic() {
let a = Array::from_slice(&[1., 2., 3., 4., 5., 6.], shape![2, 3]);
let b = Array::from_slice(&[10., 11., 20., 21., 30., 31.], shape![3, 2]);

let expected = Array::from_slice(&[140., 146., 320., 335.], shape![2, 2]);

let res = Basic::matmul::<Array<f32>>(&a, &b).unwrap();

assert!(res == expected.generic());

let a = Array::from_slice(&[1., 2., 3., 4., 5., 6., 2., 4., 6., 8., 10., 12.], shape![2, 2, 3]);
let b = Array::from_slice(&[10., 11., 20., 21., 30., 31.], shape![3, 2]);

let expected = Array::from_slice(&[140., 146., 320., 335., 280., 292., 640., 670.], shape![2, 2, 2]);

let res = Basic::matmul::<Array<f32>>(&a, &b).unwrap();

assert!(res == expected.generic());

let a = Array::from_slice(&[1., 2., 3., 4., 5., 6.], shape![2, 3]);
let b = Array::from_slice(&[10., 11., 20., 21., 30., 31., 20., 22., 40., 42., 60., 62.], shape![2, 3, 2]);

let expected = Array::from_slice(&[140., 146., 320., 335., 280., 292., 640., 670.], shape![2, 2, 2]);

let res = Basic::matmul::<Array<f32>>(&a, &b).unwrap();

assert!(res == expected.generic());

let a = Array::from_slice(&[1., 2., 3., 4., 5., 6., 2., 4., 6., 8., 10., 12.], shape![2, 2, 3]);
let b = Array::from_slice(&[10., 11., 20., 21., 30., 31., 20., 22., 40., 42., 60., 62.], shape![2, 3, 2]);

let expected = Array::from_slice(&[140., 146., 320., 335., 560., 584., 1280., 1340.], shape![2, 2, 2]);

let res = Basic::matmul::<Array<f32>>(&a, &b).unwrap();

assert!(res == expected.generic());
}
}

0 comments on commit 90a0cd3

Please sign in to comment.