Skip to content

Commit

Permalink
feat: Initial implementation (#3)
Browse files Browse the repository at this point in the history
Initial JaCe code.

This PR introduces a series of very basic functionalities, that will be
extended in subsequent PRs.
Most importantly it introduces the basic infrastructure that allows to
translate a Jaxpr object into an SDFG.
Furthermore, it also adds the `jace.jit` decorator, which can be used as
a replacement for `jax.jit`.
However, a function that was decorated with `jace.jit` remains fully
composable with Jax transformation, such as `jax.grad` or `jax.jacfwd`.
But, the functionality is still very basic and `jace.jit`, essentially
does not accepts any arguments and only works on CPU.
Nevertheless, there is a cache for caching tracing, translation and
compilation of wrapped functions.

Although, this PR introduces the components for the translation, the
actual primitive translators are not yet implemented (this commit adds
the `ALUTranslator`, but this one was back ported from the prototype for
simple tests).

As a last point the tests, located in `tests` are only a first version,
that were not included in the review of this PR.
They will be reviewed at a later stage

For more information see the `ROADMAP.md` file.

---------

Co-authored-by: Enrique González Paredes <[email protected]>
Co-authored-by: Philip Mueller <[email protected]>
  • Loading branch information
3 people authored Jun 18, 2024
1 parent dca9349 commit 1ebbcf3
Show file tree
Hide file tree
Showing 33 changed files with 4,317 additions and 8 deletions.
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,13 @@ src/*/_version.py
ehthumbs.db
Thumbs.db

# DaCe
.dacecache/
_dacegraphs

# JaCe
.jacecache/

# Common editor files
*~
*.swp
8 changes: 4 additions & 4 deletions ROADMAP.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ A kind of roadmap that gives a rough idea about how the project will be continue

- [x] Being able to perform _some_ translations [PR#3](https://github.com/GridTools/jace/pull/3).
- [ ] Basic functionalities:
- [ ] Annotation `@jace.jit`.
- [ ] Composable with Jax, i.e. take the Jax derivative of a JaCe annotated function.
- [ ] Implementing the `stages` model that is supported by Jax.
- [x] Annotation `@jace.jit`.
- [x] Composable with Jax, i.e. take the Jax derivative of a JaCe annotated function.
- [x] Implementing the `stages` model that is supported by Jax.
- [ ] Handling Jax arrays as native input (only on single host).
- [ ] Cache the compilation and lowering results for later reuse.
- [x] Cache the compilation and lowering results for later reuse.
In Jax these parts (together with the dispatch) are actually written in C++, thus in the beginning we will use a self made cache.
- [ ] Implementing some basic `PrimitiveTranslators`, that allows us to run some early tests, such as:
- [ ] Backporting the ones from the prototype.
Expand Down
5 changes: 4 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@
"sphinx_copybutton",
]

source_suffix = [".rst", ".md"]
source_suffix = [
".rst",
".md",
]
exclude_patterns = [
"_build",
"**.ipynb_checkpoints",
Expand Down
2 changes: 2 additions & 0 deletions docs/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[tool.ruff.format]
skip-magic-trailing-comma = false
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,10 @@ module = [

[tool.pytest.ini_options]
addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"]
filterwarnings = ["error"]
filterwarnings = [
"error",
"ignore:numpy\\..*:DeprecationWarning", # DaCe is not NumPy v2.0 ready so ignore the usage of deprecated features.
]
log_cli_level = "INFO"
minversion = "6.0"
testpaths = ["tests"]
Expand Down Expand Up @@ -174,6 +177,7 @@ ignore = [
"TRY003", # [raise-vanilla-args] # TODO(egparedes): reevaluate if it should be activated
"UP038", # [non-pep604-isinstance]
]
task-tags = ["TODO"]
# ignore-init-module-imports = true # deprecated in preview mode
unfixable = []

Expand Down
23 changes: 23 additions & 0 deletions src/jace/__about__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming)
#
# Copyright (c) 2024, ETH Zurich
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

"""Package metadata: version, authors, license and copyright."""

from __future__ import annotations

from typing import Final

from packaging import version as pkg_version


__all__ = ["__author__", "__copyright__", "__license__", "__version__", "__version_info__"]

__author__: Final = "ETH Zurich and individual contributors"
__copyright__: Final = "Copyright (c) 2024 ETH Zurich"
__license__: Final = "BSD-3-Clause-License"
__version__: Final = "0.0.1"
__version_info__: Final = pkg_version.parse(__version__)
17 changes: 15 additions & 2 deletions src/jace/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,20 @@

from __future__ import annotations

import jace.translator.primitive_translators as _ # noqa: F401 # Populate the internal registry.

__version__ = "0.1.0"
from .__about__ import __author__, __copyright__, __license__, __version__, __version_info__
from .api import grad, jacfwd, jacrev, jit

__all__ = ["__version__"]

__all__ = [
"__author__",
"__copyright__",
"__license__",
"__version__",
"__version_info__",
"grad",
"jacfwd",
"jacrev",
"jit",
]
87 changes: 87 additions & 0 deletions src/jace/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming)
#
# Copyright (c) 2024, ETH Zurich
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

"""Implementation of the `jax.*` namespace."""

from __future__ import annotations

import functools
from typing import TYPE_CHECKING, Any, Literal, overload

from jax import grad, jacfwd, jacrev

from jace import stages, translator


if TYPE_CHECKING:
from collections.abc import Callable, Mapping


__all__ = ["grad", "jacfwd", "jacrev", "jit"]


@overload
def jit(
fun: Literal[None] = None,
/,
primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None,
**kwargs: Any,
) -> Callable[[Callable], stages.JaCeWrapped]: ...


@overload
def jit(
fun: Callable,
/,
primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None,
**kwargs: Any,
) -> stages.JaCeWrapped: ...


def jit(
fun: Callable | None = None,
/,
primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None,
**kwargs: Any,
) -> stages.JaCeWrapped | Callable[[Callable], stages.JaCeWrapped]:
"""
JaCe's replacement for `jax.jit` (just-in-time) wrapper.
It works the same way as `jax.jit` does, but instead of using XLA the
computation is lowered to DaCe. In addition it accepts some JaCe specific
arguments.
Args:
fun: Function to wrap.
primitive_translators: Use these primitive translators for the lowering to SDFG.
If not specified the translators in the global registry are used.
kwargs: Jit arguments.
Notes:
After constructions any change to `primitive_translators` has no effect.
"""
if kwargs:
# TODO(phimuell): Add proper name verification and exception type.
raise NotImplementedError(
f"The following arguments to 'jace.jit' are not yet supported: {', '.join(kwargs)}."
)

def wrapper(f: Callable) -> stages.JaCeWrapped:
# TODO(egparedes): Improve typing.
jace_wrapper = stages.JaCeWrapped(
fun=f,
primitive_translators=(
translator.primitive_translator._PRIMITIVE_TRANSLATORS_REGISTRY
if primitive_translators is None
else primitive_translators
),
jit_options=kwargs,
)
functools.update_wrapper(jace_wrapper, f)
return jace_wrapper

return wrapper if fun is None else wrapper(fun)
70 changes: 70 additions & 0 deletions src/jace/optimization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming)
#
# Copyright (c) 2024, ETH Zurich
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

"""
JaCe specific optimizations.
Currently just a dummy exists for the sake of providing a callable function.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Final, TypedDict

from typing_extensions import Unpack


if TYPE_CHECKING:
from jace import translator


class CompilerOptions(TypedDict, total=False):
"""
All known compiler options to `JaCeLowered.compile()`.
See `jace_optimize()` for a description of the different options.
There are some predefined option sets in `jace.jax.stages`:
- `DEFAULT_OPTIONS`
- `NO_OPTIMIZATIONS`
"""

auto_optimize: bool
simplify: bool


# TODO(phimuell): Add a context manager to modify the default.
DEFAULT_OPTIMIZATIONS: Final[CompilerOptions] = {"auto_optimize": True, "simplify": True}

NO_OPTIMIZATIONS: Final[CompilerOptions] = {"auto_optimize": False, "simplify": False}


def jace_optimize(tsdfg: translator.TranslatedJaxprSDFG, **kwargs: Unpack[CompilerOptions]) -> None: # noqa: D417 # Missing description for kwargs
"""
Performs optimization of the translated SDFG _in place_.
It is recommended to use the `CompilerOptions` `TypedDict` to pass options
to the function. However, any option that is not specified will be
interpreted as to be disabled.
Args:
tsdfg: The translated SDFG that should be optimized.
simplify: Run the simplification pipeline.
auto_optimize: Run the auto optimization pipeline (currently does nothing)
"""
# Currently this function exists primarily for the same of existing.

simplify = kwargs.get("simplify", False)
auto_optimize = kwargs.get("auto_optimize", False)

if simplify:
tsdfg.sdfg.simplify()

if auto_optimize:
pass

tsdfg.validate()
Loading

0 comments on commit 1ebbcf3

Please sign in to comment.