diff --git a/ideas.md b/ideas.md index e4038ff..92cfdd9 100644 --- a/ideas.md +++ b/ideas.md @@ -29,4 +29,5 @@ - Use phantom to make tensors references to graph so it can't outlive graph - Probably remove the distinction between context and comp_graph - Dump graph on calculation - - Allow for recalc maybe \ No newline at end of file + - Allow for recalc maybe (as in mutating tensors within the graph) +- Model how the graph interface should look externally as ergonomics is an issue rn \ No newline at end of file diff --git a/src/context/edge.rs b/src/comp_graph/edge.rs similarity index 98% rename from src/context/edge.rs rename to src/comp_graph/edge.rs index 2706118..8ff6167 100644 --- a/src/context/edge.rs +++ b/src/comp_graph/edge.rs @@ -1,6 +1,6 @@ use crate::engine::{tensor::{allowed_unit::AllowedUnit, EngineTensor}, EngineError}; -use super::comp_graph::{NodeKey, ComputationGraphError}; +use super::{NodeKey, ComputationGraphError}; #[derive(Clone, Copy, Debug, PartialEq)] pub enum Edge { diff --git a/src/context/comp_graph.rs b/src/comp_graph/mod.rs similarity index 67% rename from src/context/comp_graph.rs rename to src/comp_graph/mod.rs index 186fd99..927b1f6 100644 --- a/src/context/comp_graph.rs +++ b/src/comp_graph/mod.rs @@ -1,11 +1,13 @@ +mod edge; + use std::collections::{HashSet, HashMap}; use slotmap::{SlotMap, new_key_type}; use thiserror::Error; -use crate::engine::{tensor::{EngineTensor, allowed_unit::AllowedUnit, factory::EngineTensorFactory}, Engine, EngineError}; +use crate::engine::{tensor::{EngineTensor, allowed_unit::AllowedUnit, factory::EngineTensorFactory, iter::EngineTensorIterator}, Engine, EngineError}; -use super::edge::Edge; +use self::edge::Edge; #[derive(Debug)] pub struct Node { @@ -14,29 +16,29 @@ pub struct Node { } impl Node { - pub fn create_root(tensor: Box>) -> Self { + fn create_root(tensor: Box>) -> Self { Self { tensor: Some(tensor), edge: Edge::Root, } } - pub fn create_node(edge: Edge) -> Self { + fn create_node(edge: Edge) -> Self { Self { tensor: None, edge, } } - pub fn tensor(&self) -> Option<&dyn EngineTensor> { + fn tensor(&self) -> Option<&dyn EngineTensor> { self.tensor.as_deref() } - pub fn set_tensor(&mut self, tensor: Box>) { + fn set_tensor(&mut self, tensor: Box>) { self.tensor = Some(tensor) } - pub fn clear_tensor(&mut self) -> Result<(), ComputationGraphError> { + fn clear_tensor(&mut self) -> Result<(), ComputationGraphError> { if self.is_root() { return Err(ComputationGraphError::CannotClearRoot()) } @@ -45,11 +47,11 @@ impl Node { Ok(()) } - pub fn edge(&self) -> &Edge { + fn edge(&self) -> &Edge { &self.edge } - pub fn is_root(&self) -> bool { + fn is_root(&self) -> bool { *self.edge() == Edge::Root } } @@ -70,32 +72,38 @@ impl CompGraph { } } - pub fn get_node(&self, node_key: &NodeKey) -> Option<&Node> { + fn get_node(&self, node_key: &NodeKey) -> Option<&Node> { self.nodes.get(*node_key) } - pub fn get_node_error(&self, node_key: &NodeKey) -> Result<&Node, ComputationGraphError> { + fn get_node_error(&self, node_key: &NodeKey) -> Result<&Node, ComputationGraphError> { self.nodes.get(*node_key).ok_or(ComputationGraphError::NodeDoesNotExist(*node_key)) } - pub fn get_node_mut(&mut self, node_key: &NodeKey) -> Option<&mut Node> { + fn get_node_mut(&mut self, node_key: &NodeKey) -> Option<&mut Node> { self.nodes.get_mut(*node_key) } - pub fn get_node_mut_error(&mut self, node_key: &NodeKey) -> Result<&mut Node, ComputationGraphError> { + fn get_node_mut_error(&mut self, node_key: &NodeKey) -> Result<&mut Node, ComputationGraphError> { self.nodes.get_mut(*node_key).ok_or(ComputationGraphError::NodeDoesNotExist(*node_key)) } //Root is a node that is a starting point for computation - pub fn create_root(&mut self, tensor: Box>) -> NodeKey { + fn create_root_node(&mut self, tensor: Box>) -> NodeKey { self.nodes.insert(Node::create_root(tensor)) } - pub fn create_node(&mut self, edge: Edge) -> NodeKey { - self.nodes.insert(Node::create_node(edge)) + pub fn create_root(&mut self, tensor: Box>) -> CompGraphTensor { + CompGraphTensor::new(self.create_root_node(tensor)) } + fn create_node(&mut self, edge: Edge) -> NodeKey { + self.nodes.insert(Node::create_node(edge)) + } + pub fn iter(&self, tensor: &CompGraphTensor) -> EngineTensorIterator { + EngineTensorIterator::new(self.get_node(tensor.node_key()).unwrap().tensor().unwrap()) + } //First return is open nodes, second is node_to_children //The algorithm is more efficient if done at the same time @@ -140,7 +148,7 @@ impl CompGraph { } //Uses Kahn's Algorithm - pub fn populating_eval(&mut self, target: NodeKey) -> Result<(), ComputationGraphError> { + fn populating_eval_node(&mut self, target: NodeKey) -> Result<(), ComputationGraphError> { //Nodes that have all dependencies satisfied let (open_roots, node_to_children) = self.generate_node_to_children(&target)?; @@ -181,7 +189,11 @@ impl CompGraph { Ok(()) } - pub fn non_populating_eval(&mut self, target: NodeKey) -> Result<(), ComputationGraphError> { + pub fn populating_eval(&mut self, target: &CompGraphTensor) -> Result<(), ComputationGraphError> { + self.populating_eval_node(*target.node_key()) + } + + fn non_populating_eval_node(&mut self, target: NodeKey) -> Result<(), ComputationGraphError> { //Nodes that have all dependencies satisfied let (open_roots, node_to_children) = self.generate_node_to_children(&target)?; @@ -240,52 +252,75 @@ impl CompGraph { Ok(()) } - pub fn abs, F: EngineTensorFactory>(&mut self, a: NodeKey) -> NodeKey { - self.create_node(Edge::Abs(a, E::abs::)) + pub fn non_populating_eval(&mut self, target: &CompGraphTensor) -> Result<(), ComputationGraphError> { + self.non_populating_eval_node(*target.node_key()) + } + + pub fn abs, F: EngineTensorFactory>(&mut self, a: &CompGraphTensor) -> CompGraphTensor { + CompGraphTensor::new(self.create_node(Edge::Abs(*a.node_key(), E::abs::))) } - pub fn neg, F: EngineTensorFactory>(&mut self, a: NodeKey) -> NodeKey { - self.create_node(Edge::Neg(a, E::neg::)) + pub fn neg, F: EngineTensorFactory>(&mut self, a: &CompGraphTensor) -> CompGraphTensor { + CompGraphTensor::new(self.create_node(Edge::Neg(*a.node_key(), E::neg::))) } - pub fn add_scalar, F: EngineTensorFactory>(&mut self, s: T, a: NodeKey) -> NodeKey { - self.create_node(Edge::AddScalar(s, a, E::add_scalar::)) + pub fn add_scalar, F: EngineTensorFactory>(&mut self, s: T, a: &CompGraphTensor) -> CompGraphTensor { + CompGraphTensor::new(self.create_node(Edge::AddScalar(s, *a.node_key(), E::add_scalar::))) } - pub fn sub_scalar_lh, F: EngineTensorFactory>(&mut self, s: T, a: NodeKey) -> NodeKey { - self.create_node(Edge::SubScalarLH(s, a, E::sub_scalar_lh::)) + pub fn sub_scalar_lh, F: EngineTensorFactory>(&mut self, s: T, a: &CompGraphTensor) -> CompGraphTensor { + CompGraphTensor::new(self.create_node(Edge::SubScalarLH(s, *a.node_key(), E::sub_scalar_lh::))) } - pub fn sub_scalar_rh, F: EngineTensorFactory>(&mut self, a: NodeKey, s: T) -> NodeKey { - self.create_node(Edge::SubScalarRH(a, s, E::sub_scalar_rh::)) + pub fn sub_scalar_rh, F: EngineTensorFactory>(&mut self, a: &CompGraphTensor, s: T) -> CompGraphTensor { + CompGraphTensor::new(self.create_node(Edge::SubScalarRH(*a.node_key(), s, E::sub_scalar_rh::))) } - pub fn mul_scalar, F: EngineTensorFactory>(&mut self, s: T, a: NodeKey) -> NodeKey { - self.create_node(Edge::MulScalar(s, a, E::mul_scalar::)) + pub fn mul_scalar, F: EngineTensorFactory>(&mut self, s: T, a: &CompGraphTensor) -> CompGraphTensor { + CompGraphTensor::new(self.create_node(Edge::MulScalar(s, *a.node_key(), E::mul_scalar::))) } - pub fn div_scalar_lh, F: EngineTensorFactory>(&mut self, s: T, a: NodeKey) -> NodeKey { - self.create_node(Edge::DivScalarLH(s, a, E::div_scalar_lh::)) + pub fn div_scalar_lh, F: EngineTensorFactory>(&mut self, s: T, a: &CompGraphTensor) -> CompGraphTensor { + CompGraphTensor::new(self.create_node(Edge::DivScalarLH(s, *a.node_key(), E::div_scalar_lh::))) } - pub fn div_scalar_rh, F: EngineTensorFactory>(&mut self, a: NodeKey, s: T) -> NodeKey { - self.create_node(Edge::DivScalarRH(a, s, E::div_scalar_rh::)) + pub fn div_scalar_rh, F: EngineTensorFactory>(&mut self, a: &CompGraphTensor, s: T) -> CompGraphTensor { + CompGraphTensor::new(self.create_node(Edge::DivScalarRH(*a.node_key(), s, E::div_scalar_rh::))) } - pub fn add, F: EngineTensorFactory>(&mut self, a: NodeKey, b: NodeKey) -> NodeKey { - self.create_node(Edge::Add(a, b, E::add::)) + pub fn add, F: EngineTensorFactory>(&mut self, a: &CompGraphTensor, b: &CompGraphTensor) -> CompGraphTensor { + CompGraphTensor::new(self.create_node(Edge::Add(*a.node_key(), *b.node_key(), E::add::))) } - pub fn sub, F: EngineTensorFactory>(&mut self, a: NodeKey, b: NodeKey) -> NodeKey { - self.create_node(Edge::Sub(a, b, E::sub::)) + pub fn sub, F: EngineTensorFactory>(&mut self, a: &CompGraphTensor, b: &CompGraphTensor) -> CompGraphTensor { + CompGraphTensor::new(self.create_node(Edge::Sub(*a.node_key(), *b.node_key(), E::sub::))) } - pub fn mul, F: EngineTensorFactory>(&mut self, a: NodeKey, b: NodeKey) -> NodeKey { - self.create_node(Edge::Mul(a, b, E::mul::)) + pub fn mul, F: EngineTensorFactory>(&mut self, a: &CompGraphTensor, b: &CompGraphTensor) -> CompGraphTensor { + CompGraphTensor::new(self.create_node(Edge::Mul(*a.node_key(), *b.node_key(), E::mul::))) + } + + pub fn div, F: EngineTensorFactory>(&mut self, a: &CompGraphTensor, b: &CompGraphTensor) -> CompGraphTensor { + CompGraphTensor::new(self.create_node(Edge::Div(*a.node_key(), *b.node_key(), E::div::))) + } +} + +//External handle for nodes +//Lifetime tied to graph +#[derive(Debug, Clone)] +pub struct CompGraphTensor { + node_key: NodeKey, +} + +impl CompGraphTensor { + fn new(node_key: NodeKey) -> Self { + Self { + node_key, + } } - pub fn div, F: EngineTensorFactory>(&mut self, a: NodeKey, b: NodeKey) -> NodeKey { - self.create_node(Edge::Div(a, b, E::div::)) + fn node_key(&self) -> &NodeKey { + &self.node_key } } @@ -313,20 +348,20 @@ mod test { use super::*; - pub fn init_simple_graph() -> (NodeKey, NodeKey, NodeKey, Box>, CompGraph) { + pub fn init_simple_graph() -> (CompGraphTensor, CompGraphTensor, CompGraphTensor, Box>, CompGraph) { let mut graph = CompGraph::::new(); let root1 = graph.create_root(Array::from_slice([0.0, 1.0, 2.0, 3.0].as_slice(), Shape::from([2, 2].as_slice()))); let root2 = graph.create_root(Array::from_slice([0.0, 1.0, 2.0, 3.0].as_slice(), Shape::from([2, 2].as_slice()))); - let added = graph.add::>(root1, root2); + let added = graph.add::>(&root1, &root2); let expected = Array::from_slice([0.0, 2.0, 4.0, 6.0].as_slice(), Shape::from([2, 2].as_slice())); return (root1, root2, added, expected, graph); } - pub fn init_complex_graph() -> (NodeKey, Box>, Box>, CompGraph) { + pub fn init_complex_graph() -> (CompGraphTensor, Box>, Box>, CompGraph) { let mut graph = CompGraph::::new(); let root1 = graph.create_root(Array::from_slice([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0].as_slice(), Shape::from([3, 3].as_slice()))); @@ -334,16 +369,16 @@ mod test { let root3 = graph.create_root(Array::from_slice([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0].as_slice(), Shape::from([3, 3].as_slice()))); let root4 = graph.create_root(Array::from_slice([1.0, 4.0, 9.0, 16.0, 25.0, 36.0, 49.0, 64.0, 81.0].as_slice(), Shape::from([3, 3].as_slice()))); - let op1 = graph.div::>(root4, root1); - let op2 = graph.mul::>(op1, root2); - let op3 = graph.sub::>(op2, root3); + let op1 = graph.div::>(&root4, &root1); + let op2 = graph.mul::>(&op1, &root2); + let op3 = graph.sub::>(&op2, &root3); - let op4 = graph.mul_scalar::>(2.0, op3); - let op5 = graph.div_scalar_rh::>(op4, 2.0); + let op4 = graph.mul_scalar::>(2.0, &op3); + let op5 = graph.div_scalar_rh::>(&op4, 2.0); - let op6 = graph.mul::>(op5, op5); + let op6 = graph.mul::>(&op5, &op5); - let op7 = graph.div::>(op6, root1); + let op7 = graph.div::>(&op6, &root1); return (op7, Array::from_slice([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0].as_slice(), Shape::from([3, 3].as_slice())), Array::from_slice([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0].as_slice(), Shape::from([3, 3].as_slice())), graph); } @@ -352,9 +387,9 @@ mod test { fn simple_no_eval() { let (_, _, added, _, graph) = init_simple_graph(); - assert!(graph.get_node(&added).is_some()); + assert!(graph.get_node(added.node_key()).is_some()); - let node = graph.get_node(&added).unwrap(); + let node = graph.get_node(added.node_key()).unwrap(); assert!(node.tensor().is_none()); } @@ -363,11 +398,11 @@ mod test { fn simple_eval() { let (_, _, added, expected, mut graph) = init_simple_graph(); - graph.non_populating_eval(added).unwrap(); + graph.non_populating_eval(&added).unwrap(); - assert!(graph.get_node(&added).is_some()); + assert!(graph.get_node(added.node_key()).is_some()); - let node = graph.get_node_mut(&added).unwrap(); + let node = graph.get_node_mut(added.node_key()).unwrap(); assert!(node.tensor().is_some()); @@ -375,11 +410,11 @@ mod test { node.clear_tensor().unwrap(); - graph.populating_eval(added).unwrap(); + graph.populating_eval(&added).unwrap(); - assert!(graph.get_node(&added).is_some()); + assert!(graph.get_node(added.node_key()).is_some()); - let node = graph.get_node(&added).unwrap(); + let node = graph.get_node(added.node_key()).unwrap(); assert!(node.tensor().is_some()); @@ -394,20 +429,20 @@ mod test { let mut out = node_key; for _ in 0..2_usize.pow(power as u32) { - out = graph.div::>(out, out); + out = graph.div::>(&out, &out); } - graph.non_populating_eval(out).unwrap(); + graph.non_populating_eval(&out).unwrap(); - let node = graph.get_node_mut(&out).unwrap(); + let node = graph.get_node_mut(out.node_key()).unwrap(); assert_eq!(*node.tensor().unwrap(), *expected_unit); node.clear_tensor().unwrap(); - graph.populating_eval(out).unwrap(); + graph.populating_eval(&out).unwrap(); - let node = graph.get_node(&out).unwrap(); + let node = graph.get_node(out.node_key()).unwrap(); assert_eq!(*node.tensor().unwrap(), *expected_unit); } @@ -418,36 +453,36 @@ mod test { let power = 12u16; - let mut curr_node_keys: Vec; + let mut curr_node_keys: Vec; let mut new_node_keys = vec![node_key; 2_usize.pow(power as u32)]; while new_node_keys.len() > 1 { curr_node_keys = new_node_keys; - new_node_keys = Vec::::new(); + new_node_keys = Vec::::new(); for keys in curr_node_keys.chunks_exact(2) { - let a_key = keys[0]; - let b_key = keys[1]; + let a_key = &keys[0]; + let b_key = &keys[1]; - new_node_keys.push(graph.add::>(a_key, b_key)); + new_node_keys.push(graph.add::>(&a_key, &b_key)); } } - let node_key = new_node_keys.last().unwrap(); + let tensor = new_node_keys.last().unwrap(); let expected = Array::from_iter( &mut expected_original.iter().map(|x| x * 2.0f32.pow(power)), expected_original.shape().clone()); - graph.non_populating_eval(*node_key).unwrap(); + graph.non_populating_eval(&tensor).unwrap(); - let node = graph.get_node_mut(node_key).unwrap(); + let node = graph.get_node_mut(tensor.node_key()).unwrap(); assert_eq!(*node.tensor().unwrap(), *expected); node.clear_tensor().unwrap(); - graph.populating_eval(*node_key).unwrap(); + graph.populating_eval(tensor).unwrap(); - let node = graph.get_node(node_key).unwrap(); + let node = graph.get_node(tensor.node_key()).unwrap(); assert_eq!(*node.tensor().unwrap(), *expected); } diff --git a/src/context/mod.rs b/src/context/mod.rs deleted file mode 100644 index 9d4ede2..0000000 --- a/src/context/mod.rs +++ /dev/null @@ -1,105 +0,0 @@ -use std::marker::PhantomData; - -use crate::{engine::{tensor::{allowed_unit::AllowedUnit, factory::EngineTensorFactory}, Engine}, helper::Shape}; - -use self::comp_graph::{CompGraph, NodeKey}; - -mod comp_graph; -mod edge; - -#[derive(Debug)] -pub struct Context> { - comp_graph: CompGraph, - default_engine: PhantomData, -} - -pub struct ContextTensor { - node: NodeKey, -} - -impl ContextTensor { - pub fn new(node: NodeKey) -> Self { - Self { - node, - } - } - - pub fn node(&self) -> NodeKey { - self.node - } -} - -impl> Context { - pub fn new() -> Self { - Self { - comp_graph: CompGraph::new(), - default_engine: PhantomData, - } - } - - pub fn eval(&mut self, tensor: &ContextTensor) { - self.comp_graph.populating_eval(tensor.node).unwrap(); - } - - pub fn from_iter>(&mut self, iter: &mut dyn Iterator, shape: Shape) -> ContextTensor { - ContextTensor::new(self.comp_graph.create_root(F::from_iter(iter, shape))) - } - - pub fn from_slice>(&mut self, slice: &[T], shape: Shape) -> ContextTensor { - ContextTensor::new(self.comp_graph.create_root(F::from_slice(slice, shape))) - } - - pub fn iter(&mut self, tensor: &ContextTensor) -> Box + '_> { - self.eval(tensor); - - Box::from(self.comp_graph.get_node(&tensor.node()).unwrap().tensor().unwrap().iter()) - } - - pub fn abs>(&mut self, a: &ContextTensor) -> ContextTensor { - ContextTensor::new(self.comp_graph.abs::(a.node())) - } - - pub fn neg>(&mut self, a: &ContextTensor) -> ContextTensor { - ContextTensor::new(self.comp_graph.neg::(a.node())) - } - - pub fn add_scalar>(&mut self, s: T, a: &ContextTensor) -> ContextTensor { - ContextTensor::new(self.comp_graph.add_scalar::(s, a.node())) - } - - pub fn sub_scalar_lh>(&mut self, s: T, a: &ContextTensor) -> ContextTensor { - ContextTensor::new(self.comp_graph.sub_scalar_lh::(s, a.node())) - } - - pub fn sub_scalar_rh>(&mut self, a: &ContextTensor, s: T) -> ContextTensor { - ContextTensor::new(self.comp_graph.sub_scalar_rh::(a.node(), s)) - } - - pub fn mul_scalar>(&mut self, s: T, a: &ContextTensor) -> ContextTensor { - ContextTensor::new(self.comp_graph.mul_scalar::(s, a.node())) - } - - pub fn div_scalar_lh>(&mut self, s: T, a: &ContextTensor) -> ContextTensor { - ContextTensor::new(self.comp_graph.div_scalar_lh::(s, a.node())) - } - - pub fn div_scalar_rh>(&mut self, a: &ContextTensor, s: T) -> ContextTensor { - ContextTensor::new(self.comp_graph.div_scalar_rh::(a.node(), s)) - } - - pub fn add>(&mut self, a: &ContextTensor, b: &ContextTensor) -> ContextTensor { - ContextTensor::new(self.comp_graph.add::(a.node(), b.node())) - } - - pub fn sub>(&mut self, a: &ContextTensor, b: &ContextTensor) -> ContextTensor { - ContextTensor::new(self.comp_graph.sub::(a.node(), b.node())) - } - - pub fn mul>(&mut self, a: &ContextTensor, b: &ContextTensor) -> ContextTensor { - ContextTensor::new(self.comp_graph.mul::(a.node(), b.node())) - } - - pub fn div>(&mut self, a: &ContextTensor, b: &ContextTensor) -> ContextTensor { - ContextTensor::new(self.comp_graph.div::(a.node(), b.node())) - } -} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index 8554b10..a0d4515 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,24 +1,27 @@ #![allow(dead_code)] -use context::Context; use engine::{tensor::Array, basic::Basic}; use helper::Shape; +use crate::{comp_graph::CompGraph, engine::tensor::factory::EngineTensorFactory}; + mod engine; mod helper; -mod context; +mod comp_graph; fn main() { - let mut context = Context::::new(); + let mut graph = CompGraph::::new(); - let a = context.from_slice::>([1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.].as_slice(), Shape::from([4, 3].as_slice())); - let b = context.from_slice::>([2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13.].as_slice(), Shape::from([4, 3].as_slice())); + let a = graph.create_root(Array::from_slice([1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.].as_slice(), Shape::from([4, 3].as_slice()))); + let b = graph.create_root(Array::from_slice([2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13.].as_slice(), Shape::from([4, 3].as_slice()))); let mut c = a; for _ in 0..100000 { - c = context.div_scalar_rh::>(&c, 1.00001); + c = graph.div_scalar_rh::>(&c, 1.00001); } - println!("{:?}", context.iter(&c).collect::>()); + graph.non_populating_eval(&c).unwrap(); + + println!("{:?}", graph.iter(&c).collect::>()); //println!("{:?}", context); } \ No newline at end of file