From dbaf1b28af1c4b3000c65383cdecb6cb3d725c02 Mon Sep 17 00:00:00 2001 From: Ben Ruijl Date: Tue, 15 Oct 2024 13:30:43 +0200 Subject: [PATCH] Add safe loading and saving of expressions --- src/api/python.rs | 72 +++++++++++++++++++++++++++++++++++++- src/atom/representation.rs | 43 +++++++++++++++++++++-- src/state.rs | 23 ++++++++---- symbolica.pyi | 43 +++++++++++++++++++++++ tests/import_export.rs | 8 ++--- 5 files changed, 176 insertions(+), 13 deletions(-) diff --git a/src/api/python.rs b/src/api/python.rs index 6e138a11..7e1a4c7f 100644 --- a/src/api/python.rs +++ b/src/api/python.rs @@ -2,7 +2,7 @@ use std::{ borrow::Borrow, fs::File, hash::{Hash, Hasher}, - io::BufWriter, + io::{BufReader, BufWriter}, ops::{Deref, Neg}, sync::Arc, }; @@ -2399,6 +2399,76 @@ impl PythonExpression { hasher.finish() } + /// Save the expression and its state to a binary file. + /// The data is compressed and the compression level can be set between 0 and 11. + /// + /// The expression can be loaded using `Expression.load`. + /// + /// Examples + /// -------- + /// >>> e = E("f(x)+f(y)").expand() + /// >>> e.save('export.dat') + #[pyo3(signature = (filename, compression_level=9))] + pub fn save(&self, filename: &str, compression_level: u32) -> PyResult<()> { + let f = File::create(filename) + .map_err(|e| exceptions::PyIOError::new_err(format!("Could not create file: {}", e)))?; + let mut writer = CompressorWriter::new(BufWriter::new(f), 4096, compression_level, 22); + + self.expr + .as_view() + .export(&mut writer) + .map_err(|e| exceptions::PyIOError::new_err(format!("Could not write file: {}", e))) + } + + /// Load an expression and its state from a file. The state will be merged + /// with the current one. If a symbol has conflicting attributes, the conflict + /// can be resolved using the renaming function `conflict_fn`. + /// + /// Expressions can be saved using `Expression.save`. + /// + /// Examples + /// -------- + /// If `export.dat` contains a serialized expression: `f(x)+f(y)`: + /// >>> e = Expression.load('export.dat') + /// + /// whill yield `f(x)+f(y)`. + /// + /// If we have defined symbols in a different order: + /// >>> y, x = S('y', 'x') + /// >>> e = Expression.load('export.dat') + /// + /// we get `f(y)+f(x)`. + /// + /// If we define a symbol with conflicting attributes, we can resolve the conflict + /// using a renaming function: + /// + /// >>> x = S('x', is_symmetric=True) + /// >>> e = Expression.load('export.dat', lambda x: x + '_new') + /// print(e) + /// + /// will yield `f(x_new)+f(y)`. + #[classmethod] + pub fn load(_cls: &PyType, filename: &str, conflict_fn: Option) -> PyResult { + let f = File::open(filename) + .map_err(|e| exceptions::PyIOError::new_err(format!("Could not read file: {}", e)))?; + let mut reader = brotli::Decompressor::new(BufReader::new(f), 4096); + + Atom::import( + &mut reader, + match conflict_fn { + Some(f) => Some(Box::new(move |name: &str| -> SmartString { + Python::with_gil(|py| { + f.call1(py, (name,)).unwrap().extract::(py).unwrap() + }) + .into() + })), + None => None, + }, + ) + .map(|a| a.into()) + .map_err(|e| exceptions::PyIOError::new_err(format!("Could not read file: {}", e))) + } + /// Get the type of the atom. pub fn get_type(&self) -> PythonAtomType { match self.expr.as_ref() { diff --git a/src/atom/representation.rs b/src/atom/representation.rs index 3492f303..5fb92de5 100644 --- a/src/atom/representation.rs +++ b/src/atom/representation.rs @@ -1,5 +1,6 @@ use byteorder::{LittleEndian, WriteBytesExt}; use bytes::{Buf, BufMut}; +use smartstring::alias::String; use std::{ cmp::Ordering, io::{Read, Write}, @@ -7,7 +8,7 @@ use std::{ use crate::{ coefficient::{Coefficient, CoefficientView}, - state::{StateMap, Workspace}, + state::{State, StateMap, Workspace}, }; use super::{ @@ -167,7 +168,33 @@ impl Atom { Ok(()) } - pub fn import(source: R, state_map: &StateMap) -> Result { + /// Export the atom and state to a binary stream. It can be loaded + /// with [Atom::import]. + pub fn export(&self, dest: W) -> Result<(), std::io::Error> { + self.as_view().export(dest) + } + + /// Import an expression and its state from a binary stream. The state will be merged + /// with the current one. If a symbol has conflicting attributes, the conflict + /// can be resolved using the renaming function `conflict_fn`. + /// + /// Expressions can be exported using [Atom::export]. + pub fn import( + mut source: R, + conflict_fn: Option String>>, + ) -> Result { + let state_map = State::import(&mut source, conflict_fn)?; + + let mut a = Atom::new(); + a.read(source)?; + Ok(a.as_view().rename(&state_map)) + } + + /// Read a stateless expression from a binary stream, renaming the symbols using the provided state map. + pub fn import_with_map( + source: R, + state_map: &StateMap, + ) -> Result { let mut a = Atom::new(); a.read(source)?; Ok(a.as_view().rename(state_map)) @@ -1478,6 +1505,18 @@ impl<'a> AtomView<'a> { } } + /// Export the atom and state to a binary stream. It can be loaded + /// with [Atom::import]. + #[inline(always)] + pub fn export(&self, mut dest: W) -> Result<(), std::io::Error> { + State::export(&mut dest)?; + + let d = self.get_data(); + dest.write_u8(0)?; + dest.write_u64::(d.len() as u64)?; + dest.write_all(d) + } + /// Write the expression to a binary stream. The format is the byte-length first /// followed by the data. To import the expression in new session, also export the [`State`]. #[inline(always)] diff --git a/src/state.rs b/src/state.rs index 4c15a16a..7303e704 100644 --- a/src/state.rs +++ b/src/state.rs @@ -27,6 +27,7 @@ use crate::{ LicenseManager, }; +pub const SYMBOLICA_MAGIC: u32 = 0x37871367; pub const EXPORT_FORMAT_VERSION: u16 = 1; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -486,7 +487,8 @@ impl State { /// Write the state to a binary stream. #[inline(always)] - pub fn export(mut dest: W) -> Result<(), std::io::Error> { + pub fn export(dest: &mut W) -> Result<(), std::io::Error> { + dest.write_u32::(SYMBOLICA_MAGIC)?; dest.write_u16::(EXPORT_FORMAT_VERSION)?; dest.write_u64::( @@ -539,13 +541,22 @@ impl State { /// Import a state, merging it with the current state. /// Upon a conflict, i.e. when a symbol with the same name but different attributes is - /// encountered, the conflict_fn is called with the conflicting name as argument which + /// encountered, `conflict_fn` is called with the conflicting name as argument which /// should yield a new name for the symbol. #[inline(always)] pub fn import( - mut source: R, + source: &mut R, conflict_fn: Option String>>, ) -> Result { + let magic = source.read_u32::()?; + + if magic != SYMBOLICA_MAGIC { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "Invalid magic number: the file is not exported from Symbolica", + )); + } + let version = source.read_u16::()?; if version != EXPORT_FORMAT_VERSION { return Err(std::io::Error::new( @@ -661,14 +672,14 @@ impl State { }; let mut f = Atom::new(); - f.read(&mut source)?; + f.read(&mut *source)?; let f_r = f.as_view().rename(&state_map); variables.push(Variable::Function(symb, Arc::new(f_r))); } 3 => { let mut f = Atom::new(); - f.read(&mut source)?; + f.read(&mut *source)?; let f_r = f.as_view().rename(&state_map); variables.push(Variable::Other(Arc::new(f_r))); @@ -862,7 +873,7 @@ mod tests { let mut export = vec![]; State::export(&mut export).unwrap(); - let i = State::import(Cursor::new(&export), None).unwrap(); + let i = State::import(&mut Cursor::new(&export), None).unwrap(); assert!(i.is_empty()); } diff --git a/symbolica.pyi b/symbolica.pyi index 185f908a..6666b592 100644 --- a/symbolica.pyi +++ b/symbolica.pyi @@ -330,6 +330,49 @@ class Expression: Convert the expression into a human-readable string. """ + @classmethod + def load(_cls, filename: str, conflict_fn: Callable[[str], str]) -> Expression: + """Load an expression and its state from a file. The state will be merged + with the current one. If a symbol has conflicting attributes, the conflict + can be resolved using the renaming function `conflict_fn`. + + Expressions can be saved using `Expression.save`. + + Examples + -------- + If `export.dat` contains a serialized expression: `f(x)+f(y)`: + >>> e = Expression.load('export.dat') + + whill yield `f(x)+f(y)`. + + If we have defined symbols in a different order: + >>> y, x = S('y', 'x') + >>> e = Expression.load('export.dat') + + we get `f(y)+f(x)`. + + If we define a symbol with conflicting attributes, we can resolve the conflict + using a renaming function: + + >>> x = S('x', is_symmetric=True) + >>> e = Expression.load('export.dat', lambda x: x + '_new') + print(e) + + will yield `f(x_new)+f(y)`. + """ + + def save(self, filename: str, compression_level: int = 9): + """Save the expression and its state to a binary file. + The data is compressed and the compression level can be set between 0 and 11. + + The data can be loaded using `Expression.load`. + + Examples + -------- + >>> e = E("f(x)+f(y)").expand() + >>> e.save('export.dat') + """ + def get_byte_size(self) -> int: """ Get the number of bytes that this expression takes up in memory.""" diff --git a/tests/import_export.rs b/tests/import_export.rs index 00881c31..f161aa04 100644 --- a/tests/import_export.rs +++ b/tests/import_export.rs @@ -27,12 +27,12 @@ fn conflict() { State::get_symbol("f"); let state_map = State::import( - Cursor::new(&state_export), + &mut Cursor::new(&state_export), Some(Box::new(|old_name| SmartString::from(old_name) + "1")), ) .unwrap(); - let a_rec = Atom::import(Cursor::new(&a_export), &state_map).unwrap(); + let a_rec = Atom::import_with_map(Cursor::new(&a_export), &state_map).unwrap(); let r = Atom::parse("x^2*f1(y, x)").unwrap(); assert_eq!(a_rec, r); @@ -55,9 +55,9 @@ fn rational_rename() { State::get_symbol("y"); - let state_map = State::import(Cursor::new(&state_export), None).unwrap(); + let state_map = State::import(&mut Cursor::new(&state_export), None).unwrap(); - let a_rec = Atom::import(Cursor::new(&a_export), &state_map).unwrap(); + let a_rec = Atom::import_with_map(Cursor::new(&a_export), &state_map).unwrap(); let r = Atom::parse("x^2*coeff(x)").unwrap(); assert_eq!(a_rec, r);