From 82bba808b924e4c20c73079d72255a90c8c08d06 Mon Sep 17 00:00:00 2001 From: Ben Ruijl Date: Thu, 20 Jun 2024 12:52:57 +0200 Subject: [PATCH] Allow pickling of expressions --- src/api/python.rs | 24 ++++++++++++++++++++++++ src/atom/representation.rs | 12 ++++++++++++ symbolica.pyi | 6 +++++- 3 files changed, 41 insertions(+), 1 deletion(-) diff --git a/src/api/python.rs b/src/api/python.rs index 9db7134..6d4934b 100644 --- a/src/api/python.rs +++ b/src/api/python.rs @@ -1618,6 +1618,30 @@ impl PythonExpression { Ok(e.into()) } + /// Create a new expression that represents 0. + #[new] + pub fn __new__() -> PythonExpression { + Atom::new().into() + } + + /// Construct an expression from a serialized state. + pub fn __setstate__(&mut self, state: Vec) -> PyResult<()> { + unsafe { + self.expr = Atom::from_raw(state); + } + Ok(()) + } + + /// Get a serialized version of the expression. + pub fn __getstate__(&self) -> PyResult> { + Ok(self.expr.clone().into_raw()) + } + + /// Get the default positional arguments for `__new__`. + pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<&'py PyTuple> { + Ok(PyTuple::empty(py)) + } + /// Copy the expression. pub fn __copy__(&self) -> PythonExpression { self.expr.clone().into() diff --git a/src/atom/representation.rs b/src/atom/representation.rs index 073965e..f85c8d8 100644 --- a/src/atom/representation.rs +++ b/src/atom/representation.rs @@ -168,6 +168,18 @@ impl Atom { a.read(source)?; Ok(a.as_view().rename(state_map)) } + + pub(crate) unsafe fn from_raw(raw: RawAtom) -> Self { + match raw[0] & TYPE_MASK { + NUM_ID => Atom::Num(Num::from_raw(raw)), + VAR_ID => Atom::Var(Var::from_raw(raw)), + FUN_ID => Atom::Fun(Fun::from_raw(raw)), + MUL_ID => Atom::Mul(Mul::from_raw(raw)), + ADD_ID => Atom::Add(Add::from_raw(raw)), + POW_ID => Atom::Pow(Pow::from_raw(raw)), + _ => unreachable!("Unknown type {}", raw[0]), + } + } } #[derive(Debug, Clone, PartialEq, Eq, Hash)] diff --git a/symbolica.pyi b/symbolica.pyi index 6053109..11743d6 100644 --- a/symbolica.pyi +++ b/symbolica.pyi @@ -236,6 +236,9 @@ class Expression: If the input is not a valid Symbolica expression. """ + def __new__(cls) -> Expression: + """Create a new expression that represents 0.""" + def __copy__(self) -> Expression: """ Copy the expression. @@ -618,10 +621,11 @@ class Expression: def map( self, transformations: Transformer, + n_cores: Optional[int] = 1, ) -> Expression: """ Map the transformations to every term in the expression. - The execution happen in parallel. + The execution happens in parallel. No new functions or variables can be defined and no new