From 02b878f68da5efb71e0ad27565c7fe2f253579b9 Mon Sep 17 00:00:00 2001
From: Ben Ruijl
Date: Mon, 24 Jun 2024 10:20:31 +0200
Subject: [PATCH] Add in-memory parallel term map

- Fix parallelization by moving function call
---
 src/api/python.rs | 39 +++++++++++--------------
 src/streaming.rs  | 72 +++++++++++++++++++++++++++++++++++++++++++++--
 symbolica.pyi     |  7 +----
 3 files changed, 87 insertions(+), 31 deletions(-)

diff --git a/src/api/python.rs b/src/api/python.rs
index 6d4934b8..8fe9dedf 100644
--- a/src/api/python.rs
+++ b/src/api/python.rs
@@ -2424,7 +2424,7 @@ impl PythonExpression {
     }
 
     /// Map the transformations to every term in the expression.
-    /// The execution happen in parallel.
+    /// The execution happens in parallel, using `n_cores`.
     ///
     /// Examples
     /// --------
@@ -2457,30 +2457,23 @@ impl PythonExpression {
 
         // release the GIL as Python functions may be called from
         // within the term mapper
-        let mut stream = py.allow_threads(move || {
-            // map every term in the expression
-            let mut stream = TermStreamer::<CompressorWriter<_>>::new(TermStreamerConfig {
-                n_cores: n_cores.unwrap_or(1),
-                ..Default::default()
-            });
-            stream.push(self.expr.clone());
-
-            let m = stream.map(|x| {
-                let mut out = Atom::default();
-                Workspace::get_local().with(|ws| {
-                    Transformer::execute(x.as_view(), &t, ws, &mut out).unwrap_or_else(|e| {
-                        // TODO: capture and abort the parallel run
-                        panic!("Transformer failed during parallel execution: {:?}", e)
+        let r = py.allow_threads(move || {
+            self.expr.as_view().map_terms(
+                |x| {
+                    let mut out = Atom::default();
+                    Workspace::get_local().with(|ws| {
+                        Transformer::execute(x, &t, ws, &mut out).unwrap_or_else(|e| {
+                            // TODO: capture and abort the parallel run
+                            panic!("Transformer failed during parallel execution: {:?}", e)
+                        });
                     });
-                });
-                out
-            });
-            Ok::<_, PyErr>(m)
-        })?;
-
-        let b = stream.to_expression();
+                    out
+                },
+                n_cores.unwrap_or(1),
+            )
+        });
 
-        Ok(b.into())
+        Ok(r.into())
     }
 
     /// Set the coefficient ring to contain the variables in the `vars` list.
diff --git a/src/streaming.rs b/src/streaming.rs
index 5e896312..8b31c442 100644
--- a/src/streaming.rs
+++ b/src/streaming.rs
@@ -11,7 +11,7 @@ use rayon::prelude::*;
 
 use crate::{
     atom::{Atom, AtomView},
-    state::RecycledAtom,
+    state::{RecycledAtom, Workspace},
     LicenseManager,
 };
 
@@ -460,7 +460,8 @@ impl TermStreamer {
             #[inline(always)]
             || {
                 reader.par_bridge().for_each(|x| {
-                    out_wrap.lock().unwrap().push(f(x));
+                    let r = f(x);
+                    out_wrap.lock().unwrap().push(r);
                 });
             },
         );
@@ -494,6 +495,56 @@ impl TermStreamer {
     }
 }
 
+impl<'a> AtomView<'a> {
+    /// Map the function `f` over all its terms, using parallel execution with `n_cores` cores.
+    ///
+    /// The mapped terms are added up and the sum is normalized. If the expression is not
+    /// a sum, `f` is applied to the expression itself and its result is returned as is.
+    ///
+    /// For `n_cores < 2` the terms are processed sequentially on the calling thread.
+    /// Otherwise a dedicated thread pool is created, which is limited to a single
+    /// thread when no license is active, and the order in which `f` is evaluated
+    /// on the terms is unspecified.
+    pub fn map_terms(&self, f: impl Fn(AtomView) -> Atom + Send + Sync, n_cores: usize) -> Atom {
+        /// Add up `terms` and normalize the resulting sum.
+        fn sum_terms(terms: impl Iterator<Item = Atom>) -> Atom {
+            Workspace::get_local().with(|ws| {
+                let mut r = ws.new_atom();
+                let rr = r.to_add();
+                for term in terms {
+                    rr.extend(term.as_view());
+                }
+                let mut out = Atom::new();
+                r.as_view().normalize(ws, &mut out);
+                out
+            })
+        }
+
+        if let AtomView::Add(aa) = self {
+            if n_cores < 2 {
+                return sum_terms(aa.iter().map(|x| f(x)));
+            }
+
+            let t = rayon::ThreadPoolBuilder::new()
+                .num_threads(if LicenseManager::is_licensed() {
+                    n_cores
+                } else {
+                    1
+                })
+                .build()
+                .unwrap();
+
+            // let rayon gather the mapped terms instead of pushing every
+            // result through a shared lock
+            let res: Vec<Atom> = t.install(|| aa.iter().par_bridge().map(|x| f(x)).collect());
+
+            sum_terms(res.into_iter())
+        } else {
+            f(self.clone())
+        }
+    }
+}
+
 #[cfg(test)]
 mod test {
     use std::{fs::File, io::BufWriter};
@@ -595,4 +646,21 @@ mod test {
         let res = Atom::parse("11*v1+10*f1(v1)").unwrap();
         assert_eq!(r, res);
     }
+
+    #[test]
+    fn term_map() {
+        let input = Atom::parse("v1 + v2 + v3 + v4").unwrap();
+
+        let r = input
+            .as_view()
+            .map_terms(|x| Atom::new_num(1) + &x.to_owned(), 4);
+
+        let r2 = input
+            .as_view()
+            .map_terms(|x| Atom::new_num(1) + &x.to_owned(), 1);
+        assert_eq!(r, r2);
+
+        let res = Atom::parse("v1 + v2 + v3 + v4 + 4").unwrap();
+        assert_eq!(r, res);
+    }
 }
diff --git a/symbolica.pyi b/symbolica.pyi
index 11743d66..51a474e4 100644
--- a/symbolica.pyi
+++ b/symbolica.pyi
@@ -625,12 +625,7 @@ class Expression:
     ) -> Expression:
         """
         Map the transformations to every term in the expression.
-        The execution happens in parallel.
-
-
-        No new functions or variables can be defined and no new
-        expressions can be parsed inside the map. Doing so will
-        result in a deadlock.
+        The execution happens in parallel using `n_cores`.
 
         Examples
         --------