Skip to content

Commit

Permalink
Add safe loading and saving of expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
benruijl committed Oct 15, 2024
1 parent 9de3aa2 commit dbaf1b2
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 13 deletions.
72 changes: 71 additions & 1 deletion src/api/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{
borrow::Borrow,
fs::File,
hash::{Hash, Hasher},
io::BufWriter,
io::{BufReader, BufWriter},
ops::{Deref, Neg},
sync::Arc,
};
Expand Down Expand Up @@ -2399,6 +2399,76 @@ impl PythonExpression {
hasher.finish()
}

/// Save the expression and its state to a binary file.
/// The data is compressed and the compression level can be set between 0 and 11.
///
/// The expression can be loaded using `Expression.load`.
///
/// Examples
/// --------
/// >>> e = E("f(x)+f(y)").expand()
/// >>> e.save('export.dat')
#[pyo3(signature = (filename, compression_level=9))]
pub fn save(&self, filename: &str, compression_level: u32) -> PyResult<()> {
let f = File::create(filename)
.map_err(|e| exceptions::PyIOError::new_err(format!("Could not create file: {}", e)))?;
let mut writer = CompressorWriter::new(BufWriter::new(f), 4096, compression_level, 22);

self.expr
.as_view()
.export(&mut writer)
.map_err(|e| exceptions::PyIOError::new_err(format!("Could not write file: {}", e)))
}

/// Load an expression and its state from a file. The state will be merged
/// with the current one. If a symbol has conflicting attributes, the conflict
/// can be resolved using the renaming function `conflict_fn`.
///
/// Expressions can be saved using `Expression.save`.
///
/// Examples
/// --------
/// If `export.dat` contains a serialized expression: `f(x)+f(y)`:
/// >>> e = Expression.load('export.dat')
///
/// whill yield `f(x)+f(y)`.
///
/// If we have defined symbols in a different order:
/// >>> y, x = S('y', 'x')
/// >>> e = Expression.load('export.dat')
///
/// we get `f(y)+f(x)`.
///
/// If we define a symbol with conflicting attributes, we can resolve the conflict
/// using a renaming function:
///
/// >>> x = S('x', is_symmetric=True)
/// >>> e = Expression.load('export.dat', lambda x: x + '_new')
/// print(e)
///
/// will yield `f(x_new)+f(y)`.
#[classmethod]
pub fn load(_cls: &PyType, filename: &str, conflict_fn: Option<PyObject>) -> PyResult<Self> {
let f = File::open(filename)
.map_err(|e| exceptions::PyIOError::new_err(format!("Could not read file: {}", e)))?;
let mut reader = brotli::Decompressor::new(BufReader::new(f), 4096);

Atom::import(
&mut reader,
match conflict_fn {
Some(f) => Some(Box::new(move |name: &str| -> SmartString<LazyCompact> {
Python::with_gil(|py| {
f.call1(py, (name,)).unwrap().extract::<String>(py).unwrap()
})
.into()
})),
None => None,
},
)
.map(|a| a.into())
.map_err(|e| exceptions::PyIOError::new_err(format!("Could not read file: {}", e)))
}

/// Get the type of the atom.
pub fn get_type(&self) -> PythonAtomType {
match self.expr.as_ref() {
Expand Down
43 changes: 41 additions & 2 deletions src/atom/representation.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use byteorder::{LittleEndian, WriteBytesExt};
use bytes::{Buf, BufMut};
use smartstring::alias::String;
use std::{
cmp::Ordering,
io::{Read, Write},
};

use crate::{
coefficient::{Coefficient, CoefficientView},
state::{StateMap, Workspace},
state::{State, StateMap, Workspace},
};

use super::{
Expand Down Expand Up @@ -167,7 +168,33 @@ impl Atom {
Ok(())
}

pub fn import<R: Read>(source: R, state_map: &StateMap) -> Result<Atom, std::io::Error> {
/// Export the atom and state to a binary stream. It can be loaded
/// with [Atom::import].
pub fn export<W: Write>(&self, dest: W) -> Result<(), std::io::Error> {
self.as_view().export(dest)
}

/// Import an expression and its state from a binary stream. The state will be merged
/// with the current one. If a symbol has conflicting attributes, the conflict
/// can be resolved using the renaming function `conflict_fn`.
///
/// Expressions can be exported using [Atom::export].
pub fn import<R: Read>(
mut source: R,
conflict_fn: Option<Box<dyn Fn(&str) -> String>>,
) -> Result<Atom, std::io::Error> {
let state_map = State::import(&mut source, conflict_fn)?;

let mut a = Atom::new();
a.read(source)?;
Ok(a.as_view().rename(&state_map))
}

/// Read a stateless expression from a binary stream, renaming the symbols using the provided state map.
pub fn import_with_map<R: Read>(
source: R,
state_map: &StateMap,
) -> Result<Atom, std::io::Error> {
let mut a = Atom::new();
a.read(source)?;
Ok(a.as_view().rename(state_map))
Expand Down Expand Up @@ -1478,6 +1505,18 @@ impl<'a> AtomView<'a> {
}
}

/// Export the atom and state to a binary stream. It can be loaded
/// with [Atom::import].
#[inline(always)]
pub fn export<W: Write>(&self, mut dest: W) -> Result<(), std::io::Error> {
State::export(&mut dest)?;

let d = self.get_data();
dest.write_u8(0)?;
dest.write_u64::<LittleEndian>(d.len() as u64)?;
dest.write_all(d)
}

/// Write the expression to a binary stream. The format is the byte-length first
/// followed by the data. To import the expression in new session, also export the [`State`].
#[inline(always)]
Expand Down
23 changes: 17 additions & 6 deletions src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use crate::{
LicenseManager,
};

pub const SYMBOLICA_MAGIC: u32 = 0x37871367;
pub const EXPORT_FORMAT_VERSION: u16 = 1;

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -486,7 +487,8 @@ impl State {

/// Write the state to a binary stream.
#[inline(always)]
pub fn export<W: Write>(mut dest: W) -> Result<(), std::io::Error> {
pub fn export<W: Write>(dest: &mut W) -> Result<(), std::io::Error> {
dest.write_u32::<LittleEndian>(SYMBOLICA_MAGIC)?;
dest.write_u16::<LittleEndian>(EXPORT_FORMAT_VERSION)?;

dest.write_u64::<LittleEndian>(
Expand Down Expand Up @@ -539,13 +541,22 @@ impl State {

/// Import a state, merging it with the current state.
/// Upon a conflict, i.e. when a symbol with the same name but different attributes is
/// encountered, the conflict_fn is called with the conflicting name as argument which
/// encountered, `conflict_fn` is called with the conflicting name as argument which
/// should yield a new name for the symbol.
#[inline(always)]
pub fn import<R: Read>(
mut source: R,
source: &mut R,
conflict_fn: Option<Box<dyn Fn(&str) -> String>>,
) -> Result<StateMap, std::io::Error> {
let magic = source.read_u32::<LittleEndian>()?;

if magic != SYMBOLICA_MAGIC {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Invalid magic number: the file is not exported from Symbolica",
));
}

let version = source.read_u16::<LittleEndian>()?;
if version != EXPORT_FORMAT_VERSION {
return Err(std::io::Error::new(
Expand Down Expand Up @@ -661,14 +672,14 @@ impl State {
};

let mut f = Atom::new();
f.read(&mut source)?;
f.read(&mut *source)?;

let f_r = f.as_view().rename(&state_map);
variables.push(Variable::Function(symb, Arc::new(f_r)));
}
3 => {
let mut f = Atom::new();
f.read(&mut source)?;
f.read(&mut *source)?;

let f_r = f.as_view().rename(&state_map);
variables.push(Variable::Other(Arc::new(f_r)));
Expand Down Expand Up @@ -862,7 +873,7 @@ mod tests {
let mut export = vec![];
State::export(&mut export).unwrap();

let i = State::import(Cursor::new(&export), None).unwrap();
let i = State::import(&mut Cursor::new(&export), None).unwrap();
assert!(i.is_empty());
}

Expand Down
43 changes: 43 additions & 0 deletions symbolica.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,49 @@ class Expression:
Convert the expression into a human-readable string.
"""

@classmethod
def load(_cls, filename: str, conflict_fn: Callable[[str], str]) -> Expression:
"""Load an expression and its state from a file. The state will be merged
with the current one. If a symbol has conflicting attributes, the conflict
can be resolved using the renaming function `conflict_fn`.
Expressions can be saved using `Expression.save`.
Examples
--------
If `export.dat` contains a serialized expression: `f(x)+f(y)`:
>>> e = Expression.load('export.dat')
whill yield `f(x)+f(y)`.
If we have defined symbols in a different order:
>>> y, x = S('y', 'x')
>>> e = Expression.load('export.dat')
we get `f(y)+f(x)`.
If we define a symbol with conflicting attributes, we can resolve the conflict
using a renaming function:
>>> x = S('x', is_symmetric=True)
>>> e = Expression.load('export.dat', lambda x: x + '_new')
print(e)
will yield `f(x_new)+f(y)`.
"""

def save(self, filename: str, compression_level: int = 9):
"""Save the expression and its state to a binary file.
The data is compressed and the compression level can be set between 0 and 11.
The data can be loaded using `Expression.load`.
Examples
--------
>>> e = E("f(x)+f(y)").expand()
>>> e.save('export.dat')
"""

def get_byte_size(self) -> int:
""" Get the number of bytes that this expression takes up in memory."""

Expand Down
8 changes: 4 additions & 4 deletions tests/import_export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ fn conflict() {
State::get_symbol("f");

let state_map = State::import(
Cursor::new(&state_export),
&mut Cursor::new(&state_export),
Some(Box::new(|old_name| SmartString::from(old_name) + "1")),
)
.unwrap();

let a_rec = Atom::import(Cursor::new(&a_export), &state_map).unwrap();
let a_rec = Atom::import_with_map(Cursor::new(&a_export), &state_map).unwrap();

let r = Atom::parse("x^2*f1(y, x)").unwrap();
assert_eq!(a_rec, r);
Expand All @@ -55,9 +55,9 @@ fn rational_rename() {

State::get_symbol("y");

let state_map = State::import(Cursor::new(&state_export), None).unwrap();
let state_map = State::import(&mut Cursor::new(&state_export), None).unwrap();

let a_rec = Atom::import(Cursor::new(&a_export), &state_map).unwrap();
let a_rec = Atom::import_with_map(Cursor::new(&a_export), &state_map).unwrap();

let r = Atom::parse("x^2*coeff(x)").unwrap();
assert_eq!(a_rec, r);
Expand Down

0 comments on commit dbaf1b2

Please sign in to comment.