Skip to content

Commit

Permalink
Fix caching for multiple replacements
Browse files Browse the repository at this point in the history
- Cache is no longer default in Rust, use `MatchSettings::cached()` instead of
  `MatchSettings::default()`
- Add cache warnings in Python API
- Fix dual mul_add convention
  • Loading branch information
benruijl committed Oct 8, 2024
1 parent c6dbe6d commit b304255
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 13 deletions.
8 changes: 5 additions & 3 deletions src/api/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1216,6 +1216,7 @@ impl PythonTransformer {
///
/// For efficiency, the first `rhs_cache_size` substituted patterns are cached.
/// If set to `None`, an internally determined cache size is used.
/// Caching should be disabled (`rhs_cache_size=0`) if the right-hand side contains side effects, such as updating a global variable.
///
/// Examples
/// --------
Expand All @@ -1235,7 +1236,7 @@ impl PythonTransformer {
allow_new_wildcards_on_rhs: Option<bool>,
rhs_cache_size: Option<usize>,
) -> PyResult<PythonTransformer> {
let mut settings = MatchSettings::default();
let mut settings = MatchSettings::cached();

if let Some(ngw) = non_greedy_wildcards {
settings.non_greedy_wildcards = ngw
Expand Down Expand Up @@ -3855,6 +3856,7 @@ impl PythonExpression {
/// If set to `True`, wildcards that do not appear ion the pattern are allowed on the right-hand side.
/// rhs_cache_size: int, optional
/// Cache the first `rhs_cache_size` substituted patterns. If set to `None`, an internally determined cache size is used.
/// Warning: caching should be disabled (`rhs_cache_size=0`) if the right-hand side contains side effects, such as updating a global variable.
/// repeat: bool, optional
/// If set to `True`, the entire operation will be repeated until there are no more matches.
pub fn replace_all(
Expand All @@ -3872,7 +3874,7 @@ impl PythonExpression {
let pattern = &pattern.to_pattern()?.expr;
let rhs = &rhs.to_pattern_or_map()?;

let mut settings = MatchSettings::default();
let mut settings = MatchSettings::cached();

if let Some(ngw) = non_greedy_wildcards {
settings.non_greedy_wildcards = ngw
Expand Down Expand Up @@ -4611,7 +4613,7 @@ impl PythonReplacement {
let pattern = pattern.to_pattern()?.expr;
let rhs = rhs.to_pattern_or_map()?;

let mut settings = MatchSettings::default();
let mut settings = MatchSettings::cached();

if let Some(ngw) = non_greedy_wildcards {
settings.non_greedy_wildcards = ngw
Expand Down
2 changes: 1 addition & 1 deletion src/domains/dual.rs
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,7 @@ macro_rules! create_hyperdual_from_components {
{
#[inline(always)]
fn mul_add(&self, a: &Self, b: &Self) -> Self {
a.clone() * b + self
self.clone() * a + b
}

#[inline(always)]
Expand Down
4 changes: 2 additions & 2 deletions src/domains/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2509,8 +2509,8 @@ impl LowerExp for Rational {
}

impl NumericalFloatLike for Rational {
fn mul_add(&self, a: &Self, c: &Self) -> Self {
&(self * a) + c
fn mul_add(&self, a: &Self, b: &Self) -> Self {
self * a + b
}

fn neg(&self) -> Self {
Expand Down
33 changes: 27 additions & 6 deletions src/id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,11 +326,11 @@ impl<'a> AtomView<'a> {
workspace: &Workspace,
tree_level: usize,
fn_level: usize,
rhs_cache: &mut HashMap<Vec<(Symbol, Match<'a>)>, Atom>,
rhs_cache: &mut HashMap<(usize, Vec<(Symbol, Match<'a>)>), Atom>,
out: &mut Atom,
) -> bool {
let mut beyond_max_level = true;
for r in replacements {
for (rep_id, r) in replacements.iter().enumerate() {
let def_c = Condition::default();
let def_s = MatchSettings::default();
let conditions = r.conditions.unwrap_or(&def_c);
Expand Down Expand Up @@ -359,9 +359,14 @@ impl<'a> AtomView<'a> {
if let Some((_, used_flags)) = it.next(&mut match_stack) {
let mut rhs_subs = workspace.new_atom();

if let Some(rhs) = rhs_cache.get(&match_stack.stack) {
let key = (rep_id, std::mem::take(&mut match_stack.stack));

if let Some(rhs) = rhs_cache.get(&key) {
match_stack.stack = key.1;
rhs_subs.set_from_view(&rhs.as_view());
} else {
match_stack.stack = key.1;

match r.rhs {
PatternOrMap::Pattern(rhs) => {
rhs.substitute_wildcards(workspace, &mut rhs_subs, &match_stack)
Expand All @@ -376,8 +381,10 @@ impl<'a> AtomView<'a> {
if rhs_cache.len() < settings.rhs_cache_size
&& !matches!(r.rhs, PatternOrMap::Pattern(Pattern::Literal(_)))
{
rhs_cache
.insert(match_stack.stack.clone(), rhs_subs.deref_mut().clone());
rhs_cache.insert(
(rep_id, match_stack.stack.clone()),
rhs_subs.deref_mut().clone(),
);
}
}

Expand Down Expand Up @@ -1707,14 +1714,28 @@ pub struct MatchSettings {
pub rhs_cache_size: usize,
}

impl MatchSettings {
/// Create default match settings, but enable caching of the rhs.
pub fn cached() -> Self {
Self {
non_greedy_wildcards: Vec::new(),
level_range: (0, None),
level_is_tree_depth: false,
allow_new_wildcards_on_rhs: false,
rhs_cache_size: 100,
}
}
}

impl Default for MatchSettings {
/// Create default match settings. Use [`MatchSettings::cached`] to enable caching.
fn default() -> Self {
Self {
non_greedy_wildcards: Vec::new(),
level_range: (0, None),
level_is_tree_depth: false,
allow_new_wildcards_on_rhs: false,
rhs_cache_size: 100,
rhs_cache_size: 0,
}
}
}
Expand Down
5 changes: 4 additions & 1 deletion symbolica.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1100,6 +1100,7 @@ class Expression:
If set to `True`, allow wildcards that do not appear in the pattern on the right-hand side.
rhs_cache_size: int, optional
Cache the first `rhs_cache_size` substituted patterns. If set to `None`, an internally determined cache size is used.
**Warning**: caching should be disabled (`rhs_cache_size=0`) if the right-hand side contains side effects, such as updating a global variable.
repeat: bool, optional
If set to `True`, the entire operation will be repeated until there are no more matches.
"""
Expand Down Expand Up @@ -1377,7 +1378,8 @@ class Replacement:
non_greedy_wildcards: Optional[Sequence[Expression]] = None,
level_range: Optional[Tuple[int, Optional[int]]] = None,
level_is_tree_depth: Optional[bool] = False,
allow_new_wildcards_on_rhs: Optional[bool] = False) -> Replacement:
allow_new_wildcards_on_rhs: Optional[bool] = False,
rhs_cache_size: Optional[int] = None) -> Replacement:
"""Create a new replacement. See `replace_all` for more information."""


Expand Down Expand Up @@ -1859,6 +1861,7 @@ class Transformer:
If set to `True`, allow wildcards that do not appear in the pattern on the right-hand side.
rhs_cache_size: int, optional
Cache the first `rhs_cache_size` substituted patterns. If set to `None`, an internally determined cache size is used.
**Warning**: caching should be disabled (`rhs_cache_size=0`) if the right-hand side contains side effects, such as updating a global variable.
repeat:
If set to `True`, the entire operation will be repeated until there are no more matches.
"""
Expand Down

0 comments on commit b304255

Please sign in to comment.