Skip to content

Commit

Permalink
Get extraction working
Browse files Browse the repository at this point in the history
  • Loading branch information
mwillsey committed Nov 4, 2021
1 parent c30c99a commit a8e05a7
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 21 deletions.
60 changes: 45 additions & 15 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use egg::{Language, RecExpr};
use once_cell::sync::Lazy;

use std::cmp::Ordering;
Expand Down Expand Up @@ -302,6 +303,33 @@ impl EGraph {
self.egraph = runner.egraph;
Ok(())
}

#[args(exprs = "*")]
fn extract(&mut self, py: Python, exprs: &PyTuple) -> SingletonOrTuple<PyObject> {
let ids: Vec<egg::Id> = exprs.iter().map(|expr| self.add(expr).0).collect();
let extractor = egg::Extractor::new(&self.egraph, egg::AstSize);
ids.iter()
.map(|&id| {
let (_cost, recexpr) = extractor.find_best(id);
reconstruct(py, &recexpr)
})
.collect()
}
}

fn reconstruct(py: Python, recexpr: &RecExpr<PyLang>) -> PyObject {
let mut objs = vec![];
for node in recexpr.as_ref() {
if node.is_leaf() {
objs.push(node.obj.clone())
} else {
let get_child = |&id| objs[usize::from(id)].clone();
let args = PyTuple::new(py, node.children.iter().map(get_child));
let obj = node.obj.call1(py, args).expect("Failed to construct");
objs.push(obj)
}
}
objs.pop().unwrap()
}

impl EGraph {
Expand All @@ -322,17 +350,21 @@ impl EGraph {
}
}

fn singleton_or_tuple<T, TS>(py: Python<'_>, elems: TS) -> PyObject
where
T: IntoPy<PyObject>,
TS: IntoIterator<Item = T>,
TS::IntoIter: ExactSizeIterator<Item = T>,
{
let mut elems = elems.into_iter();
if elems.len() == 1 {
elems.next().unwrap().into_py(py)
} else {
PyTuple::new(py, elems.map(|x| x.into_py(py))).into_py(py)
struct SingletonOrTuple<T>(Vec<T>);

impl<T: IntoPy<PyObject>> IntoPy<PyObject> for SingletonOrTuple<T> {
fn into_py(mut self, py: Python) -> PyObject {
match self.0.len() {
0 => panic!("Shouldn't be empty"),
1 => self.0.pop().unwrap().into_py(py),
_ => PyTuple::new(py, self.0.into_iter().map(|x| x.into_py(py))).into_py(py),
}
}
}

impl<T: IntoPy<PyObject>> FromIterator<T> for SingletonOrTuple<T> {
fn from_iter<TS: IntoIterator<Item = T>>(iter: TS) -> Self {
Self(iter.into_iter().collect())
}
}

Expand All @@ -342,17 +374,15 @@ where
#[pymodule]
fn snake_egg(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<EGraph>()?;
m.add_class::<ENode>()?;
m.add_class::<Id>()?;
m.add_class::<Var>()?;
m.add_class::<Pattern>()?;
m.add_class::<Rewrite>()?;

#[pyfn(m)]
fn vars(py: Python<'_>, vars: &PyString) -> PyObject {
fn vars(vars: &PyString) -> SingletonOrTuple<Var> {
let s = vars.to_string_lossy();
let strs: Vec<&str> = s.split_whitespace().collect();
singleton_or_tuple(py, strs.iter().map(|s| Var::from_str(s)))
s.split_whitespace().map(|s| Var::from_str(s)).collect()
}
Ok(())
}
15 changes: 9 additions & 6 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,19 @@ def test_vars(self):
def test_simple(self):
egraph = EGraph()

add = egraph.add(Add(1, 1))
mul = egraph.add(Mul(1, 2))

egraph.run(rules, iter_limit=1)

self.assertTrue(egraph.equiv(add, mul))
add = egraph.add(Add(1, 0))
egraph.run(rules, iter_limit=2)

self.assertTrue(egraph.equiv(add, 1))
self.assertEqual(egraph.add(Add(7, 8)), egraph.add(Add(7, 8)))
self.assertTrue(egraph.union(2, Add(1, 1), Add(2, 0)))

# extract two separately
self.assertEqual(egraph.extract(add), egraph.extract(1))
# extract two at same time
a, b = egraph.extract(add, 1)
self.assertEqual(a, b)


if __name__ == '__main__':
# import snake_egg
Expand Down

0 comments on commit a8e05a7

Please sign in to comment.