Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
mwillsey committed Nov 4, 2021
1 parent b0eceb5 commit 1356a28
Show file tree
Hide file tree
Showing 6 changed files with 209 additions and 68 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@
Cargo.lock
venv

.vscode
.vscode
snake_egg.html
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ name = "snake_egg"
crate-type = ["cdylib"]

[dependencies]
hashbrown = "0.11"
once_cell = "1"
egg = { git = "https://github.com/egraphs-good/egg", rev = "c637bbd9243118a07e2ccad3f812f66e0952e603" }
# pyo3 = { version = "0.14.5", features = ["extension-module"] }

[dependencies.pyo3]
features = ["extension-module", "abi3", "abi3-py36"]
version = "0.14.5"
version = "0.15.0"
# git = "https://github.com/PyO3/pyo3"
# rev = "64df791741e61c331a03dbed42085b0c1adffea1"
11 changes: 9 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
.PHONY: all build test
.PHONY: all build test install doc

all: test

Expand All @@ -10,7 +10,14 @@ build:
ln -fs $(dir)/libsnake_egg.so $(dir)/snake_egg.so

test: test.py build
env PYTHONPATH=$(dir)/ python3 test.py
env PYTHONPATH=$(dir) python3 test.py

install:
maturin build
python3 -m pip install snake_egg --force-reinstall --no-index --find-link ./target/wheels/

doc: build
env PYTHONPATH=$(dir) python3 -m pydoc -w snake_egg

shell: build
$(python) -ic 'import snake_egg'
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Python bindings for [`egg`](https://github.com/egraphs-good/egg)

# Installing

- Install [`maturin`](https://github.com/PyO3/maturin), a cool Rust/Python builder thingy.
- Download from their site or just `pip install maturin`.
- Type `make install` to build and install `snake_egg` into your python installation.
- This will reinstall over any existing `snake_egg` installation.
- You may want to do this in a `virtualenv`.

If you'd like to manually install it,
just run `maturin build` and find the wheels in `./target/wheels/`.
203 changes: 145 additions & 58 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
use std::{array::IntoIter, collections::HashMap, hash::Hash, time::Duration};
use once_cell::sync::Lazy;

use std::cmp::Ordering;
use std::sync::Mutex;
use std::{borrow::Cow, fmt::Display, hash::Hash, time::Duration};

use pyo3::AsPyPointer;
use pyo3::{
basic::CompareOp,
prelude::*,
types::{PyList, PyString, PyTuple},
AsPyPointer, PyNativeType, PyObjectProtocol, ToPyObject,
types::{PyList, PyString, PyTuple, PyType},
PyObjectProtocol,
};

macro_rules! impl_py_object {
Expand Down Expand Up @@ -60,6 +65,114 @@ impl Var {
}
}

#[derive(Debug, Clone)]
struct PyLang {
obj: PyObject,
children: Vec<egg::Id>,
}

impl PyLang {
fn op(ty: &PyType, children: impl IntoIterator<Item = egg::Id>) -> Self {
let any = ty.as_ref();
let py = any.py();
Self {
obj: any.to_object(py),
children: children.into_iter().collect(),
}
}

fn leaf(any: &PyAny) -> Self {
struct Hashable {
obj: PyObject,
hash: isize,
}

impl Hash for Hashable {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.hash.hash(state);
}
}

impl PartialEq for Hashable {
fn eq(&self, other: &Self) -> bool {
let py = unsafe { Python::assume_gil_acquired() };
let cmp = self.obj.as_ref(py).rich_compare(&other.obj, CompareOp::Eq);
cmp.unwrap().is_true().unwrap()
}
}

impl Eq for Hashable {}

static LEAVES: Lazy<Mutex<hashbrown::HashSet<Hashable>>> = Lazy::new(Default::default);

let hash = any.hash().expect("failed to hash");
let py = any.py();
let obj = any.to_object(py);

let mut leaves = LEAVES.lock().unwrap();
let hashable = leaves.get_or_insert(Hashable { obj, hash });

Self {
obj: hashable.obj.clone(),
children: vec![],
}
}
}

impl PartialEq for PyLang {
fn eq(&self, other: &Self) -> bool {
self.obj.as_ptr() == other.obj.as_ptr() && self.children == other.children
}
}

impl Hash for PyLang {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.obj.as_ptr().hash(state);
self.children.hash(state);
}
}

impl Ord for PyLang {
fn cmp(&self, other: &Self) -> Ordering {
self.partial_cmp(other).expect("comparison failed")
}
}

impl PartialOrd for PyLang {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
match self.obj.as_ptr().partial_cmp(&other.obj.as_ptr()) {
Some(Ordering::Equal) => {}
ord => return ord,
}
self.children.partial_cmp(&other.children)
}
}

impl Eq for PyLang {}

impl egg::Language for PyLang {
fn matches(&self, other: &Self) -> bool {
self.obj.as_ptr() == other.obj.as_ptr() && self.children.len() == other.children.len()
}

fn children(&self) -> &[egg::Id] {
&self.children
}

fn children_mut(&mut self) -> &mut [egg::Id] {
&mut self.children
}
}

impl Display for PyLang {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Python::with_gil(|py| match self.obj.as_ref(py).str() {
Ok(s) => s.fmt(f),
Err(_) => "<<NODE>>".fmt(f),
})
}
}

#[pyclass]
struct Pattern {
pattern: egg::Pattern<PyLang>,
Expand All @@ -82,17 +195,13 @@ fn build_pattern(ast: &mut egg::PatternAst<PyLang>, tree: &PyAny) -> egg::Id {
} else if let Ok(var) = tree.extract::<Var>() {
ast.add(egg::ENodeOrVar::Var(var.0))
} else if let Ok(tuple) = tree.downcast::<PyTuple>() {
let ty = tree.get_type_ptr() as usize;
let ids: Vec<egg::Id> = tuple
.iter()
.map(|child| build_pattern(ast, child))
.collect();
ast.add(egg::ENodeOrVar::ENode(PyLang::Node(ty, ids)))
} else if let Ok(n) = tree.extract::<i64>() {
ast.add(egg::ENodeOrVar::ENode(PyLang::Int(n)))
let op = PyLang::op(
tree.get_type(),
tuple.iter().map(|child| build_pattern(ast, child)),
);
ast.add(egg::ENodeOrVar::ENode(op))
} else {
let repr = tree.repr().expect("failed to repr");
panic!("Cannot convert to pattern: {}", repr)
ast.add(egg::ENodeOrVar::ENode(PyLang::leaf(tree)))
}
}

Expand All @@ -104,43 +213,22 @@ struct Rewrite {
#[pymethods]
impl Rewrite {
#[new]
fn new(from: &PyAny, to: &PyAny) -> Self {
let searcher = Pattern::new(from).pattern;
let applier = Pattern::new(to).pattern;
let rewrite =
egg::Rewrite::new("name", searcher, applier).expect("Failed to create rewrite");
Rewrite { rewrite }
}
}

#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum PyLang {
Int(i64),
Node(usize, Vec<egg::Id>),
}

impl egg::Language for PyLang {
fn matches(&self, other: &Self) -> bool {
use PyLang::*;
match (self, other) {
(Node(op1, args1), Node(op2, args2)) => op1 == op2 && args1.len() == args2.len(),
(Int(a), Int(b)) => a == b,
_ => false,
}
}

fn children(&self) -> &[egg::Id] {
match self {
PyLang::Node(_, args) => args,
_ => &[],
#[args(name = "\"\"")]
fn new(lhs: &PyAny, rhs: &PyAny, name: &str) -> Self {
let searcher = Pattern::new(lhs).pattern;
let applier = Pattern::new(rhs).pattern;

let mut name = Cow::Borrowed(name);
if name == "" {
name = Cow::Owned(format!("{} => {}", searcher, applier));
}
let rewrite = egg::Rewrite::new(name, searcher, applier).expect("Failed to create rewrite");
Rewrite { rewrite }
}

fn children_mut(&mut self) -> &mut [egg::Id] {
match self {
PyLang::Node(_, args) => args,
_ => &mut [],
}
#[getter]
fn name(&self) -> &str {
self.rewrite.name.as_str()
}
}

Expand Down Expand Up @@ -220,19 +308,18 @@ impl EGraph {

impl EGraph {
fn add_rec(&mut self, expr: &PyAny) -> egg::Id {
if let Ok(id) = expr.extract::<Id>() {
self.egraph.find(id.0)
if let Ok(Id(id)) = expr.extract() {
self.egraph.find(id)
} else if let Ok(Var(var)) = expr.extract() {
panic!("Can't add a var: {}", var)
} else if let Ok(tuple) = expr.downcast::<PyTuple>() {
let ty = expr.get_type_ptr() as usize;
let ids: Vec<egg::Id> = tuple.iter().map(|child| self.add_rec(child)).collect();
self.egraph.add(PyLang::Node(ty, ids))
} else if let Ok(n) = expr.extract::<i64>() {
self.egraph.add(PyLang::Int(n))
// } else if let Ok(s) = expr.extract::<&str>() {
// self.egraph.add(ENode::Symbol(s.into()))
let enode = PyLang::op(
expr.get_type(),
tuple.iter().map(|child| self.add_rec(child)),
);
self.egraph.add(enode)
} else {
let repr = expr.repr().expect("failed to repr");
panic!("Cannot add {}", repr)
self.egraph.add(PyLang::leaf(expr))
}
}
}
Expand Down
44 changes: 38 additions & 6 deletions test.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,45 @@
from typing import Any, Tuple
import inspect
import unittest
# import doctest


from collections import namedtuple
from snake_egg import EGraph, Rewrite, Var, vars

from collections import namedtuple
Add = namedtuple('Add', 'x y')
Mul = namedtuple('Mul', 'x y')

x, y = vars('x y')

class ENode:
def __rshift__(self, other):
return Rewrite(self, other)


class Foo(tuple, ENode):
def __new__(cls, *args):
return super().__new__(cls, tuple(args))


print(inspect.getmro(Foo))


x, y, z = vars('x y z')

print(str(Add(x, y)))

rules = [
Rewrite(Add(x, x), to=Mul(x, 2))
Rewrite(lhs=Add(x, y), rhs=Add(y, x), name='add_comm'),
Rewrite(Mul(x, y), Mul(y, x)),
Rewrite(Add(x, Add(y, z)), Add(Add(x, y), z)),
Rewrite(Mul(x, Mul(y, z)), Mul(Mul(x, y), z)),
Rewrite(Add(x, 0), x),
Rewrite(Mul(x, 0), 0),
Foo(x, 1) >> x,
Rewrite(Add(x, x), Mul(x, 2)),
]

for r in rules:
print(r.name)


egraph = EGraph()

Expand Down Expand Up @@ -46,4 +72,10 @@ def test_simple(self):


if __name__ == '__main__':
unittest.main()
# import snake_egg
# print("--- doc tests ---")
# failed, tested = doctest.testmod(snake_egg, verbose=True, report=True)
# if failed > 0:
# exit(1)
print("\n\n--- unit tests ---")
unittest.main(verbosity=2)

0 comments on commit 1356a28

Please sign in to comment.