-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Initial JaCe code. This PR introduces a series of very basic functionalities, that will be extended in subsequent PRs. Most importantly it introduces the basic infrastructure that allows to translate a Jaxpr object into an SDFG. Furthermore, it also adds the `jace.jit` decorator, which can be used as a replacement for `jax.jit`. However, a function that was decorated with `jace.jit` remains fully composable with Jax transformation, such as `jax.grad` or `jax.jacfwd`. But, the functionality is still very basic and `jace.jit`, essentially does not accepts any arguments and only works on CPU. Nevertheless, there is a cache for caching tracing, translation and compilation of wrapped functions. Although, this PR introduces the components for the translation, the actual primitive translators are not yet implemented (this commit adds the `ALUTranslator`, but this one was back ported from the prototype for simple tests). As a last point the tests, located in `tests` are only a first version, that were not included in the review of this PR. They will be reviewed at a later stage For more information see the `ROADMAP.md` file. --------- Co-authored-by: Enrique González Paredes <[email protected]> Co-authored-by: Philip Mueller <[email protected]>
- Loading branch information
1 parent
dca9349
commit 1ebbcf3
Showing
33 changed files
with
4,317 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
[tool.ruff.format] | ||
skip-magic-trailing-comma = false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) | ||
# | ||
# Copyright (c) 2024, ETH Zurich | ||
# All rights reserved. | ||
# | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
|
||
"""Package metadata: version, authors, license and copyright.""" | ||
|
||
from __future__ import annotations | ||
|
||
from typing import Final | ||
|
||
from packaging import version as pkg_version | ||
|
||
|
||
__all__ = ["__author__", "__copyright__", "__license__", "__version__", "__version_info__"] | ||
|
||
__author__: Final = "ETH Zurich and individual contributors" | ||
__copyright__: Final = "Copyright (c) 2024 ETH Zurich" | ||
__license__: Final = "BSD-3-Clause-License" | ||
__version__: Final = "0.0.1" | ||
__version_info__: Final = pkg_version.parse(__version__) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) | ||
# | ||
# Copyright (c) 2024, ETH Zurich | ||
# All rights reserved. | ||
# | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
|
||
"""Implementation of the `jax.*` namespace.""" | ||
|
||
from __future__ import annotations | ||
|
||
import functools | ||
from typing import TYPE_CHECKING, Any, Literal, overload | ||
|
||
from jax import grad, jacfwd, jacrev | ||
|
||
from jace import stages, translator | ||
|
||
|
||
if TYPE_CHECKING: | ||
from collections.abc import Callable, Mapping | ||
|
||
|
||
__all__ = ["grad", "jacfwd", "jacrev", "jit"] | ||
|
||
|
||
@overload | ||
def jit( | ||
fun: Literal[None] = None, | ||
/, | ||
primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None, | ||
**kwargs: Any, | ||
) -> Callable[[Callable], stages.JaCeWrapped]: ... | ||
|
||
|
||
@overload | ||
def jit( | ||
fun: Callable, | ||
/, | ||
primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None, | ||
**kwargs: Any, | ||
) -> stages.JaCeWrapped: ... | ||
|
||
|
||
def jit( | ||
fun: Callable | None = None, | ||
/, | ||
primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None, | ||
**kwargs: Any, | ||
) -> stages.JaCeWrapped | Callable[[Callable], stages.JaCeWrapped]: | ||
""" | ||
JaCe's replacement for `jax.jit` (just-in-time) wrapper. | ||
It works the same way as `jax.jit` does, but instead of using XLA the | ||
computation is lowered to DaCe. In addition it accepts some JaCe specific | ||
arguments. | ||
Args: | ||
fun: Function to wrap. | ||
primitive_translators: Use these primitive translators for the lowering to SDFG. | ||
If not specified the translators in the global registry are used. | ||
kwargs: Jit arguments. | ||
Notes: | ||
After constructions any change to `primitive_translators` has no effect. | ||
""" | ||
if kwargs: | ||
# TODO(phimuell): Add proper name verification and exception type. | ||
raise NotImplementedError( | ||
f"The following arguments to 'jace.jit' are not yet supported: {', '.join(kwargs)}." | ||
) | ||
|
||
def wrapper(f: Callable) -> stages.JaCeWrapped: | ||
# TODO(egparedes): Improve typing. | ||
jace_wrapper = stages.JaCeWrapped( | ||
fun=f, | ||
primitive_translators=( | ||
translator.primitive_translator._PRIMITIVE_TRANSLATORS_REGISTRY | ||
if primitive_translators is None | ||
else primitive_translators | ||
), | ||
jit_options=kwargs, | ||
) | ||
functools.update_wrapper(jace_wrapper, f) | ||
return jace_wrapper | ||
|
||
return wrapper if fun is None else wrapper(fun) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) | ||
# | ||
# Copyright (c) 2024, ETH Zurich | ||
# All rights reserved. | ||
# | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
|
||
""" | ||
JaCe specific optimizations. | ||
Currently just a dummy exists for the sake of providing a callable function. | ||
""" | ||
|
||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING, Final, TypedDict | ||
|
||
from typing_extensions import Unpack | ||
|
||
|
||
if TYPE_CHECKING: | ||
from jace import translator | ||
|
||
|
||
class CompilerOptions(TypedDict, total=False): | ||
""" | ||
All known compiler options to `JaCeLowered.compile()`. | ||
See `jace_optimize()` for a description of the different options. | ||
There are some predefined option sets in `jace.jax.stages`: | ||
- `DEFAULT_OPTIONS` | ||
- `NO_OPTIMIZATIONS` | ||
""" | ||
|
||
auto_optimize: bool | ||
simplify: bool | ||
|
||
|
||
# TODO(phimuell): Add a context manager to modify the default. | ||
DEFAULT_OPTIMIZATIONS: Final[CompilerOptions] = {"auto_optimize": True, "simplify": True} | ||
|
||
NO_OPTIMIZATIONS: Final[CompilerOptions] = {"auto_optimize": False, "simplify": False} | ||
|
||
|
||
def jace_optimize(tsdfg: translator.TranslatedJaxprSDFG, **kwargs: Unpack[CompilerOptions]) -> None: # noqa: D417 # Missing description for kwargs | ||
""" | ||
Performs optimization of the translated SDFG _in place_. | ||
It is recommended to use the `CompilerOptions` `TypedDict` to pass options | ||
to the function. However, any option that is not specified will be | ||
interpreted as to be disabled. | ||
Args: | ||
tsdfg: The translated SDFG that should be optimized. | ||
simplify: Run the simplification pipeline. | ||
auto_optimize: Run the auto optimization pipeline (currently does nothing) | ||
""" | ||
# Currently this function exists primarily for the same of existing. | ||
|
||
simplify = kwargs.get("simplify", False) | ||
auto_optimize = kwargs.get("auto_optimize", False) | ||
|
||
if simplify: | ||
tsdfg.sdfg.simplify() | ||
|
||
if auto_optimize: | ||
pass | ||
|
||
tsdfg.validate() |
Oops, something went wrong.