diff --git a/.gitignore b/.gitignore index 25cf9a4..e49382d 100644 --- a/.gitignore +++ b/.gitignore @@ -153,6 +153,13 @@ src/*/_version.py ehthumbs.db Thumbs.db +# DaCe +.dacecache/ +_dacegraphs + +# JaCe +.jacecache/ + # Common editor files *~ *.swp diff --git a/ROADMAP.md b/ROADMAP.md index ec14397..5beb172 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -4,11 +4,11 @@ A kind of roadmap that gives a rough idea about how the project will be continue - [x] Being able to perform _some_ translations [PR#3](https://github.com/GridTools/jace/pull/3). - [ ] Basic functionalities: - - [ ] Annotation `@jace.jit`. - - [ ] Composable with Jax, i.e. take the Jax derivative of a JaCe annotated function. - - [ ] Implementing the `stages` model that is supported by Jax. + - [x] Annotation `@jace.jit`. + - [x] Composable with Jax, i.e. take the Jax derivative of a JaCe annotated function. + - [x] Implementing the `stages` model that is supported by Jax. - [ ] Handling Jax arrays as native input (only on single host). - - [ ] Cache the compilation and lowering results for later reuse. + - [x] Cache the compilation and lowering results for later reuse. In Jax these parts (together with the dispatch) are actually written in C++, thus in the beginning we will use a self made cache. - [ ] Implementing some basic `PrimitiveTranslators`, that allows us to run some early tests, such as: - [ ] Backporting the ones from the prototype. diff --git a/docs/conf.py b/docs/conf.py index cb0bb09..01d2ca7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -25,7 +25,10 @@ "sphinx_copybutton", ] -source_suffix = [".rst", ".md"] +source_suffix = [ + ".rst", + ".md", +] exclude_patterns = [ "_build", "**.ipynb_checkpoints", diff --git a/docs/pyproject.toml b/docs/pyproject.toml new file mode 100644 index 0000000..b6658df --- /dev/null +++ b/docs/pyproject.toml @@ -0,0 +1,2 @@ +[tool.ruff.format] +skip-magic-trailing-comma = false diff --git a/pyproject.toml b/pyproject.toml index 0add471..393ce01 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -110,7 +110,10 @@ module = [ [tool.pytest.ini_options] addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"] -filterwarnings = ["error"] +filterwarnings = [ + "error", + "ignore:numpy\\..*:DeprecationWarning", # DaCe is not NumPy v2.0 ready so ignore the usage of deprecated features. +] log_cli_level = "INFO" minversion = "6.0" testpaths = ["tests"] @@ -174,6 +177,7 @@ ignore = [ "TRY003", # [raise-vanilla-args] # TODO(egparedes): reevaluate if it should be activated "UP038", # [non-pep604-isinstance] ] +task-tags = ["TODO"] # ignore-init-module-imports = true # deprecated in preview mode unfixable = [] diff --git a/src/jace/__about__.py b/src/jace/__about__.py new file mode 100644 index 0000000..437e86b --- /dev/null +++ b/src/jace/__about__.py @@ -0,0 +1,23 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. 
+# +# SPDX-License-Identifier: BSD-3-Clause + +"""Package metadata: version, authors, license and copyright.""" + +from __future__ import annotations + +from typing import Final + +from packaging import version as pkg_version + + +__all__ = ["__author__", "__copyright__", "__license__", "__version__", "__version_info__"] + +__author__: Final = "ETH Zurich and individual contributors" +__copyright__: Final = "Copyright (c) 2024 ETH Zurich" +__license__: Final = "BSD-3-Clause-License" +__version__: Final = "0.0.1" +__version_info__: Final = pkg_version.parse(__version__) diff --git a/src/jace/__init__.py b/src/jace/__init__.py index ebc8a4f..11c5d2a 100644 --- a/src/jace/__init__.py +++ b/src/jace/__init__.py @@ -9,7 +9,20 @@ from __future__ import annotations +import jace.translator.primitive_translators as _ # noqa: F401 # Populate the internal registry. -__version__ = "0.1.0" +from .__about__ import __author__, __copyright__, __license__, __version__, __version_info__ +from .api import grad, jacfwd, jacrev, jit -__all__ = ["__version__"] + +__all__ = [ + "__author__", + "__copyright__", + "__license__", + "__version__", + "__version_info__", + "grad", + "jacfwd", + "jacrev", + "jit", +] diff --git a/src/jace/api.py b/src/jace/api.py new file mode 100644 index 0000000..8afc20a --- /dev/null +++ b/src/jace/api.py @@ -0,0 +1,87 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Implementation of the `jax.*` namespace.""" + +from __future__ import annotations + +import functools +from typing import TYPE_CHECKING, Any, Literal, overload + +from jax import grad, jacfwd, jacrev + +from jace import stages, translator + + +if TYPE_CHECKING: + from collections.abc import Callable, Mapping + + +__all__ = ["grad", "jacfwd", "jacrev", "jit"] + + +@overload +def jit( + fun: Literal[None] = None, + /, + primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None, + **kwargs: Any, +) -> Callable[[Callable], stages.JaCeWrapped]: ... + + +@overload +def jit( + fun: Callable, + /, + primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None, + **kwargs: Any, +) -> stages.JaCeWrapped: ... + + +def jit( + fun: Callable | None = None, + /, + primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None, + **kwargs: Any, +) -> stages.JaCeWrapped | Callable[[Callable], stages.JaCeWrapped]: + """ + JaCe's replacement for `jax.jit` (just-in-time) wrapper. + + It works the same way as `jax.jit` does, but instead of using XLA the + computation is lowered to DaCe. In addition it accepts some JaCe specific + arguments. + + Args: + fun: Function to wrap. + primitive_translators: Use these primitive translators for the lowering to SDFG. + If not specified the translators in the global registry are used. + kwargs: Jit arguments. + + Notes: + After constructions any change to `primitive_translators` has no effect. + """ + if kwargs: + # TODO(phimuell): Add proper name verification and exception type. + raise NotImplementedError( + f"The following arguments to 'jace.jit' are not yet supported: {', '.join(kwargs)}." + ) + + def wrapper(f: Callable) -> stages.JaCeWrapped: + # TODO(egparedes): Improve typing. 
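+        # If no translators were passed explicitly, use the global registry.
+        # `JaCeWrapped` copies the mapping on construction, so later changes to
+        # the registry do not affect this wrapped function (see the Notes above).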
+ jace_wrapper = stages.JaCeWrapped( + fun=f, + primitive_translators=( + translator.primitive_translator._PRIMITIVE_TRANSLATORS_REGISTRY + if primitive_translators is None + else primitive_translators + ), + jit_options=kwargs, + ) + functools.update_wrapper(jace_wrapper, f) + return jace_wrapper + + return wrapper if fun is None else wrapper(fun) diff --git a/src/jace/optimization.py b/src/jace/optimization.py new file mode 100644 index 0000000..b5af4fa --- /dev/null +++ b/src/jace/optimization.py @@ -0,0 +1,70 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +""" +JaCe specific optimizations. + +Currently just a dummy exists for the sake of providing a callable function. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Final, TypedDict + +from typing_extensions import Unpack + + +if TYPE_CHECKING: + from jace import translator + + +class CompilerOptions(TypedDict, total=False): + """ + All known compiler options to `JaCeLowered.compile()`. + + See `jace_optimize()` for a description of the different options. + + There are some predefined option sets in `jace.jax.stages`: + - `DEFAULT_OPTIONS` + - `NO_OPTIMIZATIONS` + """ + + auto_optimize: bool + simplify: bool + + +# TODO(phimuell): Add a context manager to modify the default. +DEFAULT_OPTIMIZATIONS: Final[CompilerOptions] = {"auto_optimize": True, "simplify": True} + +NO_OPTIMIZATIONS: Final[CompilerOptions] = {"auto_optimize": False, "simplify": False} + + +def jace_optimize(tsdfg: translator.TranslatedJaxprSDFG, **kwargs: Unpack[CompilerOptions]) -> None: # noqa: D417 # Missing description for kwargs + """ + Performs optimization of the translated SDFG _in place_. + + It is recommended to use the `CompilerOptions` `TypedDict` to pass options + to the function. However, any option that is not specified will be + interpreted as to be disabled. + + Args: + tsdfg: The translated SDFG that should be optimized. + simplify: Run the simplification pipeline. + auto_optimize: Run the auto optimization pipeline (currently does nothing) + """ + # Currently this function exists primarily for the same of existing. + + simplify = kwargs.get("simplify", False) + auto_optimize = kwargs.get("auto_optimize", False) + + if simplify: + tsdfg.sdfg.simplify() + + if auto_optimize: + pass + + tsdfg.validate() diff --git a/src/jace/stages.py b/src/jace/stages.py new file mode 100644 index 0000000..4639b11 --- /dev/null +++ b/src/jace/stages.py @@ -0,0 +1,319 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause +""" +Reimplementation of the `jax.stages` module. + +This module reimplements the public classes of that Jax module. +However, they are a bit different, because JaCe uses DaCe as backend. + +As in Jax JaCe has different stages, the terminology is taken from +[Jax' AOT-Tutorial](https://jax.readthedocs.io/en/latest/aot.html). +- Stage out: + In this phase an executable Python function is translated to Jaxpr. +- Lower: + This will transform the Jaxpr into an SDFG equivalent. As a implementation + note, currently this and the previous step are handled as a single step. +- Compile: + This will turn the SDFG into an executable object, see `dace.codegen.CompiledSDFG`. +- Execution: + This is the actual running of the computation. 
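+
+A rough sketch of the full chain (assuming some traceable function `f` and a
+NumPy array `x` as input):
+
+    wrapped = jace.jit(f)         # JaCeWrapped: ready to be lowered/compiled.
+    lowered = wrapped.lower(x)    # JaCeLowered: the computation as an SDFG.
+    compiled = lowered.compile()  # JaCeCompiled: an executable object.
+    result = compiled(x)          # The actual execution.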
+ +As in Jax the `stages` module give access to the last three stages, but not +the first one. +""" + +from __future__ import annotations + +import copy +from typing import TYPE_CHECKING, Any + +import jax as _jax + +from jace import optimization, translator, util +from jace.optimization import CompilerOptions +from jace.translator import post_translation as ptrans +from jace.util import dace_helper, translation_cache as tcache + + +if TYPE_CHECKING: + from collections.abc import Callable, Mapping, Sequence + + import dace + +__all__ = [ + "CompilerOptions", # export for compatibility with Jax. + "JaCeCompiled", + "JaCeLowered", + "JaCeWrapped", + "Stage", +] + + +class JaCeWrapped(tcache.CachingStage["JaCeLowered"]): + """ + A function ready to be specialized, lowered, and compiled. + + This class represents the output of functions such as `jace.jit()` and is + the first stage in the translation/compilation chain of JaCe. A user should + never create a `JaCeWrapped` object directly, instead `jace.jit` should be + used for that. While it supports just-in-time lowering and compilation, by + just calling it, these steps can also be performed explicitly. The lowering + performed by this stage is cached, thus if a `JaCeWrapped` object is lowered + later, with the same argument the result is taken from the cache. + Furthermore, a `JaCeWrapped` object is composable with all Jax transformations. + + Args: + fun: The function that is wrapped. + primitive_translators: Primitive translators that that should be used. + jit_options: Options to influence the jit process. + + Todo: + - Support pytrees. + - Support keyword arguments and default values of the wrapped function. + - Support static arguments. + + Note: + The tracing of function will always happen with enabled `x64` mode, + which is implicitly and temporary activated while tracing. + """ + + _fun: Callable + _primitive_translators: dict[str, translator.PrimitiveTranslator] + _jit_options: dict[str, Any] + + def __init__( + self, + fun: Callable, + primitive_translators: Mapping[str, translator.PrimitiveTranslator], + jit_options: Mapping[str, Any], + ) -> None: + super().__init__() + # We have to shallow copy both the translator and the jit options. + # This prevents that any modifications affect `self`. + # Shallow is enough since the translators themselves are immutable. + self._primitive_translators = {**primitive_translators} + # TODO(phimuell): Do we need to deepcopy the options? + self._jit_options = {**jit_options} + self._fun = fun + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + """ + Executes the wrapped function, lowering and compiling as needed in one step. + + The arguments passed to this function are the same as the wrapped function uses. + """ + # If we are inside a traced context, then we forward the call to the wrapped + # function. This ensures that JaCe is composable with Jax. + if util.is_tracing_ongoing(*args, **kwargs): + return self._fun(*args, **kwargs) + + lowered = self.lower(*args, **kwargs) + compiled = lowered.compile() + return compiled(*args, **kwargs) + + @tcache.cached_transition + def lower(self, *args: Any, **kwargs: Any) -> JaCeLowered: + """ + Lower this function explicitly for the given arguments. + + Performs the first two steps of the AOT steps described above, i.e. + trace the wrapped function with the given arguments and stage it out + to a Jaxpr. Then translate it to SDFG. The result is encapsulated + inside a `JaCeLowered` object which can later be compiled. 
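+
+        For example (a sketch, assuming `wrapped = jace.jit(f)` for some traceable
+        `f` and a C contiguous NumPy array `x`), `wrapped.lower(x).as_sdfg()` gives
+        access to the generated `dace.SDFG`.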
+ + Note: + The call to the function is cached. As key an abstract description + of the call, similar to the tracers used by Jax, is used. + The tracing is always done with activated `x64` mode. + """ + if len(kwargs) != 0: + raise NotImplementedError("Currently only positional arguments are supported.") + + # TODO(phimuell): Currently the SDFG that we build only supports `C_CONTIGUOUS` + # memory order. Since we support the paradigm that "everything passed to + # `lower()` should also be accepted as argument to call the result", we forbid + # other memory orders here. + if not all((not util.is_array(arg)) or arg.flags["C_CONTIGUOUS"] for arg in args): + raise NotImplementedError("Currently can not yet handle strides beside 'C_CONTIGUOUS'.") + + # In Jax `float32` is the main datatype, and they go to great lengths to avoid + # some aggressive [type promotion](https://jax.readthedocs.io/en/latest/type_promotion.html). + # However, in this case we will have problems when we call the SDFG, for some + # reasons `CompiledSDFG` does not work in that case correctly, thus we enable + # it for the tracing. + with _jax.experimental.enable_x64(): + builder = translator.JaxprTranslationBuilder( + primitive_translators=self._primitive_translators + ) + jaxpr = _jax.make_jaxpr(self._fun)(*args) + trans_ctx: translator.TranslationContext = builder.translate_jaxpr(jaxpr) + + # Perform the post processing and turn it into a `TranslatedJaxprSDFG` that can + # be compiled and called later. + # NOTE: `tsdfg` was deepcopied as a side effect of post processing. + tsdfg: translator.TranslatedJaxprSDFG = ptrans.postprocess_jaxpr_sdfg( + trans_ctx=trans_ctx, + fun=self.wrapped_fun, + call_args=args, # Already linearised, since we only accept positional args. + intree=None, # Not yet implemented. + ) + + return JaCeLowered(tsdfg) + + @property + def wrapped_fun(self) -> Callable: + """Returns the wrapped function.""" + return self._fun + + def _make_call_description(self, *args: Any) -> tcache.StageTransformationSpec: + """ + Computes the key for the `JaCeWrapped.lower()` call inside the cache. + + The function will compute a full abstract description on its argument. + """ + call_args = tuple(tcache._AbstractCallArgument.from_value(x) for x in args) + return tcache.StageTransformationSpec(stage_id=id(self), call_args=call_args) + + +class JaCeLowered(tcache.CachingStage["JaCeCompiled"]): + """ + Represents the original computation as an SDFG. + + This class is the output type of `JaCeWrapped.lower()` and represents the + originally wrapped computation as an SDFG. This stage is followed by the + `JaCeCompiled` stage. + + Args: + tsdfg: The translated SDFG object representing the computation. + + Note: + `self` will manage the passed `tsdfg` object. Modifying it results in + undefined behavior. Although `JaCeWrapped` is composable with Jax + transformations `JaCeLowered` is not. A user should never create such + an object, instead `JaCeWrapped.lower()` should be used. + """ + + _translated_sdfg: translator.TranslatedJaxprSDFG + + def __init__(self, tsdfg: translator.TranslatedJaxprSDFG) -> None: + super().__init__() + self._translated_sdfg = tsdfg + + @tcache.cached_transition + def compile(self, compiler_options: CompilerOptions | None = None) -> JaCeCompiled: + """ + Optimize and compile the lowered SDFG using `compiler_options`. + + Returns an object that encapsulates a compiled SDFG object. To influence + the various optimizations and compile options of JaCe you can use the + `compiler_options` argument. 
If nothing is specified + `jace.optimization.DEFAULT_OPTIMIZATIONS` will be used. + + Note: + Before `compiler_options` is forwarded to `jace_optimize()` it + will be merged with the default arguments. + """ + # We **must** deepcopy before we do any optimization, because all optimizations + # are in place, to properly cache stages, stages needs to be immutable. + tsdfg: translator.TranslatedJaxprSDFG = copy.deepcopy(self._translated_sdfg) + optimization.jace_optimize(tsdfg=tsdfg, **self._make_compiler_options(compiler_options)) + + return JaCeCompiled( + csdfg=dace_helper.compile_jax_sdfg(tsdfg), + inp_names=tsdfg.inp_names, + out_names=tsdfg.out_names, + ) + + def compiler_ir(self, dialect: str | None = None) -> translator.TranslatedJaxprSDFG: + """ + Returns the internal SDFG. + + The function returns a `TranslatedJaxprSDFG` object. Direct modification + of the returned object is forbidden and will cause undefined behaviour. + """ + if (dialect is None) or (dialect.upper() == "SDFG"): + return self._translated_sdfg + raise ValueError(f"Unknown dialect '{dialect}'.") + + def view(self, filename: str | None = None) -> None: + """ + Runs the `view()` method of the underlying SDFG. + + This will open a browser and display the SDFG. + """ + self.compiler_ir().sdfg.view(filename=filename, verbose=False) + + def as_sdfg(self) -> dace.SDFG: + """ + Returns the encapsulated SDFG. + + Modifying the returned SDFG in any way is undefined behavior. + """ + return self.compiler_ir().sdfg + + def _make_call_description( + self, compiler_options: CompilerOptions | None = None + ) -> tcache.StageTransformationSpec: + """ + This function computes the key for the `self.compile()` call inside the cache. + + The key that is computed by this function is based on the concrete + values of the passed compiler options. + """ + options = self._make_compiler_options(compiler_options) + call_args = tuple(sorted(options.items(), key=lambda x: x[0])) + return tcache.StageTransformationSpec(stage_id=id(self), call_args=call_args) + + @staticmethod + def _make_compiler_options(compiler_options: CompilerOptions | None) -> CompilerOptions: + return optimization.DEFAULT_OPTIMIZATIONS | (compiler_options or {}) + + +class JaCeCompiled: + """ + Compiled version of the SDFG. + + This is the last stage of the jit chain. A user should never create a + `JaCeCompiled` instance, instead `JaCeLowered.compile()` should be used. + + Args: + csdfg: The compiled SDFG object. + inp_names: Names of the SDFG variables used as inputs. + out_names: Names of the SDFG variables used as outputs. + + Note: + The class assumes ownership of its input arguments. + + Todo: + - Handle pytrees. + """ + + _csdfg: dace_helper.CompiledSDFG + _inp_names: tuple[str, ...] + _out_names: tuple[str, ...] + + def __init__( + self, csdfg: dace_helper.CompiledSDFG, inp_names: Sequence[str], out_names: Sequence[str] + ) -> None: + if (not inp_names) or (not out_names): + raise ValueError("Input and output can not be empty.") + self._csdfg = csdfg + self._inp_names = tuple(inp_names) + self._out_names = tuple(out_names) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + """ + Calls the embedded computation. + + The arguments must be the same as for the wrapped function, but with + all static arguments removed. + """ + return dace_helper.run_jax_sdfg(self._csdfg, self._inp_names, self._out_names, args, kwargs) + + +#: Known compilation stages in JaCe. 
+Stage = JaCeWrapped | JaCeLowered | JaCeCompiled diff --git a/src/jace/translator/__init__.py b/src/jace/translator/__init__.py new file mode 100644 index 0000000..2f184a0 --- /dev/null +++ b/src/jace/translator/__init__.py @@ -0,0 +1,37 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +""" +Subpackage containing all the code related to the Jaxpr to SDFG translation. + +The concrete primitive translators that ships with JaCe are inside the +`primitive_translators` subpackage. +""" + +from __future__ import annotations + +from .jaxpr_translator_builder import JaxprTranslationBuilder, TranslationContext +from .primitive_translator import ( + PrimitiveTranslator, + PrimitiveTranslatorCallable, + get_registered_primitive_translators, + make_primitive_translator, + register_primitive_translator, +) +from .translated_jaxpr_sdfg import TranslatedJaxprSDFG + + +__all__ = [ + "JaxprTranslationBuilder", + "PrimitiveTranslator", + "PrimitiveTranslatorCallable", + "TranslatedJaxprSDFG", + "TranslationContext", + "get_registered_primitive_translators", + "make_primitive_translator", + "register_primitive_translator", +] diff --git a/src/jace/translator/jaxpr_translator_builder.py b/src/jace/translator/jaxpr_translator_builder.py new file mode 100644 index 0000000..da2e68f --- /dev/null +++ b/src/jace/translator/jaxpr_translator_builder.py @@ -0,0 +1,760 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause +"""Contains the translator that actually builds an SDFG based on a Jaxpr description.""" + +from __future__ import annotations + +import copy +from collections.abc import Mapping, Sequence +from typing import TYPE_CHECKING, Any, Literal, cast, overload + +import dace +from dace import data as ddata, properties as dprop +from jax import core as jax_core + +from jace import util + + +if TYPE_CHECKING: + from jace import translator + + +class JaxprTranslationBuilder: + """ + Internal builder class for creating an SDFG equivalent of a `Jaxpr` instance. + + The SDFG created by this class has a very particular form, which we call + canonical. The main features of such an SDFG are: + - the SDFG is a list of states, + - it has a single source and sink state. + - all variable names are derived from Jax names, + - there are only transient variables inside the SDFG, + - It lacks the special `__return` variable, + - the `arg_names` parameter is not set. + + For these reasons the SDFG is not directly usable, and further manipulations + have to be performed. Especially, DaCe's validation function will fail and + it is unable to be processed by JaCe's optimization pipeline. For more + information also see `jace.translator.post_translation` module. + + The idea of the translator is extremely simple. A Jaxpr is essentially a + list consisting of more or less simple instructions/equations, they get + processed one after the other. Each equation is translated into its own + state that is successively appended to the SDFG, while the SDFG is being + build, which explains the particular form of the SDFG. + + However, the actual translation of the equations is not handled by the + builder. Instead the request is forwarded to a `PrimitiveTranslator` + object, known as primitive translator. 
This is a highly specialized object + that is able to handle one kind of primitive. For more information on them + see the documentation of `PrimitiveTranslator`. + + To start a translation the `translate_jaxpr()` function has to be called, + if this happens it is said that the builder has an ongoing translation. + The first translator is known as root, translator. If `translate_jaxpr()` + is called on a builder that has an ongoing translation, a new translation + context will be set up. Thus the builder will then translate the supplied + (nested) Jaxpr and return the result. However, this will have no influence + on the translation process that is already going. + + Args: + primitive_translators: Primitive translators to use in the translation. + + Notes: + After a translation has been performed the translator object can be used + again. Currently the builder will generate only Array as SDFG variables, + however, this is a temporary solution, see `add_array()`. + """ + + _primitive_translators: Mapping[str, translator.PrimitiveTranslatorCallable] + _jax_name_map: dict[jax_core.Var | util.JaCeVar, str] + _ctx_stack: list[TranslationContext] + + def __init__( + self, primitive_translators: Mapping[str, translator.PrimitiveTranslatorCallable] + ) -> None: + # Maps name of primitives to the associated translator. + self._primitive_translators = {**primitive_translators} + + # Maps Jax variables to the name of its SDFG equivalent. + # Shared between all translation contexts, to ensure consecutive variable + # naming as seen as in a pretty printed Jaxpr. Will be cleared by + # `_clear_translation_ctx()` at the end of the root translation. + self._jax_name_map = {} + + # Stack of all context, to handle nested Jaxpr instances. + # The first one, i.e. index 0, is known as head translator. + self._ctx_stack = [] + + def translate_jaxpr( + self, jaxpr: jax_core.ClosedJaxpr, *, name: str | None = None + ) -> TranslationContext: + """ + Perform the translation of a Jaxpr into a SDFG. + + In case this function is called and `self` has an ongoing translation + process, a new translation context will be created. This allows to + handle nested Jaxprs. However, the variable map is shared among all. + + Returns: + The function will translate the passed Jaxpr object into an SDFG + in canonical form. This SDFG together with additional meta data, + that is needed for further processing is encapsulated inside a + `TranslationContext` object. For further use it should be passed + to `postprocess_jaxpr_sdfg()`. + + Args: + name: Use this name for the SDFG instead some generated one. + jaxpr: The Jaxpr object that should be translated. + """ + if len(jaxpr.effects) != 0: + raise NotImplementedError("'Jaxpr' with side effects are not supported.") + + # NOTE: If `self` is already allocated, i.e. has an ongoing translation + # process, the `_allocate_translation_ctx()` function will start a new + # context. Thus the builder will start to translate a second (nested) + # SDFG. Also note that there is no mechanism that forces the integration + # of the nested SDFG/Jaxpr, this must be done manually. 
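+        # Set up a (new) context, create the SDFG variables for the Jaxpr's
+        # constants and input arguments, then translate the equations one by one.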
+ self._allocate_translation_ctx(name=name) + self._create_constants(jaxpr=jaxpr) + self._create_initial_input(jaxpr=jaxpr) + + return self._translate_jaxpr_internal(jaxpr) + + def append_new_state( + self, + label: str | None = None, + condition: dprop.CodeBlock | None = None, + assignments: Mapping[str, Any] | None = None, + prev_state: dace.SDFGState | None = None, + ) -> dace.SDFGState: + """ + Creates a new `SDFGState`, adds it to the SDFG and returns it. + + By default the new state is appended to the current terminal state. + However, if `prev_state` is specified it will be appended to it. In + case the new state is appended to the current terminal state, this will + modify the terminal state of `self`. + + Args: + label: The name that should be given to the new `SDFGState`. + condition: Condition on the `InterstateEdge`. + assignments: Symbol assignments on the `InterstateEdge`. + prev_state: Alternative state at which we append. + + Notes: + It is potentially dangerous to not append to the current terminal + state, as a canonical SDFG only has one sink state. If this is done + the user has to ensure, that at the end of the processing the SDFG + is back in canonical form. + """ + if isinstance(label, str) and (not util.VALID_SDFG_OBJ_NAME.fullmatch(label)): + raise ValueError(f"Can not create state with label '{label}' since it is invalid.") + + # Decide if appending to that state will modify the terminal state. + modify_term_state: bool = False + if (prev_state is self._ctx.terminal_state) or (prev_state is None): + modify_term_state = True + app_state = self._ctx.terminal_state + else: + app_state = prev_state + + new_state = self._ctx.sdfg.add_state(label, is_start_block=False) + self._ctx.sdfg.add_edge( + app_state, + new_state, + dace.sdfg.InterstateEdge(condition=condition, assignments=assignments), + ) + + if modify_term_state: + self._ctx.terminal_state = new_state + return new_state + + @property + def arrays(self) -> Mapping[str, ddata.Data]: + """ + Get all data descriptors that are currently known to the SDFG. + + Notes: + Essentially a shorthand and preferred way for `self.sdfg.arrays`. + For getting a specific data descriptor use `self.get_array()`. + """ + return cast(Mapping[str, ddata.Data], self._ctx.sdfg.arrays) + + def get_array(self, name: str | jax_core.Atom | util.JaCeVar) -> ddata.Data: + """ + Returns the SDFG `Data` object `name` referees to. + + `name` can either be a string, in which case it is interpreted as a + verbatim SDFG name. If it is a Jax or JaCe variable, the function will + first perform a lookup using `self.map_jax_var_to_sdfg(name)`. + """ + if isinstance(name, (jax_core.Var, util.JaCeVar)): + sdfg_name: str = self.map_jax_var_to_sdfg(name) + elif isinstance(name, str): + sdfg_name = name + else: + raise TypeError(f"The literal '{name}' does not have an SDFG equivalent.") + if sdfg_name not in self._ctx.sdfg.arrays: + raise KeyError(f"Requested SDFG object '{name}' is not known.") + return self._ctx.sdfg.arrays[sdfg_name] + + @overload + def map_jax_var_to_sdfg( + self, jax_var: jax_core.Atom | util.JaCeVar, allow_fail: Literal[False] = False + ) -> str: ... + + @overload + def map_jax_var_to_sdfg( + self, jax_var: jax_core.Atom | util.JaCeVar, allow_fail: Literal[True] + ) -> str | None: ... + + def map_jax_var_to_sdfg( + self, jax_var: jax_core.Atom | util.JaCeVar, allow_fail: bool = False + ) -> str | None: + """ + Get the name of the SDFG variable to which `jax_var` is referring to. + + Args: + jax_var: The Jax variable to look up. 
+ allow_fail: Return `None` instead of raising a `KeyError`. + """ + if isinstance(jax_var, jax_core.Literal): + raise TypeError(f"There is no SDFG variable for literal '{jax_var}'.") + if jax_var in self._jax_name_map: + sdfg_name = self._jax_name_map[jax_var] + elif allow_fail: + return None + else: + raise KeyError(f"The Jax variable '{jax_var}' was never registered.") + if sdfg_name not in self._ctx.sdfg.arrays: + raise KeyError( + f"Jax variable '{jax_var}' was supposed to map to '{sdfg_name}'," + " but no such SDFG variable is known." + ) + return sdfg_name + + @property + def sdfg(self) -> dace.SDFG: + """Returns the SDFG that is currently constructed.""" + return self._ctx.sdfg + + def is_allocated(self) -> bool: + """ + Tests if `self` has an allocated context. + + If `self` is allocated then there is also an ongoing translation process. + """ + return len(self._ctx_stack) != 0 + + def is_root_translator(self) -> bool: + """ + Tests if `self` is the root translator. + + The root translator (context) is the very first translator process. + """ + if not self.is_allocated(): + raise RuntimeError("Builder is not allocated.") + return len(self._ctx_stack) == 1 + + def add_jax_name_mapping( + self, jax_var: jax_core.Var | util.JaCeVar, sdfg_name: str + ) -> JaxprTranslationBuilder: + """ + Creates a new mapping between `jax_var` to `sdfg_name`. + + If the mapping already exists an error will be generated. This function + is not able to delete a variable mapping that was established before. + + Args: + jax_var: The Jax variable. + sdfg_name: The name of the corresponding SDFG variable. + """ + assert sdfg_name + + if jax_var in self._jax_name_map: + raise ValueError( + f"Cannot change the mapping of '{jax_var}' from" + f" '{self.map_jax_var_to_sdfg(jax_var)}' to '{sdfg_name}'." + ) + if sdfg_name not in self._ctx.sdfg.arrays: + raise KeyError(f"Mapping '{jax_var} -> {sdfg_name}': SDFG target unknown.") + if sdfg_name in util.FORBIDDEN_SDFG_VAR_NAMES: + raise NameError(f"Mapping '{jax_var} -> {sdfg_name}': Forbidden name.") + + self._jax_name_map[jax_var] = sdfg_name + return self + + def add_array( + self, + arg: jax_core.Atom | util.JaCeVar, + *, + name_prefix: str | None = None, + update_var_mapping: bool = False, + ) -> str: + """ + Creates an SDFG variable for Jax variable `arg` and returns its SDFG name. + + The SDFG object is always created as a transient. Furthermore, the + function will not update the internal variable mapping, by default. + + By default the function will use `jace.util.propose_jax_name()` to derive + the name that should be used. However, by passing a `JaCeVar` with a + name it is possible to suggest a specific name. In addition it is possible + to specify `name_prefix` to supply a prefix to the determined name that + should be used. + + Args: + arg: The Jax object for which a SDFG equivalent should be created. + name_prefix: If given it will be used as prefix for the name. + update_var_mapping: Update the internal variable mapping. + + Notes: + As a temporary fix for handling scalar return values, the function + will always generate arrays, even if `arg` is a scalar. According to + the DaCe developer, the majority of the backend, i.e. optimization + pipeline, should be able to handle it. But there are some special + parts that might explicitly want a scalar, it also might block + certain compiler optimization. 
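+
+        Example:
+            A rough sketch (`some_jax_var` is a made up stand-in for a Jax
+            variable encountered during the translation)::
+
+                sdfg_name = builder.add_array(some_jax_var, update_var_mapping=True)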
+ """ + if isinstance(arg, jax_core.Literal): + raise TypeError(f"Can not generate an SDFG variable for literal '{arg}'.") + + shape: tuple[int | dace.symbol | str, ...] = util.get_jax_var_shape(arg) + dtype: dace.typeclass = util.get_jax_var_dtype(arg) + storage: dace.StorageType = dace.StorageType.Default # Set at later stages (optimization) + offset = None + as_transient = True + strides = None + + # Temporary fix for handling DaCe scalars, see above for more. + shape = shape or (1,) + + # Propose a name and if needed extend it. + arg_name = util.propose_jax_name(arg, self._jax_name_map) + if name_prefix: + arg_name = f"{name_prefix}{arg_name}" + + # final checks + if arg_name in self._ctx.sdfg.arrays: + raise ValueError(f"add_array({arg}): The proposed name '{arg_name}', is used.") + if not util.VALID_SDFG_VAR_NAME.fullmatch(arg_name): + raise ValueError(f"add_array({arg}): The proposed name '{arg_name}', is invalid.") + if arg_name in util.FORBIDDEN_SDFG_VAR_NAMES: + raise ValueError(f"add_array({arg}): The proposed name '{arg_name}', is forbidden.") + + self._ctx.sdfg.add_array( + name=arg_name, + shape=shape, + strides=strides, + offset=offset, + storage=storage, + dtype=dtype, + transient=as_transient, + ) + + if update_var_mapping: + try: + # If the mapping fails, remove the variable from the SDFG. + self.add_jax_name_mapping(jax_var=arg, sdfg_name=arg_name) + except: + del self._ctx.sdfg.arrays[arg_name] + raise + + return arg_name + + @overload + def create_jax_var_list( + self, + jax_var_list: Sequence[jax_core.Atom | util.JaCeVar], + prevent_creation: bool = False, + only_creation: bool = True, + handle_literals: bool = False, + **kwargs: Any, + ) -> list[str]: ... + + @overload + def create_jax_var_list( # type: ignore[misc] + self, + jax_var_list: Sequence[jax_core.Atom | util.JaCeVar], + prevent_creation: bool = False, + only_creation: bool = False, + handle_literals: bool = False, + **kwargs: Any, + ) -> list[None | str]: ... + + def create_jax_var_list( # type: ignore[misc] + self, + jax_var_list: Sequence[jax_core.Atom | util.JaCeVar], + prevent_creation: bool = False, + only_creation: bool = False, + handle_literals: bool = False, + **kwargs: Any, + ) -> list[None | str]: + """ + Create SDFG variables from the passed Jax variables. + + If a Jax variable already has a SDFG equivalent then the function will + use this variable. If no corresponding SDFG variable is known the function + will create one using `add_array()`. + + By setting `prevent_creation` the function will not create any new SDFG + variables, if no corresponding SDFG variable exists an error is generated. + By setting `only_creation` the function will only create new SDFG variables, + if a variable already have a corresponding SDFG variable an error will be + generated. + + By default literals cause an error. However, by setting `handle_literals` + to `True` literals will will be included in the output with the value `None`. + + Args: + jax_var_list: The list of Jax variables that should be processed. + prevent_creation: Never create a variable, all must already be known. + only_creation: Always create a variable. + handle_literals: Allow the processing of literals. + kwargs: Will be forwarded to `self.add_array()` if a variable is created. + + Todo: + - Rollback if the creation fails. 
+ """ + if only_creation and prevent_creation: + raise ValueError("Specified both 'only_creation' and 'prevent_creation'.") + + ret_list: list[None | str] = [] + for jax_var in jax_var_list: + if isinstance(jax_var, jax_core.Literal): + if not handle_literals: + raise ValueError("Encountered a literal but `handle_literals` was `False`.") + sdfg_name = None + else: + mapped_sdfg_name: str | None = self.map_jax_var_to_sdfg(jax_var, allow_fail=True) + if prevent_creation and (mapped_sdfg_name is None): + raise ValueError(f"'prevent_creation' given but have to create '{jax_var}'.") + if mapped_sdfg_name is None: + sdfg_name = self.add_array(arg=jax_var, **kwargs) + elif only_creation: + raise ValueError(f"'only_creation' given but '{jax_var}' already exists.") + else: + sdfg_name = mapped_sdfg_name + ret_list.append(sdfg_name) + + return ret_list + + def _create_initial_input(self, jaxpr: jax_core.ClosedJaxpr) -> None: + """ + Creates the input variables of `jaxpr`. + + Notes: + The function will populate the `inp_names` member of the current context. + """ + assert self.is_allocated(), "Builder is not allocated, can not create constants." + assert self._ctx.inp_names is None + + # Handle the initial input arguments + init_in_var_names: Sequence[str] = self.create_jax_var_list( + jax_var_list=jaxpr.jaxpr.invars, + only_creation=True, # Nothing exists yet. + handle_literals=False, # Initial arguments are never literals + update_var_mapping=True, + ) + self.sdfg.arg_names = [] + + # The output list is populated by `self._translate_jaxpr_internal()` + self._ctx.inp_names = tuple(init_in_var_names) + + def _create_constants(self, jaxpr: jax_core.ClosedJaxpr) -> None: + """ + Creates all constants requested by the `jaxpr`. + + The function will create an SDFG variable and add them as constant to + the SDFG. Their value is deepcopied. + """ + assert self.is_allocated(), "Builder is not allocated, can not create constants." + if len(jaxpr.consts) == 0: + return + + sdfg_const_names: Sequence[str] = self.create_jax_var_list( + jax_var_list=jaxpr.jaxpr.constvars, + only_creation=True, # Nothing exists yet. + handle_literals=False, # It seems that constants are never literals. + name_prefix="__const_", + update_var_mapping=True, + ) + for sdfg_name, const_value in zip(sdfg_const_names, jaxpr.consts, strict=True): + self._ctx.sdfg.add_constant( + sdfg_name, copy.deepcopy(const_value), self._ctx.sdfg.arrays[sdfg_name] + ) + + def _allocate_translation_ctx(self, name: str | None = None) -> JaxprTranslationBuilder: + """ + Allocate a new context and activate it. + + Args: + name: The name of the SDFG. + """ + self._ctx_stack.append(TranslationContext(name=name)) + return self + + @property + def _ctx(self) -> TranslationContext: + """Returns the currently active translation context.""" + assert len(self._ctx_stack) != 0, "No context is active." + return self._ctx_stack[-1] + + def _clear_translation_ctx(self) -> TranslationContext | None: + """ + Remove the currently active context from `self` and returns it. + + If `self` is not allocated it will return `None`. + """ + if not self.is_allocated(): + return None + + if self.is_root_translator(): + # The translation, as a whole has finished, so restore the builder, + # i.e. delete all the shared state. + self._jax_name_map = {} + + # Remove the current head stack. + return self._ctx_stack.pop() + + def _translate_single_eqn(self, eqn: jax_core.JaxprEqn) -> None: + """ + Translate `eqn` into its SDFG equivalent. 
+ + To do this the function will perform the following steps: + - Assemble the in and output variables. + - Select the appropriate primitive translator to use. + - Create a new empty state terminal state. + - Call the primitive translator to perform the translation inside the new state. + """ + if len(eqn.effects) != 0: + raise NotImplementedError(f"Equation '{eqn}' has side effects.") + + # Input/Output variables + # Using a tuple for the input ensures that it cannot be modified. + in_var_names: Sequence[str | None] = self.create_jax_var_list( + eqn.invars, + prevent_creation=True, # Inputs must already exists. + handle_literals=True, # but they can be literals. + ) + out_var_names: Sequence[str] = self.create_jax_var_list( + eqn.outvars, + only_creation=True, # Output must not exist yet. + update_var_mapping=True, + ) + + primitive_name: str = eqn.primitive.name + if primitive_name not in self._primitive_translators: + raise NotImplementedError(f"No translator known to handle '{primitive_name}'.") + translator = self._primitive_translators[primitive_name] + + # Create the state into which the equation should be translated + eqn_state = self.append_new_state( + label=f"{primitive_name}_{'_'.join(out_var_names)}", + prev_state=None, # forces the creation of a new terminal state + ) + + # Now perform the actual translation of the equation. + new_sdfg_term_state = translator( + builder=self, + in_var_names=in_var_names, + out_var_names=out_var_names, + eqn=eqn, + eqn_state=eqn_state, + ) + + # Determine the new (tentative) terminal state of the SDFG we are building. + if new_sdfg_term_state is None: + if eqn_state is not self._ctx.terminal_state: + raise RuntimeError("Inconsistent terminal state was detected.") + new_sdfg_term_state = eqn_state + if not self._ctx.validate(): + raise RuntimeError("Detected an invalid SDFG under construction.") + + # Modify terminal root state of 'self' + self._ctx.terminal_state = new_sdfg_term_state + + def _translate_jaxpr_internal(self, jaxpr: jax_core.ClosedJaxpr) -> TranslationContext: + """ + Performs the actual translation of the Jaxpr into an SDFG. + + The function assumes that the context is allocated as well as the + initial variables. The function removes and returns the currently + active translation context. + + Args: + jaxpr: The Jaxpr to translate. + + Notes: + Equations that store into drop variables, i.e. with name `_`, + will be ignored. + """ + nb_translated_eqn: int = 0 + out_var_names: Sequence[str] = () + + for eqn in jaxpr.jaxpr.eqns: + if any(util.is_drop_var(outVar) for outVar in eqn.outvars): + assert all(util.is_drop_var(outVar) for outVar in eqn.outvars) + continue + self._translate_single_eqn(eqn=eqn) + nb_translated_eqn += 1 + + # Handle the output or the case of an empty Jaxpr + if nb_translated_eqn == 0: + out_var_names = self._handle_null_jaxpr(jaxpr) + else: + out_var_names = self.create_jax_var_list( + jaxpr.jaxpr.outvars, prevent_creation=True, handle_literals=False + ) + + self._ctx.out_names = tuple(out_var_names) + + return cast(TranslationContext, self._clear_translation_ctx()) + + def _handle_null_jaxpr(self, jaxpr: jax_core.ClosedJaxpr) -> list[str]: + """ + This function is called in case a `Jaxpr` with zero equations is encountered. + + A function with zero equation might still have output, in which case + an input is copied to an output. This function will handle the copying + from the input into the corresponding output variable. 
It is important + that the function will remove the variables that are used as input and + output from the mapping. + + Returns: + The function returns a tuple containing the SDFG variables that + refer to the output. The order of the list is the same as in + `jaxpr.jaxpr.outvars`. + + Todo: + - Handle the case if if the output is a literal. + + Note: + The function will _not_ update the `out_names` field of the current context. + """ + assert self._ctx.terminal_state is self._ctx.start_state + assert self._ctx.inp_names + assert self._ctx.out_names is None + + # There is not output so we do not have to copy anything around. + if not jaxpr.out_avals: + return [] + + # List of the real output variables. + out_var_names: list[str] = [] + + # If we are here then we are dealing with a nested SDFG/Jaxpr, that has output. + # Because an input also serves as output, the nested SDFG will have a + # connector for the input and one for the output, but both with the same name. + # This will make node validation fail. We have to work around this by + # introducing some fake copies, which will be removed by DaCe later. + for jax_out_var in jaxpr.jaxpr.outvars: + sdfg_in_name: str = self.map_jax_var_to_sdfg(jax_out_var) + + # Now we create a variable that serves as true output, however, since the + # Jax variable is already known we can not update the variable mapping and + # must use another name. + sdfg_out_name = self.add_array( + jax_out_var, name_prefix="_zero_equation_output_for_", update_var_mapping=False + ) + out_var_names.append(sdfg_out_name) + + # Now we perform the copy from the input variable in the newly created + # output variable. + inp_acc = self._start_state.add_read(sdfg_in_name) + out_acc = self._start_state.add_write(sdfg_out_name) + self._start_state.add_nedge( + src=inp_acc, + dst=out_acc, + data=dace.Memlet.from_array(sdfg_in_name, self.get_array(sdfg_in_name)), + ) + + # `jax_out_var` now has, in some sense, two SDFG equivalents, the input, + # that was previously created by `self._create_initial_input()` and the + # `sdfg_out_name` we just created. But we can not add this to the mapping. + # Because it is the best, as in the least worst thing we can do, we remove + # it from the mapping. I am open for different approaches. + self._jax_name_map.pop(jax_out_var) + + return out_var_names + + @property + def _start_state(self) -> dace.SDFGState: + return cast(dace.SDFGState, self._ctx.start_state) + + @property + def _terminal_sdfg_state(self) -> dace.SDFGState: + """Returns the current terminal state of the SDFG under construction.""" + return cast(dace.SDFGState, self._ctx.terminal_state) + + +class TranslationContext: + """ + Translation context used by the `JaxprTranslationBuilder`. + + Internal representation of the builder of an SDFG under construction together + with the needed metadata. Essentially it is an extended version of the + `TranslatedJaxprSDFG`, but carrying an unfinished canonical SDFG. + A user should consider this class as an opaque object, that represents an + invalid `TranslatedJaxprSDFG` object, and the only valid operation a user + can do with it is passing it either to `finalize_translation_context()` or + the `postprocess_jaxpr_sdfg()` function. + + Attributes: + sdfg: The encapsulated SDFG object. + inp_names: A list of the SDFG variables that are used as input + out_names: A list of the SDFG variables that are used as output. + start_state: The first state in the SDFG state machine. + terminal_state: The (currently) last state in the state machine. 
+ + Args: + name: The name of the SDFG. + + Note: + Access of any attribute of this class by an outside user is considered + undefined behaviour. + """ + + sdfg: dace.SDFG + inp_names: tuple[str, ...] | None + out_names: tuple[str, ...] | None + start_state: dace.SDFGState + terminal_state: dace.SDFGState + + def __init__(self, name: str | None = None) -> None: + if isinstance(name, str) and not util.VALID_SDFG_OBJ_NAME.fullmatch(name): + raise ValueError(f"'{name}' is not a valid SDFG name.") + + self.sdfg = dace.SDFG(name=(name or f"unnamed_SDFG_{id(self)}")) + self.inp_names = None + self.out_names = None + self.start_state = self.sdfg.add_state(label="initial_state", is_start_block=True) + self.terminal_state = self.start_state + + def validate(self) -> bool: + """ + Validate internal state of `self`. + + Since the SDFG is under construction it will not be validated, instead the + meta data will be validated. + """ + if self.start_state is not self.sdfg.start_block: + raise dace.sdfg.InvalidSDFGError( + f"Expected to find '{self.start_state}' as start state," + f" but instead found '{self.sdfg.start_block}'.", + self.sdfg, + self.sdfg.node_id(self.start_state), + ) + if {self.terminal_state} != set(self.sdfg.sink_nodes()): + raise dace.sdfg.InvalidSDFGError( + f"Expected to find as terminal state '{self.terminal_state}'," + f" but instead found '{self.sdfg.sink_nodes()}'.", + self.sdfg, + self.sdfg.node_id(self.terminal_state), + ) + return True diff --git a/src/jace/translator/post_translation.py b/src/jace/translator/post_translation.py new file mode 100644 index 0000000..ec445e9 --- /dev/null +++ b/src/jace/translator/post_translation.py @@ -0,0 +1,108 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +""" +This module contains all functions that are related to post processing the SDFG. + +Most of them operate on `TranslatedJaxprSDFG` objects. +Currently they mostly exist for the sake of existing. +""" + +from __future__ import annotations + +import copy +from typing import TYPE_CHECKING, Any + +from jace import translator + + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + + +def postprocess_jaxpr_sdfg( + trans_ctx: translator.TranslationContext, + fun: Callable, # noqa: ARG001 # Currently unused + call_args: Sequence[Any], # noqa: ARG001 # Currently unused + intree: None, # noqa: ARG001 # Currently unused +) -> translator.TranslatedJaxprSDFG: + """ + Perform the final post processing steps on the `TranslationContext` _in place_. + + The function will perform post processing stages on the context in place. + However, the function will return a decoupled `TranslatedJaxprSDFG` object. + + Args: + trans_ctx: The `TranslationContext` obtained from a `translate_jaxpr()` call. + fun: The original function that was translated. + call_args: The linearized input arguments. + intree: The pytree describing the inputs. + + Todo: + - Setting correct input names (layer that does not depend on JAX). + - Setting the correct strides & storage properties. + - Fixing the scalar input problem on GPU. + """ + # Currently we do nothing except finalizing. + trans_ctx.validate() + + # + # Assume some post processing here. 
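+    # For example, this is where the strides and storage locations of the inputs
+    # could be adapted to `call_args` and proper argument names could be set
+    # (see the Todo list in the docstring above).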
+ # + + return finalize_translation_context(trans_ctx, validate=True) + + +def finalize_translation_context( + trans_ctx: translator.TranslationContext, validate: bool = True +) -> translator.TranslatedJaxprSDFG: + """ + Finalizes the supplied translation context `trans_ctx`. + + The function will process the SDFG that is encapsulated inside the context, + i.e. a canonical one, into a proper SDFG, as it is described in + `TranslatedJaxprSDFG`. It is important to realize that this function does + not perform any optimization of the underlying SDFG itself, instead it + prepares an SDFG such that it can be passed to the optimization pipeline. + + The function will not mutate the passed translation context and the output + is always decoupled from its output. + + Args: + trans_ctx: The context that should be finalized. + validate: Call the validate function after the finalizing. + """ + trans_ctx.validate() + if trans_ctx.inp_names is None: + raise ValueError("Input names are not specified.") + if trans_ctx.out_names is None: + raise ValueError("Output names are not specified.") + + # We guarantee decoupling + tsdfg = translator.TranslatedJaxprSDFG( + sdfg=copy.deepcopy(trans_ctx.sdfg), + inp_names=trans_ctx.inp_names, + out_names=trans_ctx.out_names, + ) + + # Make inputs and outputs to globals. + sdfg_arg_names: list[str] = [] + for glob_name in tsdfg.inp_names + tsdfg.out_names: + if glob_name in sdfg_arg_names: + continue + tsdfg.sdfg.arrays[glob_name].transient = False + sdfg_arg_names.append(glob_name) + + # This forces the signature of the SDFG to include all arguments in order they + # appear. If an argument is used as input and output then it is only listed as + # input. + tsdfg.sdfg.arg_names = sdfg_arg_names + + if validate: + tsdfg.validate() + + return tsdfg diff --git a/src/jace/translator/primitive_translator.py b/src/jace/translator/primitive_translator.py new file mode 100644 index 0000000..dc3bd74 --- /dev/null +++ b/src/jace/translator/primitive_translator.py @@ -0,0 +1,222 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause +""" +Interface for all primitive translators and managing of the global translator registry. + +Todo: + Implement proper context manager for working with the registry. +""" + +from __future__ import annotations + +import abc +from typing import TYPE_CHECKING, Literal, Protocol, cast, overload, runtime_checkable + + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + + import dace + from jax import core as jax_core + + from jace import translator + +#: Global registry of the active primitive translators. +#: The `dict` maps the name of a primitive to its associated translators. +_PRIMITIVE_TRANSLATORS_REGISTRY: dict[str, translator.PrimitiveTranslator] = {} + + +class PrimitiveTranslatorCallable(Protocol): + """Callable version of the primitive translators.""" + + @abc.abstractmethod + def __call__( + self, + builder: translator.JaxprTranslationBuilder, + in_var_names: Sequence[str | None], + out_var_names: Sequence[str], + eqn: jax_core.JaxprEqn, + eqn_state: dace.SDFGState, + ) -> dace.SDFGState | None: + """ + Translates the Jax primitive into its SDFG equivalent. + + Before the builder calls this function it will perform the following + preparatory tasks: + - It will allocate the SDFG variables that are used as outputs. 
Their + names will be passed through the `out_var_names` argument, in the + same order as `eqn.outvars`. + - It will collect the names of the SDFG variables that are used as + inputs and place them in `in_var_names`, in the same order as + `eqn.invars`. If an input argument refers to a literal no SDFG + variable is created for it and `None` is used to indicate this. + - The builder will create variables that are used as output. They are + passed as `out_var_names`, same order as in the equation. + - The builder will create a new terminal state and pass it as `eqn_state` + argument. This state is guaranteed to be empty and + `translator.terminal_sdfg_state is eqn_state` holds. + + Then the primitive translator is called. + Usually a primitive translator should construct the dataflow graph + inside `eqn_state`. However, it is allowed that the primitive translators + creates more states if needed, but this state machinery has to have a + single terminal state, which must be returned and reachable from + `eqn_state`. If the function returns `None` the builder will assume that + primitive translator was able to fully construct the dataflow graph + within `eqn_state`. + + A primitive translator has to use the passed input variables, + `in_var_names` and must write its output into the variables indicated + by `out_var_names`. But it is allowed that a primitive translator + creates intermediate values as needed. To ensure that there are no + collision with further variables, the translator should prefix them, + see the `name_prefix` argument of `JaxprTranslationBuilder.add_array()`. + + Args: + builder: The builder object of the translation. + in_var_names: List of the names of the arrays created inside the + SDFG for the inpts or `None` in case of a literal. + out_var_names: List of the names of the arrays created inside the + SDFG for the outputs. + eqn: The Jax primitive that should be translated. + eqn_state: State into which the primitive`s SDFG representation + should be constructed. + """ + ... + + +@runtime_checkable +class PrimitiveTranslator(PrimitiveTranslatorCallable, Protocol): + """ + Interface for all Jax primitive translators. + + A translator for a primitive translates a single equation of a Jaxpr into + its SDFG equivalent. For satisfying this interface a concrete implementation + must be immutable after construction. + + Primitive translators are simple, but highly specialized objects that are + only able to perform the translation of a single primitive. The overall + translation process itself is managed by a builder object, which also owns + and manage the primitive translators. In the end this implements the + delegation pattern. + + The `jace.translator.register_primitive_translator()` function can be used + to add a translator to the JaCe global registry. + """ + + @property + @abc.abstractmethod + def primitive(self) -> str: + """Returns the name of the Jax primitive that `self` is able to handle.""" + ... + + +@overload +def make_primitive_translator( + primitive: str, primitive_translator: Literal[None] = None +) -> Callable[[translator.PrimitiveTranslatorCallable], translator.PrimitiveTranslator]: ... + + +@overload +def make_primitive_translator( + primitive: str, primitive_translator: translator.PrimitiveTranslatorCallable +) -> translator.PrimitiveTranslator: ... 
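+
+
+# Hypothetical usage sketch (the primitive name "my_prim" and the callable are
+# made up): `make_primitive_translator` attaches the `primitive` property and
+# `register_primitive_translator()` then adds the result to the global registry.
+#
+#     @register_primitive_translator()
+#     @make_primitive_translator("my_prim")
+#     def my_prim_translator(builder, in_var_names, out_var_names, eqn, eqn_state):
+#         ...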
+ + +def make_primitive_translator( + primitive: str, primitive_translator: translator.PrimitiveTranslatorCallable | None = None +) -> ( + Callable[[translator.PrimitiveTranslatorCallable], translator.PrimitiveTranslator] + | translator.PrimitiveTranslator +): + """ + Turn `primitive_translator` into a `PrimitiveTranslator` for primitive `primitive`. + + Essentially, this function adds the `primitive` property to a callable, such + that it satisfy the `PrimitiveTranslator` protocol. However, it does not add + it to the registry, for that `register_primitive_translator()` has to be used. + + Notes: + This function can also be used as decorator. + """ + + def wrapper( + primitive_translator: translator.PrimitiveTranslatorCallable, + ) -> translator.PrimitiveTranslator: + if getattr(primitive_translator, "primitive", primitive) != primitive: + raise ValueError( + f"Tried to change the 'primitive' property of '{primitive_translator}' from " + f"'{primitive_translator.primitive}' to '{primitive}'." # type: ignore[attr-defined] + ) + primitive_translator.primitive = primitive # type: ignore[attr-defined] # We define the attribute. + return cast("translator.PrimitiveTranslator", primitive_translator) + + return wrapper if primitive_translator is None else wrapper(primitive_translator) + + +@overload +def register_primitive_translator( + primitive_translator: Literal[None] = None, overwrite: bool = False +) -> Callable[[translator.PrimitiveTranslator], translator.PrimitiveTranslator]: ... + + +@overload +def register_primitive_translator( + primitive_translator: translator.PrimitiveTranslator, overwrite: bool = False +) -> translator.PrimitiveTranslator: ... + + +def register_primitive_translator( + primitive_translator: translator.PrimitiveTranslator | None = None, overwrite: bool = False +) -> ( + translator.PrimitiveTranslator + | Callable[[translator.PrimitiveTranslator], translator.PrimitiveTranslator] +): + """ + Adds a primitive translator to JaCe's global registry. + + The default set of primitives that are used if nothing is specified to to + `jace.jit` are stored inside a global registry. To add a translator to this + registry this function can be used. + + If a translator for `primitive` is already registered an error will be + generated. However, by specifying `overwrite` `primitive_translator` will + replace the current one. + + Args: + primitive_translator: The primitive translator to add to the global registry. + overwrite: Replace the current primitive translator with `primitive_translator`. + + Note: + To add a `primitive` property use the `@make_primitive_translator` decorator. + This function returns `primitive_translator` unmodified, which allows it to be + used as decorator. + """ + + def wrapper( + primitive_translator: translator.PrimitiveTranslator, + ) -> translator.PrimitiveTranslator: + if primitive_translator.primitive in _PRIMITIVE_TRANSLATORS_REGISTRY and not overwrite: + raise ValueError( + f"Explicit override=True needed for primitive '{primitive_translator.primitive}' " + "to overwrite existing one." + ) + _PRIMITIVE_TRANSLATORS_REGISTRY[primitive_translator.primitive] = primitive_translator + return primitive_translator + + return wrapper if primitive_translator is None else wrapper(primitive_translator) + + +def get_registered_primitive_translators() -> dict[str, translator.PrimitiveTranslator]: + """ + Returns a copy of the current state of JaCe's global primitive registry. 
+ + The state returned by this function is compatible to what `jace.jit`'s + `primitive_translators` argument expects. It is important the the returned + object is decoupled from the registry. + """ + return _PRIMITIVE_TRANSLATORS_REGISTRY.copy() diff --git a/src/jace/translator/primitive_translators/__init__.py b/src/jace/translator/primitive_translators/__init__.py new file mode 100644 index 0000000..65f9153 --- /dev/null +++ b/src/jace/translator/primitive_translators/__init__.py @@ -0,0 +1,14 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause +"""Module collecting all built-in primitive translators.""" + +from __future__ import annotations + +from .alu_translator import ALUTranslator + + +__all__ = ["ALUTranslator"] diff --git a/src/jace/translator/primitive_translators/alu_translator.py b/src/jace/translator/primitive_translators/alu_translator.py new file mode 100644 index 0000000..d865ee8 --- /dev/null +++ b/src/jace/translator/primitive_translators/alu_translator.py @@ -0,0 +1,286 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""This module contains the `ALUTranslator` which translates all arithmetic and logic primitives.""" +# ruff: noqa: W505 PLR0912 C901 PLR0914 PLR0915 D417 + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Final, cast + +import dace +import numpy as np +from jax import core as jax_core +from typing_extensions import override + +from jace import translator, util + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +class ALUTranslator(translator.PrimitiveTranslator): + """ + This translator handles all arithmetic and logical operations. + + This translator will be reworked soon, it just exists that the initial PR can do anything at all!! + """ + + def __init__(self, prim_name: str, prim_tmpl: str) -> None: + """Initialize the `ALUTranslator`.""" + self._prim_name = prim_name + self._prim_tmpl = prim_tmpl + + @property + @override + def primitive(self) -> str: + return self._prim_name + + @override + def __call__( + self, + builder: translator.JaxprTranslationBuilder, + in_var_names: Sequence[str | None], + out_var_names: Sequence[str], + eqn: jax_core.JaxprEqn, + eqn_state: dace.SDFGState, + ) -> None: + """ + Perform the translation. + + Deepening on the shapes of the input the function will either create a Tasklet or a mapped Tasklet. + The translator is able to handle broadcasting with NumPy rules. + The function will always perform the translation inside the provided state. + + Args: + builder: The builder object of the translation. + in_var_names: List of the names of the arrays created inside the SDFG for the inpts or 'None' in case of a literal. + out_var_names: List of the names of the arrays created inside the SDFG for the outputs. + eqn: The Jax equation that is translated. + eqn_state: State into which the primitive's SDFG representation is constructed. + """ + assert self._prim_name == eqn.primitive.name + + # Determine what kind of input we got and how we should proceed. 
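+        # `is_scalar`:             the output, and hence the whole operation, is a scalar.
+        # `inp_scalars`:           which of the inputs are scalars.
+        # `has_some_literals`:     literal inputs have no SDFG variable (a `None` entry).
+        # `inps_same_shape`:       all inputs have the same shape, so no broadcasting is needed.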
+ is_scalar = len(util.get_jax_var_shape(eqn.outvars[0])) == 0 + inp_scalars = [len(util.get_jax_var_shape(Inp)) == 0 for i, Inp in enumerate(eqn.invars)] + has_scalars_as_inputs = any(inp_scalars) + has_some_literals = any(x is None for x in in_var_names) + inps_same_shape = all( + util.get_jax_var_shape(eqn.invars[0]) == util.get_jax_var_shape(eqn.invars[i]) + for i in range(1, len(eqn.invars)) + ) + + # We will now look which dimensions have to be broadcasted on which operator. + # I.e. in the dimensions in the lists below there will be no map iteration index. + dims_to_bcastl: list[int] = [] + dims_to_bcastr: list[int] = [] + + # Determine if and how we have to broadcast. + if inps_same_shape or is_scalar: + pass + + elif has_some_literals or has_scalars_as_inputs: + # This is essentially an array plus a scalar, that is eitehr a literal or a variable. + assert (not has_some_literals) or all( + util.get_jax_var_shape(invar) == util.get_jax_var_shape(eqn.outvars[0]) + for (invar, x) in zip(eqn.invars, in_var_names, strict=False) + if x is not None + ) + assert (not has_scalars_as_inputs) or all( + util.get_jax_var_shape(invar) in {util.get_jax_var_shape(eqn.outvars[0]), ()} + for (invar, x) in zip(eqn.invars, in_var_names, strict=False) + if x is not None + ) + + else: + # This is the general broadcasting case + # We assume that both inputs and the output have the same rank but different sizes in each dimension. + # It seems that Jax ensures this. + # We further assume that if the size in a dimension differs then one must have size 1. + # This is the size we broadcast over, i.e. conceptually replicated. + out_shps = tuple(util.get_jax_var_shape(eqn.outvars[0])) # Shape of the output + inp_shpl = tuple(util.get_jax_var_shape(eqn.invars[0])) # Shape of the left/first input + inp_shpr = tuple( + util.get_jax_var_shape(eqn.invars[1]) + ) # Shape of the right/second input + + if not ((len(inp_shpl) == len(inp_shpr)) and (len(out_shps) == len(inp_shpr))): + raise NotImplementedError("Can not broadcast over different ranks.") + + for dim, (shp_lft, shp_rgt, out_shp) in enumerate(zip(inp_shpl, inp_shpr, out_shps)): + if shp_lft == shp_rgt: + assert out_shp == shp_lft + elif shp_lft == 1: + assert shp_rgt == out_shp + dims_to_bcastl.append(dim) + elif shp_rgt == 1: + assert shp_lft == out_shp + dims_to_bcastr.append(dim) + else: + raise ValueError(f"Invalid shapes in dimension {dim} for broadcasting.") + + # Now we create the Tasklet in which the calculation is performed. + tskl_code: str = self._write_tasklet_code(in_var_names, eqn) + tskl_name: str = eqn.primitive.name + tskl_map_ranges: list[tuple[str, str]] = [ + (f"__i{dim}", f"0:{N}") for dim, N in enumerate(util.get_jax_var_shape(eqn.outvars[0])) + ] + tskl_output: tuple[str, dace.Memlet] = None # type: ignore[assignment] + tskl_inputs: list[tuple[str, dace.Memlet] | tuple[None, None]] = [] + + # Generate the Memlets for the input. + for i, dims_to_bcast in zip(range(len(in_var_names)), [dims_to_bcastl, dims_to_bcastr]): + if in_var_names[i] is None: # Literal: No input needed. 
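+                # The literal's value is not read through a Memlet; it is inlined
+                # into the tasklet code later by `_write_tasklet_code()`.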
+ tskl_inputs.append((None, None)) + continue + if inp_scalars[i]: # Scalar + assert len(dims_to_bcast) == 0 + i_memlet = dace.Memlet.simple(in_var_names[i], "0") + else: # Array: We may have to broadcast + inputs_: list[str] = [] + for dim, (map_var, _) in enumerate(tskl_map_ranges): + if dim in dims_to_bcast: + inputs_.append("0") + else: + inputs_.append(map_var) + i_memlet = dace.Memlet.simple(in_var_names[i], ", ".join(inputs_)) + del inputs_ + tskl_inputs.append((f"__in{i}", i_memlet)) + + # Now generate the Memlets for the output + if is_scalar: + tskl_output = ("__out0", dace.Memlet.simple(out_var_names[0], "0")) + else: + tskl_output = ( + "__out0", + dace.Memlet.simple(out_var_names[0], ", ".join([X[0] for X in tskl_map_ranges])), + ) + + if is_scalar: + tskl_tasklet = eqn_state.add_tasklet( + tskl_name, + _list_to_dict(tskl_inputs).keys(), + _list_to_dict([tskl_output]).keys(), + tskl_code, + ) + for in_var, (in_connector, in_memlet) in zip(in_var_names, tskl_inputs, strict=False): + if in_var is None: # So access node for literal + continue + eqn_state.add_edge( + eqn_state.add_read(in_var), None, tskl_tasklet, in_connector, in_memlet + ) + eqn_state.add_edge( + tskl_tasklet, + tskl_output[0], + eqn_state.add_write(out_var_names[0]), + None, + tskl_output[1], + ) + else: + eqn_state.add_mapped_tasklet( + name=tskl_name, + map_ranges=_list_to_dict(tskl_map_ranges), + inputs=_list_to_dict(tskl_inputs), + code=tskl_code, + outputs=_list_to_dict([tskl_output]), + external_edges=True, + ) + + return eqn_state + + def _write_tasklet_code( + self, in_var_names: Sequence[str | None], eqn: jax_core.JaxprEqn + ) -> str: + """ + This function generates the Tasklet code based on a primitive. + + The function will also perform literal substitution and parameter handling. + + Args: + in_var_names: The list of SDFG variables used as input. + """ + t_code = self._prim_tmpl + + # Now we handle Literal substitution + for i, in_var_name in enumerate(in_var_names): + if in_var_name is not None: + continue + + jax_in_var: jax_core.Literal = cast(jax_core.Literal, eqn.invars[i]) + if util.get_jax_var_shape(jax_in_var) == (): + t_val = jax_in_var.val + if isinstance(t_val, np.ndarray): + t_val = jax_in_var.val.max() # I do not know a better way in that case + t_code = t_code.replace(f"__in{i}", str(t_val)) + else: + raise ValueError( + f"Can not handle the literal case of shape: {util.get_jax_var_shape(jax_in_var)}" + ) + + # Now replace the parameters + if len(eqn.params) != 0: + t_code = t_code.format(**eqn.params) + + return t_code + + +def _list_to_dict(inp: Sequence[tuple[None | Any, Any]]) -> dict[Any, Any]: + """ + This method turns a `list` of pairs into a `dict` and applies a `None` filter. + + The function will only include pairs whose key, i.e. first element is not `None`. + """ + return {k: v for k, v in inp if k is not None} + + +# Contains all the templates for ALU operations. 
+_ALU_OPS_TASKLET_TEMPLATES: Final[dict[str, str]] = { + # Unary operations + "pos": "__out0 = +(__in0)", + "neg": "__out0 = -(__in0)", + "not": "__out0 = not (__in0)", + "floor": "__out0 = floor(__in0)", + "ceil": "__out0 = ceil(__in0)", + "round": "__out0 = round(__in0)", + "abs": "__out0 = abs(__in0)", + "sign": "__out0 = sign(__in0)", + "sqrt": "__out0 = sqrt(__in0)", + "log": "__out0 = log(__in0)", + "exp": "__out0 = exp(__in0)", + "integer_pow": "__out0 = (__in0)**({y})", # 'y' is a parameter of the primitive + "sin": "__out0 = sin(__in0)", + "asin": "__out0 = asin(__in0)", + "cos": "__out0 = cos(__in0)", + "acos": "__out0 = acos(__in0)", + "tan": "__out0 = tan(__in0)", + "atan": "__out0 = atan(__in0)", + "tanh": "__out0 = tanh(__in0)", + # Binary operations + "add": "__out0 = (__in0)+(__in1)", + "add_any": "__out0 = (__in0)+(__in1)", # No idea what makes `add_any` differ from `add` + "sub": "__out0 = (__in0)-(__in1)", + "mul": "__out0 = (__in0)*(__in1)", + "div": "__out0 = (__in0)/(__in1)", + "rem": "__out0 = (__in0)%(__in1)", + "and": "__out0 = (__in0) and (__in1)", + "or": "__out0 = (__in0) or (__in1)", + "pow": "__out0 = (__in0)**(__in1)", + "ipow": "__out0 = (__in0)**(int(__in1))", + "min": "__out0 = min(__in0, __in1)", + "max": "__out0 = max(__in0, __in1)", + "eq": "__out0 = __in0 == __in1", + "ne": "__out0 = __in0 != __in1", + "ge": "__out0 = __in0 >= __in1", + "gt": "__out0 = __in0 > __in1", + "le": "__out0 = __in0 <= __in1", + "lt": "__out0 = __in0 < __in1", +} + +for prim_name, prim_tmpl in _ALU_OPS_TASKLET_TEMPLATES.items(): + translator.register_primitive_translator(ALUTranslator(prim_name, prim_tmpl)) diff --git a/src/jace/translator/translated_jaxpr_sdfg.py b/src/jace/translator/translated_jaxpr_sdfg.py new file mode 100644 index 0000000..afa91ff --- /dev/null +++ b/src/jace/translator/translated_jaxpr_sdfg.py @@ -0,0 +1,71 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Container for storing a translated SDFG.""" + +from __future__ import annotations + +import dataclasses + +import dace + + +@dataclasses.dataclass(kw_only=True, frozen=True) +class TranslatedJaxprSDFG: + """ + Encapsulates a translated SDFG with additional the metadata. + + Contrary to the SDFG that is encapsulated inside the `TranslationContext` + object, `self` carries a proper SDFG, however: + - It does not have `__return*` variables, instead all return arguments are + passed by arguments. + - All input arguments are passed through arguments mentioned in `inp_names`, + while the outputs are passed through `out_names`. + - Only variables listed as in/outputs are non transient. + - The order inside `inp_names` and `out_names` is the same as in the original Jaxpr. + - If an input is used as outputs it appears in both `inp_names` and `out_names`. + - Its `arg_names` is set to `inp_names + out_names`, but arguments that are + input and outputs are only listed as inputs. + + The only valid way to obtain a `TranslatedJaxprSDFG` is by passing a + `TranslationContext`, that was in turn constructed by + `JaxprTranslationBuilder.translate_jaxpr()`, to the + `finalize_translation_context()` or preferably to the `postprocess_jaxpr_sdfg()` + function. + + Attributes: + sdfg: The encapsulated SDFG object. + inp_names: A list of the SDFG variables that are used as input + out_names: A list of the SDFG variables that are used as output. 
+    """
+
+    sdfg: dace.SDFG
+    inp_names: tuple[str, ...]
+    out_names: tuple[str, ...]
+
+    def validate(self) -> bool:
+        """Validate the underlying SDFG."""
+        if any(self.sdfg.arrays[inp].transient for inp in self.inp_names):
+            raise dace.sdfg.InvalidSDFGError(
+                f"Found transient inputs: {[inp for inp in self.inp_names if self.sdfg.arrays[inp].transient]}",
+                self.sdfg,
+                self.sdfg.node_id(self.sdfg.start_state),
+            )
+        if any(self.sdfg.arrays[out].transient for out in self.out_names):
+            raise dace.sdfg.InvalidSDFGError(
+                f"Found transient outputs: {[out for out in self.out_names if self.sdfg.arrays[out].transient]}",
+                self.sdfg,
+                self.sdfg.node_id(self.sdfg.start_state),
+            )
+        if self.sdfg.free_symbols:  # This is a simplification that makes our life simple.
+            raise dace.sdfg.InvalidSDFGError(
+                f"Found free symbols: {self.sdfg.free_symbols}",
+                self.sdfg,
+                self.sdfg.node_id(self.sdfg.start_state),
+            )
+        self.sdfg.validate()
+        return True
diff --git a/src/jace/util/__init__.py b/src/jace/util/__init__.py
new file mode 100644
index 0000000..ab73e4e
--- /dev/null
+++ b/src/jace/util/__init__.py
@@ -0,0 +1,49 @@
+# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming)
+#
+# Copyright (c) 2024, ETH Zurich
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Global utility package for the Jax to DaCe translator."""
+
+from __future__ import annotations
+
+from .definitions import FORBIDDEN_SDFG_VAR_NAMES, VALID_SDFG_OBJ_NAME, VALID_SDFG_VAR_NAME
+from .jax_helper import (
+    JaCeVar,
+    get_jax_var_dtype,
+    get_jax_var_name,
+    get_jax_var_shape,
+    is_tracing_ongoing,
+    propose_jax_name,
+    translate_dtype,
+)
+from .traits import (
+    is_array,
+    is_drop_var,
+    is_fully_addressable,
+    is_jax_array,
+    is_on_device,
+    is_scalar,
+)
+
+
+__all__ = [
+    "FORBIDDEN_SDFG_VAR_NAMES",
+    "VALID_SDFG_OBJ_NAME",
+    "VALID_SDFG_VAR_NAME",
+    "JaCeVar",
+    "get_jax_var_dtype",
+    "get_jax_var_name",
+    "get_jax_var_shape",
+    "is_array",
+    "is_drop_var",
+    "is_fully_addressable",
+    "is_jax_array",
+    "is_on_device",
+    "is_scalar",
+    "is_tracing_ongoing",
+    "propose_jax_name",
+    "translate_dtype",
+]
diff --git a/src/jace/util/dace_helper.py b/src/jace/util/dace_helper.py
new file mode 100644
index 0000000..1828fac
--- /dev/null
+++ b/src/jace/util/dace_helper.py
@@ -0,0 +1,144 @@
+# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming)
+#
+# Copyright (c) 2024, ETH Zurich
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Implements all utility functions that are related to DaCe."""
+
+from __future__ import annotations
+
+import time
+from typing import TYPE_CHECKING, Any
+
+import dace
+import numpy as np
+from dace import data as dace_data
+
+# The compiled SDFG is not available in the `dace` namespace or anywhere else,
+# thus we import it here directly.
+from dace.codegen.compiled_sdfg import CompiledSDFG
+
+from jace import util
+
+
+if TYPE_CHECKING:
+    from collections.abc import Mapping, Sequence
+
+    from jace import translator
+
+__all__ = ["CompiledSDFG", "compile_jax_sdfg", "run_jax_sdfg"]
+
+
+def compile_jax_sdfg(tsdfg: translator.TranslatedJaxprSDFG) -> CompiledSDFG:
+    """Compiles the embedded SDFG and returns the resulting `CompiledSDFG` object."""
+    if any(  # We do not support the DaCe return mechanism
+        array_name.startswith("__return")
+        for array_name in tsdfg.sdfg.arrays.keys()  # noqa: SIM118 # we can not use `in` because we are also interested in `__return_`!
+ ): + raise ValueError("Only support SDFGs without '__return' members.") + + # To ensure that the SDFG is compiled and to get rid of a warning we must modify + # some settings of the SDFG. To fake an immutable SDFG, we will restore them later. + sdfg = tsdfg.sdfg + org_sdfg_name = sdfg.name + org_recompile = sdfg._recompile + org_regenerate_code = sdfg._regenerate_code + + try: + # We need to give the SDFG another name, this is needed to prevent a DaCe + # error/warning. This happens if we compile the same lowered SDFG multiple + # times with different options. + sdfg.name = f"{sdfg.name}__comp_{int(time.time() * 1000)}" + + with dace.config.temporary_config(): + sdfg._recompile = True + sdfg._regenerate_code = True + dace.Config.set("compiler", "use_cache", value=False) + csdfg: CompiledSDFG = sdfg.compile() + + finally: + sdfg.name = org_sdfg_name + sdfg._recompile = org_recompile + sdfg._regenerate_code = org_regenerate_code + + return csdfg + + +def run_jax_sdfg( + csdfg: CompiledSDFG, + inp_names: Sequence[str], + out_names: Sequence[str], + call_args: Sequence[Any], + call_kwargs: Mapping[str, Any], +) -> tuple[Any, ...] | Any: + """ + Run the compiled SDFG. + + The function assumes that the SDFG was finalized and then compiled by + `compile_jax_sdfg()`. For running the SDFG you also have to pass the input + names (`inp_names`) and output names (`out_names`) that were inside the + `TranslatedJaxprSDFG` from which `csdfg` was compiled from. + + Args: + csdfg: The `CompiledSDFG` object. + inp_names: List of names of the input arguments. + out_names: List of names of the output arguments. + call_args: All positional arguments of the call. + call_kwargs: All keyword arguments of the call. + + Note: + There is no pytree mechanism jet, thus the return values are returned + inside a `tuple` or in case of one value, directly, in the order + determined by Jax. Furthermore, DaCe does not support scalar return + values, thus they are silently converted into arrays of length 1, the + same holds for inputs. + + Todo: + - Implement non C strides. + """ + sdfg: dace.SDFG = csdfg.sdfg + + if len(call_kwargs) != 0: + raise NotImplementedError("No kwargs are supported yet.") + if len(inp_names) != len(call_args): + raise RuntimeError("Wrong number of arguments.") + if sdfg.free_symbols: # This is a simplification that makes our life simple. + raise NotImplementedError( + f"No externally defined symbols are allowed, found: {sdfg.free_symbols}" + ) + + # Build the argument list that we will pass to the compiled object. + sdfg_call_args: dict[str, Any] = {} + for in_name, in_val in zip(inp_names, call_args, strict=True): + if util.is_scalar(in_val): + # Currently the translator makes scalar into arrays, this has to be + # reflected here + in_val = np.array([in_val]) # noqa: PLW2901 # Loop variable is intentionally modified. + sdfg_call_args[in_name] = in_val + + for out_name, sdfg_array in ((out_name, sdfg.arrays[out_name]) for out_name in out_names): + if out_name in sdfg_call_args: + if util.is_jax_array(sdfg_call_args[out_name]): + # Jax arrays are immutable, so they can not be return values too. + raise ValueError("Passed a Jax array as output.") + else: + sdfg_call_args[out_name] = dace_data.make_array_from_descriptor(sdfg_array) + + assert len(sdfg_call_args) == len(csdfg.argnames), ( + "Failed to construct the call arguments," + f" expected {len(csdfg.argnames)} but got {len(call_args)}." 
+ f"\nExpected: {csdfg.argnames}\nGot: {list(sdfg_call_args.keys())}" + ) + + # Calling the SDFG + with dace.config.temporary_config(): + dace.Config.set("compiler", "allow_view_arguments", value=True) + csdfg(**sdfg_call_args) + + # Handling the output (pytrees are missing) + if not out_names: + return None + ret_val: tuple[Any] = tuple(sdfg_call_args[out_name] for out_name in out_names) + return ret_val[0] if len(out_names) == 1 else ret_val diff --git a/src/jace/util/definitions.py b/src/jace/util/definitions.py new file mode 100644 index 0000000..13daf7a --- /dev/null +++ b/src/jace/util/definitions.py @@ -0,0 +1,38 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Definitions of patterns for valid names.""" + +from __future__ import annotations + +import re +from typing import Final + + +#: Valid name for an SDFG variable. +VALID_SDFG_VAR_NAME: re.Pattern = re.compile("[a-zA-Z_][a-zA-Z0-9_]*") + +#: Valid name for an SDFG itself, includes `SDFGState` objects. +VALID_SDFG_OBJ_NAME: re.Pattern = re.compile("[a-zA-Z_][a-zA-Z0-9_]*") + + +# fmt: off +#: This is a set of all names that are invalid SDFG names. +FORBIDDEN_SDFG_VAR_NAMES: Final[set[str]] = { + # These should be most of the C++ keywords, it is more important to have the short + # ones. Taken from 'https://learn.microsoft.com/en-us/cpp/cpp/keywords-cpp?view=msvc-170' + "alignas", "alignof", "and", "asm", "auto", "bitand", "bitor", "bool", "break", "case", + "catch", "char", "class", "compl", "concept", "const", "consteval", "constexpr", + "constinit", "continue", "decltype", "default", "delete", "directive", "do", "double", + "else", "enum", "explicit", "export", "extern", "false", "float", "for", "friend", + "goto", "if", "inline", "int", "long", "mutable", "namespace", "new", "noexcept", "not", + "nullptr", "operator", "or", "private", "protected", "public", "register", "requires", + "return", "short", "signed", "sizeof", "static", "struct", "switch", "template", "this", + "throw", "true", "try", "typedef", "typeid", "typename", "union", "unsigned", "using", + "virtual", "void", "volatile", "while", "xor", "std", "", +} +# fmt: on diff --git a/src/jace/util/jax_helper.py b/src/jace/util/jax_helper.py new file mode 100644 index 0000000..175671f --- /dev/null +++ b/src/jace/util/jax_helper.py @@ -0,0 +1,217 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +""" +Implements all utility functions that are related to Jax. + +Most of the functions defined here allow an unified access to Jax' internal in +a consistent and stable way. +""" + +from __future__ import annotations + +import dataclasses +import itertools +from typing import TYPE_CHECKING, Any + +import dace +import jax.core as jax_core +import numpy as np + +from jace import util + + +if TYPE_CHECKING: + from collections.abc import Mapping + + +@dataclasses.dataclass(repr=True, frozen=True, eq=False) +class JaCeVar: + """ + Replacement for the `jax.Var` class. + + This class can be seen as some kind of substitute `jax.core.Var`. The main + intention of this class is as an internal representation of values, as they + are used in Jax, but without the Jax machinery. As abstract values in Jax + this class has a datatype, which is a `dace.typeclass` instance and a shape. 
+ In addition it has an optional name, which allows to create variables with + a certain name using `JaxprTranslationBuilder.add_array()`. + + If it is expected that code must handle both Jax variables and `JaCeVar` + then the `get_jax_var_*()` functions should be used. + + Args: + shape: The shape of the variable. + dtype: The dace datatype of the variable. + name: Name the variable should have, optional. + + Note: + If the name of a `JaCeVar` is '_' it is considered a drop variable. The + definitions of `__hash__` and `__eq__` are in accordance with how Jax + variable works. + + Todo: + - Add support for strides. + """ + + shape: tuple[int | dace.symbol | str, ...] + dtype: dace.typeclass + name: str | None = None + + def __post_init__(self) -> None: + """Sanity checks.""" + if self.name is not None and ( + (not util.VALID_SDFG_VAR_NAME.fullmatch(self.name)) + or self.name in util.FORBIDDEN_SDFG_VAR_NAMES + ): + raise ValueError(f"Supplied the invalid name '{self.name}'.") + if not isinstance(self.dtype, dace.typeclass): # No typechecking yet. + raise TypeError(f"'dtype' is not a 'dace.typeclass' but '{type(self.dtype).__name__}'.") + + def __hash__(self) -> int: + return id(self) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, JaCeVar): + return NotImplemented + return id(self) == id(other) + + +def get_jax_var_name(jax_var: jax_core.Atom | JaCeVar) -> str: + """Returns the name of `jax_var` as a string.""" + match jax_var: + case jax_core.DropVar(): + return "_" + case JaCeVar(): + return jax_var.name if jax_var.name else f"jax{id(jax_var)}" + case jax_core.Var(): + # This is not how the pretty printer works nor `jax.Var.__repr__()`, + # but leads to stable and valid names. + return f"jax{jax_var.count}{jax_var.suffix}" + case jax_core.Literal(): + raise TypeError("Can not derive a name from a Jax Literal.") + case _: + raise TypeError( + f"Does not know how to transform '{jax_var}' (type: '{type(jax_var).__name__}') " + "into a string." + ) + + +def get_jax_var_shape(jax_var: jax_core.Atom | JaCeVar) -> tuple[int | dace.symbol | str, ...]: + """Returns the shape of `jax_var`.""" + match jax_var: + case jax_core.Var() | jax_core.Literal(): + assert hasattr(jax_var.aval, "shape") # To silences mypy. + return jax_var.aval.shape + case JaCeVar(): + return jax_var.shape + case _: + raise TypeError(f"'get_jax_var_shape()` is not implemented for '{type(jax_var)}'.") + + +def get_jax_var_dtype(jax_var: jax_core.Atom | JaCeVar) -> dace.typeclass: + """Returns the DaCe equivalent of `jax_var`s datatype.""" + match jax_var: + case jax_core.Var() | jax_core.Literal(): + assert hasattr(jax_var.aval, "dtype") # To silences mypy. + return translate_dtype(jax_var.aval.dtype) + case JaCeVar(): + return jax_var.dtype + case _: + raise TypeError(f"'get_jax_var_dtype()` is not implemented for '{type(jax_var)}'.") + + +def is_tracing_ongoing(*args: Any, **kwargs: Any) -> bool: + """ + Test if tracing is ongoing. + + While a return value `True` guarantees that a translation is ongoing, a + value of `False` does not guarantees that no tracing is ongoing. + """ + # The current implementation only checks the arguments if it contains tracers. 
+ if (len(args) == 0) and (len(kwargs) == 0): + raise RuntimeError("Failed to determine if tracing is ongoing.") + return any(isinstance(x, jax_core.Tracer) for x in itertools.chain(args, kwargs.values())) + + +def translate_dtype(dtype: Any) -> dace.typeclass: + """Turns a Jax datatype into a DaCe datatype.""" + if dtype is None: + raise NotImplementedError # Handling a special case in DaCe. + if isinstance(dtype, dace.typeclass): + return dtype + try: + return dace.typeclass(dtype) + except (NameError, KeyError): + pass + return dace.dtype_to_typeclass(getattr(dtype, "type", dtype)) + + +def propose_jax_name( + jax_var: jax_core.Atom | JaCeVar, + jax_name_map: Mapping[jax_core.Var | JaCeVar, str] | None = None, +) -> str: + """ + Proposes a variable name for `jax_var`. + + If `jax_name_map` is `None` the function will fallback to + `get_jax_var_name(jax_var)`. If `jax_name_map` is supplied the function + will: + - If `jax_var` is stored inside `jax_name_map`, returns the mapped value. + - If `jax_var` is a `JaCeVar` with a set `.name` property that name will + be returned. + - Otherwise the function will generate a new name in a similar way to the + pretty printer of Jaxpr. + + Args: + jax_var: The variable for which a name to propose. + jax_name_map: A mapping of all Jax variables that were already named. + + Note: + The function guarantees that the returned name passes `VALID_SDFG_VAR_NAME` + test and that the name is not inside `util.FORBIDDEN_SDFG_VAR_NAMES`. + Dropped variables will always be named `'_'`. + """ + if isinstance(jax_var, jax_core.Literal): + raise TypeError(f"Can not propose a name for literal '{jax_var}'.") + if util.is_drop_var(jax_var) or (jax_name_map is None): + return get_jax_var_name(jax_var) + if jax_var in jax_name_map: + return jax_name_map[jax_var] + if isinstance(jax_var, JaCeVar) and (jax_var.name is not None): + return jax_var.name + + # This code is taken from Jax so it will generate similar ways, the difference is + # that we do the counting differently. + # Note that `z` is followed by `ba` and not `aa` as it is in Excel. + c = len(jax_name_map) + jax_name = "" + while len(jax_name) == 0 or c != 0: + c, i = c // 26, c % 26 + jax_name = chr(97 + i) + jax_name + jax_name += getattr(jax_var, "suffix", "") + + if jax_name in util.FORBIDDEN_SDFG_VAR_NAMES: + jax_name = f"__jace_forbidden_{jax_name}" + return jax_name + + +def get_jax_literal_value(lit: jax_core.Atom) -> bool | float | int | np.generic: + """ + Returns the value a literal is wrapping. + + The function guarantees to return a scalar value. + """ + if not isinstance(lit, jax_core.Literal): + raise TypeError(f"Can only extract literals not '{type(lit)}'.") + val = lit.val + if isinstance(val, np.ndarray): + assert val.shape == () + return val.max() + if isinstance(val, (bool, float, int)): + return val + raise TypeError(f"Failed to extract value from '{lit}'.") diff --git a/src/jace/util/traits.py b/src/jace/util/traits.py new file mode 100644 index 0000000..a8e6bc8 --- /dev/null +++ b/src/jace/util/traits.py @@ -0,0 +1,93 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. 
+# +# SPDX-License-Identifier: BSD-3-Clause + +"""Contains all traits function needed inside JaCe.""" + +from __future__ import annotations + +from typing import Any, TypeGuard + +import dace +import jax +import numpy as np +from jax import core as jax_core + +from jace import util + + +def is_drop_var(jax_var: jax_core.Atom | util.JaCeVar) -> TypeGuard[jax_core.DropVar]: + """Tests if `jax_var` is a drop variable.""" + if isinstance(jax_var, jax_core.DropVar): + return True + if isinstance(jax_var, util.JaCeVar): + return jax_var.name == "_" if jax_var.name else False + return False + + +def is_jax_array(obj: Any) -> TypeGuard[jax.Array]: + """ + Tests if `obj` is a Jax array. + + Note: + Jax arrays are special as they can not be mutated. Furthermore, they always + allocate on the CPU _and_ on the GPU, if present. + """ + return isinstance(obj, jax.Array) + + +def is_array(obj: Any) -> bool: + """Identifies arrays, this also includes Jax arrays.""" + return dace.is_array(obj) or is_jax_array(obj) + + +def is_scalar(obj: Any) -> bool: + """Tests if `obj` is a scalar.""" + # These are the type known to DaCe; Taken from `dace.dtypes`. + known_types = { + bool, + int, + float, + complex, + np.intc, + np.uintc, + np.bool_, + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + np.float16, + np.float32, + np.float64, + np.complex64, + np.complex128, + np.longlong, + np.ulonglong, + } + return type(obj) in known_types + + +def is_on_device(obj: Any) -> bool: + """ + Tests if `obj` is on a device. + + Jax arrays are always on the CPU and GPU (if there is one). Thus for Jax + arrays this function is more of a test, if there is a GPU at all. + """ + if is_jax_array(obj): + return hasattr(obj, "__cuda_array_interface__") + return dace.is_gpu_array(obj) + + +def is_fully_addressable(obj: Any) -> bool: + """Tests if `obj` is fully addressable, i.e. is only on this host.""" + if is_jax_array(obj): + return obj.is_fully_addressable + return True diff --git a/src/jace/util/translation_cache.py b/src/jace/util/translation_cache.py new file mode 100644 index 0000000..f6366bd --- /dev/null +++ b/src/jace/util/translation_cache.py @@ -0,0 +1,274 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +""" +This module contains the functionality related to the compilation cache of the stages. + +The cache currently caches the lowering, i.e. the result of `JaCeWrapped.lower()` +and the compilation, i.e. `JaCeLowered.compile()`. The caches are on a per stage +basis and not on a per instant basis. To make a stage cacheable, it must be +derived from `CachingStage` and its transition function must be decoration with +`@cached_transition`. +""" + +from __future__ import annotations + +import abc +import collections +import dataclasses +import functools +from collections.abc import Callable, Hashable +from typing import TYPE_CHECKING, Any, Concatenate, Generic, ParamSpec, TypeAlias, TypeVar, cast + +import dace +from jax import core as jax_core + +from jace import util + + +if TYPE_CHECKING: + from jace import stages + +#: Caches used to store the state transition. +#: The caches are on a per stage and not per instant basis. +_TRANSLATION_CACHES: dict[type[CachingStage], StageCache] = {} + + +# Denotes the stage that follows the current one. +# Used by the `NextStage` Mixin. 
+NextStage = TypeVar("NextStage", bound="stages.Stage") + + +class CachingStage(Generic[NextStage]): + """ + Annotates a stage whose transition to the next stage is cacheable. + + To make the transition of a stage cacheable, the stage must be derived from + this class, and its initialization must call `CachingStage.__init__()`. + Furthermore, its transition function must be annotated by the + `@cached_transition` decorator. + + A class must implement the `_make_call_description()` to compute an abstract + description of the call. This is needed to operate the cache to store the + stage transitions. + + Notes: + The `__init__()` function must explicitly be called to fully setup `self`. + + Todo: + - Handle eviction from the cache due to collecting of unused predecessor stages. + """ + + _cache: StageCache[NextStage] + + def __init__(self) -> None: + self._cache = get_cache(self) + + @abc.abstractmethod + def _make_call_description( + self: CachingStage, *args: Any, **kwargs: Any + ) -> StageTransformationSpec: + """Generates the key that is used to store/locate the call in the cache.""" + ... + + +# Type annotation for the caching. +P = ParamSpec("P") +TransitionFunction = Callable[Concatenate[CachingStage[NextStage], P], NextStage] +CachingStageType = TypeVar("CachingStageType", bound=CachingStage) + + +def cached_transition( + transition: Callable[Concatenate[CachingStageType, P], NextStage], +) -> Callable[Concatenate[CachingStage[NextStage], P], NextStage]: + """ + Decorator for making the transition function of the stage cacheable. + + In order to work, the stage must be derived from `CachingStage`. For computing + the key of a call the function will use the `_make_call_description()` + function of the cache. + + Todo: + - Implement a way to temporary disable the cache. + """ + + @functools.wraps(transition) + def transition_wrapper(self: CachingStageType, *args: P.args, **kwargs: P.kwargs) -> NextStage: + key: StageTransformationSpec = self._make_call_description(*args, **kwargs) + if key in self._cache: + return self._cache[key] + next_stage = transition(self, *args, **kwargs) + self._cache[key] = next_stage + return next_stage + + return cast(TransitionFunction, transition_wrapper) + + +def clear_translation_cache() -> None: + """Clear all caches associated to translation.""" + for stage_caches in _TRANSLATION_CACHES.values(): + stage_caches.clear() + + +def get_cache(stage: CachingStage) -> StageCache: + """Returns the cache that should be used for `stage`.""" + stage_type = type(stage) + if stage_type not in _TRANSLATION_CACHES: + _TRANSLATION_CACHES[stage_type] = StageCache() + return _TRANSLATION_CACHES[stage_type] + + +@dataclasses.dataclass(frozen=True) +class _AbstractCallArgument: + """ + Class to represent a single argument to the transition function in an abstract way. + + As noted in `StageTransformationSpec` there are two ways to describe an + argument, either by using its concrete value or an abstract description, + which is similar to tracers in Jax. This class represents the second way. + To create an instance you should use `_AbstractCallArgument.from_value()`. + + Its description is limited to scalars and arrays. To describe more complex + types, they should be processed by pytrees first. + + Attributes: + shape: In case of an array its shape, in case of a scalar the empty tuple. + dtype: The DaCe type of the argument. + strides: The strides of the argument, or `None` if they are unknown or a scalar. + storage: The storage type where the argument is stored. 
+ """ + + shape: tuple[int, ...] + dtype: dace.typeclass + strides: tuple[int, ...] | None + storage: dace.StorageType + + @classmethod + def from_value(cls, value: Any) -> _AbstractCallArgument: + """Construct an `_AbstractCallArgument` from `value`.""" + if not util.is_fully_addressable(value): + raise NotImplementedError("Distributed arrays are not addressed yet.") + if isinstance(value, jax_core.Literal): + raise TypeError("Jax Literals are not supported as cache keys.") + + if util.is_array(value): + if util.is_jax_array(value): + value = value.__array__() # Passing `copy=False` leads to error in NumPy. + shape = value.shape + dtype = util.translate_dtype(value.dtype) + strides = getattr(value, "strides", None) + # Is `CPU_Heap` always okay? There would also be `CPU_Pinned`. + storage = ( + dace.StorageType.GPU_Global + if util.is_on_device(value) + else dace.StorageType.CPU_Heap + ) + + return cls(shape=shape, dtype=dtype, strides=strides, storage=storage) + + if util.is_scalar(value): + shape = () + dtype = util.translate_dtype(type(value)) + strides = None + # Scalar arguments are always on the CPU and never on the GPU. + storage = dace.StorageType.CPU_Heap + + return cls(shape=shape, dtype=dtype, strides=strides, storage=storage) + + raise TypeError(f"Can not make 'an abstract description from '{type(value).__name__}'.") + + +#: This type is the abstract description of a function call. +#: It is part of the key used in the cache. +CallArgsSpec: TypeAlias = tuple[ + _AbstractCallArgument | Hashable | tuple[str, _AbstractCallArgument | Hashable], ... +] + + +@dataclasses.dataclass(frozen=True) +class StageTransformationSpec: + """ + Represents the entire call to a state transformation function of a stage. + + State transition functions are annotated with `@cached_transition` and their + result may be cached. They key to locate them inside the cache is represented + by this class and computed by the `CachingStage._make_call_description()` + function. The actual key is consists of two parts, `stage_id` and `call_args`. + + Args: + stage_id: Origin of the call, for which the id of the stage object should + be used. + call_args: Description of the arguments of the call. There are two ways + to describe the arguments: + - Abstract description: In this way, the actual value of the argument + is irrelevant, only the structure of them are important, similar + to the tracers used in Jax. + - Concrete description: Here one caches on the actual value of the + argument. The only requirement is that they can be hashed. + """ + + stage_id: int + call_args: CallArgsSpec + + +# Denotes the stage that is stored inside the cache. +StageType = TypeVar("StageType", bound="stages.Stage") + + +class StageCache(Generic[StageType]): + """ + Simple LRU cache to cache the results of the stage transition function. + + Args: + size: The size of the cache, defaults to 256. + """ + + # The most recently used entry is at the end of the `OrderedDict`. 
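+    #  Eviction (see `popitem()`) therefore removes the entry at the front of the
+    #  `OrderedDict`, i.e. the least recently used one.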
+ _memory: collections.OrderedDict[StageTransformationSpec, StageType] + _size: int + + def __init__(self, size: int = 256) -> None: + self._memory = collections.OrderedDict() + self._size = size + + def __contains__(self, key: StageTransformationSpec) -> bool: + return key in self._memory + + def __getitem__(self, key: StageTransformationSpec) -> StageType: + if key not in self: + raise KeyError(f"Key '{key}' is unknown.") + self._memory.move_to_end(key, last=True) + return self._memory[key] + + def __setitem__(self, key: StageTransformationSpec, res: StageType) -> None: + if key in self: + self._memory.move_to_end(key, last=True) + self._memory[key] = res + else: + if len(self._memory) == self._size: + self.popitem(None) + self._memory[key] = res + + def popitem(self, key: StageTransformationSpec | None) -> None: + """ + Evict `key` from `self`. + + If `key` is `None` the oldest entry is evicted. + """ + if len(self._memory) == 0: + return + if key is None: + self._memory.popitem(last=False) + elif key in self: + self._memory.move_to_end(key, last=False) + self._memory.popitem(last=False) + + def clear(self) -> None: # noqa: D102 # Missing description. + self._memory.clear() + + def __repr__(self) -> str: + return f"StageCache({len(self._memory)} / {self._size} || {', '.join('[' + repr(k) + ']' for k in self._memory)})" diff --git a/tests/__init__.py b/tests/__init__.py index 116302a..a5e868c 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -4,3 +4,10 @@ # All rights reserved. # # SPDX-License-Identifier: BSD-3-Clause + +"""JaCe's tests. + + +Note: + This is work in progress. +""" diff --git a/tests/test_caching.py b/tests/test_caching.py new file mode 100644 index 0000000..bc0e44c --- /dev/null +++ b/tests/test_caching.py @@ -0,0 +1,256 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Tests for the caching infrastructure. +.""" + +from __future__ import annotations + +import itertools as it +import re + +import numpy as np +import pytest + +import jace +from jace import optimization, stages +from jace.util import translation_cache as tcache + + +@pytest.fixture(autouse=True) +def _clear_translation_cache(): + """Decorator that clears the translation cache. + + Ensures that a function finds an empty cache and clears up afterwards. + """ + tcache.clear_translation_cache() + yield + tcache.clear_translation_cache() + + +def test_caching_same_sizes() -> None: + """The behaviour of the cache if same sizes are used, in two different functions.""" + + # Counter for how many time it was lowered. + lowering_cnt = [0] + + # This is the pure Python function. + def testee(A: np.ndarray, B: np.ndarray) -> np.ndarray: + return A * B + + # this is the wrapped function. + @jace.jit + def wrapped(A, B): + lowering_cnt[0] += 1 + return testee(A, B) + + # First batch of arguments. + A = np.arange(12, dtype=np.float64).reshape((4, 3)) + B = np.full((4, 3), 10, dtype=np.float64) + + # The second batch of argument, same structure but different values. + AA = A + 1.0362 + BB = B + 0.638956 + + # Now let's lower it once directly and call it. + lowered: stages.JaCeLowered = wrapped.lower(A, B) + compiled: stages.JaCeCompiled = lowered.compile() + assert lowering_cnt[0] == 1 + assert np.allclose(testee(A, B), compiled(A, B)) + + # Now lets call the wrapped object directly, since we already did the lowering + # no longering (and compiling) is needed. 
+ assert np.allclose(testee(A, B), wrapped(A, B)) + assert lowering_cnt[0] == 1 + + # Now lets call it with different objects, that have the same structure. + # Again no lowering should happen. + assert np.allclose(testee(AA, BB), wrapped(AA, BB)) + assert wrapped.lower(AA, BB) is lowered + assert wrapped.lower(A, B) is lowered + assert lowering_cnt[0] == 1 + + +def test_caching_different_sizes(): + """The behaviour of the cache if different sizes where used.""" + + # Counter for how many time it was lowered. + lowering_cnt = [0] + + # This is the wrapped function. + @jace.jit + def wrapped(A, B): + lowering_cnt[0] += 1 + return A * B + + # First size of arguments + A = np.arange(12, dtype=np.float64).reshape((4, 3)) + B = np.full((4, 3), 10, dtype=np.float64) + + # Second size of arguments + C = np.arange(16, dtype=np.float64).reshape((4, 4)) + D = np.full((4, 4), 10, dtype=np.float64) + + # Now lower the function once for each. + lowered1 = wrapped.lower(A, B) + lowered2 = wrapped.lower(C, D) + assert lowering_cnt[0] == 2 + assert lowered1 is not lowered2 + + # Now also check if the compilation works as intended + compiled1 = lowered1.compile() + compiled2 = lowered2.compile() + assert lowering_cnt[0] == 2 + assert compiled1 is not compiled2 + + +@pytest.mark.skip("'convert_element_type' primitive is not implemented.") +def test_caching_different_structure() -> None: + """Now tests if we can handle multiple arguments with different structures. + + Todo: + - Extend with strides once they are part of the cache. + """ + + # This is the wrapped function. + lowering_cnt = [0] + + @jace.jit + def wrapped(A, B): + lowering_cnt[0] += 1 + return A * 4.0, B + 2.0 + + A = np.full((4, 30), 10, dtype=np.float64) + B = np.full((4, 3), 10, dtype=np.float64) + C = np.full((5, 3), 14, dtype=np.float64) + D = np.full((6, 3), 14, dtype=np.int64) + + # These are the known lowerings. + lowerings: dict[tuple[int, int], stages.JaCeLowered] = {} + lowering_ids: set[int] = set() + # These are the known compilations. + compilations: dict[tuple[int, int], stages.JaCeCompiled] = {} + compiled_ids: set[int] = set() + + # Generating the lowerings + for arg1, arg2 in it.permutations([A, B, C, D], 2): + lower = wrapped.lower(arg1, arg2) + compiled = lower.compile() + assert id(lower) not in lowering_ids + assert id(compiled) not in compiled_ids + lowerings[id(arg1), id(arg2)] = lower + lowering_ids.add(id(lower)) + compilations[id(arg1), id(arg2)] = compiled + compiled_ids.add(id(compiled)) + + # Now check if they are still cached. + for arg1, arg2 in it.permutations([A, B, C, D], 2): + lower = wrapped.lower(arg1, arg2) + clower = lowerings[id(arg1), id(arg2)] + assert clower is lower + + compiled1 = lower.compile() + compiled2 = clower.compile() + ccompiled = compilations[id(arg1), id(arg2)] + assert compiled1 is compiled2 + assert compiled1 is ccompiled + + +def test_caching_compilation() -> None: + """Tests the compilation cache, this is just very simple.""" + + @jace.jit + def jaceWrapped(A: np.ndarray, B: np.ndarray) -> np.ndarray: + C = A * B + D = C + A + E = D + B # Just enough state. + return A + B + C + D + E + + # These are the argument + A = np.arange(12, dtype=np.float64).reshape((4, 3)) + B = np.full((4, 3), 10, dtype=np.float64) + + # Now we lower it. + jaceLowered = jaceWrapped.lower(A, B) + + # Compiling it without any information. + optiCompiled = jaceLowered.compile() + + # This should be the same as passing the defaults directly. 
+    assert optiCompiled is jaceLowered.compile(optimization.DEFAULT_OPTIMIZATIONS)
+
+    # Also if we pass the empty dict, we should get the default.
+    assert optiCompiled is jaceLowered.compile({})
+
+    # Now we disable all optimizations.
+    unoptiCompiled = jaceLowered.compile(optimization.NO_OPTIMIZATIONS)
+
+    # Because of how the optimizations work, the optimized SDFG must have fewer
+    # nodes than the unoptimized one. If the two compilations shared their result,
+    # this would not be the case.
+    assert unoptiCompiled is not optiCompiled
+    assert optiCompiled._csdfg.sdfg.number_of_nodes() == 1
+    assert optiCompiled._csdfg.sdfg.number_of_nodes() < unoptiCompiled._csdfg.sdfg.number_of_nodes()
+
+
+def test_caching_dtype():
+    """Tests if the data type is properly included in the cache key."""
+
+    lowering_cnt = [0]
+
+    @jace.jit
+    def testee(A: np.ndarray) -> np.ndarray:
+        lowering_cnt[0] += 1
+        return A + A
+
+    dtypes = [np.float64, np.float32, np.int32, np.int64]
+    shape = (10, 10)
+
+    for i, dtype in enumerate(dtypes):
+        A = np.array((np.random.random(shape) - 0.5) * 10, dtype=dtype)  # noqa: NPY002
+
+        assert lowering_cnt[0] == i
+        _ = testee(A)
+        assert lowering_cnt[0] == i + 1
+
+        assert np.allclose(testee(A), 2 * A)
+        assert lowering_cnt[0] == i + 1
+
+
+def test_caching_strides() -> None:
+    """Test if the cache detects a change in strides."""
+
+    @jace.jit
+    def wrapped(A: np.ndarray) -> np.ndarray:
+        return A + 10.0
+
+    shape = (10, 100, 1000)
+    C = np.array(
+        (np.random.random(shape) - 0.5) * 10,  # noqa: NPY002
+        order="C",
+        dtype=np.float64,
+    )
+    F = np.array(C, copy=True, order="F")
+
+    # First we compile and run it with C strides.
+    C_lower = wrapped.lower(C)
+    C_res = wrapped(C)
+
+    # Now we run it with FORTRAN strides.
+    # However, this does not work because we do not support non-C strides yet.
+    # But the cache is aware of the strides, which helps catch some nasty bugs.
+    F_lower = None  # Remove later
+    F_res = C_res.copy()  # Remove later
+    with pytest.raises(  # noqa: PT012 # Multiple calls
+        expected_exception=NotImplementedError,
+        match=re.escape("Currently can not yet handle strides beside 'C_CONTIGUOUS'."),
+    ):
+        F_lower = wrapped.lower(F)
+        F_res = wrapped(F)
+    assert F_lower is None  # Remove later.
+    assert C_res is not F_res
+    assert np.allclose(F_res, C_res)
+    assert F_lower is not C_lower
diff --git a/tests/test_decorator.py b/tests/test_decorator.py
new file mode 100644
index 0000000..7971b29
--- /dev/null
+++ b/tests/test_decorator.py
@@ -0,0 +1,95 @@
+# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming)
+#
+# Copyright (c) 2024, ETH Zurich
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Implements tests for the jit decorator.
+
+Also see the `test_jax_api.py` test file, which tests composability.
+"""
+
+from __future__ import annotations
+
+import numpy as np
+import pytest
+
+import jace
+from jace.util import translation_cache as tcache
+
+
+@pytest.fixture(autouse=True)
+def _clear_translation_cache():
+    """Fixture that clears the translation cache.
+
+    Ensures that a function finds an empty cache and cleans up afterwards.
+
+    Todo:
+        Should be used _everywhere_.
+ """ + + tcache.clear_translation_cache() + yield + tcache.clear_translation_cache() + + +def test_decorator_individually(): + """Tests the compilation steps individually.""" + + def testee_(A: np.ndarray, B: np.ndarray) -> np.ndarray: + return A + B + + lowering_cnt = [0] + + @jace.jit + def testee(A, B): + lowering_cnt[0] += 1 + return testee_(A, B) + + A = np.arange(12, dtype=np.float64).reshape((4, 3)) + B = np.full((4, 3), 10, dtype=np.float64) + + lowered = testee.lower(A, B) + compiled = lowered.compile() + + ref = testee_(A, B) + res = compiled(A, B) + + assert np.allclose(ref, res), f"Expected '{ref}' got '{res}'." + assert lowering_cnt[0] == 1 + + +def test_decorator_one_go(): + """Tests the compilation steps in one go.""" + + def testee_(A: np.ndarray, B: np.ndarray) -> np.ndarray: + return A + B + + lowering_cnt = [0] + + @jace.jit + def testee(A, B): + lowering_cnt[0] += 1 + return testee_(A, B) + + A = np.arange(12, dtype=np.float64).reshape((4, 3)) + B = np.full((4, 3), 10, dtype=np.float64) + + ref = testee_(A, B) + res = testee(A, B) + + assert np.allclose(ref, res), f"Expected '{ref}' got '{res}'." + assert lowering_cnt[0] == 1 + + +def test_decorator_wrapped(): + """Tests if some properties are set correctly.""" + + def testee(A: np.ndarray, B: np.ndarray) -> np.ndarray: + return A * B + + wrapped = jace.jit(testee) + + assert wrapped.wrapped_fun is testee + assert wrapped.__wrapped__ is testee diff --git a/tests/test_empty_jaxpr.py b/tests/test_empty_jaxpr.py new file mode 100644 index 0000000..36e8247 --- /dev/null +++ b/tests/test_empty_jaxpr.py @@ -0,0 +1,48 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Implements tests for empty jaxprs. +.""" + +from __future__ import annotations + +import jax +import numpy as np +import pytest + +import jace + + +def test_empty_array(): + @jace.jit + def testee(A: np.ndarray) -> np.ndarray: + return A + + A = np.arange(12, dtype=np.float64).reshape((4, 3)) + + assert np.all(testee(A) == A) + + +def test_empty_scalar(): + @jace.jit + def testee(A: float) -> float: + return A + + A = np.pi + + assert np.all(testee(A) == A) + + +@pytest.mark.skip(reason="Nested Jaxpr are not handled.") +def test_empty_nested(): + @jace.jit + def testee3(A: float) -> float: + return jax.jit(lambda A: A)(A) + + A = np.pi + + assert np.all(testee3(A) == A) diff --git a/tests/test_jax_api.py b/tests/test_jax_api.py new file mode 100644 index 0000000..0c1905d --- /dev/null +++ b/tests/test_jax_api.py @@ -0,0 +1,200 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Tests the compatibility of the JaCe api to Jax.""" + +from __future__ import annotations + +import jax +import numpy as np +import pytest +from jax import numpy as jnp + +import jace + + +np.random.seed(42) # noqa: NPY002 # random generator + + +def test_jit(): + """Simple add function.""" + + def testee(A: np.ndarray, B: np.ndarray) -> np.ndarray: + return A + B + + A = np.arange(12, dtype=np.float64).reshape((4, 3)) + B = np.full((4, 3), 10, dtype=np.float64) + + jax_testee = jax.jit(testee) + jace_testee = jace.jit(testee) + + ref = jax_testee(A, B) + res = jace_testee(A, B) + + assert np.allclose(ref, res), f"Expected '{ref}' got '{res}'." 
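[Editor's note] The jit tests in this file always rely on the default translator registry. As a hedged illustration of the `primitive_translators` argument that `get_registered_primitive_translators()`'s docstring refers to, a customized registry copy might be passed like this (assuming `jace.jit` accepts the keyword when used as a decorator factory; `my_add_translator` is hypothetical and commented out):

```python
import numpy as np

import jace
from jace import translator

# A decoupled copy of the global registry; modifying it does not affect the default.
my_registry = translator.get_registered_primitive_translators()

# Hypothetically swap in a custom translator for the "add" primitive:
# my_registry["add"] = my_add_translator


@jace.jit(primitive_translators=my_registry)
def testee(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    return a + b
```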
+ + +def test_composition_itself(): + """Tests if JaCe is composable with itself.""" + + # Pure Python functions + def f_ref(x): + return jnp.sin(x) + + def df_ref(x): + return jnp.cos(x) + + def ddf_ref(x): + return -jnp.sin(x) + + # Annotated functions. + + @jace.jit + def f(x): + return f_ref(x) + + @jace.jit + def df(x): + return jace.grad(f)(x) + + @jace.jit + @jace.grad + def ddf(x): + return df(x) + + assert all(isinstance(x, jace.stages.JaCeWrapped) for x in [f, df, ddf]) + + x = 1.0 + for fun, fref in zip([f, df, ddf], [f_ref, df_ref, ddf_ref]): + ref = fref(x) + res = fun(x) + assert np.allclose(ref, res), f"f: Expected '{ref}', got '{res}'." + + +@pytest.mark.skip(reason="Nested Jaxpr are not handled.") +def test_composition_with_jax(): + """Tests if JaCe can interact with Jax and vice versa.""" + + def base_fun(A, B, C): + return A + B * jnp.sin(C) - A * B + + @jace.jit + def jace_fun(A, B, C): + return jax.jit(base_fun)(A, B, C) + + def jax_fun(A, B, C): + return jace.jit(base_fun)(A, B, C) + + A, B, C = (np.random.random((10, 3, 50)) for _ in range(3)) # noqa: NPY002 # random generator + + assert np.allclose(jace_fun(A, B, C), jax_fun(A, B, C)) + + +@pytest.mark.skip(reason="Nested Jaxpr are not handled.") +def test_composition_with_jax_2(): + """Second test if JaCe can interact with Jax and vice versa.""" + + @jax.jit + def f1_jax(A, B): + return A + B + + @jace.jit + def f2_jace(A, B, C): + return f1_jax(A, B) - C + + @jax.jit + def f3_jax(A, B, C, D): + return f2_jace(A, B, C) * D + + @jace.jit + def f3_jace(A, B, C, D): + return f3_jax(A, B, C, D) + + A, B, C, D = (np.random.random((10, 3, 50)) for _ in range(4)) # noqa: NPY002 # random generator + + ref = ((A + B) - C) * D + res_jax = f3_jax(A, B, C, D) + res_jace = f3_jace(A, B, C, D) + + assert np.allclose(ref, res_jax), "Jax failed." + assert np.allclose(ref, res_jace), "JaCe Failed." + + +def test_grad_annotation_direct(): + """Test if `jace.grad` works directly.""" + + def f(x): + return jnp.sin(jnp.exp(jnp.cos(x**2))) + + @jax.grad + def jax_ddf(x): + return jax.grad(f)(x) + + @jax.jit + def jace_ddf(x): + return jace.grad(jace.grad(f))(x) + + # These are the random numbers where we test + Xs = (np.random.random(10) - 0.5) * 10 # noqa: NPY002 # Random number generator + + for i in range(Xs.shape[0]): + x = Xs[i] + res = jace_ddf(x) + ref = jax_ddf(x) + assert np.allclose(res, ref) + + +def test_grad_control_flow(): + """Tests if `grad` and controlflow works. + + This requirement is mentioned in `https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-autodiff`. + """ + + @jace.grad + def df(x): + if x < 3: + return 3.0 * x**2 + return -4 * x + + x1 = 2.0 + df_x1 = 6 * x1 + x2 = 4.0 + df_x2 = -4.0 + + res_1 = df(x1) + res_2 = df(x2) + + assert df(x1) == df_x1, f"Failed lower branch, expected '{df_x1}', got '{res_1}'." + assert df(x2) == df_x2, f"Failed upper branch, expected '{df_x2}', got '{res_2}'." + + +@pytest.mark.skip(reason="Running JaCe with disabled 'x64' support does not work.") +def test_disabled_x64(): + """Tests the behaviour of the tool chain if x64 is disabled. + + If you want to test, if this restriction still applies, you can enable the test. 
+    """
+
+    def testee(A: np.ndarray, B: np.float64) -> np.ndarray:
+        return A + B
+
+    A = np.arange(12, dtype=np.float64).reshape((4, 3))
+    B = np.float64(10.0)
+
+    # Run both with x64 support disabled.
+    with jax.experimental.disable_x64():
+        # JaCe
+        jace_testee = jace.jit(testee)
+        jace_lowered = jace_testee.lower(A, B)
+        jace_comp = jace_lowered.compile()
+        res = jace_comp(A, B)
+
+        # Jax
+        jax_testee = jax.jit(testee)
+        ref = jax_testee(A, B)
+
+    assert np.allclose(ref, res), f"Expected {ref.tolist()}, but got {res.tolist()}."
diff --git a/tests/test_jaxpr_translator_builder.py b/tests/test_jaxpr_translator_builder.py
new file mode 100644
index 0000000..efc6657
--- /dev/null
+++ b/tests/test_jaxpr_translator_builder.py
@@ -0,0 +1,539 @@
+# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming)
+#
+# Copyright (c) 2024, ETH Zurich
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Implements tests for the Jaxpr translation builder."""
+
+from __future__ import annotations
+
+import re
+
+import dace
+import jax
+import numpy as np
+import pytest
+from dace.data import Array
+from jax import core as jax_core
+
+import jace
+from jace import translator, util
+from jace.util import JaCeVar
+
+
+# These are some JaCe variables that we use inside the tests.
+# Unnamed arrays
+array1 = JaCeVar((10, 12), dace.float64)
+array2 = JaCeVar((10, 13), dace.float32)
+array3 = JaCeVar((11, 16), dace.int64)
+
+# Unnamed scalars
+scal1 = JaCeVar((), dace.float16)
+scal2 = JaCeVar((), dace.float32)
+scal3 = JaCeVar((), dace.int64)
+
+# Named variables
+narray = JaCeVar((10,), dace.float16, "narr")
+nscal = JaCeVar((), dace.int32, "nscal")
+
+
+@pytest.fixture()
+def translation_builder():
+    """Returns an allocated builder instance."""
+    name = "fixture_builder"
+    builder = translator.JaxprTranslationBuilder(
+        primitive_translators=translator.get_registered_primitive_translators()
+    )
+    builder._allocate_translation_ctx(name=name)
+    return builder
+
+
+def test_builder_alloc() -> None:
+    """Tests the state right after allocation.
+
+    Does not use the fixture because it performs the allocation itself.
+    """
+    builder = translator.JaxprTranslationBuilder(
+        primitive_translators=translator.get_registered_primitive_translators()
+    )
+    assert not builder.is_allocated(), "Builder was created allocated."
+    assert len(builder._ctx_stack) == 0
+
+    # The reserved names will be tested in `test_builder_fork()`.
+ sdfg_name = "qwertzuiopasdfghjkl" + builder._allocate_translation_ctx(name=sdfg_name) + assert len(builder._ctx_stack) == 1 + assert builder.is_root_translator() + + sdfg: dace.SDFG = builder.sdfg + + assert builder._ctx.sdfg is sdfg + assert builder.sdfg.name == sdfg_name + assert sdfg.number_of_nodes() == 1 + assert sdfg.number_of_edges() == 0 + assert sdfg.start_block is builder._ctx.start_state + assert builder._terminal_sdfg_state is builder._ctx.start_state + + +def test_builder_variable_alloc_auto_naming( + translation_builder: translator.JaxprTranslationBuilder, +) -> None: + """Tests simple variable allocation.""" + for i, var in enumerate([array1, array2, scal1, array3, scal2, scal3]): + sdfg_name = translation_builder.add_array(var, update_var_mapping=True) + sdfg_var = translation_builder.get_array(sdfg_name) + assert sdfg_name == chr(97 + i) + assert isinstance(sdfg_var, Array) # Everything is now an array + assert sdfg_var.shape == ((1,) if var.shape == () else var.shape) + assert sdfg_var.dtype == var.dtype + + +def test_builder_variable_alloc_mixed_naming( + translation_builder: translator.JaxprTranslationBuilder, +) -> None: + """Tests the naming in a mixed setting. + + If `update_var_mapping=True` is given, then the naming will skip variables, + see also `test_builder_variable_alloc_mixed_naming2()`. + """ + # * b c d * f g + for i, var in enumerate([narray, array1, array2, scal1, nscal, scal2, scal3]): + sdfg_name = translation_builder.add_array(var, update_var_mapping=True) + sdfg_var = translation_builder.get_array(sdfg_name) + if var.name is None: + assert sdfg_name == chr(97 + i) + else: + assert sdfg_name == var.name + assert isinstance(sdfg_var, Array) # Everything is now an array + assert sdfg_var.shape == ((1,) if var.shape == () else var.shape) + assert sdfg_var.dtype == var.dtype + + +def test_builder_variable_alloc_mixed_naming2( + translation_builder: translator.JaxprTranslationBuilder, +) -> None: + """Tests the naming in a mixed setting. + + This time we do not use `update_var_mapping=True`, instead it now depends on the + name. This means that automatic naming will now again include all, letters, but not + in a linear order. + """ + letoff = 0 + # * a b c * d e + for var in [narray, array1, array2, scal1, nscal, scal2, scal3]: + sdfg_name = translation_builder.add_array(var, update_var_mapping=var.name is None) + sdfg_var = translation_builder.get_array(sdfg_name) + if var.name is None: + assert sdfg_name == chr(97 + letoff) + letoff += 1 + else: + assert sdfg_name == var.name + assert isinstance(sdfg_var, Array) # Everything is now an array + assert sdfg_var.shape == ((1,) if var.shape == () else var.shape) + assert sdfg_var.dtype == var.dtype + + +def test_builder_variable_alloc_prefix_naming( + translation_builder: translator.JaxprTranslationBuilder, +) -> None: + """Using the prefix to name variables.""" + prefix_1 = "__my_special_prefix" + exp_name_1 = prefix_1 + "a" + sdfg_name_1 = translation_builder.add_array( + array1, name_prefix=prefix_1, update_var_mapping=False + ) + assert exp_name_1 == sdfg_name_1 + + # Because `update_var_mapping` is `False` above, 'a' will be reused. + prefix_2 = "__my_special_prefix_second_" + exp_name_2 = prefix_2 + "a" + sdfg_name_2 = translation_builder.add_array( + array1, name_prefix=prefix_2, update_var_mapping=False + ) + assert exp_name_2 == sdfg_name_2 + + # Now we use a named variables, which are also affected. 
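+    #  The prefix is prepended to the variable's own name here, not to an automatically
+    #  generated letter.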
+ prefix_3 = "__my_special_prefix_third_named_" + exp_name_3 = prefix_3 + nscal.name # type: ignore[operator] # `.name` is not `None`. + sdfg_name_3 = translation_builder.add_array( + nscal, name_prefix=prefix_3, update_var_mapping=False + ) + assert exp_name_3 == sdfg_name_3 + + +def test_builder_variable_alloc_auto_naming_wrapped( + translation_builder: translator.JaxprTranslationBuilder, +) -> None: + """Tests the variable naming if we have more than 26 variables.""" + single_letters = [chr(x) for x in range(97, 123)] + i = 0 + for let1 in ["", *single_letters[1:]]: # Note `z` is followed by `ba` and not by `aa`. + for let2 in single_letters: + i += 1 + # Create a variable and enter it into the variable naming. + var = JaCeVar(shape=(19, 19), dtype=dace.float64) + sdfg_name = translation_builder.add_array(arg=var, update_var_mapping=True) + mapped_name = translation_builder.map_jax_var_to_sdfg(var) + assert ( + sdfg_name == mapped_name + ), f"Mapping for '{var}' failed, expected '{sdfg_name}' got '{mapped_name}'." + + # Get the name that we really expect, we must also handle some situations. + exp_name = let1 + let2 + if exp_name in util.FORBIDDEN_SDFG_VAR_NAMES: + exp_name = "__jace_forbidden_" + exp_name + assert ( + exp_name == sdfg_name + ), f"Automated naming failed, expected '{exp_name}' but got '{sdfg_name}'." + + +def test_builder_nested(translation_builder: translator.JaxprTranslationBuilder) -> None: + """Tests the ability of the nesting of the builder.""" + + # Now add a variable to the current subtext. + name_1 = translation_builder.add_array(array1, update_var_mapping=True) + assert name_1 == "a" + assert translation_builder.map_jax_var_to_sdfg(array1) == name_1 + + # For the sake of doing it add a new state to the SDFG. + translation_builder.append_new_state("sake_state") + assert translation_builder.sdfg.number_of_nodes() == 2 + assert translation_builder.sdfg.number_of_edges() == 1 + + # Now we go one subcontext deeper + translation_builder._allocate_translation_ctx("builder") + assert len(translation_builder._ctx_stack) == 2 + assert translation_builder.sdfg.name == "builder" + assert translation_builder.sdfg.number_of_nodes() == 1 + assert translation_builder.sdfg.number_of_edges() == 0 + assert not translation_builder.is_root_translator() + + # Because we have a new SDFG the mapping to previous SDFG does not work, + # regardless the fact that it still exists. + with pytest.raises( + expected_exception=KeyError, + match=re.escape( + f"Jax variable '{array1}' was supposed to map to '{name_1}', but no such SDFG variable is known." + ), + ): + _ = translation_builder.map_jax_var_to_sdfg(array1) + + # Because the SDFGs are distinct it is possible to add `array1` to the nested one. + # However, it is not able to update the mapping. + with pytest.raises( + expected_exception=ValueError, + match=re.escape(f"Cannot change the mapping of '{array1}' from '{name_1}' to '{name_1}'."), + ): + _ = translation_builder.add_array(array1, update_var_mapping=True) + assert name_1 not in translation_builder.sdfg.arrays + + # Without updating the mapping it is possible create the variable. + assert name_1 == translation_builder.add_array(array1, update_var_mapping=False) + + # Now add a new variable, the map is shared, so a new name will be generated. + name_2 = translation_builder.add_array(array2, update_var_mapping=True) + assert name_2 == "b" + assert name_2 == translation_builder.map_jax_var_to_sdfg(array2) + + # Now we go one stack level back. 
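+    #  Clearing the context pops the nested SDFG from the stack; the parent context
+    #  becomes the active one again, which the assertions below verify.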
+ translation_builder._clear_translation_ctx() + assert len(translation_builder._ctx_stack) == 1 + assert translation_builder.sdfg.number_of_nodes() == 2 + assert translation_builder.sdfg.number_of_edges() == 1 + + # Again the variable that was declared in the last stack is now no longer present. + # Note if the nested SDFG was integrated into the parent SDFG it would be + # accessible + with pytest.raises( + expected_exception=KeyError, + match=re.escape( + f"Jax variable '{array2}' was supposed to map to '{name_2}', but no such SDFG variable is known." + ), + ): + _ = translation_builder.map_jax_var_to_sdfg(array2) + assert name_2 == translation_builder._jax_name_map[array2] + + # Now add a new variable, since the map is shared, we will now get the next name. + name_3 = translation_builder.add_array(array3, update_var_mapping=True) + assert name_3 == "c" + assert name_3 == translation_builder.map_jax_var_to_sdfg(array3) + + +def test_builder_append_state(translation_builder: translator.JaxprTranslationBuilder) -> None: + """Tests the functionality of appending states.""" + sdfg: dace.SDFG = translation_builder.sdfg + + terminal_state_1: dace.SDFGState = translation_builder.append_new_state("terminal_state_1") + assert sdfg.number_of_nodes() == 2 + assert sdfg.number_of_edges() == 1 + assert terminal_state_1 is translation_builder._terminal_sdfg_state + assert translation_builder._terminal_sdfg_state is translation_builder._ctx.terminal_state + assert translation_builder._ctx.start_state is sdfg.start_block + assert translation_builder._ctx.start_state is not terminal_state_1 + assert next(iter(sdfg.edges())).src is sdfg.start_block + assert next(iter(sdfg.edges())).dst is terminal_state_1 + + # Specifying an explicit append state that is the terminal should also update the + # terminal state of the builder. + terminal_state_2: dace.SDFGState = translation_builder.append_new_state( + "terminal_state_2", prev_state=terminal_state_1 + ) + assert sdfg.number_of_nodes() == 3 + assert sdfg.number_of_edges() == 2 + assert terminal_state_2 is translation_builder._terminal_sdfg_state + assert sdfg.out_degree(terminal_state_1) == 1 + assert sdfg.out_degree(terminal_state_2) == 0 + assert sdfg.in_degree(terminal_state_2) == 1 + assert next(iter(sdfg.in_edges(terminal_state_2))).src is terminal_state_1 + + # Specifying a previous node that is not the terminal state should not do anything. + non_terminal_state: dace.SDFGState = translation_builder.append_new_state( + "non_terminal_state", prev_state=terminal_state_1 + ) + assert translation_builder._terminal_sdfg_state is not non_terminal_state + assert sdfg.in_degree(non_terminal_state) == 1 + assert sdfg.out_degree(non_terminal_state) == 0 + assert next(iter(sdfg.in_edges(non_terminal_state))).src is terminal_state_1 + + +def test_builder_variable_multiple_variables( + translation_builder: translator.JaxprTranslationBuilder, +) -> None: + """Add an already known variable, but with a different name.""" + # Now we will add `array1` and then different ways of updating it. + narray1: str = translation_builder.add_array(array1, update_var_mapping=True) + + # It will fail if we use the prefix, because we also want to update. + prefix = "__jace_prefix" + prefix_expected_name = prefix + narray1 + with pytest.raises( + expected_exception=ValueError, + match=re.escape( + f"Cannot change the mapping of '{array1}' from '{translation_builder.map_jax_var_to_sdfg(array1)}' to '{prefix_expected_name}'." 
+ ), + ): + _ = translation_builder.add_array(array1, update_var_mapping=True, name_prefix=prefix) + assert prefix_expected_name not in translation_builder.sdfg.arrays + + # But if we do not want to update it then it works. + prefix_sdfg_name = translation_builder.add_array( + array1, update_var_mapping=False, name_prefix=prefix + ) + assert prefix_expected_name == prefix_sdfg_name + assert prefix_expected_name in translation_builder.sdfg.arrays + assert narray1 == translation_builder.map_jax_var_to_sdfg(array1) + + +def test_builder_variable_invalid_prefix( + translation_builder: translator.JaxprTranslationBuilder, +) -> None: + """Use invalid prefix.""" + # It will fail if we use the prefix, because we also want to update. + for iprefix in ["0_", "_ja ", "_!"]: + with pytest.raises( + expected_exception=ValueError, + match=re.escape(f"add_array({array1}): The proposed name '{iprefix}a', is invalid."), + ): + _ = translation_builder.add_array(array1, update_var_mapping=False, name_prefix=iprefix) + assert len(translation_builder.sdfg.arrays) == 0 + + +def test_builder_variable_alloc_list( + translation_builder: translator.JaxprTranslationBuilder, +) -> None: + """Tests part of the `JaxprTranslationBuilder.create_jax_var_list()` api.""" + var_list_1 = [array1, nscal, scal2] + exp_names_1 = ["a", nscal.name, "c"] + + res_names_1 = translation_builder.create_jax_var_list(var_list_1, update_var_mapping=True) + assert len(translation_builder.arrays) == 3 + assert res_names_1 == exp_names_1 + + # Now a mixture of the collection and creation. + var_list_2 = [array2, nscal, scal1] + exp_names_2 = ["d", nscal.name, "e"] + + res_names_2 = translation_builder.create_jax_var_list(var_list_2, update_var_mapping=True) + assert res_names_2 == exp_names_2 + assert len(translation_builder.arrays) == 5 + + +@pytest.mark.skip(reason="'create_jax_var_list()' does not clean up in case of an error.") +def test_builder_variable_alloc_list_cleaning( + translation_builder: translator.JaxprTranslationBuilder, +) -> None: + """Tests part of the `JaxprTranslationBuilder.create_jax_var_list()` api. + + It will fail because `update_var_mapping=False` thus the third variable will + cause an error because it is proposed to `a`, which is already used. + """ + var_list = [array1, nscal, scal2] + + with pytest.raises( + expected_exception=ValueError, + match=re.escape(f"add_array({scal2}): The proposed name 'a', is used."), + ): + _ = translation_builder.create_jax_var_list(var_list) + + assert len(translation_builder.arrays) == 0 + + +def test_builder_variable_alloc_list_prevent_creation( + translation_builder: translator.JaxprTranslationBuilder, +) -> None: + """Tests part of the `JaxprTranslationBuilder.create_jax_var_list()` api. + + It will test the `prevent_creation` flag. + """ + # First create a variable. 
+ translation_builder.add_array(array1, update_var_mapping=True) + assert len(translation_builder.arrays) == 1 + + # Now create the variables + var_list = [array1, array2] + + with pytest.raises( + expected_exception=ValueError, + match=re.escape(f"'prevent_creation' given but have to create '{array2}'."), + ): + translation_builder.create_jax_var_list(var_list, prevent_creation=True) + assert len(translation_builder.arrays) == 1 + assert translation_builder.map_jax_var_to_sdfg(array1) == "a" + + +@pytest.mark.skip(reason="'create_jax_var_list()' does not clean up in case of an error.") +def test_builder_variable_alloc_list_only_creation( + translation_builder: translator.JaxprTranslationBuilder, +) -> None: + """Tests part of the `JaxprTranslationBuilder.create_jax_var_list()` api. + + It will test the `only_creation` flag. + """ + # First create a variable. + translation_builder.add_array(array1, update_var_mapping=True) + assert len(translation_builder.arrays) == 1 + + # Now create the variables + var_list = [array2, array1] + + with pytest.raises( + expected_exception=ValueError, + match=re.escape(f"'only_creation' given '{array1}' already exists."), + ): + translation_builder.create_jax_var_list(var_list, only_creation=True) + assert len(translation_builder.arrays) == 1 + assert translation_builder.map_jax_var_to_sdfg(array1) == "a" + + +def test_builder_variable_alloc_list_handle_literal( + translation_builder: translator.JaxprTranslationBuilder, +) -> None: + """Tests part of the `JaxprTranslationBuilder.create_jax_var_list()` api. + + It will test the `handle_literals` flag. + """ + + val = np.array(1) + aval = jax_core.get_aval(val) + lit = jax_core.Literal(val, aval) + var_list = [lit] + + with pytest.raises( + expected_exception=ValueError, + match=re.escape("Encountered a literal but `handle_literals` was `False`."), + ): + translation_builder.create_jax_var_list(var_list, handle_literals=False) + assert len(translation_builder.arrays) == 0 + + name_list = translation_builder.create_jax_var_list(var_list, handle_literals=True) + assert len(translation_builder.arrays) == 0 + assert name_list == [None] + + +def test_builder_constants(translation_builder: translator.JaxprTranslationBuilder) -> None: + """Tests part of the `JaxprTranslationBuilder._create_constants()` api. + + See also the `test_subtranslators_alu.py::test_add3` test. + """ + # Create the Jaxpr that we need. + constant = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]] + jaxpr = jax.make_jaxpr(lambda A: A + jax.numpy.array(constant))(1.0) + + # We have to manually allocate the builder context. + # You should not do that. + translation_builder._allocate_translation_ctx(name="Manual_test") + + # No create the constants. + translation_builder._create_constants(jaxpr) + + # Test if it was created with the correct value. 
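+    #  The constant is expected to show up under the reserved '__const_' prefix.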
+ assert len(translation_builder.arrays) == 1 + assert len(translation_builder._jax_name_map) == 1 + assert next(iter(translation_builder._jax_name_map.values())) == "__const_a" + assert len(translation_builder.sdfg.constants) == 1 + assert np.all(translation_builder.sdfg.constants["__const_a"] == constant) + + +def test_builder_scalar_return_value() -> None: + """Tests if scalars can be returned directly.""" + + def scalar_ops(A: float) -> float: + return A + A - A * A + + lower_cnt = [0] + + @jace.jit + def wrapped(A: float) -> float: + lower_cnt[0] += 1 + return scalar_ops(A) + + vals = np.random.random(100) # noqa: NPY002 + for i in range(vals.size): + res = wrapped(vals[i]) + ref = scalar_ops(vals[i]) + assert np.allclose(res, ref) + assert lower_cnt[0] == 1 + + +@pytest.mark.skip(reason="Currently 'scalar' return values, are actually shape '(1,)' arrays.") +def test_builder_scalar_return_type() -> None: + """Tests if the type is the same, in case of scalar return.""" + + @jace.jit + def wrapped(A: np.float64) -> np.float64: + return A + A - A * A + + A = np.float64(1.0) + assert type(A) is np.float64, f"Expected type 'np.float64', but got '{type(A).__name__}'." + + +def test_builder_jace_var() -> None: + """Simple tests about the `JaCeVar` objects.""" + for iname in ["do", "", "_ _", "9al", "_!"]: + with pytest.raises( + expected_exception=ValueError, match=re.escape(f"Supplied the invalid name '{iname}'.") + ): + _ = JaCeVar((), dace.int8, name=iname) + + +def test_builder_F_strides() -> None: + """Tests if we can lower without a standard stride. + + Notes: + This tests if the restriction is currently in place. + See also `tests/test_caching.py::test_caching_strides`. + """ + + @jace.jit + def testee(A: np.ndarray) -> np.ndarray: + return A + 10.0 + + F = np.full((4, 3), 10, dtype=np.float64, order="F") + + with pytest.raises( + expected_exception=NotImplementedError, + match=re.escape("Currently can not yet handle strides beside 'C_CONTIGUOUS'."), + ): + _ = testee(F) diff --git a/tests/test_misc.py b/tests/test_misc.py new file mode 100644 index 0000000..80abefd --- /dev/null +++ b/tests/test_misc.py @@ -0,0 +1,40 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Implements general tests for JaCe.""" + +from __future__ import annotations + +import numpy as np +import pytest + +import jace + + +@pytest.mark.skip("Possible bug in DaCe.") +def test_mismatch_in_datatyte_calling(): + """Tests compilation and calling with different types. + + Note that this more or less tests the calling implementation of the `CompiledSDFG` + class in DaCe. As I understand the `CompiledSDFG::_construct_args()` function this + should be detected. However, as evidently it does not do this. + """ + + @jace.jit + def testee(A: np.ndarray) -> np.ndarray: + return -A + + # Different types. + A1 = np.arange(12, dtype=np.float32).reshape((4, 3)) + A2 = np.arange(12, dtype=np.int64).reshape((4, 3)) + + # Lower and compilation for first type + callee = testee.lower(A1).compile() + + # But calling with the second type + with pytest.raises(Exception): # noqa: B017, PT011 # Unknown exception. 
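+        # Calling the compiled SDFG with a mismatching dtype is expected to be rejected
+        #  while the arguments are constructed.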
+ _ = callee(A2) diff --git a/tests/test_package.py b/tests/test_package.py index bf92c00..5237aeb 100644 --- a/tests/test_package.py +++ b/tests/test_package.py @@ -9,8 +9,11 @@ import importlib.metadata +import pytest + import jace as m +@pytest.mark.skip(reason="This does not work yet.") def test_version(): assert importlib.metadata.version("jace") == m.__version__ diff --git a/tests/test_sub_translators_alu.py b/tests/test_sub_translators_alu.py new file mode 100644 index 0000000..603f57c --- /dev/null +++ b/tests/test_sub_translators_alu.py @@ -0,0 +1,63 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Implements tests for the ALU translator.""" + +from __future__ import annotations + +import jax +import numpy as np + +import jace + + +def test_add(): + """Simple add function.""" + + def testee(A: np.ndarray, B: np.ndarray) -> np.ndarray: + return A + B + + A = np.arange(12, dtype=np.float64).reshape((4, 3)) + B = np.full((4, 3), 10, dtype=np.float64) + + ref = testee(A, B) + res = jace.jit(testee)(A, B) + + assert np.allclose(ref, res), f"Expected '{ref}' got '{res}'." + + +def test_add2(): + """Simple add function, with literal.""" + + def testee(A: np.ndarray, B: np.ndarray) -> np.ndarray: + c = A + 0.01 + d = B * 0.6 + e = c / 1.0 + f = d - 0.1 + return e + f * d + + A = np.arange(12, dtype=np.float64).reshape((4, 3)) + B = np.full((4, 3), 10, dtype=np.float64) + + ref = testee(A, B) + res = jace.jit(testee)(A, B) + + assert np.allclose(ref, res), f"Expected '{ref.tolist()}' got '{res.tolist()}'." + + +def test_add3(): + """Simple add function, with constant.""" + + def testee(A: np.ndarray) -> np.ndarray: + return A + jax.numpy.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) + + A = np.ones((3, 3), dtype=np.float64) + + ref = testee(A) + res = jace.jit(testee)(A) + + assert np.allclose(ref, res), f"Expected '{ref}' got '{res}'." diff --git a/tests/test_subtranslator_helper.py b/tests/test_subtranslator_helper.py new file mode 100644 index 0000000..56b30fb --- /dev/null +++ b/tests/test_subtranslator_helper.py @@ -0,0 +1,217 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. 
+# +# SPDX-License-Identifier: BSD-3-Clause + +"""Implements tests for managing the primitive subtranslators.""" + +from __future__ import annotations + +import re +from typing import Any + +import numpy as np +import pytest + +import jace +from jace import translator +from jace.translator import ( + get_registered_primitive_translators, + make_primitive_translator, + register_primitive_translator, +) + + +@pytest.fixture(autouse=True) +def _conserve_builtin_translators(): + """Restores the set of registered subtranslators after a test.""" + initial_translators = translator.primitive_translator._PRIMITIVE_TRANSLATORS_REGISTRY.copy() + yield + translator.primitive_translator._PRIMITIVE_TRANSLATORS_REGISTRY.clear() + translator.primitive_translator._PRIMITIVE_TRANSLATORS_REGISTRY.update(initial_translators) + + +@pytest.fixture() +def no_builtin_translators(): # noqa: PT004 # This is how you should do it: https://docs.pytest.org/en/7.1.x/how-to/fixtures.html#use-fixtures-in-classes-and-modules-with-usefixtures + """This fixture can be used if the test does not want any builtin translators.""" + initial_translators = translator.primitive_translator._PRIMITIVE_TRANSLATORS_REGISTRY.copy() + translator.primitive_translator._PRIMITIVE_TRANSLATORS_REGISTRY.clear() + yield + translator.primitive_translator._PRIMITIVE_TRANSLATORS_REGISTRY.clear() + translator.primitive_translator._PRIMITIVE_TRANSLATORS_REGISTRY.update(initial_translators) + + +# These are definitions of some Subtranslators that can be used to test things. +class SubTrans1(translator.PrimitiveTranslator): + @property + def primitive(self): + return "non_existing_primitive1" + + def __call__(self) -> None: # type: ignore[override] # Arguments + raise NotImplementedError + + +class SubTrans2(translator.PrimitiveTranslator): + @property + def primitive(self): + return "non_existing_primitive2" + + def __call__(self) -> None: # type: ignore[override] # Arguments + raise NotImplementedError + + +@make_primitive_translator("non_existing_callable_primitive3") +def SubTrans3_Callable(*args: Any, **kwargs: Any) -> None: # noqa: ARG001 + raise NotImplementedError + + +@make_primitive_translator("add") +def fake_add_translator(*args: Any, **kwargs: Any) -> None: # noqa: ARG001 + raise NotImplementedError + + +def test_are_subtranslators_imported(): + """Tests if something is inside the list of subtranslators.""" + # Must be adapted if new primitives are implemented. + assert len(get_registered_primitive_translators()) == 37 + + +@pytest.mark.usefixtures("no_builtin_translators") +def test_subtranslatior_managing(): + """Basic functionality of the subtranslators.""" + original_active_subtrans = get_registered_primitive_translators() + assert len(original_active_subtrans) == 0 + + # Create the classes. + sub1 = SubTrans1() + sub2 = SubTrans2() + + # These are all primitive translators + prim_translators = [sub1, sub2, SubTrans3_Callable] + + # Add the instances. 
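+    #  Note that 'register_primitive_translator()' returns its argument, which is what
+    #  also makes it usable as a decorator (see the tests further below).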
+ for sub in prim_translators: + assert register_primitive_translator(sub) is sub + + # Tests if they were correctly registered + active_subtrans = get_registered_primitive_translators() + for expected in prim_translators: + assert active_subtrans[expected.primitive] is expected + assert len(active_subtrans) == 3 + + +def test_subtranslatior_managing_isolation(): + """Tests if `get_registered_primitive_translators()` decouples.""" + assert ( + get_registered_primitive_translators() + is not translator.primitive_translator._PRIMITIVE_TRANSLATORS_REGISTRY + ) + + initial_primitives = get_registered_primitive_translators() + assert get_registered_primitive_translators() is not initial_primitives + assert "add" in initial_primitives, "For this test the 'add' primitive must be registered." + org_add_prim = initial_primitives["add"] + + initial_primitives["add"] = fake_add_translator + assert org_add_prim is not fake_add_translator + assert get_registered_primitive_translators()["add"] is org_add_prim + + +@pytest.mark.usefixtures("no_builtin_translators") +def test_subtranslatior_managing_callable_annotation(): + """Test if `make_primitive_translator()` works.""" + + prim_name = "non_existing_property" + + @make_primitive_translator(prim_name) + def non_existing_translator(*args: Any, **kwargs: Any) -> None: # noqa: ARG001 + raise NotImplementedError + + assert hasattr(non_existing_translator, "primitive") + assert non_existing_translator.primitive == prim_name + assert len(get_registered_primitive_translators()) == 0 + + +def test_subtranslatior_managing_overwriting(): + """Tests if we are able to overwrite something.""" + current_add_translator = get_registered_primitive_translators()["add"] + + @make_primitive_translator("add") + def useless_add_translator(*args: Any, **kwargs: Any) -> None: # noqa: ARG001 + raise NotImplementedError + + # This will not work because it is not overwritten. + with pytest.raises( + expected_exception=ValueError, + match=re.escape( + "Explicit override=True needed for primitive 'add' to overwrite existing one." + ), + ): + register_primitive_translator(useless_add_translator) + assert current_add_translator is get_registered_primitive_translators()["add"] + + # Now we use overwrite, thus it will now work. + assert useless_add_translator is register_primitive_translator( + useless_add_translator, overwrite=True + ) + assert useless_add_translator is get_registered_primitive_translators()["add"] + + +@pytest.mark.usefixtures("no_builtin_translators") +def test_subtranslatior_managing_overwriting_2(): + """Again an overwriting test, but this time a bit more complicated.""" + + trans_cnt = [0] + + @register_primitive_translator(overwrite=True) + @make_primitive_translator("add") + def still_useless_but_a_bit_less(*args: Any, **kwargs: Any) -> None: # noqa: ARG001 + trans_cnt[0] += 1 + + @jace.jit + def foo(A): + B = A + 1 + C = B + 1 + D = C + 1 + return D + 1 + + _ = foo.lower(1) + assert trans_cnt[0] == 4 + + +def test_subtranslatior_managing_decoupling(): + """Shows that we have proper decoupling. + + I.e. changes to the global state, does not affect already annotated functions. + """ + + # This will use the translators that are currently installed. 
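+    #  The set of translators is captured when 'jace.jit' is applied, not when the
+    #  function is lowered for the first time.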
+    @jace.jit
+    def foo(A):
+        B = A + 1
+        C = B + 1
+        D = C + 1
+        return D + 1
+
+    @register_primitive_translator(overwrite=True)
+    @make_primitive_translator("add")
+    def useless_add_translator(*args: Any, **kwargs: Any) -> None:  # noqa: ARG001
+        raise NotImplementedError("The 'useless_add_translator' was called as expected.")
+
+    # Since `foo` was already constructed, registering a new translator cannot change anything.
+    A = np.zeros((10, 10))
+    assert np.all(foo(A) == 4)
+
+    # But if we annotate a new function now, it will pick up the useless translator.
+    @jace.jit
+    def foo_fail(A):
+        B = A + 1
+        return B + 1
+
+    with pytest.raises(
+        expected_exception=NotImplementedError,
+        match=re.escape("The 'useless_add_translator' was called as expected."),
+    ):
+        _ = foo_fail.lower(A)
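+
+
+# A sketch of how a single translator can be restored from the snapshot returned by
+#  'get_registered_primitive_translators()'; the autouse fixture above already restores
+#  the whole registry, so this only illustrates the per-primitive API.
def test_subtranslator_restore_sketch():
+    original_add = get_registered_primitive_translators()["add"]
+
+    @register_primitive_translator(overwrite=True)
+    @make_primitive_translator("add")
+    def broken_add_translator(*args: Any, **kwargs: Any) -> None:  # noqa: ARG001
+        raise NotImplementedError
+
+    assert get_registered_primitive_translators()["add"] is broken_add_translator
+
+    # Re-register the snapshot to undo the override.
+    assert register_primitive_translator(original_add, overwrite=True) is original_add
+    assert get_registered_primitive_translators()["add"] is original_add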