Skip to content

Commit

Permalink
Add pattern specialization
Browse files Browse the repository at this point in the history
- `replace_all_multiple` now takes a tuple of replacements as arguments
  • Loading branch information
benruijl committed Oct 21, 2024
1 parent 136feba commit fdbdada
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 3 deletions.
31 changes: 29 additions & 2 deletions src/api/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4080,12 +4080,17 @@ impl PythonExpression {
/// The list of replacements to apply.
/// repeat: bool, optional
/// If set to `True`, the entire operation will be repeated until there are no more matches.
#[pyo3(signature = (replacements, repeat = None))]
#[pyo3(signature = (*replacements, repeat = None))]
pub fn replace_all_multiple(
&self,
replacements: Vec<PythonReplacement>,
replacements: &Bound<'_, PyTuple>,
repeat: Option<bool>,
) -> PyResult<PythonExpression> {
let replacements = replacements
.iter()
.map(|x| x.extract::<PythonReplacement>())
.collect::<PyResult<Vec<_>>>()?;

let reps = replacements
.iter()
.map(|x| {
Expand Down Expand Up @@ -4793,6 +4798,28 @@ impl PythonReplacement {
settings,
})
}

/// Specialize the replacement by providing the wildcards that should be replaced.
pub fn specialize(
&self,
wildcard_map: HashMap<PythonExpression, ConvertibleToExpression>,
) -> PyResult<PythonReplacement> {
let wildcard_map: HashMap<Symbol, Atom> = wildcard_map
.into_iter()
.map(|(k, v)| Ok(((&k.expr).try_into()?, v.to_expression().expr)))
.collect::<Result<HashMap<_, _>, String>>()
.map_err(|e| exceptions::PyValueError::new_err(e))?;

Ok(PythonReplacement {
pattern: self.pattern.specialize(&wildcard_map),
rhs: match &self.rhs {
PatternOrMap::Pattern(p) => PatternOrMap::Pattern(p.specialize(&wildcard_map)),
PatternOrMap::Map(f) => PatternOrMap::Map(f.clone()),
},
cond: self.cond.clone(),
settings: self.settings.clone(),
})
}
}

#[derive(FromPyObject)]
Expand Down
9 changes: 9 additions & 0 deletions src/atom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,15 @@ impl std::fmt::Debug for Symbol {
}
}

impl TryFrom<&Atom> for Symbol {
type Error = String;

fn try_from(v: &Atom) -> Result<Symbol, Self::Error> {
v.get_symbol()
.ok_or_else(|| format!("{} is not a variable or a function", v))
}
}

impl Symbol {
/// Create a new variable symbol. This constructor should be used with care as there are no checks
/// about the validity of the identifier.
Expand Down
86 changes: 86 additions & 0 deletions src/id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,26 @@ impl std::fmt::Debug for PatternOrMap {
}
}

/// A replacement, specified by a pattern and the right-hand side,
/// with optional conditions and settings.
pub struct OwnedReplacement {
pat: Pattern,
rhs: PatternOrMap,
conditions: Option<Condition<PatternRestriction>>,
settings: Option<MatchSettings>,
}

impl OwnedReplacement {
pub fn borrow<'a>(&'a self) -> Replacement<'a> {
Replacement {
pat: &self.pat,
rhs: &self.rhs,
conditions: self.conditions.as_ref(),
settings: self.settings.as_ref(),
}
}
}

/// A replacement, specified by a pattern and the right-hand side,
/// with optional conditions and settings.
pub struct Replacement<'a> {
Expand Down Expand Up @@ -79,6 +99,19 @@ impl<'a> Replacement<'a> {
self.settings = Some(settings);
self
}

/// Specialize the replacement by replacing wildcards.
pub fn specialize(&self, wildcard_map: &HashMap<Symbol, Atom>) -> OwnedReplacement {
OwnedReplacement {
pat: self.pat.specialize(wildcard_map),
rhs: match self.rhs {
PatternOrMap::Pattern(p) => PatternOrMap::Pattern(p.specialize(wildcard_map)),
PatternOrMap::Map(f) => PatternOrMap::Map(f.clone()),
},
conditions: self.conditions.cloned(),
settings: self.settings.cloned(),
}
}
}

impl Atom {
Expand Down Expand Up @@ -549,6 +582,59 @@ impl Pattern {
Ok(Atom::parse(input)?.into_pattern())
}

/// Specialize the pattern by replacing wildcards.
pub fn specialize(&self, wildcard_map: &HashMap<Symbol, Atom>) -> Pattern {
if let Ok(mut a) = self.to_atom() {
for (s, r) in wildcard_map {
a = a.replace_all(
&Pattern::Literal(Atom::new_var(*s)),
&r.into_pattern().into(),
None,
None,
);
}

return a.into_pattern();
}

match self {
Pattern::Mul(arg) => {
let mut new_args = vec![];
for a in arg {
new_args.push(a.specialize(&wildcard_map));
}
Pattern::Mul(new_args)
}
Pattern::Add(arg) => {
let mut new_args = vec![];
for a in arg {
new_args.push(a.specialize(&wildcard_map));
}
Pattern::Add(new_args)
}
Pattern::Fn(symbol, arg) => {
let mut new_args = vec![];
for a in arg {
new_args.push(a.specialize(&wildcard_map));
}
Pattern::Fn(*symbol, new_args)
}
Pattern::Pow(p) => Pattern::Pow(Box::new([
p[0].specialize(wildcard_map),
p[1].specialize(wildcard_map),
])),
Pattern::Transformer(t) => {
if let Some(p) = &t.0 {
let p = p.specialize(wildcard_map);
Pattern::Transformer(Box::new((Some(p), t.1.clone())))
} else {
self.clone()
}
}
_ => unreachable!(),
}
}

/// Convert the pattern to an atom, if there are not transformers present.
pub fn to_atom(&self) -> Result<Atom, &'static str> {
Workspace::get_local().with(|ws| {
Expand Down
5 changes: 4 additions & 1 deletion symbolica.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1148,7 +1148,7 @@ class Expression:
If set to `True`, the entire operation will be repeated until there are no more matches.
"""

def replace_all_multiple(self, replacements: Sequence[Replacement], repeat: Optional[bool] = False) -> Expression:
def replace_all_multiple(self, *replacements: Replacement, repeat: Optional[bool] = False) -> Expression:
"""
Replace all atoms matching the patterns. See `replace_all` for more information.
Expand Down Expand Up @@ -1425,6 +1425,9 @@ class Replacement:
rhs_cache_size: Optional[int] = None) -> Replacement:
"""Create a new replacement. See `replace_all` for more information."""

def specialize(self, wildcards: dict[Expression, Expression | int | float | Decimal]) -> Replacement:
"""Specialize the replacement by providing the wildcards that should be replaced."""


class PatternRestriction:
"""A restriction on wildcards."""
Expand Down

0 comments on commit fdbdada

Please sign in to comment.