Skip to content

Commit

Permalink
Allow pickling of expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
benruijl committed Jun 20, 2024
1 parent f49b609 commit 82bba80
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 1 deletion.
24 changes: 24 additions & 0 deletions src/api/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1618,6 +1618,30 @@ impl PythonExpression {
Ok(e.into())
}

/// Create a new expression that represents 0.
#[new]
pub fn __new__() -> PythonExpression {
Atom::new().into()
}

/// Construct an expression from a serialized state.
pub fn __setstate__(&mut self, state: Vec<u8>) -> PyResult<()> {
unsafe {
self.expr = Atom::from_raw(state);
}
Ok(())
}

/// Get a serialized version of the expression.
pub fn __getstate__(&self) -> PyResult<Vec<u8>> {
Ok(self.expr.clone().into_raw())
}

/// Get the default positional arguments for `__new__`.
pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<&'py PyTuple> {
Ok(PyTuple::empty(py))
}

/// Copy the expression.
pub fn __copy__(&self) -> PythonExpression {
self.expr.clone().into()
Expand Down
12 changes: 12 additions & 0 deletions src/atom/representation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,18 @@ impl Atom {
a.read(source)?;
Ok(a.as_view().rename(state_map))
}

pub(crate) unsafe fn from_raw(raw: RawAtom) -> Self {
match raw[0] & TYPE_MASK {
NUM_ID => Atom::Num(Num::from_raw(raw)),
VAR_ID => Atom::Var(Var::from_raw(raw)),
FUN_ID => Atom::Fun(Fun::from_raw(raw)),
MUL_ID => Atom::Mul(Mul::from_raw(raw)),
ADD_ID => Atom::Add(Add::from_raw(raw)),
POW_ID => Atom::Pow(Pow::from_raw(raw)),
_ => unreachable!("Unknown type {}", raw[0]),
}
}
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
Expand Down
6 changes: 5 additions & 1 deletion symbolica.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,9 @@ class Expression:
If the input is not a valid Symbolica expression.
"""

def __new__(cls) -> Expression:
"""Create a new expression that represents 0."""

def __copy__(self) -> Expression:
"""
Copy the expression.
Expand Down Expand Up @@ -618,10 +621,11 @@ class Expression:
def map(
self,
transformations: Transformer,
n_cores: Optional[int] = 1,
) -> Expression:
"""
Map the transformations to every term in the expression.
The execution happen in parallel.
The execution happens in parallel.
No new functions or variables can be defined and no new
Expand Down

0 comments on commit 82bba80

Please sign in to comment.