diff --git a/src/api/python.rs b/src/api/python.rs index 0a4e62c..65ef75a 100644 --- a/src/api/python.rs +++ b/src/api/python.rs @@ -729,6 +729,55 @@ impl PythonTransformer { return append_transformer!(self, transformer); } + /// Map a chain of transformers over the terms of the expression, optionally using multiple cores. + /// + /// Examples + /// -------- + /// >>> from symbolica import * + /// >>> x, y = S('x', 'y') + /// >>> t = Transformer().map_terms(Transformer().print(), n_cores=2) + /// >>> e = t(x + y) + #[pyo3(signature = (*transformers, n_cores=1))] + pub fn map_terms(&self, transformers: &PyTuple, n_cores: usize) -> PyResult { + let mut rep_chain = vec![]; + // fuse all sub-transformers into one chain + for r in transformers { + let p = r.extract::()?; + + let Pattern::Transformer(t) = p.expr.borrow() else { + return Err(exceptions::PyValueError::new_err( + "Argument must be a transformer", + )); + }; + + if t.0.is_some() { + return Err(exceptions::PyValueError::new_err( + "Transformers in a for_each must be unbound. Use Transformer() to create it.", + )); + } + + rep_chain.extend_from_slice(&t.1); + } + + let pool = if n_cores < 2 || !LicenseManager::is_licensed() { + None + } else { + Some(Arc::new( + rayon::ThreadPoolBuilder::new() + .num_threads(n_cores) + .build() + .map_err(|e| { + exceptions::PyValueError::new_err(format!( + "Could not create thread pool: {}", + e + )) + })?, + )) + }; + + return append_transformer!(self, Transformer::MapTerms(rep_chain, pool)); + } + /// Create a transformer that applies a transformer chain to every argument of the `arg()` function. /// If the input is not `arg()`, the transformer is applied to the input. /// diff --git a/src/normalize.rs b/src/normalize.rs index e978de0..53e894e 100644 --- a/src/normalize.rs +++ b/src/normalize.rs @@ -1071,6 +1071,14 @@ impl<'a> AtomView<'a> { out_f.set_normalized(true); + if let Some(f) = State::get_normalization_function(id) { + let mut fs = workspace.new_atom(); + if f(handle.as_view(), &mut fs) { + std::mem::swap(&mut handle, &mut fs); + } + debug_assert!(!handle.as_view().needs_normalization()); + } + let m = out.to_mul(); m.extend(handle.as_view()); handle.to_num((-1).into()); diff --git a/src/streaming.rs b/src/streaming.rs index 0bb6cd3..384582e 100644 --- a/src/streaming.rs +++ b/src/streaming.rs @@ -7,7 +7,7 @@ use std::{ use brotli::{CompressorWriter, Decompressor}; use rand::{thread_rng, Rng}; -use rayon::prelude::*; +use rayon::{prelude::*, ThreadPool}; use crate::{ atom::{Atom, AtomView}, @@ -528,32 +528,36 @@ impl<'a> AtomView<'a> { /// Map the function `f` over all terms, using parallel execution with `n_cores` cores. pub fn map_terms(&self, f: impl Fn(AtomView) -> Atom + Send + Sync, n_cores: usize) -> Atom { - if let AtomView::Add(aa) = self { - if n_cores < 2 { - return Workspace::get_local().with(|ws| { - let mut r = ws.new_atom(); - let rr = r.to_add(); - for arg in aa { - rr.extend(f(arg).as_view()); - } - let mut out = Atom::new(); - r.as_view().normalize(ws, &mut out); - out - }); - } - - let out_wrap = Mutex::new(vec![]); + if n_cores < 2 || !LicenseManager::is_licensed() { + return self.map_terms_single_core(f); + } + if let AtomView::Add(_) = self { let t = rayon::ThreadPoolBuilder::new() - .num_threads(if LicenseManager::is_licensed() { - n_cores - } else { - 1 - }) + .num_threads(n_cores) .build() .unwrap(); - t.install( + self.map_terms_with_pool(f, &t) + } else { + f(*self) + } + } + + /// Map the function `f` over all terms, using parallel execution with `n_cores` cores. + pub fn map_terms_with_pool( + &self, + f: impl Fn(AtomView) -> Atom + Send + Sync, + p: &ThreadPool, + ) -> Atom { + if !LicenseManager::is_licensed() { + return self.map_terms_single_core(f); + } + + if let AtomView::Add(aa) = self { + let out_wrap = Mutex::new(vec![]); + + p.install( #[inline(always)] || { aa.iter().par_bridge().for_each(|x| { diff --git a/src/transformer.rs b/src/transformer.rs index c78105b..dc0fadc 100644 --- a/src/transformer.rs +++ b/src/transformer.rs @@ -1,4 +1,4 @@ -use std::time::Instant; +use std::{sync::Arc, time::Instant}; use crate::{ atom::{representation::FunView, Atom, AtomView, Fun, Symbol}, @@ -12,6 +12,7 @@ use crate::{ use ahash::HashMap; use colored::Colorize; use dyn_clone::DynClone; +use rayon::ThreadPool; pub trait Map: Fn(AtomView, &mut Atom) -> Result<(), TransformerError> + DynClone + Send + Sync @@ -145,6 +146,8 @@ pub enum Transformer { /// Apply a transformation to each argument of the `arg()` function. /// If the input is not `arg()`, map the current input. ForEach(Vec), + /// Map the transformers over the terms, potentially in parallel + MapTerms(Vec, Option>), /// Split a `Mul` or `Add` into a list of arguments. Split, Partition(Vec<(Symbol, usize)>, bool, bool), @@ -177,6 +180,7 @@ impl std::fmt::Debug for Transformer { Transformer::ArgCount(p) => f.debug_tuple("ArgCount").field(p).finish(), Transformer::Linearize(s) => f.debug_tuple("Linearize").field(s).finish(), Transformer::Map(_) => f.debug_tuple("Map").finish(), + Transformer::MapTerms(v, c) => f.debug_tuple("Map").field(v).field(c).finish(), Transformer::ForEach(t) => f.debug_tuple("ForEach").field(t).finish(), Transformer::Split => f.debug_tuple("Split").finish(), Transformer::Partition(g, b1, b2) => f @@ -426,6 +430,28 @@ impl Transformer { Transformer::Map(f) => { f(cur_input, out)?; } + Transformer::MapTerms(t, p) => { + if let Some(p) = p { + *out = cur_input.map_terms_with_pool( + |arg| { + Workspace::get_local().with(|ws| { + let mut a = Atom::new(); + Self::execute_chain(arg, t, ws, &mut a).unwrap(); + a + }) + }, + p, + ); + } else { + *out = cur_input.map_terms_single_core(|arg| { + Workspace::get_local().with(|ws| { + let mut a = Atom::new(); + Self::execute_chain(arg, t, ws, &mut a).unwrap(); + a + }) + }) + } + } Transformer::ForEach(t) => { if let AtomView::Fun(f) = cur_input { if f.get_symbol() == State::ARG { diff --git a/symbolica.pyi b/symbolica.pyi index 09eaa7a..1c870c1 100644 --- a/symbolica.pyi +++ b/symbolica.pyi @@ -147,7 +147,8 @@ class Expression: is_symmetric: Optional[bool] = None, is_antisymmetric: Optional[bool] = None, is_cyclesymmetric: Optional[bool] = None, - is_linear: Optional[bool] = None) -> Expression: + is_linear: Optional[bool] = None, + custom_normalization: Optional[Transformer] = None) -> Expression: """ Create a new symbol from a `name`. Symbols carry information about their attributes. The symbol can signal that it is symmetric if it is used as a function @@ -1636,6 +1637,17 @@ class Transformer: >>> print(e) """ + def map_terms(self, *transformers: Transformer, n_cores: int = 1) -> Transformer: + """Map a chain of transformer over the terms of the expression, optionally using multiple cores. + + Examples + -------- + >>> from symbolica import * + >>> x, y = S('x', 'y') + >>> t = Transformer().map_terms(Transformer().print(), n_cores=2) + >>> e = t(x + y) + """ + def for_each(self, *transformers: Transformer) -> Transformer: """Create a transformer that applies a transformer chain to every argument of the `arg()` function. If the input is not `arg()`, the transformer is applied to the input. @@ -3260,6 +3272,7 @@ class CompiledEvaluator: _cls, filename: str, function_name: str, + input_len: int, output_len: int, ) -> CompiledEvaluator: """Load a compiled library, previously generated with `Evaluator.compile()`."""