Skip to content

Commit

Permalink
Revert to first fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wraith1995 authored Dec 6, 2022
1 parent 54642e7 commit adcecf6
Showing 1 changed file with 14 additions and 17 deletions.
31 changes: 14 additions & 17 deletions snake_egg/tests/test_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,37 +4,34 @@

from dataclasses import dataclass
from typing import Any
import sys
current_version = sys.version_info()
isatleast10 = current_version[1] >= 10
from snake_egg import EGraph, Rewrite, Var, vars

from snake_egg import EGraph, Rewrite, Var, vars


# Operations
@dataclass(frozen=True)
class Add:
__match_args__ = ("x", "y")
x: Any
y: Any

@dataclass(frozen=True)
class Mul:
__match_args__ = ("x", "y")
x: Any
y: Any


# Rewrite rules
a, b = vars("a b") # type: ignore
if isatleast10:
rules = [
Rewrite(Add(a, b), Add(b, a), name="commute-add"),
Rewrite(Mul(a, b), Mul(b, a), name="commute-mul"),
Rewrite(Add(a, 0), a, name="add-0"),
Rewrite(Mul(a, 0), 0, name="mul-0"),
Rewrite(Mul(a, 1), a, name="mul-1"),
]
else:
rules = []

rules = [
Rewrite(Add(a, b), Add(b, a), name="commute-add"),
Rewrite(Mul(a, b), Mul(b, a), name="commute-mul"),
Rewrite(Add(a, 0), a, name="add-0"),
Rewrite(Mul(a, 0), 0, name="mul-0"),
Rewrite(Mul(a, 1), a, name="mul-1"),
]


def simplify(expr, iters=7):
Expand All @@ -46,16 +43,16 @@ def simplify(expr, iters=7):


def test_simple_1():
assert (not isatleast10) or simplify(Mul(0, 42)) == 0
assert simplify(Mul(0, 42)) == 0


def test_simple_2():
foo = "foo"
assert (not isatleast10) or simplify(Add(0, Mul(1, foo))) == foo
assert simplify(Add(0, Mul(1, foo))) == foo


def test_simple_3():
assert (not isatleast10) or simplify(Mul(2, Mul(1, "foo"))) == Mul(2, "foo")
assert simplify(Mul(2, Mul(1, "foo"))) == Mul(2, "foo")


test_simple_1()
Expand Down

0 comments on commit adcecf6

Please sign in to comment.