Skip to content

Commit

Permalink
Add in-memory parallel term map
Browse files Browse the repository at this point in the history
- Fix parallelization by moving function call
  • Loading branch information
benruijl committed Jun 24, 2024
1 parent 82bba80 commit 02b878f
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 31 deletions.
39 changes: 16 additions & 23 deletions src/api/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2424,7 +2424,7 @@ impl PythonExpression {
}

/// Map the transformations to every term in the expression.
/// The execution happen in parallel.
/// The execution happens in parallel, using `n_cores`.
///
/// Examples
/// --------
Expand Down Expand Up @@ -2457,30 +2457,23 @@ impl PythonExpression {

// release the GIL as Python functions may be called from
// within the term mapper
let mut stream = py.allow_threads(move || {
// map every term in the expression
let mut stream = TermStreamer::<CompressorWriter<_>>::new(TermStreamerConfig {
n_cores: n_cores.unwrap_or(1),
..Default::default()
});
stream.push(self.expr.clone());

let m = stream.map(|x| {
let mut out = Atom::default();
Workspace::get_local().with(|ws| {
Transformer::execute(x.as_view(), &t, ws, &mut out).unwrap_or_else(|e| {
// TODO: capture and abort the parallel run
panic!("Transformer failed during parallel execution: {:?}", e)
let r = py.allow_threads(move || {
self.expr.as_view().map_terms(
|x| {
let mut out = Atom::default();
Workspace::get_local().with(|ws| {
Transformer::execute(x, &t, ws, &mut out).unwrap_or_else(|e| {
// TODO: capture and abort the parallel run
panic!("Transformer failed during parallel execution: {:?}", e)
});
});
});
out
});
Ok::<_, PyErr>(m)
})?;

let b = stream.to_expression();
out
},
n_cores.unwrap_or(1),
)
});

Ok(b.into())
Ok(r.into())
}

/// Set the coefficient ring to contain the variables in the `vars` list.
Expand Down
78 changes: 76 additions & 2 deletions src/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use rayon::prelude::*;

use crate::{
atom::{Atom, AtomView},
state::RecycledAtom,
state::{RecycledAtom, Workspace},
LicenseManager,
};

Expand Down Expand Up @@ -460,7 +460,8 @@ impl<W: WriteableNamedStream> TermStreamer<W> {
#[inline(always)]
|| {
reader.par_bridge().for_each(|x| {
out_wrap.lock().unwrap().push(f(x));
let r = f(x);
out_wrap.lock().unwrap().push(r);
});
},
);
Expand Down Expand Up @@ -494,6 +495,62 @@ impl<W: WriteableNamedStream> TermStreamer<W> {
}
}

impl<'a> AtomView<'a> {
/// Map the function `f` over all its terms, using parallel execution with `n_cores` cores.
pub fn map_terms(&self, f: impl Fn(AtomView) -> Atom + Send + Sync, n_cores: usize) -> Atom {
if let AtomView::Add(aa) = self {
if n_cores < 2 {
return Workspace::get_local().with(|ws| {
let mut r = ws.new_atom();
let rr = r.to_add();
for arg in aa {
rr.extend(f(arg).as_view());
}
let mut out = Atom::new();
r.as_view().normalize(ws, &mut out);
out
});
}

let out_wrap = Mutex::new(vec![]);

let t = rayon::ThreadPoolBuilder::new()
.num_threads(if LicenseManager::is_licensed() {
n_cores
} else {
1
})
.build()
.unwrap();

t.install(
#[inline(always)]
|| {
aa.iter().par_bridge().for_each(|x| {
let r = f(x);
out_wrap.lock().unwrap().push(r);
});
},
);

let res = out_wrap.into_inner().unwrap();

Workspace::get_local().with(|ws| {
let mut r = ws.new_atom();
let rr = r.to_add();
for arg in res {
rr.extend(arg.as_view());
}
let mut out = Atom::new();
r.as_view().normalize(ws, &mut out);
out
})
} else {
f(self.clone())
}
}
}

#[cfg(test)]
mod test {
use std::{fs::File, io::BufWriter};
Expand Down Expand Up @@ -595,4 +652,21 @@ mod test {
let res = Atom::parse("11*v1+10*f1(v1)").unwrap();
assert_eq!(r, res);
}

#[test]
fn term_map() {
let input = Atom::parse("v1 + v2 + v3 + v4").unwrap();

let r = input
.as_view()
.map_terms(|x| Atom::new_num(1) + &x.to_owned(), 4);

let r2 = input
.as_view()
.map_terms(|x| Atom::new_num(1) + &x.to_owned(), 1);
assert_eq!(r, r2);

let res = Atom::parse("v1 + v2 + v3 + v4 + 4").unwrap();
assert_eq!(r, res);
}
}
7 changes: 1 addition & 6 deletions symbolica.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -625,12 +625,7 @@ class Expression:
) -> Expression:
"""
Map the transformations to every term in the expression.
The execution happens in parallel.
No new functions or variables can be defined and no new
expressions can be parsed inside the map. Doing so will
result in a deadlock.
The execution happens in parallel using `n_cores`.
Examples
--------
Expand Down

0 comments on commit 02b878f

Please sign in to comment.