diff --git a/.gitignore b/.gitignore index 785c2eb..e26358a 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ Cargo.lock venv -.vscode \ No newline at end of file +.vscode +snake_egg.html diff --git a/Cargo.toml b/Cargo.toml index b7443a6..1c650fc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,11 +8,13 @@ name = "snake_egg" crate-type = ["cdylib"] [dependencies] +hashbrown = "0.11" +once_cell = "1" egg = { git = "https://github.com/egraphs-good/egg", rev = "c637bbd9243118a07e2ccad3f812f66e0952e603" } # pyo3 = { version = "0.14.5", features = ["extension-module"] } [dependencies.pyo3] features = ["extension-module", "abi3", "abi3-py36"] -version = "0.14.5" +version = "0.15.0" # git = "https://github.com/PyO3/pyo3" # rev = "64df791741e61c331a03dbed42085b0c1adffea1" diff --git a/Makefile b/Makefile index 3cbe936..4f73f92 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: all build test +.PHONY: all build test install doc all: test @@ -10,7 +10,14 @@ build: ln -fs $(dir)/libsnake_egg.so $(dir)/snake_egg.so test: test.py build - env PYTHONPATH=$(dir)/ python3 test.py + env PYTHONPATH=$(dir) python3 test.py + +install: + maturin build + python3 -m pip install snake_egg --force-reinstall --no-index --find-link ./target/wheels/ + +doc: build + env PYTHONPATH=$(dir) python3 -m pydoc -w snake_egg shell: build $(python) -ic 'import snake_egg' diff --git a/README.md b/README.md new file mode 100644 index 0000000..1df3088 --- /dev/null +++ b/README.md @@ -0,0 +1,12 @@ +# Python bindings for [`egg`](https://github.com/egraphs-good/egg) + +# Installing + +- Install [`maturin`](https://github.com/PyO3/maturin), a cool Rust/Python builder thingy. + - Download from their site or just `pip install maturin`. +- Type `make install` to build and install `snake_egg` into your python installation. + - This will reinstall over any existing `snake_egg` installation. + - You may want to do this in a `virtualenv`. + +If you'd like to manually install it, + just run `maturin build` and find the wheels in `./target/wheels/`. \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index b06d87d..e1e55bd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,10 +1,15 @@ -use std::{array::IntoIter, collections::HashMap, hash::Hash, time::Duration}; +use once_cell::sync::Lazy; +use std::cmp::Ordering; +use std::sync::Mutex; +use std::{borrow::Cow, fmt::Display, hash::Hash, time::Duration}; + +use pyo3::AsPyPointer; use pyo3::{ basic::CompareOp, prelude::*, - types::{PyList, PyString, PyTuple}, - AsPyPointer, PyNativeType, PyObjectProtocol, ToPyObject, + types::{PyList, PyString, PyTuple, PyType}, + PyObjectProtocol, }; macro_rules! impl_py_object { @@ -60,6 +65,114 @@ impl Var { } } +#[derive(Debug, Clone)] +struct PyLang { + obj: PyObject, + children: Vec, +} + +impl PyLang { + fn op(ty: &PyType, children: impl IntoIterator) -> Self { + let any = ty.as_ref(); + let py = any.py(); + Self { + obj: any.to_object(py), + children: children.into_iter().collect(), + } + } + + fn leaf(any: &PyAny) -> Self { + struct Hashable { + obj: PyObject, + hash: isize, + } + + impl Hash for Hashable { + fn hash(&self, state: &mut H) { + self.hash.hash(state); + } + } + + impl PartialEq for Hashable { + fn eq(&self, other: &Self) -> bool { + let py = unsafe { Python::assume_gil_acquired() }; + let cmp = self.obj.as_ref(py).rich_compare(&other.obj, CompareOp::Eq); + cmp.unwrap().is_true().unwrap() + } + } + + impl Eq for Hashable {} + + static LEAVES: Lazy>> = Lazy::new(Default::default); + + let hash = any.hash().expect("failed to hash"); + let py = any.py(); + let obj = any.to_object(py); + + let mut leaves = LEAVES.lock().unwrap(); + let hashable = leaves.get_or_insert(Hashable { obj, hash }); + + Self { + obj: hashable.obj.clone(), + children: vec![], + } + } +} + +impl PartialEq for PyLang { + fn eq(&self, other: &Self) -> bool { + self.obj.as_ptr() == other.obj.as_ptr() && self.children == other.children + } +} + +impl Hash for PyLang { + fn hash(&self, state: &mut H) { + self.obj.as_ptr().hash(state); + self.children.hash(state); + } +} + +impl Ord for PyLang { + fn cmp(&self, other: &Self) -> Ordering { + self.partial_cmp(other).expect("comparison failed") + } +} + +impl PartialOrd for PyLang { + fn partial_cmp(&self, other: &Self) -> Option { + match self.obj.as_ptr().partial_cmp(&other.obj.as_ptr()) { + Some(Ordering::Equal) => {} + ord => return ord, + } + self.children.partial_cmp(&other.children) + } +} + +impl Eq for PyLang {} + +impl egg::Language for PyLang { + fn matches(&self, other: &Self) -> bool { + self.obj.as_ptr() == other.obj.as_ptr() && self.children.len() == other.children.len() + } + + fn children(&self) -> &[egg::Id] { + &self.children + } + + fn children_mut(&mut self) -> &mut [egg::Id] { + &mut self.children + } +} + +impl Display for PyLang { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + Python::with_gil(|py| match self.obj.as_ref(py).str() { + Ok(s) => s.fmt(f), + Err(_) => "<>".fmt(f), + }) + } +} + #[pyclass] struct Pattern { pattern: egg::Pattern, @@ -82,17 +195,13 @@ fn build_pattern(ast: &mut egg::PatternAst, tree: &PyAny) -> egg::Id { } else if let Ok(var) = tree.extract::() { ast.add(egg::ENodeOrVar::Var(var.0)) } else if let Ok(tuple) = tree.downcast::() { - let ty = tree.get_type_ptr() as usize; - let ids: Vec = tuple - .iter() - .map(|child| build_pattern(ast, child)) - .collect(); - ast.add(egg::ENodeOrVar::ENode(PyLang::Node(ty, ids))) - } else if let Ok(n) = tree.extract::() { - ast.add(egg::ENodeOrVar::ENode(PyLang::Int(n))) + let op = PyLang::op( + tree.get_type(), + tuple.iter().map(|child| build_pattern(ast, child)), + ); + ast.add(egg::ENodeOrVar::ENode(op)) } else { - let repr = tree.repr().expect("failed to repr"); - panic!("Cannot convert to pattern: {}", repr) + ast.add(egg::ENodeOrVar::ENode(PyLang::leaf(tree))) } } @@ -104,43 +213,22 @@ struct Rewrite { #[pymethods] impl Rewrite { #[new] - fn new(from: &PyAny, to: &PyAny) -> Self { - let searcher = Pattern::new(from).pattern; - let applier = Pattern::new(to).pattern; - let rewrite = - egg::Rewrite::new("name", searcher, applier).expect("Failed to create rewrite"); - Rewrite { rewrite } - } -} - -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -enum PyLang { - Int(i64), - Node(usize, Vec), -} - -impl egg::Language for PyLang { - fn matches(&self, other: &Self) -> bool { - use PyLang::*; - match (self, other) { - (Node(op1, args1), Node(op2, args2)) => op1 == op2 && args1.len() == args2.len(), - (Int(a), Int(b)) => a == b, - _ => false, - } - } - - fn children(&self) -> &[egg::Id] { - match self { - PyLang::Node(_, args) => args, - _ => &[], + #[args(name = "\"\"")] + fn new(lhs: &PyAny, rhs: &PyAny, name: &str) -> Self { + let searcher = Pattern::new(lhs).pattern; + let applier = Pattern::new(rhs).pattern; + + let mut name = Cow::Borrowed(name); + if name == "" { + name = Cow::Owned(format!("{} => {}", searcher, applier)); } + let rewrite = egg::Rewrite::new(name, searcher, applier).expect("Failed to create rewrite"); + Rewrite { rewrite } } - fn children_mut(&mut self) -> &mut [egg::Id] { - match self { - PyLang::Node(_, args) => args, - _ => &mut [], - } + #[getter] + fn name(&self) -> &str { + self.rewrite.name.as_str() } } @@ -220,19 +308,18 @@ impl EGraph { impl EGraph { fn add_rec(&mut self, expr: &PyAny) -> egg::Id { - if let Ok(id) = expr.extract::() { - self.egraph.find(id.0) + if let Ok(Id(id)) = expr.extract() { + self.egraph.find(id) + } else if let Ok(Var(var)) = expr.extract() { + panic!("Can't add a var: {}", var) } else if let Ok(tuple) = expr.downcast::() { - let ty = expr.get_type_ptr() as usize; - let ids: Vec = tuple.iter().map(|child| self.add_rec(child)).collect(); - self.egraph.add(PyLang::Node(ty, ids)) - } else if let Ok(n) = expr.extract::() { - self.egraph.add(PyLang::Int(n)) - // } else if let Ok(s) = expr.extract::<&str>() { - // self.egraph.add(ENode::Symbol(s.into())) + let enode = PyLang::op( + expr.get_type(), + tuple.iter().map(|child| self.add_rec(child)), + ); + self.egraph.add(enode) } else { - let repr = expr.repr().expect("failed to repr"); - panic!("Cannot add {}", repr) + self.egraph.add(PyLang::leaf(expr)) } } } diff --git a/test.py b/test.py index 59f06e3..b8f6ebc 100644 --- a/test.py +++ b/test.py @@ -1,19 +1,45 @@ -from typing import Any, Tuple +import inspect import unittest +# import doctest - -from collections import namedtuple from snake_egg import EGraph, Rewrite, Var, vars +from collections import namedtuple Add = namedtuple('Add', 'x y') Mul = namedtuple('Mul', 'x y') -x, y = vars('x y') + +class ENode: + def __rshift__(self, other): + return Rewrite(self, other) + + +class Foo(tuple, ENode): + def __new__(cls, *args): + return super().__new__(cls, tuple(args)) + + +print(inspect.getmro(Foo)) + + +x, y, z = vars('x y z') + +print(str(Add(x, y))) rules = [ - Rewrite(Add(x, x), to=Mul(x, 2)) + Rewrite(lhs=Add(x, y), rhs=Add(y, x), name='add_comm'), + Rewrite(Mul(x, y), Mul(y, x)), + Rewrite(Add(x, Add(y, z)), Add(Add(x, y), z)), + Rewrite(Mul(x, Mul(y, z)), Mul(Mul(x, y), z)), + Rewrite(Add(x, 0), x), + Rewrite(Mul(x, 0), 0), + Foo(x, 1) >> x, + Rewrite(Add(x, x), Mul(x, 2)), ] +for r in rules: + print(r.name) + egraph = EGraph() @@ -46,4 +72,10 @@ def test_simple(self): if __name__ == '__main__': - unittest.main() + # import snake_egg + # print("--- doc tests ---") + # failed, tested = doctest.testmod(snake_egg, verbose=True, report=True) + # if failed > 0: + # exit(1) + print("\n\n--- unit tests ---") + unittest.main(verbosity=2)