diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index b2a81e6..43ac8c0 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -15,7 +15,7 @@ jobs: - name: Install Python uses: actions/setup-python@v2 with: - python-version: 3.6 + python-version: 3.7 architecture: x64 - name: Install Rust toolchain uses: actions-rs/toolchain@v1 diff --git a/.gitignore b/.gitignore index e26358a..26840e1 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,175 @@ -/target +# Generated by Cargo +# will have compiled files and executables +debug/ +target/ + +# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries +# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html Cargo.lock -venv -.vscode -snake_egg.html +# These are backup files generated by rustfmt +**/*.rs.bk + +# MSVC Windows builds of rustc generate these, which store debugging information +*.pdb + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ diff --git a/Cargo.toml b/Cargo.toml index 9e7e3cb..c5bc4fb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,12 +3,19 @@ name = "snake-egg" version = "0.1.0" edition = "2021" +[package.metadata.maturin] +name = "snake_egg._internal" + [lib] name = "snake_egg" -crate-type = ["cdylib"] +crate-type = ["cdylib", "rlib"] + +[profile.release] +lto = true +codegen-units = 1 [dependencies] hashbrown = "0.11" once_cell = "1" -pyo3 = { version = "0.15", features = ["extension-module"] } -egg = { git = "https://github.com/egraphs-good/egg", rev = "v0.7.1" } +pyo3 = { version = "0.16", features = ["extension-module"] } +egg = "0.9.1" diff --git a/Makefile b/Makefile index 756e295..fdc062e 100644 --- a/Makefile +++ b/Makefile @@ -12,16 +12,18 @@ venv: build: venv $(activate) && maturin build --release -test: tests/*.py build venv - $(activate) && maturin develop && python tests/math.py - $(activate) && maturin develop && python tests/prop.py - $(activate) && maturin develop && python tests/simple.py - -stubtest: snake_egg.pyi build venv +test: snake_egg/tests/*.py build venv + $(activate) && maturin develop && python snake_egg/tests/test_math.py + $(activate) && maturin develop && python snake_egg/tests/test_prop.py + $(activate) && maturin develop && python snake_egg/tests/test_simple.py + $(activate) && maturin develop && python snake_egg/tests/test_dynamic.py + $(activate) && maturin develop && python snake_egg/tests/test_dataclass.py + +stubtest: snake_egg/__init__.pyi build venv $(activate) && maturin develop --extras=dev && python -m mypy.stubtest snake_egg --ignore-missing-stub -mypy: snake_egg.pyi tests/*.py build venv - $(activate) && maturin develop --extras=dev && mypy tests +mypy: snake_egg/__init__.pyi build venv + $(activate) && maturin develop --extras=dev && mypy snake_egg install: venv $(activate) maturin build --release && \ @@ -29,7 +31,7 @@ install: venv --find-link ./target/wheels/ doc: venv - $(activate) && maturin develop && python -m pydoc -w snake_egg + $(activate) && maturin develop && python -m pydoc -w snake_egg.internal shell: venv $(activate) && maturin develop && python -ic 'import snake_egg' diff --git a/pyproject.toml b/pyproject.toml index 75c17d4..945cb3c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,18 +1,25 @@ -[project] -name = "snake-egg" - [build-system] requires = ["maturin>=0.11,<0.12"] build-backend = "maturin" dependencies = ["typing-extensions"] +[project] +name = "snake-egg" + +[tool.isort] +profile = "black" + +[tool.maturin] +sdist-include = ["Cargo.lock"] + [project.optional-dependencies] dev = [ "mypy" ] [tool.mypy] -ignore_missing_imports = false +ignore_missing_imports = true warn_redundant_casts = true check_untyped_defs = true strict_equality = true -warn_unused_configs = true \ No newline at end of file +warn_unused_configs = true +enable_recursive_aliases = true diff --git a/snake_egg/__init__.py b/snake_egg/__init__.py new file mode 100644 index 0000000..230447e --- /dev/null +++ b/snake_egg/__init__.py @@ -0,0 +1,15 @@ +from ._internal import PyEGraph # type: ignore +from ._internal import vars # type: ignore +from ._internal import PyId as Id # type: ignore +from ._internal import PyPattern as Pattern # type: ignore +from ._internal import PyRewrite as Rewrite # type: ignore +from ._internal import PyVar as Var # type: ignore + + +class EGraph(PyEGraph): + def extract(self, expr): + result = super().extract(expr) + if len(result) == 1: + return result[0] + else: + return result diff --git a/snake_egg.pyi b/snake_egg/__init__.pyi similarity index 73% rename from snake_egg.pyi rename to snake_egg/__init__.pyi index fda8b98..5bd0ce5 100644 --- a/snake_egg.pyi +++ b/snake_egg/__init__.pyi @@ -1,5 +1,5 @@ from collections.abc import Callable, Hashable, Iterable -from typing import Optional +from typing import Dict, Optional, Protocol, Union from typing_extensions import final @@ -12,9 +12,12 @@ class Id: ... class Var: def __init__(self, name: str) -> None: ... +class _CallableApplier(Protocol): + def __call__(self, **substiution: Dict[str, _Expr]) -> _Expr: ... + @final class Rewrite: - def __init__(self, lhs: _Expr, rhs: _Expr, name: str = "") -> None: ... + def __init__(self, lhs: _Expr, rhs: Union[_Expr, _CallableApplier], name: str = "") -> None: ... @property def name(self) -> str: ... @@ -38,6 +41,6 @@ class EGraph: time_limit: float = 10.0, node_limit: int = 100000, ) -> None: ... - def extract(self, *exprs: _Expr) -> tuple[_Expr, ...] | _Expr: ... + def extract(self, expr: _Expr) -> _Expr: ... def vars(vars: str) -> tuple[Var, ...] | Var: ... diff --git a/snake_egg/py.typed b/snake_egg/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/snake_egg/tests/test_dataclass.py b/snake_egg/tests/test_dataclass.py new file mode 100644 index 0000000..d8a5f87 --- /dev/null +++ b/snake_egg/tests/test_dataclass.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 + +# This is a reimplementation of simple.rs from the Rust egg repository + +from dataclasses import dataclass +from typing import Any + +from snake_egg import EGraph, Rewrite, Var, vars + + +# Operations +@dataclass(frozen=True) +class Add: + x: Any + y: Any + + @property + def __match_args__(self): + return (self.x, self.y) + + +@dataclass(frozen=True) +class Mul: + x: Any + y: Any + + @property + def __match_args__(self): + return (self.x, self.y) + + +# Rewrite rules +a, b = vars("a b") # type: ignore + +rules = [ + Rewrite(Add(a, b), Add(b, a), name="commute-add"), + Rewrite(Mul(a, b), Mul(b, a), name="commute-mul"), + Rewrite(Add(a, 0), a, name="add-0"), + Rewrite(Mul(a, 0), 0, name="mul-0"), + Rewrite(Mul(a, 1), a, name="mul-1"), +] + + +def simplify(expr, iters=7): + egraph = EGraph() + egraph.add(expr) + egraph.run(rules, iters) + best = egraph.extract(expr) + return best + + +def test_simple_1(): + assert simplify(Mul(0, 42)) == 0 + + +def test_simple_2(): + foo = "foo" + assert simplify(Add(0, Mul(1, foo))) == foo + + +def test_simple_3(): + assert simplify(Mul(2, Mul(1, "foo"))) == Mul(2, "foo") diff --git a/snake_egg/tests/test_dynamic.py b/snake_egg/tests/test_dynamic.py new file mode 100644 index 0000000..d40c997 --- /dev/null +++ b/snake_egg/tests/test_dynamic.py @@ -0,0 +1,49 @@ +""" +This test is an example of using a dynamic rewrite rule. +The test language is a basic mathematical one with: + +* Symbols, represented as strings +* Numbers, represented as ints +* Addition, represented as a named tuple + +We can replace addition nodes with the result, if both values are numbers. +""" +from __future__ import annotations + +from collections import namedtuple +from typing import List, NamedTuple, Union, cast + +from snake_egg import EGraph, Rewrite, Var, vars + + +class Add(NamedTuple): + x: Expr + y: Expr + +Expr = Union[str, int, Add, Var] + + +def replace_add(x: Expr, y: Expr) -> Expr: + if isinstance(x, int) and isinstance(y, int): + return x + y + return Add(x, y) + +x, y = cast(List[Var], vars("x y")) + +rules = [ + Rewrite(Add(x, y), replace_add, name="replace-add"), +] + + +def simplify(expr: Expr): + egraph = EGraph() + egraph.add(expr) + egraph.run(rules) + best = egraph.extract(expr) + return best + + +def test_simplify_add(): + assert simplify(Add(1, 2)) == 3 + assert simplify(Add(1, Add("x", "y"))) == Add(1, Add("x", "y")) + diff --git a/tests/math.py b/snake_egg/tests/test_math.py similarity index 56% rename from tests/math.py rename to snake_egg/tests/test_math.py index 0b55040..6b315f8 100644 --- a/tests/math.py +++ b/snake_egg/tests/test_math.py @@ -7,12 +7,10 @@ # + this disables some rules and tests # * the last three tests -from snake_egg import EGraph, Rewrite, Var, vars - -import unittest -from typing import List, Any from collections import namedtuple +from typing import Any, List +from snake_egg import EGraph, Rewrite, vars # Operations Diff = namedtuple("Diff", "x y") @@ -55,6 +53,7 @@ def eval_math(car, cdr): pass return None + # Rewrite rules, not all are currently used since gaurds aren't in snake-egg yet a, b, c, x, f, g, y = vars("a b c x f g y") # type: ignore list_rules: List[List[Any]] = [ @@ -123,6 +122,7 @@ def eval_math(car, cdr): ["i-parts", Integral(Mul(a, b), x), Sub(Mul(a, Integral(b, x)), Integral(Mul(Diff(x, a), Integral(b, x)), x))], ] +# fmt: on # Turn the lists into rewrites rules = list() @@ -138,7 +138,8 @@ def eval_math(car, cdr): y = "y" five = "five" -def is_equal(expr_a, expr_b, iters=7): + +def is_equal(expr_a, expr_b, iters=5): egraph = EGraph(eval_math) id_a = egraph.add(expr_a) @@ -149,106 +150,99 @@ def is_equal(expr_a, expr_b, iters=7): return egraph.equiv(id_a, id_b) -class TestMathEgraph(unittest.TestCase): - - def test_math_associate_adds(self): - expr_a = Add(1, Add(2, Add(3, Add(4, Add(5, Add(6, 7)))))) - expr_b = Add(7, Add(6, Add(5, Add(4, Add(3, Add(2, 1)))))) - self.assertTrue(is_equal(expr_a, expr_b)) - - def test_math_simplify_add(self): - expr_a = Add(x, Add(x, Add(x, x))) - expr_b = Mul(4, x) - self.assertTrue(is_equal(expr_a, expr_b)) - - def test_math_powers(self): - expr_a = Mul(Pow(2, x), Pow(2, y)) - expr_b = Pow(2, Add(x, y)) - self.assertTrue(is_equal(expr_a, expr_b)) - - def test_math_simplify_const(self): - expr_a = Add(1, Sub(a, Mul(Sub(2, 1), a))) - expr_b = 1 - self.assertTrue(is_equal(expr_a, expr_b)) - - # def test_math_simplify_root(self): - # expr_a = div(1, sub(div(add(1, sqrt(five)), 2), - # div(sub(1, sqrt(five)), 2))) - # expr_b = div(1, sqrt(five)) - # self.assertTrue(is_equal(expr_a, expr_b)) - - def test_math_simplify_factor(self): - expr_a = Mul(Add(x, 3), Add(x, 1)) - expr_b = Add(Add(Mul(x, x), Mul(4, x)), 3) - self.assertTrue(is_equal(expr_a, expr_b)) - - # def test_math_diff_same(self): - # expr_a = diff(x, x) - # expr_b = 1 - # self.assertTrue(is_equal(expr_a, expr_b)) - - # def test_math_diff_different(self): - # expr_a = diff(x, y) - # expr_b = 0 - # self.assertTrue(is_equal(expr_a, expr_b)) - - # def test_math_diff_simple1(self): - # expr_a = diff(x, add(1, mul(2, x))) - # expr_b = 2 - # self.assertTrue(is_equal(expr_a, expr_b)) - - # def test_math_diff_simple2(self): - # expr_a = diff(x, add(1, mul(y, x))) - # expr_b = y - # self.assertTrue(is_equal(expr_a, expr_b)) - - # def test_math_diff_ln(self): - # expr_a = diff(x, ln(x)) - # expr_b = div(1, x) - # self.assertTrue(is_equal(expr_a, expr_b)) - - # def test_diff_power_simple(self): - # expr_a = diff(x, pow(x, 3)) - # expr_b = mul(3, pow(x, 2)) - # self.assertTrue(is_equal(expr_a, expr_b)) - - # def test_diff_power_harder(self): - # expr_a = diff(x, sub(pow(x, 3), mul(7, pow(x, 2)))) - # expr_b = mul(x, sub(mul(3, x), 14)) - # self.assertTrue(is_equal(expr_a, expr_b)) - - def test_integ_one(self): - expr_a = Integral(1, x) - expr_b = x - self.assertTrue(is_equal(expr_a, expr_b)) - - def test_integ_sin(self): - expr_a = Integral(Cos(x), x) - expr_b = Sin(x) - self.assertTrue(is_equal(expr_a, expr_b)) - - # def test_integ_x(self): - # expr_a = inte(pow(x, 1), x) - # expr_b = div(pow(x, 2), 2) - # self.assertTrue(is_equal(expr_a, expr_b)) - - # def test_integ_part1(self): - # expr_a = inte(mul(x, cos(x)), x) - # expr_b = add(mul(x, sin(x)), cos(x)) - # self.assertTrue(is_equal(expr_a, expr_b)) - - # def test_integ_part2(self): - # expr_a = inte(mul(cos(x), x), x) - # expr_b = add(mul(x, sin(x)), cos(x)) - # self.assertTrue(is_equal(expr_a, expr_b)) - - # def test_integ_part3(self): - # expr_a = inte(ln(x), x) - # expr_b = sub(mul(x, ln(x)), x) - # self.assertTrue(is_equal(expr_a, expr_b)) - - - - -if __name__ == '__main__': - unittest.main(verbosity=2) \ No newline at end of file + +def test_math_associate_adds(): + expr_a = Add(1, Add(2, Add(3, Add(4, Add(5, Add(6, 7)))))) + expr_b = Add(7, Add(6, Add(5, Add(4, Add(3, Add(2, 1)))))) + assert is_equal(expr_a, expr_b) + +def test_math_simplify_add(): + expr_a = Add(x, Add(x, Add(x, x))) + expr_b = Mul(4, x) + assert is_equal(expr_a, expr_b) + +def test_math_powers(): + expr_a = Mul(Pow(2, x), Pow(2, y)) + expr_b = Pow(2, Add(x, y)) + assert is_equal(expr_a, expr_b) + +def test_math_simplify_const(): + expr_a = Add(1, Sub(a, Mul(Sub(2, 1), a))) + expr_b = 1 + assert is_equal(expr_a, expr_b) + +# def test_math_simplify_root(): +# expr_a = div(1, sub(div(add(1, sqrt(five)), 2), +# div(sub(1, sqrt(five)), 2))) +# expr_b = div(1, sqrt(five)) +# assert is_equal(expr_a, expr_b) + +def test_math_simplify_factor(): + expr_a = Mul(Add(x, 3), Add(x, 1)) + expr_b = Add(Add(Mul(x, x), Mul(4, x)), 3) + assert is_equal(expr_a, expr_b) + +# def test_math_diff_same(): +# expr_a = diff(x, x) +# expr_b = 1 +# assert is_equal(expr_a, expr_b) + +# def test_math_diff_different(): +# expr_a = diff(x, y) +# expr_b = 0 +# assert is_equal(expr_a, expr_b) + +# def test_math_diff_simple1(): +# expr_a = diff(x, add(1, mul(2, x))) +# expr_b = 2 +# assert is_equal(expr_a, expr_b) + +# def test_math_diff_simple2(): +# expr_a = diff(x, add(1, mul(y, x))) +# expr_b = y +# assert is_equal(expr_a, expr_b) + +# def test_math_diff_ln(): +# expr_a = diff(x, ln(x)) +# expr_b = div(1, x) +# assert is_equal(expr_a, expr_b) + +# def test_diff_power_simple(): +# expr_a = diff(x, pow(x, 3)) +# expr_b = mul(3, pow(x, 2)) +# assert is_equal(expr_a, expr_b) + +# def test_diff_power_harder(): +# expr_a = diff(x, sub(pow(x, 3), mul(7, pow(x, 2)))) +# expr_b = mul(x, sub(mul(3, x), 14)) +# assert is_equal(expr_a, expr_b) + +def test_integ_one(): + expr_a = Integral(1, x) + expr_b = x + assert is_equal(expr_a, expr_b) + +def test_integ_sin(): + expr_a = Integral(Cos(x), x) + expr_b = Sin(x) + assert is_equal(expr_a, expr_b) + +# def test_integ_x(): +# expr_a = inte(pow(x, 1), x) +# expr_b = div(pow(x, 2), 2) +# assert is_equal(expr_a, expr_b) + +# def test_integ_part1(): +# expr_a = inte(mul(x, cos(x)), x) +# expr_b = add(mul(x, sin(x)), cos(x)) +# assert is_equal(expr_a, expr_b) + +# def test_integ_part2(): +# expr_a = inte(mul(cos(x), x), x) +# expr_b = add(mul(x, sin(x)), cos(x)) +# assert is_equal(expr_a, expr_b) + +# def test_integ_part3(): +# expr_a = inte(ln(x), x) +# expr_b = sub(mul(x, ln(x)), x) +# assert is_equal(expr_a, expr_b) diff --git a/tests/prop.py b/snake_egg/tests/test_prop.py similarity index 62% rename from tests/prop.py rename to snake_egg/tests/test_prop.py index f01f53f..b25cc1e 100644 --- a/tests/prop.py +++ b/snake_egg/tests/test_prop.py @@ -2,12 +2,10 @@ # This is a reimplementation of simple.rs from the Rust egg repository -from snake_egg import EGraph, Rewrite, Var, vars - -import unittest -from typing import List, Any from collections import namedtuple +from typing import Any, List +from snake_egg import EGraph, Rewrite, vars # Operations And = namedtuple("And", "x y") @@ -69,6 +67,7 @@ def eval_prod(car, cdr): ["contrapositive", Implies(a, b), Implies(Not(b), Not(a))], ["lem_imply", And(Implies(a, b), Implies(Not(a), c)), Or(b, c)], ] +# fmt: on # Turn the lists into rewrites rules = list() @@ -79,46 +78,40 @@ def eval_prod(car, cdr): rules.append(Rewrite(frm, to, name)) -def prove_something(start_expr, goal_exprs, tester): +def prove_something(start_expr, goal_exprs): egraph = EGraph(eval_prod) id_start = egraph.add(start_expr) egraph.run(rules, 10) - for i,goal in enumerate(goal_exprs): + for i, goal in enumerate(goal_exprs): id_goal = egraph.add(goal) - tester.assertTrue(egraph.equiv(id_start, id_goal), - "Couldn't prove goal {}: {}".format(i, goal)) + assert egraph.equiv(id_start, id_goal), "Couldn't prove goal {}: {}".format( + i, goal + ) x = "x" y = "y" z = "z" -class TestPropEgraph(unittest.TestCase): - - def test_prove_contrapositive(self): - prove_something(Implies(x,y), - [Implies(x,y), - Or(Not(x), y), - Or(Not(x), Not(Not(y))), - Or(Not(Not(y)), Not(x)), - Implies(Not(y), Not(x))], - self) - - def test_prove_chain(self): - prove_something(And(Implies(x, y), Implies(y, z)), - [And(Implies(x, y), Implies(y, z)), - And(Implies(Not(y), Not(x)), Implies(y, z)), - And(Implies(y, z), Implies(Not(y), Not(x))), - Or(z, Not(x)), - Or(Not(x), z), - Implies(x, z)], - self) - - def test_prove_fold(self): - prove_something(Or(And(False, True), And(True, False)), - [False], - self) - - -if __name__ == '__main__': - unittest.main(verbosity=2) \ No newline at end of file + +def test_prove_contrapositive(): + prove_something(Implies(x,y), + [Implies(x,y), + Or(Not(x), y), + Or(Not(x), Not(Not(y))), + Or(Not(Not(y)), Not(x)), + Implies(Not(y), Not(x))]) + +def test_prove_chain(): + prove_something(And(Implies(x, y), Implies(y, z)), + [And(Implies(x, y), Implies(y, z)), + And(Implies(Not(y), Not(x)), Implies(y, z)), + And(Implies(y, z), Implies(Not(y), Not(x))), + Or(z, Not(x)), + Or(Not(x), z), + Implies(x, z)]) + +def test_prove_fold(): + prove_something(Or(And(False, True), And(True, False)), + [False]) + diff --git a/snake_egg/tests/test_simple.py b/snake_egg/tests/test_simple.py new file mode 100644 index 0000000..1dc3753 --- /dev/null +++ b/snake_egg/tests/test_simple.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 + +# This is a reimplementation of simple.rs from the Rust egg repository + +from collections import namedtuple +from typing import Any, NamedTuple + +from snake_egg import EGraph, Rewrite, Var, vars + + +# Operations +class Add(NamedTuple): + x: Any + y: Any + + +class Mul(NamedTuple): + x: Any + y: Any + + +# Rewrite rules +a, b = vars("a b") # type: ignore +rules = [ + Rewrite(Add(a, b), Add(b, a), name="commute-add"), + Rewrite(Mul(a, b), Mul(b, a), name="commute-mul"), + Rewrite(Add(a, 0), a, name="add-0"), + Rewrite(Mul(a, 0), 0, name="mul-0"), + Rewrite(Mul(a, 1), a, name="mul-1"), +] + + +def simplify(expr, iters=7): + egraph = EGraph() + egraph.add(expr) + egraph.run(rules, iters) + best = egraph.extract(expr) + return best + + +def test_simple_1(): + assert simplify(Mul(0, 42)) == 0 + + +def test_simple_2(): + foo = "foo" + assert simplify(Add(0, Mul(1, foo))) == foo diff --git a/src/core.rs b/src/core.rs new file mode 100644 index 0000000..f155807 --- /dev/null +++ b/src/core.rs @@ -0,0 +1,203 @@ +use egg::{AstSize, EGraph, Extractor, Id, Pattern, PatternAst, RecExpr, Rewrite, Runner, Var}; +use pyo3::types::{PyList, PyString, PyTuple}; +use pyo3::{basic::CompareOp, prelude::*}; + +use std::collections::hash_map::DefaultHasher; +use std::hash::{Hash, Hasher}; +use std::time::Duration; + +use crate::lang::{PythonAnalysis, PythonApplier, PythonNode}; +use crate::util::{build_node, build_pattern}; + +#[pyclass] +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct PyId(pub Id); + +#[pymethods] +impl PyId { + fn __richcmp__(&self, other: Self, op: CompareOp) -> bool { + match op { + CompareOp::Lt => self.0 < other.0, + CompareOp::Le => self.0 <= other.0, + CompareOp::Eq => self.0 == other.0, + CompareOp::Ne => self.0 != other.0, + CompareOp::Gt => self.0 > other.0, + CompareOp::Ge => self.0 >= other.0, + } + } +} + +#[pyclass] +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct PyVar(pub Var); + +#[pymethods] +impl PyVar { + #[new] + fn new(str: &PyString) -> Self { + Self::from_str(str.to_string_lossy().as_ref()) + } + + fn __hash__(&self) -> u64 { + let mut hasher = DefaultHasher::new(); + self.0.hash(&mut hasher); + hasher.finish().into() + } + + fn __richcmp__(&self, other: Self, op: CompareOp) -> bool { + match op { + CompareOp::Lt => self.0 < other.0, + CompareOp::Le => self.0 <= other.0, + CompareOp::Eq => self.0 == other.0, + CompareOp::Ne => self.0 != other.0, + CompareOp::Gt => self.0 > other.0, + CompareOp::Ge => self.0 >= other.0, + } + } +} + +impl PyVar { + pub fn from_str(str: &str) -> Self { + let v = format!("?{}", str); + PyVar(v.parse().unwrap()) + } +} + +#[pyclass] +pub struct PyPattern { + pub pattern: Pattern, +} + +#[pyclass] +pub struct PyRewrite { + pub rewrite: Rewrite, +} + +#[pymethods] +impl PyRewrite { + #[new] + #[args(name = "\"\"")] + fn new(searcher: PyPattern, applier: &PyAny, name: &str) -> Self { + let rewrite = if applier.is_callable() { + let applier = PythonApplier { + eval: applier.into(), + vars: searcher.pattern.vars(), + }; + Rewrite::new(name, searcher.pattern, applier).unwrap() + } else if let Ok(pat) = applier.extract::() { + Rewrite::new(name, searcher.pattern, pat.pattern).unwrap() + } else { + panic!("Applier must be a pattern or callable"); + }; + PyRewrite { rewrite } + } + + #[getter] + fn name(&self) -> &str { + self.rewrite.name.as_str() + } +} + +impl<'source> FromPyObject<'source> for PyPattern { + fn extract(obj: &'source PyAny) -> PyResult { + let mut ast = PatternAst::default(); + build_pattern(&mut ast, obj); + let pattern = Pattern::from(ast); + Ok(Self { pattern }) + } +} + +#[pyclass(subclass)] +pub struct PyEGraph { + pub egraph: EGraph, +} + +#[pymethods] +impl PyEGraph { + #[new] + fn new(eval: Option) -> Self { + Self { + egraph: EGraph::new(PythonAnalysis { eval }), + } + } + + fn add(&mut self, expr: &PyAny) -> PyId { + PyId(build_node(&mut self.egraph, expr)) + } + + #[args(exprs = "*")] + fn union(&mut self, exprs: &PyTuple) -> bool { + assert!(exprs.len() > 1); + let mut exprs = exprs.iter(); + let id = self.add(exprs.next().unwrap()).0; + let mut did_something = false; + for expr in exprs { + let added = self.add(expr); + did_something |= self.egraph.union(id, added.0); + } + did_something + } + + #[args(exprs = "*")] + fn equiv(&mut self, exprs: &PyTuple) -> bool { + assert!(exprs.len() > 1); + let mut exprs = exprs.iter(); + let id = self.add(exprs.next().unwrap()).0; + let mut all_equiv = true; + for expr in exprs { + let added = self.add(expr); + all_equiv &= added.0 == id + } + all_equiv + } + + fn rebuild(&mut self) -> usize { + self.egraph.rebuild() + } + + #[args(iter_limit = "10", time_limit = "10.0", node_limit = "100_000")] + fn run( + &mut self, + rewrites: &PyList, + iter_limit: usize, + time_limit: f64, + node_limit: usize, + ) -> PyResult<()> { + let refs = rewrites + .iter() + .map(FromPyObject::extract) + .collect::>>>()?; + let egraph = std::mem::take(&mut self.egraph); + let scheduled_runner = Runner::::default(); + let runner = scheduled_runner + .with_iter_limit(iter_limit) + .with_node_limit(node_limit) + .with_time_limit(Duration::from_secs_f64(time_limit)) + .with_egraph(egraph) + .run(refs.iter().map(|r| &r.rewrite)); + + self.egraph = runner.egraph; + Ok(()) + } + + #[args(exprs = "*")] + fn extract(&mut self, py: Python, exprs: &PyTuple) -> Vec { + let ids: Vec = exprs.iter().map(|expr| self.add(expr).0).collect(); + let extractor = Extractor::new(&self.egraph, AstSize); + ids.iter() + .map(|&id| { + let (_cost, recexpr) = extractor.find_best(id); + reconstruct(py, &recexpr) + }) + .collect() + } +} + +fn reconstruct(py: Python, recexpr: &RecExpr) -> PyObject { + let mut objs = Vec::::with_capacity(recexpr.as_ref().len()); + for node in recexpr.as_ref() { + let obj = node.to_object(py, |id| objs[usize::from(id)].clone()); + objs.push(obj) + } + objs.pop().unwrap() +} diff --git a/src/lang.rs b/src/lang.rs new file mode 100644 index 0000000..cfcd5bf --- /dev/null +++ b/src/lang.rs @@ -0,0 +1,243 @@ +use egg::{Analysis, Applier, DidMerge, EGraph, PatternAst, Subst, Symbol}; +use egg::{Id, Language, Var}; +use once_cell::sync::Lazy; +use pyo3::AsPyPointer; +use pyo3::{ + basic::CompareOp, + prelude::*, + types::{PyTuple, PyType, PyDict}, +}; +use std::cmp::Ordering; +use std::sync::Mutex; +use std::{fmt::Display, hash::Hash}; + +use crate::util::{build_node, py_eq}; +use crate::core::PyPattern; + +struct PythonHashable { + obj: PyObject, + hash: isize, +} + +impl PythonHashable { + pub fn new(obj: &PyAny) -> Self { + Self { + obj: obj.into(), + hash: obj.hash().expect("Failed to hash"), + } + } +} + +impl Hash for PythonHashable { + fn hash(&self, state: &mut H) { + self.hash.hash(state); + } +} + +impl PartialEq for PythonHashable { + fn eq(&self, other: &Self) -> bool { + let py = unsafe { Python::assume_gil_acquired() }; + + self.obj + .as_ref(py) + .rich_compare(&other.obj, CompareOp::Eq) + .expect("Failed to compare") + .is_true() + .expect("Failed to extract bool") + } +} + +impl Eq for PythonHashable {} + +#[derive(Debug, Clone)] +pub struct PythonNode { + pub class: PyObject, + pub children: Vec, +} + +impl PythonNode { + pub fn op(ty: &PyType, children: impl IntoIterator) -> Self { + Self { + class: ty.into(), + children: children.into_iter().collect(), + } + } + + pub fn leaf(obj: &PyAny) -> Self { + static LEAVES: Lazy>> = + Lazy::new(Default::default); + let mut leaves = LEAVES.lock().unwrap(); + + let object = PythonHashable::new(obj); + let hashable = leaves.get_or_insert(object); + + Self { + class: hashable.obj.clone(), + children: vec![], + } + } + + pub fn to_object>(&self, py: Python, f: impl FnMut(Id) -> T) -> PyObject { + if self.is_leaf() { + self.class.clone() + } else { + let children = self.children.iter().copied().map(f); + let args = PyTuple::new(py, children.map(|o| o.into_py(py))); + self.class.call1(py, args).expect("Failed to construct") + } + } +} + +impl Language for PythonNode { + fn matches(&self, other: &Self) -> bool { + self.class.as_ptr() == other.class.as_ptr() && self.children.len() == other.children.len() + } + + fn children(&self) -> &[Id] { + &self.children + } + + fn children_mut(&mut self) -> &mut [Id] { + &mut self.children + } +} + +impl PartialEq for PythonNode { + fn eq(&self, other: &Self) -> bool { + self.class.as_ptr() == other.class.as_ptr() && self.children == other.children + } +} + +impl Hash for PythonNode { + fn hash(&self, state: &mut H) { + self.class.as_ptr().hash(state); + self.children.hash(state); + } +} + +impl Ord for PythonNode { + fn cmp(&self, other: &Self) -> Ordering { + self.partial_cmp(other).expect("comparison failed") + } +} + +impl PartialOrd for PythonNode { + fn partial_cmp(&self, other: &Self) -> Option { + match self.class.as_ptr().partial_cmp(&other.class.as_ptr()) { + Some(Ordering::Equal) => {} + ord => return ord, + } + self.children.partial_cmp(&other.children) + } +} + +impl Eq for PythonNode {} + +impl Display for PythonNode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + Python::with_gil(|py| match self.class.as_ref(py).str() { + Ok(s) => s.fmt(f), + Err(_) => "<>".fmt(f), + }) + } +} + +#[derive(Default)] +pub struct PythonAnalysis { + pub eval: Option, +} + +impl Analysis for PythonAnalysis { + type Data = Option; + + fn make(egraph: &EGraph, enode: &PythonNode) -> Self::Data { + let eval = egraph.analysis.eval.as_ref()?; + let py = unsafe { Python::assume_gil_acquired() }; + + // collect the children if they are not `None` in python + let mut children = Vec::with_capacity(enode.len()); + for &id in enode.children() { + let any = egraph[id].data.as_ref()?.as_ref(py); + if any.is_none() { + return None; + } else { + children.push(any) + } + } + + let res = eval + .call1(py, (enode.class.clone(), children)) + .expect("Failed to call eval"); + if res.is_none(py) { + None + } else { + Some(res) + } + } + + fn merge(&mut self, a: &mut Self::Data, b: Self::Data) -> DidMerge { + let py = unsafe { Python::assume_gil_acquired() }; + let aa = a.as_ref().map(|obj| obj.as_ref(py)).filter(|r| r.is_none()); + let bb = b.as_ref().map(|obj| obj.as_ref(py)).filter(|r| r.is_none()); + match (aa, bb) { + (None, None) => DidMerge(false, false), + (None, Some(bb)) => { + *a = Some(bb.to_object(py)); + DidMerge(true, false) + } + (Some(_), None) => egg::DidMerge(false, true), + (Some(aa), Some(bb)) => { + if !py_eq(aa, bb) { + panic!("Failed to merge") + } + DidMerge(false, false) + } + } + } + + fn modify(egraph: &mut EGraph, id: Id) { + let obj = egraph[id].data.clone(); + if let Some(obj) = obj { + let py = unsafe { Python::assume_gil_acquired() }; + let id2 = build_node(egraph, obj.as_ref(py)); + egraph.union(id, id2); + } + } +} + +pub struct PythonApplier { + pub eval: PyObject, + /// List of vars in the pattern which this is used with + pub vars: Vec, +} + +impl Applier for PythonApplier { + + fn apply_one( + &self, + egraph: &mut EGraph, + eclass: Id, + subst: &Subst, + searcher_ast: Option<&PatternAst>, + rule_name: Symbol, + ) -> Vec { + let py = unsafe { Python::assume_gil_acquired() }; + let kwargs = PyDict::new(py); + + for var in &self.vars { + let id = subst[*var]; + let obj = if let Some(data) = egraph[id].data.clone() { + data + } else { + py.None() + }; + let key = &var.to_string()[1..]; + kwargs.set_item(key, obj).unwrap(); + } + + let result = self.eval.as_ref(py).call((), Some(kwargs)).unwrap(); + let pattern = result.extract::().unwrap(); + pattern.pattern.apply_one(egraph, eclass, subst, searcher_ast, rule_name) + + } +} diff --git a/src/lib.rs b/src/lib.rs index afe35da..b95bc2d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,467 +1,27 @@ -use egg::{Language, RecExpr}; -use once_cell::sync::Lazy; +mod core; +mod lang; +mod util; -use std::cmp::Ordering; -use std::sync::Mutex; -use std::{borrow::Cow, fmt::Display, hash::Hash, time::Duration}; +use crate::core::*; +use crate::lang::*; -use pyo3::AsPyPointer; -use pyo3::{ - basic::CompareOp, - prelude::*, - types::{PyList, PyString, PyTuple, PyType}, -}; - -fn py_eq(a: &PyAny, b: impl ToPyObject) -> bool { - a.rich_compare(b, CompareOp::Eq) - .expect("Failed to compare") - .is_true() - .expect("Failed to extract bool") -} - -macro_rules! py_object { - (impl $t:ty { $($rest:tt)* }) => { - #[pymethods] - impl $t { - $($rest)* - - fn __str__(&self) -> String { - self.0.to_string() - } - - fn __repr__(&self) -> String { - format!(concat!(stringify!($t), "({})"), self.0) - } - - fn __richcmp__(&self, other: Self, op: CompareOp) -> bool { - match op { - CompareOp::Lt => self.0 < other.0, - CompareOp::Le => self.0 <= other.0, - CompareOp::Eq => self.0 == other.0, - CompareOp::Ne => self.0 != other.0, - CompareOp::Gt => self.0 > other.0, - CompareOp::Ge => self.0 >= other.0, - } - } - } - }; -} - -#[pyclass] -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -struct Id(egg::Id); - -py_object!(impl Id {}); - -#[pyclass] -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -struct Var(egg::Var); - -py_object!(impl Var { - #[new] - fn new(str: &PyString) -> Self { - Self::from_str(str.to_string_lossy().as_ref()) - } -}); - -impl Var { - fn from_str(str: &str) -> Self { - let v = format!("?{}", str); - Var(v.parse().unwrap()) - } -} - -#[derive(Debug, Clone)] -struct PyLang { - obj: PyObject, - children: Vec, -} - -impl PyLang { - fn op(ty: &PyType, children: impl IntoIterator) -> Self { - let any = ty.as_ref(); - let py = any.py(); - Self { - obj: any.to_object(py), - children: children.into_iter().collect(), - } - } - - fn leaf(any: &PyAny) -> Self { - struct Hashable { - obj: PyObject, - hash: isize, - } - - impl Hash for Hashable { - fn hash(&self, state: &mut H) { - self.hash.hash(state); - } - } - - impl PartialEq for Hashable { - fn eq(&self, other: &Self) -> bool { - let py = unsafe { Python::assume_gil_acquired() }; - py_eq(self.obj.as_ref(py), &other.obj) - } - } - - impl Eq for Hashable {} - - static LEAVES: Lazy>> = Lazy::new(Default::default); - - let hash = any.hash().expect("failed to hash"); - let py = any.py(); - let obj = any.to_object(py); - - let mut leaves = LEAVES.lock().unwrap(); - let hashable = leaves.get_or_insert(Hashable { obj, hash }); - - Self { - obj: hashable.obj.clone(), - children: vec![], - } - } - - fn to_object>(&self, py: Python, f: impl FnMut(egg::Id) -> T) -> PyObject { - if self.is_leaf() { - self.obj.clone() - } else { - let children = self.children.iter().copied().map(f); - let args = PyTuple::new(py, children.map(|o| o.into_py(py))); - self.obj.call1(py, args).expect("Failed to construct") - } - } -} - -impl PartialEq for PyLang { - fn eq(&self, other: &Self) -> bool { - self.obj.as_ptr() == other.obj.as_ptr() && self.children == other.children - } -} - -impl Hash for PyLang { - fn hash(&self, state: &mut H) { - self.obj.as_ptr().hash(state); - self.children.hash(state); - } -} - -impl Ord for PyLang { - fn cmp(&self, other: &Self) -> Ordering { - self.partial_cmp(other).expect("comparison failed") - } -} - -impl PartialOrd for PyLang { - fn partial_cmp(&self, other: &Self) -> Option { - match self.obj.as_ptr().partial_cmp(&other.obj.as_ptr()) { - Some(Ordering::Equal) => {} - ord => return ord, - } - self.children.partial_cmp(&other.children) - } -} - -impl Eq for PyLang {} - -impl egg::Language for PyLang { - fn matches(&self, other: &Self) -> bool { - self.obj.as_ptr() == other.obj.as_ptr() && self.children.len() == other.children.len() - } - - fn children(&self) -> &[egg::Id] { - &self.children - } - - fn children_mut(&mut self) -> &mut [egg::Id] { - &mut self.children - } -} - -impl Display for PyLang { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - Python::with_gil(|py| match self.obj.as_ref(py).str() { - Ok(s) => s.fmt(f), - Err(_) => "<>".fmt(f), - }) - } -} - -#[pyclass] -struct Pattern { - pattern: egg::Pattern, -} - -#[pymethods] -impl Pattern { - #[new] - fn new(tree: &PyAny) -> Self { - let mut ast = egg::PatternAst::default(); - build_pattern(&mut ast, tree); - let pattern = egg::Pattern::from(ast); - Self { pattern } - } -} - - -fn build_pattern(ast: &mut egg::PatternAst, tree: &PyAny) -> egg::Id { - if let Ok(id) = tree.extract::() { - panic!("Ids are unsupported in patterns: {}", id.0) - } else if let Ok(var) = tree.extract::() { - ast.add(egg::ENodeOrVar::Var(var.0)) - } else if let Ok(tuple) = tree.downcast::() { - let op = PyLang::op( - tree.get_type(), - tuple.iter().map(|child| build_pattern(ast, child)), - ); - ast.add(egg::ENodeOrVar::ENode(op)) - } else { - ast.add(egg::ENodeOrVar::ENode(PyLang::leaf(tree))) - } -} - -#[pyclass] -struct Rewrite { - rewrite: egg::Rewrite, -} - -#[pymethods] -impl Rewrite { - #[new] - #[args(name = "\"\"")] - fn new(lhs: &PyAny, rhs: &PyAny, name: &str) -> Self { - let searcher = Pattern::new(lhs).pattern; - let applier = Pattern::new(rhs).pattern; - - let mut name = Cow::Borrowed(name); - if name == "" { - name = Cow::Owned(format!("{} => {}", searcher, applier)); - } - let rewrite = egg::Rewrite::new(name, searcher, applier).expect("Failed to create rewrite"); - Rewrite { rewrite } - } - - #[getter] - fn name(&self) -> &str { - self.rewrite.name.as_str() - } -} - -#[derive(Default)] -struct PyAnalysis { - eval: Option, -} - -impl egg::Analysis for PyAnalysis { - type Data = Option; - - fn make(egraph: &egg::EGraph, enode: &PyLang) -> Self::Data { - let eval = egraph.analysis.eval.as_ref()?; - let py = unsafe { Python::assume_gil_acquired() }; - - // collect the children if they are not `None` in python - let mut children = Vec::with_capacity(enode.len()); - for &id in enode.children() { - let any = egraph[id].data.as_ref()?.as_ref(py); - if any.is_none() { - return None; - } else { - children.push(any) - } - } - - let res = eval - .call1(py, (enode.obj.clone(), children)) - .expect("Failed to call eval"); - if res.is_none(py) { - None - } else { - Some(res) - } - } - - fn merge(&mut self, a: &mut Self::Data, b: Self::Data) -> egg::DidMerge { - let py = unsafe { Python::assume_gil_acquired() }; - let aa = a.as_ref().map(|obj| obj.as_ref(py)).filter(|r| r.is_none()); - let bb = b.as_ref().map(|obj| obj.as_ref(py)).filter(|r| r.is_none()); - match (aa, bb) { - (None, None) => egg::DidMerge(false, false), - (None, Some(bb)) => { - *a = Some(bb.to_object(py)); - egg::DidMerge(true, false) - } - (Some(_), None) => egg::DidMerge(false, true), - (Some(aa), Some(bb)) => { - if !py_eq(aa, bb) { - panic!("Failed to merge") - } - egg::DidMerge(false, false) - } - } - } - - fn modify(egraph: &mut egg::EGraph, id: egg::Id) { - let obj = egraph[id].data.clone(); - if let Some(obj) = obj { - let py = unsafe { Python::assume_gil_acquired() }; - let id2 = add_rec(egraph, obj.as_ref(py)); - egraph.union(id, id2); - } - } -} - -#[pyclass] -struct EGraph { - egraph: egg::EGraph, -} - -type Runner = egg::Runner; - -#[pymethods] -impl EGraph { - #[new] - fn new(eval: Option) -> Self { - Self { - egraph: egg::EGraph::new(PyAnalysis { eval }), - } - } - - fn add(&mut self, expr: &PyAny) -> Id { - Id(add_rec(&mut self.egraph, expr)) - } - - #[args(exprs = "*")] - fn union(&mut self, exprs: &PyTuple) -> bool { - assert!(exprs.len() > 1); - let mut exprs = exprs.iter(); - let id = self.add(exprs.next().unwrap()).0; - let mut did_something = false; - for expr in exprs { - let added = self.add(expr); - did_something |= self.egraph.union(id, added.0); - } - did_something - } - - #[args(exprs = "*")] - fn equiv(&mut self, exprs: &PyTuple) -> bool { - assert!(exprs.len() > 1); - let mut exprs = exprs.iter(); - let id = self.add(exprs.next().unwrap()).0; - let mut all_equiv = true; - for expr in exprs { - let added = self.add(expr); - all_equiv &= added.0 == id - } - all_equiv - } - - fn rebuild(&mut self) -> usize { - self.egraph.rebuild() - } - - #[args( - iter_limit = "10", - time_limit = "10.0", - node_limit = "100_000", - )] - fn run( - &mut self, - rewrites: &PyList, - iter_limit: usize, - time_limit: f64, - node_limit: usize, - ) -> PyResult<()> { - let refs = rewrites - .iter() - .map(FromPyObject::extract) - .collect::>>>()?; - let egraph = std::mem::take(&mut self.egraph); - let scheduled_runner = Runner::default(); - let runner = scheduled_runner - .with_iter_limit(iter_limit) - .with_node_limit(node_limit) - .with_time_limit(Duration::from_secs_f64(time_limit)) - .with_egraph(egraph) - .run(refs.iter().map(|r| &r.rewrite)); - - self.egraph = runner.egraph; - Ok(()) - } - - #[args(exprs = "*")] - fn extract(&mut self, py: Python, exprs: &PyTuple) -> SingletonOrTuple { - let ids: Vec = exprs.iter().map(|expr| self.add(expr).0).collect(); - let extractor = egg::Extractor::new(&self.egraph, egg::AstSize); - ids.iter() - .map(|&id| { - let (_cost, recexpr) = extractor.find_best(id); - reconstruct(py, &recexpr) - }) - .collect() - } - -} - -fn reconstruct(py: Python, recexpr: &RecExpr) -> PyObject { - let mut objs = Vec::::with_capacity(recexpr.as_ref().len()); - for node in recexpr.as_ref() { - let obj = node.to_object(py, |id| objs[usize::from(id)].clone()); - objs.push(obj) - } - objs.pop().unwrap() -} - -fn add_rec(egraph: &mut egg::EGraph, expr: &PyAny) -> egg::Id { - if let Ok(Id(id)) = expr.extract() { - egraph.find(id) - } else if let Ok(Var(var)) = expr.extract() { - panic!("Can't add a var: {}", var) - } else if let Ok(tuple) = expr.downcast::() { - let enode = PyLang::op( - expr.get_type(), - tuple.iter().map(|child| add_rec(egraph, child)), - ); - egraph.add(enode) - } else { - egraph.add(PyLang::leaf(expr)) - } -} - -struct SingletonOrTuple(Vec); - -impl> IntoPy for SingletonOrTuple { - fn into_py(mut self, py: Python) -> PyObject { - match self.0.len() { - 0 => panic!("Shouldn't be empty"), - 1 => self.0.pop().unwrap().into_py(py), - _ => PyTuple::new(py, self.0.into_iter().map(|x| x.into_py(py))).into_py(py), - } - } -} - -impl> FromIterator for SingletonOrTuple { - fn from_iter>(iter: TS) -> Self { - Self(iter.into_iter().collect()) - } -} +use pyo3::{prelude::*, types::PyString}; /// A Python module implemented in Rust. The name of this function must match /// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to /// import the module. #[pymodule] -fn snake_egg(_py: Python, m: &PyModule) -> PyResult<()> { - m.add_class::()?; - m.add_class::()?; - m.add_class::()?; - m.add_class::()?; - m.add_class::()?; +fn _internal(_py: Python, m: &PyModule) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; #[pyfn(m)] - fn vars(vars: &PyString) -> SingletonOrTuple { + fn vars(vars: &PyString) -> Vec { let s = vars.to_string_lossy(); - s.split_whitespace().map(|s| Var::from_str(s)).collect() + s.split_whitespace().map(|s| PyVar::from_str(s)).collect() } Ok(()) } diff --git a/src/util.rs b/src/util.rs new file mode 100644 index 0000000..b89256b --- /dev/null +++ b/src/util.rs @@ -0,0 +1,67 @@ +use egg::{EGraph, ENodeOrVar, Id, PatternAst}; +use pyo3::types::{PyTuple, PyType}; +use pyo3::{basic::CompareOp, prelude::*}; + +use crate::{PyId, PyVar, PythonAnalysis, PythonNode}; + +pub fn py_eq(a: &PyAny, b: impl ToPyObject) -> bool { + a.rich_compare(b, CompareOp::Eq) + .expect("Failed to compare") + .is_true() + .expect("Failed to extract bool") +} + +// TODO(kszucs): proper error handling +pub fn build_node(egraph: &mut EGraph, expr: &PyAny) -> Id { + if let Ok(PyId(id)) = expr.extract() { + egraph.find(id) + } else if let Ok(PyVar(var)) = expr.extract() { + panic!("Can't add a var: {}", var) + } else if let Ok(args) = expr.getattr("__match_args__") { + let args = args.downcast::().unwrap(); + let class = if let Ok(class) = expr.getattr("__match_type__") { + class.downcast::().unwrap() + } else { + expr.get_type() + }; + //let children = args.iter().map(|arg| expr.getattr(arg).unwrap()); + let enode = PythonNode::op(class, args.iter().map(|child| build_node(egraph, child))); + egraph.add(enode) + } else if let Ok(tuple) = expr.downcast::() { + let enode = PythonNode::op( + expr.get_type(), + tuple.iter().map(|child| build_node(egraph, child)), + ); + egraph.add(enode) + } else { + egraph.add(PythonNode::leaf(expr)) + } +} + +// TODO(kszucs): proper error handling +pub fn build_pattern(ast: &mut PatternAst, tree: &PyAny) -> Id { + if let Ok(id) = tree.extract::() { + panic!("Ids are unsupported in patterns: {}", id.0) + } else if let Ok(var) = tree.extract::() { + ast.add(ENodeOrVar::Var(var.0)) + // check for Sequence first? + } else if let Ok(args) = tree.getattr("__match_args__") { + let args = args.downcast::().unwrap(); + let class = if let Ok(class) = tree.getattr("__match_type__") { + class.downcast::().unwrap() + } else { + tree.get_type() + }; + //let children = args.iter().map(|arg| tree.getattr(arg).unwrap()); + let enode = PythonNode::op(class, args.iter().map(|child| build_pattern(ast, child))); + ast.add(ENodeOrVar::ENode(enode)) + } else if let Ok(tuple) = tree.downcast::() { + let enode = PythonNode::op( + tree.get_type(), + tuple.iter().map(|child| build_pattern(ast, child)), + ); + ast.add(ENodeOrVar::ENode(enode)) + } else { + ast.add(ENodeOrVar::ENode(PythonNode::leaf(tree))) + } +} diff --git a/tests/simple.py b/tests/simple.py deleted file mode 100644 index ed21c26..0000000 --- a/tests/simple.py +++ /dev/null @@ -1,55 +0,0 @@ -#!/usr/bin/env python3 - -# This is a reimplementation of simple.rs from the Rust egg repository - -from snake_egg import EGraph, Rewrite, Var, vars - -import unittest -from typing import List, Any -from collections import namedtuple - - -# Operations -Add = namedtuple("Add", "x y") -Mul = namedtuple("Mul", "x y") - - -# Rewrite rules -a, b = vars("a b") # type: ignore -list_rules: List[List[Any]] = [ - ["commute-add", Add(a, b), Add(b, a)], - ["commute-mul", Mul(a, b), Mul(b, a)], - ["add-0", Add(a, 0), a], - ["mul-0", Mul(a, 0), 0], - ["mul-1", Mul(a, 1), a], -] - -# Turn the lists into rewrites -rules = list() -for l in list_rules: - name = l[0] - frm = l[1] - to = l[2] - rules.append(Rewrite(frm, to, name)) - - -def simplify(expr, iters=7): - egraph = EGraph() - egraph.add(expr) - egraph.run(rules, iters) - best = egraph.extract(expr) - return best - - -class TestSimpleEgraph(unittest.TestCase): - - def test_simple_1(self): - self.assertEqual(simplify(Mul(0, 42)), 0) - - def test_simple_2(self): - foo = "foo" - self.assertEqual(simplify(Add(0, Mul(1, foo))), foo) - - -if __name__ == '__main__': - unittest.main(verbosity=2) \ No newline at end of file