diff --git a/snake_egg/tests/test_egg_interface.py b/snake_egg/tests/test_egg_interface.py index 1399cf7..63d12f7 100644 --- a/snake_egg/tests/test_egg_interface.py +++ b/snake_egg/tests/test_egg_interface.py @@ -14,10 +14,15 @@ def __init__(self, x: Any, y: Any): self.x = x self.y = y + def __str__(self) -> str: + return f"Add({self.x}, {self.y})" + + @property def __egg_head__(self): return self.__class__ - def __egg_args(self): + @property + def __egg_args__(self): return self.x, self.y @@ -26,10 +31,15 @@ def __init__(self, x: Any, y: Any): self.x = x self.y = y + def __str__(self) -> str: + return f"Mul({self.x}, {self.y})" + + @property def __egg_head__(self): return self.__class__ - def __egg_args(self): + @property + def __egg_args__(self): return self.x, self.y @@ -45,25 +55,29 @@ def __egg_args(self): ] -def simplify(expr, iters=7): +def is_equal(expr_a, expr_b, iters=5): egraph = EGraph() - egraph.add(expr) + + id_a = egraph.add(expr_a) + id_b = egraph.add(expr_b) + egraph.run(rules, iters) - best = egraph.extract(expr) - return best + + return egraph.equiv(id_a, id_b) def test_simple_1(): - assert simplify(Mul(0, 42)) == 0 + assert is_equal(Mul(0, 42), 0) def test_simple_2(): foo = "foo" - assert simplify(Add(0, Mul(1, foo))) == foo + assert is_equal(Add(0, Mul(1, foo)), foo) def test_simple_3(): - assert simplify(Mul(2, Mul(1, "foo"))) == Mul(2, "foo") + foo = "foo" + assert is_equal(Mul(2, Mul(1, foo)), Mul(2, foo)) test_simple_1() diff --git a/src/util.rs b/src/util.rs index 18edd78..62b329c 100644 --- a/src/util.rs +++ b/src/util.rs @@ -17,7 +17,7 @@ pub fn build_node(egraph: &mut EGraph, expr: &PyAny) egraph.find(id) } else if let Ok(PyVar(var)) = expr.extract() { panic!("Can't add a var: {}", var) - } else if let Ok(args) = expr.getattr("__egg_args") { + } else if let Ok(args) = expr.getattr("__egg_args__") { let args = args.downcast::().unwrap(); let class = if let Ok(class) = expr.getattr("__egg_head__") { class.downcast::().unwrap()