Skip to content

Commit

Permalink
__egg_args__ and not __egg_args. Also, egg interface methods should b…
Browse files Browse the repository at this point in the history
…e properties.
  • Loading branch information
FedericoAureliano committed Dec 5, 2022
1 parent f07043a commit 7b561b6
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 10 deletions.
32 changes: 23 additions & 9 deletions snake_egg/tests/test_egg_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,15 @@ def __init__(self, x: Any, y: Any):
self.x = x
self.y = y

def __str__(self) -> str:
return f"Add({self.x}, {self.y})"

@property
def __egg_head__(self):
return self.__class__

def __egg_args(self):
@property
def __egg_args__(self):
return self.x, self.y


Expand All @@ -26,10 +31,15 @@ def __init__(self, x: Any, y: Any):
self.x = x
self.y = y

def __str__(self) -> str:
return f"Mul({self.x}, {self.y})"

@property
def __egg_head__(self):
return self.__class__

def __egg_args(self):
@property
def __egg_args__(self):
return self.x, self.y


Expand All @@ -45,25 +55,29 @@ def __egg_args(self):
]


def simplify(expr, iters=7):
def is_equal(expr_a, expr_b, iters=5):
egraph = EGraph()
egraph.add(expr)

id_a = egraph.add(expr_a)
id_b = egraph.add(expr_b)

egraph.run(rules, iters)
best = egraph.extract(expr)
return best

return egraph.equiv(id_a, id_b)


def test_simple_1():
assert simplify(Mul(0, 42)) == 0
assert is_equal(Mul(0, 42), 0)


def test_simple_2():
foo = "foo"
assert simplify(Add(0, Mul(1, foo))) == foo
assert is_equal(Add(0, Mul(1, foo)), foo)


def test_simple_3():
assert simplify(Mul(2, Mul(1, "foo"))) == Mul(2, "foo")
foo = "foo"
assert is_equal(Mul(2, Mul(1, foo)), Mul(2, foo))


test_simple_1()
Expand Down
2 changes: 1 addition & 1 deletion src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub fn build_node(egraph: &mut EGraph<PythonNode, PythonAnalysis>, expr: &PyAny)
egraph.find(id)
} else if let Ok(PyVar(var)) = expr.extract() {
panic!("Can't add a var: {}", var)
} else if let Ok(args) = expr.getattr("__egg_args") {
} else if let Ok(args) = expr.getattr("__egg_args__") {
let args = args.downcast::<PyTuple>().unwrap();
let class = if let Ok(class) = expr.getattr("__egg_head__") {
class.downcast::<PyType>().unwrap()
Expand Down

0 comments on commit 7b561b6

Please sign in to comment.