Skip to content

Commit

Permalink
Support other types for expressions and dynamic appliers (#7)
Browse files Browse the repository at this point in the history
* Use pytest instead of unittest

* Create a wrapper python module

* remove SingletonOrTuple conversion traits

* update to egg 0.9.0

* slightly rewrite simple test

* reorganize codebase

* Change to published version of egg

* Update tests path in makefile

* Remove unused imports

* Fix test paths again

* Remove use of private attribute

* Install custom ibis version for testing

* Rename back to snake_egg and update typings

* Fix makefile back to snake_egg from egg

* Fix stubtest make

* Update type stubs to be picked up

* Allow only one arg to extract

* Update makefile filename

* Update typings for dynamic applier

* Increase Python version in CI

* Fix tests

* Remove ibis test and add other simple test

* Fix callable type signature

* remove positional only arg for greater python support

Co-authored-by: Krisztián Szűcs <[email protected]>
  • Loading branch information
saulshanabrook and kszucs authored Nov 16, 2022
1 parent 6c4b9f6 commit d6a008c
Show file tree
Hide file tree
Showing 18 changed files with 1,045 additions and 679 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
- name: Install Python
uses: actions/setup-python@v2
with:
python-version: 3.6
python-version: 3.7
architecture: x64
- name: Install Rust toolchain
uses: actions-rs/toolchain@v1
Expand Down
177 changes: 173 additions & 4 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,175 @@
/target
# Generated by Cargo
# will have compiled files and executables
debug/
target/

# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
Cargo.lock
venv

.vscode
snake_egg.html
# These are backup files generated by rustfmt
**/*.rs.bk

# MSVC Windows builds of rustc generate these, which store debugging information
*.pdb

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
13 changes: 10 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,19 @@ name = "snake-egg"
version = "0.1.0"
edition = "2021"

[package.metadata.maturin]
name = "snake_egg._internal"

[lib]
name = "snake_egg"
crate-type = ["cdylib"]
crate-type = ["cdylib", "rlib"]

[profile.release]
lto = true
codegen-units = 1

[dependencies]
hashbrown = "0.11"
once_cell = "1"
pyo3 = { version = "0.15", features = ["extension-module"] }
egg = { git = "https://github.com/egraphs-good/egg", rev = "v0.7.1" }
pyo3 = { version = "0.16", features = ["extension-module"] }
egg = "0.9.1"
20 changes: 11 additions & 9 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,26 @@ venv:
build: venv
$(activate) && maturin build --release

test: tests/*.py build venv
$(activate) && maturin develop && python tests/math.py
$(activate) && maturin develop && python tests/prop.py
$(activate) && maturin develop && python tests/simple.py

stubtest: snake_egg.pyi build venv
test: snake_egg/tests/*.py build venv
$(activate) && maturin develop && python snake_egg/tests/test_math.py
$(activate) && maturin develop && python snake_egg/tests/test_prop.py
$(activate) && maturin develop && python snake_egg/tests/test_simple.py
$(activate) && maturin develop && python snake_egg/tests/test_dynamic.py
$(activate) && maturin develop && python snake_egg/tests/test_dataclass.py

stubtest: snake_egg/__init__.pyi build venv
$(activate) && maturin develop --extras=dev && python -m mypy.stubtest snake_egg --ignore-missing-stub

mypy: snake_egg.pyi tests/*.py build venv
$(activate) && maturin develop --extras=dev && mypy tests
mypy: snake_egg/__init__.pyi build venv
$(activate) && maturin develop --extras=dev && mypy snake_egg

install: venv
$(activate) maturin build --release && \
python -m pip install snake_egg --force-reinstall --no-index \
--find-link ./target/wheels/

doc: venv
$(activate) && maturin develop && python -m pydoc -w snake_egg
$(activate) && maturin develop && python -m pydoc -w snake_egg.internal

shell: venv
$(activate) && maturin develop && python -ic 'import snake_egg'
Expand Down
17 changes: 12 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
[project]
name = "snake-egg"

[build-system]
requires = ["maturin>=0.11,<0.12"]
build-backend = "maturin"
dependencies = ["typing-extensions"]

[project]
name = "snake-egg"

[tool.isort]
profile = "black"

[tool.maturin]
sdist-include = ["Cargo.lock"]

[project.optional-dependencies]
dev = [
"mypy"
]
[tool.mypy]
ignore_missing_imports = false
ignore_missing_imports = true
warn_redundant_casts = true
check_untyped_defs = true
strict_equality = true
warn_unused_configs = true
warn_unused_configs = true
enable_recursive_aliases = true
15 changes: 15 additions & 0 deletions snake_egg/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from ._internal import PyEGraph # type: ignore
from ._internal import vars # type: ignore
from ._internal import PyId as Id # type: ignore
from ._internal import PyPattern as Pattern # type: ignore
from ._internal import PyRewrite as Rewrite # type: ignore
from ._internal import PyVar as Var # type: ignore


class EGraph(PyEGraph):
def extract(self, expr):
result = super().extract(expr)
if len(result) == 1:
return result[0]
else:
return result
9 changes: 6 additions & 3 deletions snake_egg.pyi → snake_egg/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Callable, Hashable, Iterable
from typing import Optional
from typing import Dict, Optional, Protocol, Union

from typing_extensions import final

Expand All @@ -12,9 +12,12 @@ class Id: ...
class Var:
def __init__(self, name: str) -> None: ...

class _CallableApplier(Protocol):
def __call__(self, **substiution: Dict[str, _Expr]) -> _Expr: ...

@final
class Rewrite:
def __init__(self, lhs: _Expr, rhs: _Expr, name: str = "") -> None: ...
def __init__(self, lhs: _Expr, rhs: Union[_Expr, _CallableApplier], name: str = "") -> None: ...
@property
def name(self) -> str: ...

Expand All @@ -38,6 +41,6 @@ class EGraph:
time_limit: float = 10.0,
node_limit: int = 100000,
) -> None: ...
def extract(self, *exprs: _Expr) -> tuple[_Expr, ...] | _Expr: ...
def extract(self, expr: _Expr) -> _Expr: ...

def vars(vars: str) -> tuple[Var, ...] | Var: ...
Empty file added snake_egg/py.typed
Empty file.
62 changes: 62 additions & 0 deletions snake_egg/tests/test_dataclass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#!/usr/bin/env python3

# This is a reimplementation of simple.rs from the Rust egg repository

from dataclasses import dataclass
from typing import Any

from snake_egg import EGraph, Rewrite, Var, vars


# Operations
@dataclass(frozen=True)
class Add:
x: Any
y: Any

@property
def __match_args__(self):
return (self.x, self.y)


@dataclass(frozen=True)
class Mul:
x: Any
y: Any

@property
def __match_args__(self):
return (self.x, self.y)


# Rewrite rules
a, b = vars("a b") # type: ignore

rules = [
Rewrite(Add(a, b), Add(b, a), name="commute-add"),
Rewrite(Mul(a, b), Mul(b, a), name="commute-mul"),
Rewrite(Add(a, 0), a, name="add-0"),
Rewrite(Mul(a, 0), 0, name="mul-0"),
Rewrite(Mul(a, 1), a, name="mul-1"),
]


def simplify(expr, iters=7):
egraph = EGraph()
egraph.add(expr)
egraph.run(rules, iters)
best = egraph.extract(expr)
return best


def test_simple_1():
assert simplify(Mul(0, 42)) == 0


def test_simple_2():
foo = "foo"
assert simplify(Add(0, Mul(1, foo))) == foo


def test_simple_3():
assert simplify(Mul(2, Mul(1, "foo"))) == Mul(2, "foo")
Loading

0 comments on commit d6a008c

Please sign in to comment.