Skip to content

Commit

Permalink
Merge branch 'main' into nested_evaluate
Browse files Browse the repository at this point in the history
  • Loading branch information
benruijl committed Jul 26, 2024
2 parents 9f63a52 + 3653a54 commit 743806d
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 17 deletions.
23 changes: 22 additions & 1 deletion src/api/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -828,7 +828,8 @@ impl PythonPattern {

/// Create a transformer that replaces all patterns matching the left-hand side `self` by the right-hand side `rhs`.
/// Restrictions on pattern can be supplied through `cond`. The settings `non_greedy_wildcards` can be used to specify
/// wildcards that try to match as little as possible.
/// wildcards that try to match as little as possible. The settings `allow_new_wildcards_on_rhs` can be used to allow
/// wildcards that do not appear in the pattern on the right-hand side.
///
/// The `level_range` specifies the `[min,max]` level at which the pattern is allowed to match.
/// The first level is 0 and the level is increased when going into a function or one level deeper in the expression tree,
Expand All @@ -849,6 +850,7 @@ impl PythonPattern {
non_greedy_wildcards: Option<Vec<PythonExpression>>,
level_range: Option<(usize, Option<usize>)>,
level_is_tree_depth: Option<bool>,
allow_new_wildcards_on_rhs: Option<bool>,
) -> PyResult<PythonPattern> {
let mut settings = MatchSettings::default();

Expand Down Expand Up @@ -877,6 +879,9 @@ impl PythonPattern {
if let Some(level_is_tree_depth) = level_is_tree_depth {
settings.level_is_tree_depth = level_is_tree_depth;
}
if let Some(allow_new_wildcards_on_rhs) = allow_new_wildcards_on_rhs {
settings.allow_new_wildcards_on_rhs = allow_new_wildcards_on_rhs;
}

return append_transformer!(
self,
Expand Down Expand Up @@ -3141,13 +3146,15 @@ impl PythonExpression {
cond: Option<PythonPatternRestriction>,
level_range: Option<(usize, Option<usize>)>,
level_is_tree_depth: Option<bool>,
allow_new_wildcards_on_rhs: Option<bool>,
) -> PyResult<PythonMatchIterator> {
let conditions = cond
.map(|r| r.condition.clone())
.unwrap_or(Condition::default());
let settings = MatchSettings {
level_range: level_range.unwrap_or((0, None)),
level_is_tree_depth: level_is_tree_depth.unwrap_or(false),
allow_new_wildcards_on_rhs: allow_new_wildcards_on_rhs.unwrap_or(false),
..MatchSettings::default()
};
Ok(PythonMatchIterator::new(
Expand Down Expand Up @@ -3178,6 +3185,7 @@ impl PythonExpression {
cond: Option<PythonPatternRestriction>,
level_range: Option<(usize, Option<usize>)>,
level_is_tree_depth: Option<bool>,
allow_new_wildcards_on_rhs: Option<bool>,
) -> PyResult<bool> {
let pat = lhs.to_pattern()?.expr;
let conditions = cond
Expand All @@ -3186,6 +3194,7 @@ impl PythonExpression {
let settings = MatchSettings {
level_range: level_range.unwrap_or((0, None)),
level_is_tree_depth: level_is_tree_depth.unwrap_or(false),
allow_new_wildcards_on_rhs: allow_new_wildcards_on_rhs.unwrap_or(false),
..MatchSettings::default()
};

Expand Down Expand Up @@ -3226,13 +3235,15 @@ impl PythonExpression {
cond: Option<PythonPatternRestriction>,
level_range: Option<(usize, Option<usize>)>,
level_is_tree_depth: Option<bool>,
allow_new_wildcards_on_rhs: Option<bool>,
) -> PyResult<PythonReplaceIterator> {
let conditions = cond
.map(|r| r.condition.clone())
.unwrap_or(Condition::default());
let settings = MatchSettings {
level_range: level_range.unwrap_or((0, None)),
level_is_tree_depth: level_is_tree_depth.unwrap_or(false),
allow_new_wildcards_on_rhs: allow_new_wildcards_on_rhs.unwrap_or(false),
..MatchSettings::default()
};

Expand Down Expand Up @@ -3280,6 +3291,8 @@ impl PythonExpression {
/// Specifies the `[min,max]` level at which the pattern is allowed to match. The first level is 0 and the level is increased when going into a function or one level deeper in the expression tree, depending on `level_is_tree_depth`.
/// level_is_tree_depth: bool, optional
/// If set to `True`, the level is increased when going one level deeper in the expression tree.
/// allow_new_wildcards_on_rhs: bool, optional
/// If set to `True`, wildcards that do not appear ion the pattern are allowed on the right-hand side.
/// repeat: bool, optional
/// If set to `True`, the entire operation will be repeated until there are no more matches.
pub fn replace_all(
Expand All @@ -3290,6 +3303,7 @@ impl PythonExpression {
non_greedy_wildcards: Option<Vec<PythonExpression>>,
level_range: Option<(usize, Option<usize>)>,
level_is_tree_depth: Option<bool>,
allow_new_wildcards_on_rhs: Option<bool>,
repeat: Option<bool>,
) -> PyResult<PythonExpression> {
let pattern = &pattern.to_pattern()?.expr;
Expand Down Expand Up @@ -3322,6 +3336,9 @@ impl PythonExpression {
if let Some(level_is_tree_depth) = level_is_tree_depth {
settings.level_is_tree_depth = level_is_tree_depth;
}
if let Some(allow_new_wildcards_on_rhs) = allow_new_wildcards_on_rhs {
settings.allow_new_wildcards_on_rhs = allow_new_wildcards_on_rhs;
}

let mut expr_ref = self.expr.as_view();

Expand Down Expand Up @@ -3654,6 +3671,7 @@ impl PythonReplacement {
non_greedy_wildcards: Option<Vec<PythonExpression>>,
level_range: Option<(usize, Option<usize>)>,
level_is_tree_depth: Option<bool>,
allow_new_wildcards_on_rhs: Option<bool>,
) -> PyResult<Self> {
let pattern = pattern.to_pattern()?.expr;
let rhs = rhs.to_pattern()?.expr;
Expand Down Expand Up @@ -3685,6 +3703,9 @@ impl PythonReplacement {
if let Some(level_is_tree_depth) = level_is_tree_depth {
settings.level_is_tree_depth = level_is_tree_depth;
}
if let Some(allow_new_wildcards_on_rhs) = allow_new_wildcards_on_rhs {
settings.allow_new_wildcards_on_rhs = allow_new_wildcards_on_rhs;
}

let cond = cond
.as_ref()
Expand Down
60 changes: 45 additions & 15 deletions src/id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -822,8 +822,13 @@ impl Pattern {
Pattern::Wildcard(name) => {
if let Some(w) = match_stack.get(*name) {
w.to_atom(out);
} else if match_stack.settings.allow_new_wildcards_on_rhs {
out.to_var(*name);
} else {
panic!("Unsubstituted wildcard {}", name.get_id());
Err(TransformerError::ValueError(format!(
"Unsubstituted wildcard {}",
name
)))?;
}
}
Pattern::Fn(mut name, args) => {
Expand All @@ -834,8 +839,11 @@ impl Pattern {
} else {
unreachable!("Wildcard must be a function name")
}
} else {
panic!("Unsubstituted wildcard {}", name.get_id());
} else if !match_stack.settings.allow_new_wildcards_on_rhs {
Err(TransformerError::ValueError(format!(
"Unsubstituted wildcard {}",
name
)))?;
}
}

Expand Down Expand Up @@ -863,11 +871,16 @@ impl Pattern {
unreachable!("Wildcard cannot be function name")
}
}

continue;
} else if match_stack.settings.allow_new_wildcards_on_rhs {
func.add_arg(workspace.new_var(*w).as_view())
} else {
panic!("Unsubstituted wildcard {}", name.get_id());
Err(TransformerError::ValueError(format!(
"Unsubstituted wildcard {}",
w
)))?;
}

continue;
}

let mut handle = workspace.new_atom();
Expand Down Expand Up @@ -896,11 +909,16 @@ impl Pattern {
unreachable!("Wildcard cannot be function name")
}
}

continue;
} else if match_stack.settings.allow_new_wildcards_on_rhs {
out.set_from_view(&workspace.new_var(*w).as_view());
} else {
panic!("Unsubstituted wildcard {}", w.get_id());
Err(TransformerError::ValueError(format!(
"Unsubstituted wildcard {}",
w
)))?;
}

continue;
}

let mut handle = workspace.new_atom();
Expand Down Expand Up @@ -937,11 +955,16 @@ impl Pattern {
unreachable!("Wildcard cannot be function name")
}
}

continue;
} else if match_stack.settings.allow_new_wildcards_on_rhs {
mul.extend(workspace.new_var(*w).as_view());
} else {
panic!("Unsubstituted wildcard {}", w.get_id());
Err(TransformerError::ValueError(format!(
"Unsubstituted wildcard {}",
w
)))?;
}

continue;
}

let mut handle = workspace.new_atom();
Expand Down Expand Up @@ -975,11 +998,16 @@ impl Pattern {
unreachable!("Wildcard cannot be function name")
}
}

continue;
} else if match_stack.settings.allow_new_wildcards_on_rhs {
add.extend(workspace.new_var(*w).as_view());
} else {
panic!("Unsubstituted wildcard {}", w.get_id());
Err(TransformerError::ValueError(format!(
"Unsubstituted wildcard {}",
w
)))?;
}

continue;
}

let mut handle = workspace.new_atom();
Expand Down Expand Up @@ -1521,6 +1549,8 @@ pub struct MatchSettings {
pub level_range: (usize, Option<usize>),
/// Determine whether a level reflects the expression tree depth or the function depth.
pub level_is_tree_depth: bool,
/// Allow wildcards on the right-hand side that do not appear in the pattern.
pub allow_new_wildcards_on_rhs: bool,
}

/// An insertion-ordered map of wildcard identifiers to a subexpressions.
Expand Down
11 changes: 10 additions & 1 deletion symbolica.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -870,6 +870,7 @@ class Expression:
cond: Optional[PatternRestriction] = None,
level_range: Optional[Tuple[int, Optional[int]]] = None,
level_is_tree_depth: Optional[bool] = False,
allow_new_wildcards_on_rhs: Optional[bool] = False,
) -> MatchIterator:
"""
Return an iterator over the pattern `self` matching to `lhs`.
Expand All @@ -896,6 +897,7 @@ class Expression:
cond: Optional[PatternRestriction] = None,
level_range: Optional[Tuple[int, Optional[int]]] = None,
level_is_tree_depth: Optional[bool] = False,
allow_new_wildcards_on_rhs: Optional[bool] = False,
) -> bool:
"""
Test whether the pattern is found in the expression.
Expand All @@ -916,6 +918,7 @@ class Expression:
cond: Optional[PatternRestriction] = None,
level_range: Optional[Tuple[int, Optional[int]]] = None,
level_is_tree_depth: Optional[bool] = False,
allow_new_wildcards_on_rhs: Optional[bool] = False,
) -> ReplaceIterator:
"""
Return an iterator over the replacement of the pattern `self` on `lhs` by `rhs`.
Expand Down Expand Up @@ -945,6 +948,7 @@ class Expression:
cond: Conditions on the pattern.
level_range: Specifies the `[min,max]` level at which the pattern is allowed to match. The first level is 0 and the level is increased when going into a function or one level deeper in the expression tree, depending on `level_is_tree_depth`.
level_is_tree_depth: If set to `True`, the level is increased when going one level deeper in the expression tree.
allow_new_wildcards_on_rhs: If set to `True`, allow wildcards that do not appear in the pattern on the right-hand side.
"""

def replace_all(
Expand All @@ -955,6 +959,7 @@ class Expression:
non_greedy_wildcards: Optional[Sequence[Expression]] = None,
level_range: Optional[Tuple[int, Optional[int]]] = None,
level_is_tree_depth: Optional[bool] = False,
allow_new_wildcards_on_rhs: Optional[bool] = False,
repeat: Optional[bool] = False,
) -> Expression:
"""
Expand All @@ -978,6 +983,7 @@ class Expression:
non_greedy_wildcards: Wildcards that try to match as little as possible.
level_range: Specifies the `[min,max]` level at which the pattern is allowed to match. The first level is 0 and the level is increased when going into a function or one level deeper in the expression tree, depending on `level_is_tree_depth`.
level_is_tree_depth: If set to `True`, the level is increased when going one level deeper in the expression tree.
allow_new_wildcards_on_rhs: If set to `True`, allow wildcards that do not appear in the pattern on the right-hand side.
repeat: If set to `True`, the entire operation will be repeated until there are no more matches.
"""

Expand Down Expand Up @@ -1084,7 +1090,8 @@ class Replacement:
cond: Optional[PatternRestriction] = None,
non_greedy_wildcards: Optional[Sequence[Expression]] = None,
level_range: Optional[Tuple[int, Optional[int]]] = None,
level_is_tree_depth: Optional[bool] = False) -> Replacement:
level_is_tree_depth: Optional[bool] = False,
allow_new_wildcards_on_rhs: Optional[bool] = False) -> Replacement:
"""Create a new replacement. See `replace_all` for more information."""


Expand Down Expand Up @@ -1407,6 +1414,7 @@ class Transformer:
non_greedy_wildcards: Optional[Sequence[Expression]] = None,
level_range: Optional[Tuple[int, Optional[int]]] = None,
level_is_tree_depth: Optional[bool] = False,
allow_new_wildcards_on_rhs: Optional[bool] = False
) -> Transformer:
"""
Create a transformer that replaces all subexpressions matching the pattern `pat` by the right-hand side `rhs`.
Expand All @@ -1428,6 +1436,7 @@ class Transformer:
non_greedy_wildcards: Wildcards that try to match as little as possible.
level_range: Specifies the `[min,max]` level at which the pattern is allowed to match. The first level is 0 and the level is increased when going into a function or one level deeper in the expression tree, depending on `level_is_tree_depth`.
level_is_tree_depth: If set to `True`, the level is increased when going one level deeper in the expression tree.
allow_new_wildcards_on_rhs: If set to `True`, allow wildcards that do not appear in the pattern on the right-hand side.
repeat: If set to `True`, the entire operation will be repeated until there are no more matches.
"""

Expand Down

0 comments on commit 743806d

Please sign in to comment.