Skip to content

Commit

Permalink
Refactor impls and spike for tensor builder
Browse files Browse the repository at this point in the history
  • Loading branch information
Riley Sutton authored and Riley Sutton committed Mar 21, 2024
1 parent 5fb329d commit abf594b
Show file tree
Hide file tree
Showing 19 changed files with 286 additions and 357 deletions.
57 changes: 28 additions & 29 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion src/comp_graph/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ pub enum ComputationGraphError {
mod test {
use num::traits::Pow;

use crate::{engine::{basic::Basic, tensor::{generic::EngineTensorGeneric, Array}}, helper::Shape};
use crate::{engine_impl::{basic::Basic, tensor::array::Array}, helper::Shape};

use super::*;

Expand Down
4 changes: 0 additions & 4 deletions src/engine/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
pub mod tensor;
pub mod unit;
pub mod basic;

mod shared;
mod util;

use crate::helper::{Shape, PositionError};
use self::{tensor::{factory::EngineTensorFactory, EngineTensor}, unit::UnitCompatible};
Expand Down
8 changes: 0 additions & 8 deletions src/engine/tensor/allowed_unit.rs

This file was deleted.

12 changes: 6 additions & 6 deletions src/engine/tensor/builder.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
use std::ops::RangeBounds;

use crate::engine::unit::UnitCompatible;
use crate::{engine::unit::UnitCompatible, helper::Shape};

use super::EngineTensor;

trait EngineTensorBuilder {
pub trait EngineTensorBuilder {
type Unit: UnitCompatible;
type Tensor: EngineTensor;

fn splice<R: RangeBounds<usize>, I: IntoIterator<Item = Self::Unit>>(range: R, replace_with: I);
fn new(shape: Shape) -> Self;

fn construct() -> Box<dyn EngineTensor<Unit = Self::Unit>>;
}
fn splice<R: RangeBounds<usize>, I: IntoIterator<Item = Self::Unit>>(&mut self, range: R, replace_with: I);

struct ArrayBuilder;
fn construct(self) -> Self::Tensor;
}
49 changes: 6 additions & 43 deletions src/engine/tensor/factory.rs
Original file line number Diff line number Diff line change
@@ -1,51 +1,14 @@
use std::sync::Arc;
use crate::helper::{Shape, Stride};
use super::{allowed_unit::{AllowedArray, AllowedQuant}, generic::EngineTensorGeneric, Array, EngineTensor, Quant};
use crate::helper::Shape;
use super::EngineTensor;

pub trait EngineTensorFactory: EngineTensor + EngineTensorGeneric
where Self: Sized
pub trait EngineTensorFactory: EngineTensor
where Self: Sized + 'static
{
fn from_iter(iter: impl Iterator<Item = Self::Unit>, shape: Shape) -> Self;
fn from_slice(data: &[Self::Unit], shape: Shape) -> Self;
//fn builder() -> impl EngineTensorBuilder<Unit = Self::Unit>;
}

impl<T: AllowedArray> EngineTensorFactory for Array<T> {
fn from_iter(iter: impl Iterator<Item = T>, shape: Shape) -> Self {
Array {
stride: Stride::default_from_shape(&shape),
shape,
data: iter.collect(),
offset: 0,
}
}

fn from_slice(data: &[T], shape: Shape) -> Self {
Array {
stride: Stride::default_from_shape(&shape),
shape,
data: Arc::from(data),
offset: 0,
}
}
}

impl<T: AllowedQuant> EngineTensorFactory for Quant<T> {
fn from_iter(iter: impl Iterator<Item = T>, shape: Shape) -> Self {
Quant {
stride: Stride::default_from_shape(&shape),
shape,
data: iter.collect(),
offset: 0,
}
}

fn from_slice(data: &[T], shape: Shape) -> Self {
Quant {
stride: Stride::default_from_shape(&shape),
shape,
data: Arc::from(data),
offset: 0,
}
fn generic(self) -> Box<dyn EngineTensor<Unit = Self::Unit>> {
Box::from(self)
}
}
13 changes: 0 additions & 13 deletions src/engine/tensor/generic.rs

This file was deleted.

Loading

0 comments on commit abf594b

Please sign in to comment.