diff --git a/examples/project.rs b/examples/project.rs index 3cfb0ab..8e0dc32 100644 --- a/examples/project.rs +++ b/examples/project.rs @@ -64,17 +64,17 @@ impl gantz::Node for Debug { } #[typetag::serde] -impl gantz::project::SerdeNode for One { +impl gantz::node::SerdeNode for One { fn node(&self) -> &gantz::Node { self } } #[typetag::serde] -impl gantz::project::SerdeNode for Add { +impl gantz::node::SerdeNode for Add { fn node(&self) -> &gantz::Node { self } } #[typetag::serde] -impl gantz::project::SerdeNode for Debug { +impl gantz::node::SerdeNode for Debug { fn node(&self) -> &gantz::Node { self } } @@ -84,7 +84,7 @@ fn main() { let mut project = gantz::Project::open(path.into()).unwrap(); // Instantiate the core nodes. - let one = Box::new(One) as Box; + let one = Box::new(One) as Box; let add = Box::new(Add) as Box<_>; let debug = Box::new(Debug) as Box<_>; diff --git a/src/graph.rs b/src/graph.rs index d1c234e..0877653 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -15,9 +15,20 @@ pub struct Edge { pub input: node::Input, } -/// The petgraph type used to represent the **Graph**. +/// The petgraph type used to represent a gantz graph. +pub type Graph = petgraph::Graph; + +/// The petgraph type used to represent a stable gantz graph. pub type StableGraph = petgraph::stable_graph::StableGraph; +impl Edge { + /// Create an edge representing a connection from the given node `Output` to the given node + /// `Input`. + pub fn new(output: node::Output, input: node::Input) -> Self { + Edge { output, input } + } +} + impl Node for StableGraph where N: Node, @@ -35,11 +46,23 @@ where } } +impl From<(A, B)> for Edge +where + A: Into, + B: Into, +{ + fn from((a, b): (A, B)) -> Self { + let output = a.into(); + let input = b.into(); + Edge { output, input } + } +} + pub mod codegen { use crate::node::{self, Node}; use petgraph::visit::{Data, EdgeRef, GraphRef, IntoEdgesDirected, IntoNodeReferences, - NodeIndexable, NodeRef, Visitable}; - use std::collections::HashMap; + NodeIndexable, NodeRef, Visitable, Walker}; + use std::collections::{HashMap, HashSet}; use std::hash::Hash; use super::Edge; use syn::punctuated::Punctuated; @@ -102,16 +125,24 @@ pub mod codegen { where G: GraphRef + IntoEdgesDirected + IntoNodeReferences + NodeIndexable + Visitable, G: Data, + G::NodeId: Eq + Hash, ::Weight: Node, { + // First, find all nodes reachable by a `DFS` from this node. + let dfs: HashSet = petgraph::visit::Dfs::new(g, n).iter(g).collect(); + // The order of evaluation is DFS post order. - let mut dfs_post_order = petgraph::visit::Dfs::new(g, n); + let mut traversal = petgraph::visit::Topo::new(g); // Track the evaluation steps. let mut eval_steps = vec![]; // Step through each of the nodes. - while let Some(node) = dfs_post_order.next(g) { + while let Some(node) = traversal.next(g) { + if !dfs.contains(&node) { + continue; + } + // Fetch the node reference. let child = g.node_references() .nth(g.to_index(node)) diff --git a/src/node/expr.rs b/src/node/expr.rs new file mode 100644 index 0000000..b05dc77 --- /dev/null +++ b/src/node/expr.rs @@ -0,0 +1,174 @@ +use crate::node::Node; +use proc_macro2::{TokenStream, TokenTree}; +use quote::{TokenStreamExt, ToTokens}; +use serde::{Deserialize, Serialize}; +use std::fmt; +use std::str::FromStr; + +/// A simple node that allows for representing rust expressions as nodes within a gantz graph. +/// +/// E.g. the following expression: +/// +/// ```ignore +/// #freq.sin() * #amp +/// ``` +/// +/// will result in a single node with two inputs (`#freq` and `#amp`) and a single output which is +/// the result of the expression. +/// +/// ## Limitations +/// +/// Currently expressions cannot contain any of the following: +/// +/// - Attributes, e.g. `{ #[cfg(target_os = "macos")] { 2 + 2 } }` is a valid expr but not allowed. +/// - Raw strings, e.g. `{ r#"blah blah"# }` is a valid expr but not allowed. +/// - Comments containing the `#` token. +/// +/// These limitations are caused by the primitive way in which string interpolation is achieved (we +/// simply count each of the occurrences of `#`). This may be improved in the future. +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Expr { + #[serde(with = "crate::node::serde::tts")] + tokens: TokenStream, +} + +/// An error occurred while constructing the `Expr` node. +#[derive(Debug, Fail, From)] +pub enum NewExprError { + #[fail(display = "failed to parse the `str` as a valid `TokenStream`")] + InvalidTokenStream, + #[fail(display = "failed to parse the `str` as a valid expr: {}", err)] + InvalidExpr { + #[fail(cause)] + err: syn::Error, + }, +} + +impl Expr { + /// Construct an **Expr** node from the given rust expression. + /// + /// Returns an **Err** if the given string is not a valid expression when interpolated with + /// valid sub-expressions. + /// + /// ```rust + /// fn main() { + /// let _node = gantz::node::Expr::new("#foo + #bar").unwrap(); + /// } + /// ``` + pub fn new(expr: &str) -> Result { + // Retrieve the `TokenStream`. + let tokens = TokenStream::from_str(expr).map_err(|_| NewExprError::InvalidTokenStream)?; + // Count the number of inputs. + let n_inputs = count_hashes(&tokens); + // Interpolate the `TokenStream` with some temp `{}` expressions. + let unit_expr: syn::Expr = syn::parse_quote!{ {} }; + let test_expr_tokens = interpolate_tokens(&tokens, vec![unit_expr; n_inputs as usize]); + let test_expr_str = format!("{}", test_expr_tokens); + let _: syn::Expr = syn::parse_str(&test_expr_str)?; + // If we got this far, we have a valid `Expr`! + Ok(Expr { tokens }) + } +} + +impl Node for Expr { + fn n_inputs(&self) -> u32 { + count_hashes(&self.tokens) + } + + fn n_outputs(&self) -> u32 { + 1 + } + + fn expr(&self, args: Vec) -> syn::Expr { + let args_tokens = args.into_iter().map(|expr| expr.into_token_stream()); + let expr_tokens = interpolate_tokens(&self.tokens, args_tokens); + syn::parse_quote! { #expr_tokens } + } +} + +impl fmt::Display for Expr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.tokens) + } +} + +// A `Punct` instance representing a `#`. +fn hash_punct() -> proc_macro2::Punct { + proc_macro2::Punct::new('#', proc_macro2::Spacing::Alone) +} + +// Given a token stream, count all occurrences of `#`. +fn count_hashes(tokens: T) -> u32 +where + T: ToTokens, +{ + let mut count = 0; + for t in tokens.into_token_stream() { + match t { + TokenTree::Punct(ref p) if format!("{}", p) == format!("{}", hash_punct()) => { + count += 1; + } + TokenTree::Group(ref g) => { + count += count_hashes(g.stream()); + } + _ => (), + } + } + count +} + +// Given a token stream, sequentially replace each occurrence of `#var` with each expression. +fn interpolate_tokens(tokens: T, exprs: E) -> TokenStream +where + T: ToTokens, + E: IntoIterator, + E::Item: ToTokens, +{ + fn interpolate_tokens_inner(tokens: TokenStream, exprs: &mut E) -> TokenStream + where + E: Iterator, + E::Item: ToTokens, + { + let mut tokens = tokens.into_iter(); + let mut new_tokens = TokenStream::default(); + while let Some(t) = tokens.next() { + match t { + TokenTree::Punct(ref p) if format!("{}", p) == format!("{}", hash_punct()) => { + if let Some(expr) = exprs.next() { + tokens.next(); + new_tokens.append_all(expr.into_token_stream()); + } + } + TokenTree::Group(g) => { + let new_group_tokens = interpolate_tokens_inner(g.stream(), exprs); + let new_group = proc_macro2::Group::new(g.delimiter(), new_group_tokens); + new_tokens.append(new_group); + } + t => new_tokens.append(t), + } + } + new_tokens + } + + let tokens = tokens.into_token_stream(); + let mut exprs = exprs.into_iter(); + interpolate_tokens_inner(tokens, &mut exprs) +} + +#[test] +fn test_count_hashes() { + let expr = TokenStream::from_str("#l + #r").unwrap(); + assert_eq!(count_hashes(expr), 2); + + let expr = TokenStream::from_str("#freq.sin() * #amp").unwrap(); + assert_eq!(count_hashes(expr), 2); + + let expr = TokenStream::from_str("&#foo").unwrap(); + assert_eq!(count_hashes(expr), 1); + + let expr = TokenStream::from_str("[#a, #b, #c, #d, #e]").unwrap(); + assert_eq!(count_hashes(expr), 5); + + let expr = TokenStream::from_str("{}").unwrap(); + assert_eq!(count_hashes(expr), 0); +} diff --git a/src/node.rs b/src/node/mod.rs similarity index 87% rename from src/node.rs rename to src/node/mod.rs index b385dfa..c8e7af8 100644 --- a/src/node.rs +++ b/src/node/mod.rs @@ -1,3 +1,11 @@ +pub mod expr; +pub mod push; +pub mod serde; + +pub use self::expr::{Expr, NewExprError}; +pub use self::push::{Push, WithPushEval}; +pub use self::serde::SerdeNode; + /// Gantz allows for constructing executable directed graphs by composing together **Node**s. /// /// **Node**s are a way to allow users to abstract and encapsulate logic into smaller, re-usable @@ -49,13 +57,15 @@ pub trait Node { } /// Items that need to be known in order to generate a push evaluation function for a node. -#[derive(Clone, Debug, Eq, Hash, PartialEq)] +#[derive(Clone, Debug, Eq, Hash, PartialEq, Deserialize, Serialize)] pub struct PushEval { /// The type for each argument. + #[serde(with = "crate::node::serde::fn_decl")] pub fn_decl: syn::FnDecl, /// The name for the function. pub fn_name: String, /// Attributes for the generated `ItemFn`. + #[serde(with = "crate::node::serde::fn_attrs")] pub fn_attrs: Vec, } @@ -134,3 +144,22 @@ impl From for PushEval { PushEval { fn_decl, fn_name, fn_attrs } } } + +impl From for Input { + fn from(u: u32) -> Self { + Input(u) + } +} + +impl From for Output { + fn from(u: u32) -> Self { + Output(u) + } +} + +/// Create a node from the given Rust expression. +/// +/// Shorthand for `node::Expr::new`. +pub fn expr(expr: &str) -> Result { + Expr::new(expr) +} diff --git a/src/node/push.rs b/src/node/push.rs new file mode 100644 index 0000000..5c66708 --- /dev/null +++ b/src/node/push.rs @@ -0,0 +1,80 @@ +use crate::node::{self, Node}; + +/// A wrapper around a `Node` that enables push evaluation. +/// +/// The implementation of `Node` will match the inner node type `N`, but with a unique +/// implementation of `Node::push_eval`. +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Push { + node: N, + push_eval: node::PushEval, +} + +/// A trait implemented for all `Node` types allowing to enable push evaluation. +pub trait WithPushEval: Sized + Node { + /// Consume `self` and return a `Node` that has push evaluation enabled. + fn with_push_eval(self, push_eval: node::PushEval) -> Push; + + /// Enable push evaluation using the given push evaluation function. + /// + /// Internally, this calls `with_push_eval`. + /// + /// Note: Only the name, function declaration and attributes are used - the function definition + /// is ignored. + fn with_push_eval_fn(self, item_fn: syn::ItemFn) -> Push { + self.with_push_eval(item_fn.into()) + } + + /// Enable push evaluation. + /// + /// Internally, this calls `with_push_eval_fn` with a function that looks like `fn #name() {}`. + fn with_push_eval_name(self, fn_name: &str) -> Push { + let fn_ident = syn::Ident::new(fn_name, proc_macro2::Span::call_site()); + self.with_push_eval_fn(syn::parse_quote!{ fn #fn_ident() {} }) + } +} + +impl Push +where + N: Node, +{ + /// Given some node, return a `Push` node enabling push evaluation. + pub fn new(node: N, push_eval: node::PushEval) -> Self { + Push { node, push_eval } + } +} + +impl WithPushEval for N +where + N: Node, +{ + /// Consume `self` and return an equivalent node with push evaluation enabled. + fn with_push_eval(self, push_eval: node::PushEval) -> Push { + Push::new(self, push_eval) + } +} + +impl Node for Push +where + N: Node, +{ + fn n_inputs(&self) -> u32 { + self.node.n_inputs() + } + + fn n_outputs(&self) -> u32 { + self.node.n_outputs() + } + + fn expr(&self, args: Vec) -> syn::Expr { + self.node.expr(args) + } + + fn push_eval(&self) -> Option { + Some(self.push_eval.clone()) + } + + fn pull_eval(&self) -> Option { + self.node.pull_eval() + } +} diff --git a/src/node/serde.rs b/src/node/serde.rs new file mode 100644 index 0000000..a21a6d8 --- /dev/null +++ b/src/node/serde.rs @@ -0,0 +1,105 @@ +use crate::node::{self, Node}; + +/// A wrapper around the **Node** trait that allows for serializing and deserializing node trait +/// objects. +#[typetag::serde(tag = "type")] +pub trait SerdeNode { + fn node(&self) -> &Node; +} + +#[typetag::serde] +impl SerdeNode for node::Expr { + fn node(&self) -> &Node { self } +} + +#[typetag::serde] +impl SerdeNode for node::Push { + fn node(&self) -> &Node { self } +} + +pub mod fn_decl { + use serde::{Deserializer, Serializer}; + + pub fn serialize(t: &syn::FnDecl, s: S) -> Result + where + S: Serializer, + { + let item_fn = syn::ItemFn { + attrs: vec![], + vis: syn::Visibility::Public(syn::VisPublic { pub_token: Default::default() }), + constness: None, + unsafety: None, + asyncness: None, + abi: None, + ident: syn::Ident::new("foo", proc_macro2::Span::call_site()), + decl: Box::new(t.clone()), + block: Box::new(syn::Block { stmts: vec![], brace_token: <_>::default() }), + }; + super::tts::serialize(&item_fn, s) + } + + pub fn deserialize<'de, D>(d: D) -> Result + where + D: Deserializer<'de>, + { + let tts = super::tts::deserialize(d)?; + let syn::ItemFn { decl, .. } = syn::parse_quote!{ #tts }; + Ok(*decl) + } +} + +pub mod fn_attrs { + use serde::{Deserializer, Serializer}; + + pub fn serialize(t: &Vec, s: S) -> Result + where + S: Serializer, + { + let syn::ItemFn { decl, .. } = syn::parse_quote!{ fn foo() {} }; + let item_fn = syn::ItemFn { + attrs: t.clone(), + vis: syn::Visibility::Public(syn::VisPublic { pub_token: Default::default() }), + constness: None, + unsafety: None, + asyncness: None, + abi: None, + ident: syn::Ident::new("foo", proc_macro2::Span::call_site()), + decl: decl, + block: Box::new(syn::Block { stmts: vec![], brace_token: <_>::default() }), + }; + super::tts::serialize(&item_fn, s) + } + + pub fn deserialize<'de, D>(d: D) -> Result, D::Error> + where + D: Deserializer<'de>, + { + let tts = super::tts::deserialize(d)?; + let syn::ItemFn { attrs, .. } = syn::parse_quote!{ #tts }; + Ok(attrs) + } +} + +pub mod tts { + use proc_macro2::TokenStream; + use quote::ToTokens; + use serde::{Deserialize, Deserializer, Serializer}; + use std::str::FromStr; + + pub fn serialize(t: &T, s: S) -> Result + where + T: ToTokens, + S: Serializer, + { + let string: String = format!("{}", t.into_token_stream()); + s.serialize_str(&string) + } + + pub fn deserialize<'de, D>(d: D) -> Result + where + D: Deserializer<'de>, + { + let string = String::deserialize(d)?; + Ok(TokenStream::from_str(&string).expect("failed to parse string as token stream")) + } +} diff --git a/src/project.rs b/src/project.rs index de9c290..467c8a6 100644 --- a/src/project.rs +++ b/src/project.rs @@ -1,5 +1,5 @@ use crate::graph; -use crate::node::{self, Node}; +use crate::node::{self, Node, SerdeNode}; use quote::ToTokens; use std::{fs, io, ops}; use std::collections::{BTreeMap, HashMap}; @@ -38,13 +38,6 @@ pub struct TempProject { project: Option, } -/// A wrapper around the **Node** trait that allows for serializing and deserializing node trait -/// objects. -#[typetag::serde(tag = "type")] -pub trait SerdeNode { - fn node(&self) -> &Node; -} - /// A unique identifier representing an imported node. #[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, PartialOrd, Ord, Deserialize, Serialize)] pub struct NodeId(u64); @@ -608,7 +601,7 @@ where /// Given some UTF-8 node name, return the name of the crate. pub fn node_crate_name(node_name: &str) -> String { - format!("{}{}", NODE_CRATE_PREFIX, slug::slugify(node_name)) + format!("{}{}", NODE_CRATE_PREFIX, slug::slugify(node_name).replace("-", "_")) } /// Given the workspace directory and some UTF-8 node name, return the path to the crate directory. @@ -866,7 +859,10 @@ where } // Given a `NodeIdGraph` and `NodeCollection`, return a graph capable of evaluation. -fn id_graph_to_node_graph<'a>(g: &NodeIdGraph, ns: &'a NodeCollection) -> graph::StableGraph> { +fn id_graph_to_node_graph<'a>( + g: &NodeIdGraph, + ns: &'a NodeCollection, +) -> graph::StableGraph> { g.map( |_, n_id| { match ns[n_id] { diff --git a/tests/graph.rs b/tests/graph.rs index 04cf147..1adefd8 100644 --- a/tests/graph.rs +++ b/tests/graph.rs @@ -1,119 +1,90 @@ // Tests for the graph module. -struct One; +use gantz::Edge; +use gantz::node::{self, SerdeNode, WithPushEval}; -struct Add; - -struct Debug; - -impl gantz::Node for One { - fn n_inputs(&self) -> u32 { - 0 - } - - fn n_outputs(&self) -> u32 { - 1 - } - - fn expr(&self, args: Vec) -> syn::Expr { - assert!(args.is_empty()); - syn::parse_quote! { 1 } - } - - fn push_eval(&self) -> Option { - let item_fn: syn::ItemFn = syn::parse_quote! { fn one() {} }; - Some(item_fn.into()) - } +fn node_push() -> node::Push { + node::expr("()").unwrap().with_push_eval_name("push") } -impl gantz::Node for Add { - fn n_inputs(&self) -> u32 { - 2 - } - - fn n_outputs(&self) -> u32 { - 1 - } - - fn expr(&self, args: Vec) -> syn::Expr { - assert_eq!(args.len(), 2); - let l = &args[0]; - let r = &args[1]; - syn::parse_quote! { #l + #r } - } +fn node_int(i: i32) -> node::Expr { + node::expr(&format!("{{ #push; {} }}", i)).unwrap() } -impl gantz::Node for Debug { - fn n_inputs(&self) -> u32 { - 1 - } - - fn n_outputs(&self) -> u32 { - 0 - } +fn node_add() -> node::Expr { + node::expr("#l + #r").unwrap() +} - fn expr(&self, args: Vec) -> syn::Expr { - assert_eq!(args.len(), 1); - let input = &args[0]; - syn::parse_quote! { println!("{:?}", #input) } - } +fn node_assert_eq() -> node::Expr { + node::expr("assert_eq!(#l, #r)").unwrap() } -// A simple test graph that adds two "one"s and outputs the result to stdout. +// A simple test graph that adds two "one"s and checks that it equals "two". // -// ------- -// | One | -// -+----- -// |\ -// | \ -// | \ -// -+---+- -// | Add | -// -+----- -// | +// -------- +// | push | // push_eval +// -+------ // | -// -+----- -// |Debug| -// ------- +// |--------- +// | | +// -+----- | +// | one | | +// -+----- | +// |\ | +// | \ | +// | \ | +// -+---+- -+----- +// | add | | two | +// -+----- -+----- +// | | +// | -- +// | | +// -+-------+- +// |assert_eq| +// ----------- #[test] fn test_graph1() { + // Create a temp project. + let mut project = gantz::TempProject::open_with_name("test_graph1").unwrap(); + // Instantiate the nodes. - let one = Box::new(One) as Box; - let add = Box::new(Add) as Box<_>; - let debug = Box::new(Debug) as Box<_>; + let push = node_push(); + let one = node_int(1); + let add = node_add(); + let two = node_int(2); + let assert_eq = node_assert_eq(); + + // Add the nodes to the project. + let push = project.add_core_node(Box::new(push) as Box); + let one = project.add_core_node(Box::new(one) as Box<_>); + let add = project.add_core_node(Box::new(add) as Box<_>); + let two = project.add_core_node(Box::new(two) as Box<_>); + let assert_eq = project.add_core_node(Box::new(assert_eq) as Box<_>); // Compose the graph. - let mut g = petgraph::Graph::new(); - let one = g.add_node(one); - let add = g.add_node(add); - let debug = g.add_node(debug); - g.add_edge(one, add, gantz::Edge { - output: gantz::node::Output(0), - input: gantz::node::Input(0), - }); - g.add_edge(one, add, gantz::Edge { - output: gantz::node::Output(0), - input: gantz::node::Input(1), - }); - g.add_edge(add, debug, gantz::Edge { - output: gantz::node::Output(0), - input: gantz::node::Input(0), - }); - - // Find all push evaluation enabled nodes. This should just be our `One` node. - let mut push_ns = gantz::graph::codegen::push_nodes(&g); - assert_eq!(push_ns.len(), 1); - let (push_n, fn_decl) = push_ns.pop().unwrap(); - - // Generate the push evaluation steps. There should be three, one for each node instance. - let eval_steps = gantz::graph::codegen::push_eval_steps(&g, push_n); - assert_eq!(eval_steps.len(), 3); - - // Ensure the order was correct. - let eval_order: Vec<_> = eval_steps.iter().map(|step| step.node).collect(); - assert_eq!(eval_order, vec![one, add, debug]); - - // Generate the push evaluation function. - let push_eval_fn = gantz::graph::codegen::push_eval_fn(&g, fn_decl, &eval_steps); - println!("{}", quote::ToTokens::into_token_stream(push_eval_fn)); + let root = project.root_node_id(); + project.update_graph(&root, |g| { + let push = g.add_node(push); + let one = g.add_node(one); + let add = g.add_node(add); + let two = g.add_node(two); + let assert_eq = g.add_node(assert_eq); + g.add_edge(push, one, Edge::from((0, 0))); + g.add_edge(push, two, Edge::from((0, 0))); + g.add_edge(one, add, Edge::from((0, 0))); + g.add_edge(one, add, Edge::from((0, 1))); + g.add_edge(add, assert_eq, Edge::from((0, 0))); + g.add_edge(two, assert_eq, Edge::from((0, 1))); + }).unwrap(); + + // Retrieve the path to the compiled library. + let dylib_path = project.graph_node_dylib(&root).unwrap().expect("no dylib or node"); + let lib = libloading::Library::new(&dylib_path).expect("failed to load library"); + let symbol_name = "push".as_bytes(); + unsafe { + let push_eval_fn: libloading::Symbol = + lib.get(symbol_name).expect("failed to load symbol"); + // Execute the gantz graph. + push_eval_fn(); + } }