Skip to content

Commit

Permalink
Update test_dataclass.py
Browse files Browse the repository at this point in the history
alternative solution
  • Loading branch information
wraith1995 authored Dec 6, 2022
1 parent 8737926 commit 54642e7
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions snake_egg/tests/test_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,37 @@

from dataclasses import dataclass
from typing import Any

import sys
current_version = sys.version_info()
isatleast10 = current_version[1] >= 10
from snake_egg import EGraph, Rewrite, Var, vars



# Operations
@dataclass(frozen=True)
class Add:
__match_args__ = ("x", "y")
x: Any
y: Any

@dataclass(frozen=True)
class Mul:
__match_args__ = ("x", "y")
x: Any
y: Any


# Rewrite rules
a, b = vars("a b") # type: ignore

rules = [
Rewrite(Add(a, b), Add(b, a), name="commute-add"),
Rewrite(Mul(a, b), Mul(b, a), name="commute-mul"),
Rewrite(Add(a, 0), a, name="add-0"),
Rewrite(Mul(a, 0), 0, name="mul-0"),
Rewrite(Mul(a, 1), a, name="mul-1"),
]
if isatleast10:
rules = [
Rewrite(Add(a, b), Add(b, a), name="commute-add"),
Rewrite(Mul(a, b), Mul(b, a), name="commute-mul"),
Rewrite(Add(a, 0), a, name="add-0"),
Rewrite(Mul(a, 0), 0, name="mul-0"),
Rewrite(Mul(a, 1), a, name="mul-1"),
]
else:
rules = []


def simplify(expr, iters=7):
Expand All @@ -43,16 +46,16 @@ def simplify(expr, iters=7):


def test_simple_1():
assert simplify(Mul(0, 42)) == 0
assert (not isatleast10) or simplify(Mul(0, 42)) == 0


def test_simple_2():
foo = "foo"
assert simplify(Add(0, Mul(1, foo))) == foo
assert (not isatleast10) or simplify(Add(0, Mul(1, foo))) == foo


def test_simple_3():
assert simplify(Mul(2, Mul(1, "foo"))) == Mul(2, "foo")
assert (not isatleast10) or simplify(Mul(2, Mul(1, "foo"))) == Mul(2, "foo")


test_simple_1()
Expand Down

0 comments on commit 54642e7

Please sign in to comment.