Skip to content

Commit

Permalink
Add a transformer that parallelizes over terms
Browse files Browse the repository at this point in the history
- Fix missing call to custom normalization function
  • Loading branch information
benruijl committed Oct 5, 2024
1 parent 91da5f5 commit 0e3d7e5
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 24 deletions.
49 changes: 49 additions & 0 deletions src/api/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,55 @@ impl PythonTransformer {
return append_transformer!(self, transformer);
}

/// Map a chain of transformers over the terms of the expression, optionally using multiple cores.
///
/// Examples
/// --------
/// >>> from symbolica import *
/// >>> x, y = S('x', 'y')
/// >>> t = Transformer().map_terms(Transformer().print(), n_cores=2)
/// >>> e = t(x + y)
#[pyo3(signature = (*transformers, n_cores=1))]
pub fn map_terms(&self, transformers: &PyTuple, n_cores: usize) -> PyResult<PythonTransformer> {
let mut rep_chain = vec![];
// fuse all sub-transformers into one chain
for r in transformers {
let p = r.extract::<PythonTransformer>()?;

let Pattern::Transformer(t) = p.expr.borrow() else {
return Err(exceptions::PyValueError::new_err(
"Argument must be a transformer",
));
};

if t.0.is_some() {
return Err(exceptions::PyValueError::new_err(
"Transformers in a for_each must be unbound. Use Transformer() to create it.",
));
}

rep_chain.extend_from_slice(&t.1);
}

let pool = if n_cores < 2 || !LicenseManager::is_licensed() {
None
} else {
Some(Arc::new(
rayon::ThreadPoolBuilder::new()
.num_threads(n_cores)
.build()
.map_err(|e| {
exceptions::PyValueError::new_err(format!(
"Could not create thread pool: {}",
e
))
})?,
))
};

return append_transformer!(self, Transformer::MapTerms(rep_chain, pool));
}

/// Create a transformer that applies a transformer chain to every argument of the `arg()` function.
/// If the input is not `arg()`, the transformer is applied to the input.
///
Expand Down
8 changes: 8 additions & 0 deletions src/normalize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1071,6 +1071,14 @@ impl<'a> AtomView<'a> {

out_f.set_normalized(true);

if let Some(f) = State::get_normalization_function(id) {
let mut fs = workspace.new_atom();
if f(handle.as_view(), &mut fs) {
std::mem::swap(&mut handle, &mut fs);
}
debug_assert!(!handle.as_view().needs_normalization());
}

let m = out.to_mul();
m.extend(handle.as_view());
handle.to_num((-1).into());
Expand Down
48 changes: 26 additions & 22 deletions src/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::{

use brotli::{CompressorWriter, Decompressor};
use rand::{thread_rng, Rng};
use rayon::prelude::*;
use rayon::{prelude::*, ThreadPool};

use crate::{
atom::{Atom, AtomView},
Expand Down Expand Up @@ -528,32 +528,36 @@ impl<'a> AtomView<'a> {

/// Map the function `f` over all terms, using parallel execution with `n_cores` cores.
pub fn map_terms(&self, f: impl Fn(AtomView) -> Atom + Send + Sync, n_cores: usize) -> Atom {
if let AtomView::Add(aa) = self {
if n_cores < 2 {
return Workspace::get_local().with(|ws| {
let mut r = ws.new_atom();
let rr = r.to_add();
for arg in aa {
rr.extend(f(arg).as_view());
}
let mut out = Atom::new();
r.as_view().normalize(ws, &mut out);
out
});
}

let out_wrap = Mutex::new(vec![]);
if n_cores < 2 || !LicenseManager::is_licensed() {
return self.map_terms_single_core(f);
}

if let AtomView::Add(_) = self {
let t = rayon::ThreadPoolBuilder::new()
.num_threads(if LicenseManager::is_licensed() {
n_cores
} else {
1
})
.num_threads(n_cores)
.build()
.unwrap();

t.install(
self.map_terms_with_pool(f, &t)
} else {
f(*self)
}
}

/// Map the function `f` over all terms, using parallel execution with `n_cores` cores.
pub fn map_terms_with_pool(
&self,
f: impl Fn(AtomView) -> Atom + Send + Sync,
p: &ThreadPool,
) -> Atom {
if !LicenseManager::is_licensed() {
return self.map_terms_single_core(f);
}

if let AtomView::Add(aa) = self {
let out_wrap = Mutex::new(vec![]);

p.install(
#[inline(always)]
|| {
aa.iter().par_bridge().for_each(|x| {
Expand Down
28 changes: 27 additions & 1 deletion src/transformer.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::time::Instant;
use std::{sync::Arc, time::Instant};

use crate::{
atom::{representation::FunView, Atom, AtomView, Fun, Symbol},
Expand All @@ -12,6 +12,7 @@ use crate::{
use ahash::HashMap;
use colored::Colorize;
use dyn_clone::DynClone;
use rayon::ThreadPool;

pub trait Map:
Fn(AtomView, &mut Atom) -> Result<(), TransformerError> + DynClone + Send + Sync
Expand Down Expand Up @@ -145,6 +146,8 @@ pub enum Transformer {
/// Apply a transformation to each argument of the `arg()` function.
/// If the input is not `arg()`, map the current input.
ForEach(Vec<Transformer>),
/// Map the transformers over the terms, potentially in parallel
MapTerms(Vec<Transformer>, Option<Arc<ThreadPool>>),
/// Split a `Mul` or `Add` into a list of arguments.
Split,
Partition(Vec<(Symbol, usize)>, bool, bool),
Expand Down Expand Up @@ -177,6 +180,7 @@ impl std::fmt::Debug for Transformer {
Transformer::ArgCount(p) => f.debug_tuple("ArgCount").field(p).finish(),
Transformer::Linearize(s) => f.debug_tuple("Linearize").field(s).finish(),
Transformer::Map(_) => f.debug_tuple("Map").finish(),
Transformer::MapTerms(v, c) => f.debug_tuple("Map").field(v).field(c).finish(),
Transformer::ForEach(t) => f.debug_tuple("ForEach").field(t).finish(),
Transformer::Split => f.debug_tuple("Split").finish(),
Transformer::Partition(g, b1, b2) => f
Expand Down Expand Up @@ -426,6 +430,28 @@ impl Transformer {
Transformer::Map(f) => {
f(cur_input, out)?;
}
Transformer::MapTerms(t, p) => {
if let Some(p) = p {
*out = cur_input.map_terms_with_pool(
|arg| {
Workspace::get_local().with(|ws| {
let mut a = Atom::new();
Self::execute_chain(arg, t, ws, &mut a).unwrap();
a
})
},
p,
);
} else {
*out = cur_input.map_terms_single_core(|arg| {
Workspace::get_local().with(|ws| {
let mut a = Atom::new();
Self::execute_chain(arg, t, ws, &mut a).unwrap();
a
})
})
}
}
Transformer::ForEach(t) => {
if let AtomView::Fun(f) = cur_input {
if f.get_symbol() == State::ARG {
Expand Down
15 changes: 14 additions & 1 deletion symbolica.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ class Expression:
is_symmetric: Optional[bool] = None,
is_antisymmetric: Optional[bool] = None,
is_cyclesymmetric: Optional[bool] = None,
is_linear: Optional[bool] = None) -> Expression:
is_linear: Optional[bool] = None,
custom_normalization: Optional[Transformer] = None) -> Expression:
"""
Create a new symbol from a `name`. Symbols carry information about their attributes.
The symbol can signal that it is symmetric if it is used as a function
Expand Down Expand Up @@ -1636,6 +1637,17 @@ class Transformer:
>>> print(e)
"""

def map_terms(self, *transformers: Transformer, n_cores: int = 1) -> Transformer:
"""Map a chain of transformer over the terms of the expression, optionally using multiple cores.
Examples
--------
>>> from symbolica import *
>>> x, y = S('x', 'y')
>>> t = Transformer().map_terms(Transformer().print(), n_cores=2)
>>> e = t(x + y)
"""

def for_each(self, *transformers: Transformer) -> Transformer:
"""Create a transformer that applies a transformer chain to every argument of the `arg()` function.
If the input is not `arg()`, the transformer is applied to the input.
Expand Down Expand Up @@ -3260,6 +3272,7 @@ class CompiledEvaluator:
_cls,
filename: str,
function_name: str,
input_len: int,
output_len: int,
) -> CompiledEvaluator:
"""Load a compiled library, previously generated with `Evaluator.compile()`."""
Expand Down

0 comments on commit 0e3d7e5

Please sign in to comment.