diff --git a/ROADMAP.md b/ROADMAP.md index 5beb172..2633585 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -7,7 +7,7 @@ A kind of roadmap that gives a rough idea about how the project will be continue - [x] Annotation `@jace.jit`. - [x] Composable with Jax, i.e. take the Jax derivative of a JaCe annotated function. - [x] Implementing the `stages` model that is supported by Jax. - - [ ] Handling Jax arrays as native input (only on single host). + - [x] Handling Jax arrays as native input (only on single host). - [x] Cache the compilation and lowering results for later reuse. In Jax these parts (together with the dispatch) are actually written in C++, thus in the beginning we will use a self made cache. - [ ] Implementing some basic `PrimitiveTranslators`, that allows us to run some early tests, such as: @@ -23,6 +23,7 @@ A kind of roadmap that gives a rough idea about how the project will be continue But passing these benchmarks could give us some better hint of how to proceed in this matter. - [ ] Passing the [pyhpc-benchmark](https://github.com/dionhaefner/pyhpc-benchmarks) - [ ] Passing Felix' fluid project; possibility. + - [ ] Flash-Attention, there is a DaCe implementation. - [ ] Support of static arguments. - [ ] Stop relying on `jax.make_jaxpr()`. Look at the `jax._src.pjit.make_jit()` function for how to hijack the staging process. diff --git a/docs/index.md b/docs/index.md index e447b26..c395a73 100644 --- a/docs/index.md +++ b/docs/index.md @@ -3,11 +3,9 @@ ```{toctree} :maxdepth: 2 :hidden: - ``` ```{include} ../README.md -:start-after: ``` ## Indices and tables diff --git a/docs/main_differences.md b/docs/main_differences.md new file mode 100644 index 0000000..5ac19aa --- /dev/null +++ b/docs/main_differences.md @@ -0,0 +1,21 @@ +# Main Differences Between DaCe and JaCe and JAX and JaCe + +Essentially JaCe is a frontend that allows DaCe to process JAX code, thus it has to be compatible with both, at least in some sense. +We will now list the main differences between them, furthermore, you should also consult the ROADMAP. + +### JAX vs. JaCe: + +- JaCe always traces with enabled `x64` mode. + This is a restriction that might be lifted in the future. +- JAX returns scalars as zero-dimensional arrays, JaCe returns them as array with shape `(1, )`. +- In JAX parts of the computation runs on CPU parts on GPU, in JaCe everything runs (currently) either on CPU or GPU. +- Currently JaCe is only able to run on CPU (will be lifted soon). +- Currently JaCe is not able to run distributed (will be lifted later). +- Currently not all primitives are supported. +- JaCe does not return `jax.Array` instances, but NumPy/CuPy arrays. +- The execution is not asynchronous. + +### DaCe vs. JaCe: + +- JaCe accepts complex objects using JAX' pytrees. +- JaCe will support scalar inputs on GPU. diff --git a/noxfile.py b/noxfile.py index b6aec1b..76e0a2e 100644 --- a/noxfile.py +++ b/noxfile.py @@ -134,7 +134,10 @@ def docs(session: nox.Session) -> None: @nox.session(reuse_venv=True) def api_docs(session: nox.Session) -> None: """Build (regenerate) API docs.""" - session.install(f"sphinx=={REQUIREMENTS['sphinx']}") + sphinx_req = REQUIREMENTS["sphinx"] + if sphinx_req.isdigit(): + sphinx_req = "==" + sphinx_req + session.install(f"sphinx{sphinx_req}") session.chdir("docs") session.run( "sphinx-apidoc", diff --git a/pyproject.toml b/pyproject.toml index 393ce01..a03a349 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,6 +113,7 @@ addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"] filterwarnings = [ "error", "ignore:numpy\\..*:DeprecationWarning", # DaCe is not NumPy v2.0 ready so ignore the usage of deprecated features. + "ignore:pandas not found, skipping conversion test\\.:ImportWarning", # Pandas is not installed on the CI. ] log_cli_level = "INFO" minversion = "6.0" @@ -233,7 +234,6 @@ max-complexity = 12 ] "tests/**" = [ "D", # pydocstyle - "N", # TODO(egparedes): remove ignore as soon as all tests are properly named "PLR2004", # [magic-value-comparison] "T10", # flake8-debugger "T20", # flake8-print diff --git a/tests/common_fixture.py b/tests/common_fixture.py new file mode 100644 index 0000000..b36f22d --- /dev/null +++ b/tests/common_fixture.py @@ -0,0 +1,8 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Contains all common fixture we need.""" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..9f454a1 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,102 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""General configuration for the tests. + +Todo: + - Implement some fixture that allows to force validation. +""" + +from __future__ import annotations + +from collections.abc import Generator + +import jax +import numpy as np +import pytest + +from jace import optimization, stages +from jace.util import translation_cache as tcache + + +@pytest.fixture(autouse=True) +def _enable_x64_mode_in_jax() -> Generator[None, None, None]: + """Fixture of enable the `x64` mode in JAX. + + Currently, JaCe requires that `x64` mode is enabled and will do all JAX + things with it enabled. However, if we use JAX with the intend to compare + it against JaCe we must also enable it for JAX. + """ + with jax.experimental.enable_x64(): + yield + + +@pytest.fixture(autouse=True) +def _disable_jit() -> Generator[None, None, None]: + """Fixture for disable the dynamic jiting in JAX, used by default. + + Using this fixture has two effects. + - JAX will not cache the results, i.e. every call to a jitted function will + result in a tracing operation. + - JAX will not use implicit jit operations, i.e. nested Jaxpr expressions + using `pjit` are avoided. + + This essentially disable the `jax.jit` decorator, however, the `jace.jit` + decorator is still working. + + Note: + The second point, i.e. preventing JAX from running certain things in `pjit`, + is the main reason why this fixture is used by default, without it + literal substitution is useless and essentially untestable. + In certain situation it can be disabled. + """ + with jax.disable_jit(disable=True): + yield + + +@pytest.fixture() +def _enable_jit() -> Generator[None, None, None]: + """Fixture to enable jit compilation. + + Essentially it undoes the effects of the `_disable_jit()` fixture. + It is important that this fixture is not automatically activated. + """ + with jax.disable_jit(disable=False): + yield + + +@pytest.fixture(autouse=True) +def _clear_translation_cache() -> Generator[None, None, None]: + """Decorator that clears the translation cache. + + Ensures that a function finds an empty cache and clears up afterwards. + """ + tcache.clear_translation_cache() + yield + tcache.clear_translation_cache() + + +@pytest.fixture(autouse=True) +def _reset_random_seed() -> None: + """Fixture for resetting the random seed. + + This ensures that for every test the random seed of NumPy is reset. + This seed is used by the `util.mkarray()` helper. + """ + np.random.seed(42) # noqa: NPY002 [numpy-legacy-random] + + +@pytest.fixture(autouse=True) +def _set_compile_options() -> Generator[None, None, None]: + """Disable all optimizations of jitted code. + + Without explicitly supplied arguments `JaCeLowered.compile()` will not + perform any optimizations. + Please not that certain tests might override this fixture. + """ + with stages.set_compiler_options(optimization.NO_OPTIMIZATIONS): + yield diff --git a/tests/integration_tests/__init__.py b/tests/integration_tests/__init__.py new file mode 100644 index 0000000..edaf6ea --- /dev/null +++ b/tests/integration_tests/__init__.py @@ -0,0 +1,11 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""JaCe's integration tests. + +Currently they are mostly related to the primitive translators. +""" diff --git a/tests/integration_tests/primitive_translators/__init__.py b/tests/integration_tests/primitive_translators/__init__.py new file mode 100644 index 0000000..16abf65 --- /dev/null +++ b/tests/integration_tests/primitive_translators/__init__.py @@ -0,0 +1,8 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Tests related to the actual primitive subtranslators.""" diff --git a/tests/integration_tests/primitive_translators/conftest.py b/tests/integration_tests/primitive_translators/conftest.py new file mode 100644 index 0000000..6f81b09 --- /dev/null +++ b/tests/integration_tests/primitive_translators/conftest.py @@ -0,0 +1,40 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""General configuration for the tests of the primitive translators.""" + +from __future__ import annotations + +from collections.abc import Generator + +import pytest + +from jace import optimization, stages + + +@pytest.fixture( + autouse=True, + params=[ + optimization.NO_OPTIMIZATIONS, + pytest.param( + optimization.DEFAULT_OPTIMIZATIONS, + marks=pytest.mark.skip( + "Simplify bug 'https://github.com/spcl/dace/issues/1595'; resolved > 16.1" + ), + ), + ], +) +def _set_compile_options(request) -> Generator[None, None, None]: + """Set the options used for testing the primitive translators. + + This fixture override the global defined fixture. + + Todo: + Implement a system that only runs the optimization case in CI. + """ + with stages.set_compiler_options(request.param): + yield diff --git a/tests/integration_tests/primitive_translators/test_primitive_arithmetic_logical_operations.py b/tests/integration_tests/primitive_translators/test_primitive_arithmetic_logical_operations.py new file mode 100644 index 0000000..7573de5 --- /dev/null +++ b/tests/integration_tests/primitive_translators/test_primitive_arithmetic_logical_operations.py @@ -0,0 +1,414 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Tests for `MappedOperationTranslatorBase` class and arithmetic & logical operations. + +The `MappedOperationTranslatorBase` can not be tested on its own, since it does +not generate a Tasklet. For that reason it is thoroughly tested together with +the arithmetic and logical translators (ALT). + +Thus the first tests tests the behaviour of the `MappedOperationTranslatorBase` +class such as +- broadcasting, +- literal substitution, +- scalar vs array computation. + +Followed by tests that are specific to the ALTs, which mostly focuses +on the validity of the template of the ALT. +""" + +from __future__ import annotations + +from collections.abc import Callable, Generator +from typing import Any + +import dace +import jax +import numpy as np +import pytest +from jax import numpy as jnp + +import jace + +from tests import util as testutil + + +@pytest.fixture(autouse=True) +def _only_alt_translators() -> Generator[None, None, None]: + """Removes all non arithmetic/logical translator from the registry. + + This ensures that JAX is not doing some stuff that is supposed to be handled by the + test class, such as broadcasting. It makes writing tests a bit harder, but it is + worth. For some reasons also type conversion s allowed. + """ + from jace.translator.primitive_translators.arithmetic_logical_translators import ( # noqa: PLC0415 [import-outside-top-level] + _ARITMETIC_OPERATION_TEMPLATES, # noqa: PLC2701 [import-private-name] + _LOGICAL_OPERATION_TEMPLATES, # noqa: PLC2701 [import-private-name] + ) + + # Remove all non ALU translators from the registry + primitive_translators = jace.translator.get_registered_primitive_translators() + allowed_translators = ( + _LOGICAL_OPERATION_TEMPLATES.keys() + | _ARITMETIC_OPERATION_TEMPLATES.keys() + | {"convert_element_type", "pjit"} + ) + testutil.set_active_primitive_translators_to({ + p: t for p, t in primitive_translators.items() if p in allowed_translators + }) + + yield + + # Restore the initial state + testutil.set_active_primitive_translators_to(primitive_translators) + + +@pytest.fixture( + params=[ + (jnp.logical_and, 2, np.bool_), + (jnp.logical_or, 2, np.bool_), + (jnp.logical_xor, 2, np.bool_), + (jnp.logical_not, 1, np.bool_), + (jnp.bitwise_and, 2, np.int64), + (jnp.bitwise_or, 2, np.int64), + (jnp.bitwise_xor, 2, np.int64), + (jnp.bitwise_not, 1, np.int64), + ] +) +def logical_ops(request) -> tuple[Callable, tuple[np.ndarray, ...]]: + """Returns a logical operation function and inputs.""" + return ( + request.param[0], + tuple(testutil.make_array((2, 2), request.param[2]) for _ in range(request.param[1])), + ) + + +@pytest.fixture( + params=[ + np.float32, + pytest.param( + np.complex64, + marks=pytest.mark.skip("Some complex values operations are not fully supported."), + ), + ] +) +def dtype(request) -> type: + """Data types that should be used for the numerical tests of the ALT translators.""" + return request.param + + +@pytest.fixture( + params=[ + lambda x: +(x - 0.5), + lambda x: -x, + jnp.floor, + jnp.ceil, + jnp.round, + jnp.exp2, + lambda x: jnp.abs(-x), + lambda x: jnp.sqrt(x**2), # includes integer power. + lambda x: jnp.log(jnp.exp(x)), + lambda x: jnp.log1p(jnp.expm1(x)), + lambda x: jnp.asin(jnp.sin(x)), + lambda x: jnp.acos(jnp.cos(x)), + lambda x: jnp.atan(jnp.tan(x)), + lambda x: jnp.asinh(jnp.sinh(x)), + lambda x: jnp.acosh(jnp.cosh(x)), + lambda x: jnp.atanh(jnp.tanh(x)), + ] +) +def alt_unary_ops(request, dtype: type) -> tuple[Callable, np.ndarray]: + """The inputs and the operation we need for the full test. + + Some of the unary operations are combined to ensure that they will succeed. + An example is `asin()` which only takes values in the range `[-1, 1]`. + """ + return (request.param, testutil.make_array((2, 2), dtype)) + + +@pytest.fixture( + params=[ + jnp.add, + jnp.multiply, + jnp.divide, + jnp.minimum, + jnp.maximum, + jnp.atan2, + jnp.nextafter, + lambda x, y: x**y, + ] +) +def alt_binary_ops_float(request) -> tuple[Callable, tuple[np.ndarray, np.ndarray]]: + """Binary ALT operations that operates on floats.""" + # Getting 0 in the division test is unlikely. + return ( # type: ignore[return-value] # Type confusion. + request.param, + tuple(testutil.make_array((2, 2), np.float64) for _ in range(2)), + ) + + +@pytest.fixture( + params=[ + lambda x, y: x == y, + lambda x, y: x != y, + lambda x, y: x <= y, + lambda x, y: x < y, + lambda x, y: x >= y, + lambda x, y: x > y, + ] +) +def alt_binary_compare_ops(request) -> tuple[Callable, tuple[np.ndarray, np.ndarray]]: + """Comparison operations, operates on integers.""" + return ( + request.param, + tuple(np.abs(testutil.make_array((20, 20), np.int32)) % 30 for _ in range(2)), + ) + + +@pytest.fixture( + params=[ + [(100, 1), (100, 10)], + [(100, 1, 3), (100, 1, 1)], + [(5, 1, 3, 4, 1, 5), (5, 1, 3, 1, 2, 5)], + ] +) +def binary_broadcast_input(request) -> tuple[np.ndarray, np.ndarray]: + """Inputs to be used for the binary broadcast test.""" + return tuple(testutil.make_array(shape) for shape in request.param) # type: ignore[return-value] # can not deduce that it is only size 2. + + +@pytest.fixture( + params=[ + [(100, 100), (100, 100), (100, 100)], + [(100, 1), (100, 100), (100, 100)], + [(100, 100), (100, 1), (100, 100)], + [(100, 100), (100, 100), (100, 1)], + [(100, 1), (100, 1), (100, 100)], + [(100, 100), (100, 1), (100, 1)], + [(100, 1), (100, 100), (100, 1)], + [(100, 100), (), ()], + [(), (100, 100), ()], + [(), (), (100, 100)], + [(), (100, 100), (100, 100)], + [(100, 100), (), (100, 100)], + ] +) +def ternary_broadcast_input(request) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Inputs to be used for the ternary broadcast test.""" + + min_val = testutil.make_array(request.param[0]) / 2.0 + value = testutil.make_array(request.param[1]) + max_val = testutil.make_array(request.param[2]) / 2.0 + 0.5 + return (min_val, value, max_val) + + +def _perform_alt_test(testee: Callable, *args: Any) -> Any: + """General function that just performs the test. + + The function returns the JaCe result. + """ + wrapped = jace.jit(testee) + + ref = testee(*args) + res = wrapped(*args) + + if jace.util.is_scalar(ref) or ref.shape == (): + assert res.shape == (1,) + else: + assert ref.shape == res.shape + assert ref.dtype == res.dtype + assert np.allclose(ref, res), f"Expected '{ref.tolist()}' got '{res.tolist()}'" + return res + + +# <------------ Tests for `MappedOperationTranslatorBase` + + +def test_mapped_unary_scalar() -> None: + def testee(a: np.float64) -> np.float64 | jax.Array: + return jnp.cos(a) + + _perform_alt_test(testee, np.float64(1.0)) + + +def test_mapped_unary_array() -> None: + def testee(a: np.ndarray) -> jax.Array: + return jnp.sin(a) + + a = testutil.make_array((100, 10, 3)) + + _perform_alt_test(testee, a) + + +def test_mapped_unary_scalar_literal() -> None: + def testee(a: float) -> float | jax.Array: + return jnp.sin(1.98) + a + + _perform_alt_test(testee, 10.0) + + +def test_mapped_binary_scalar() -> None: + def testee(a: np.float64, b: np.float64) -> np.float64: + return a * b + + _perform_alt_test(testee, np.float64(1.0), np.float64(2.0)) + + +def test_mapped_binary_scalar_partial_literal() -> None: + def testee_r(a: np.float64) -> np.float64: + return a * 2.03 + + def testee_l(a: np.float64) -> np.float64: + return 2.03 * a + + a = np.float64(7.0) + _perform_alt_test(testee_r, a) + _perform_alt_test(testee_l, a) + + +def test_mapped_binary_array() -> None: + """Test binary of arrays, with same size.""" + + def testee(a: np.ndarray, b: np.ndarray) -> np.ndarray: + return a + b + + a = testutil.make_array((100, 10, 3)) + b = testutil.make_array((100, 10, 3)) + _perform_alt_test(testee, a, b) + + +def test_mapped_binary_array_scalar() -> None: + def testee(a: np.ndarray | np.float64, b: np.float64 | np.ndarray) -> np.ndarray: + return a + b # type: ignore[return-value] # It is always an array. + + a = testutil.make_array((100, 22)) + b = np.float64(1.34) + _perform_alt_test(testee, a, b) + _perform_alt_test(testee, b, a) + + +def test_mapped_binary_array_partial_literal() -> None: + def testee_r(a: np.ndarray) -> np.ndarray: + return a + 1.52 + + def testee_l(a: np.ndarray) -> np.ndarray: + return 1.52 + a + + a = testutil.make_array((100, 22)) + _perform_alt_test(testee_r, a) + _perform_alt_test(testee_l, a) + + +def test_mapped_binary_array_constants() -> None: + def testee(a: np.ndarray) -> np.ndarray: + return a + jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) + + a = testutil.make_array((3, 3)) + _perform_alt_test(testee, a) + + +def test_mapped_broadcast_binary(binary_broadcast_input: tuple[np.ndarray, np.ndarray]) -> None: + def testee(a: np.ndarray, b: np.ndarray) -> np.ndarray: + return a + b + + a = binary_broadcast_input[0] + b = binary_broadcast_input[1] + _perform_alt_test(testee, a, b) + _perform_alt_test(testee, b, a) + + +def test_mapped_broadcast_ternary( + ternary_broadcast_input: tuple[np.ndarray, np.ndarray, np.ndarray], +) -> None: + def testee(min_val: np.ndarray, value: np.ndarray, max_val: np.ndarray) -> np.ndarray: + return jax.numpy.clip(value, min_val, max_val) # type: ignore[return-value] # JAX returns JAX Arrays. + + _perform_alt_test(testee, *ternary_broadcast_input) + + +# <------------ Tests for arithmetic and logical translators/operations + + +def test_alt_general_unary(alt_unary_ops: tuple[Callable, np.ndarray]) -> None: + def testee(a: np.ndarray) -> np.ndarray: + return alt_unary_ops[0](a) + + _perform_alt_test(testee, alt_unary_ops[1]) + + +def test_alt_unary_isfinite() -> None: + def testee(a: np.ndarray) -> jax.Array: + return jnp.isfinite(a) + + a = np.array([np.inf, +np.inf, -np.inf, np.nan, -np.nan, 1.0]) + + args = dace.Config.get("compiler", "cpu", "args") + try: + new_args = args.replace("-ffast-math", "-fno-finite-math-only") + dace.Config.set("compiler", "cpu", "args", value=new_args) + _perform_alt_test(testee, a) + + finally: + dace.Config.set("compiler", "cpu", "args", value=args) + + +def test_alt_general_binary_float( + alt_binary_ops_float: tuple[Callable, tuple[np.ndarray, np.ndarray]], +) -> None: + def testee(a: np.ndarray, b: np.ndarray) -> np.ndarray: + return alt_binary_ops_float[0](a, b) + + _perform_alt_test(testee, *alt_binary_ops_float[1]) + + +def test_alt_ternary_clamp() -> None: + """Tests `jax.lax.clamp()` primitive. + + This primitive is similar to `numpy.clip()` but with a different signature. + Furthermore, this is a ternary operation. + """ + + def testee(min_: np.ndarray, val_: np.ndarray, max_: np.ndarray) -> np.ndarray: + return jax.lax.clamp(min_, val_, max_) # type: ignore[return-value] + + shape = (20, 20) + min_ = testutil.make_array(shape) / 2.0 + max_ = testutil.make_array(shape) / 2.0 + 0.5 + val_ = testutil.make_array(shape) + + jace_res = _perform_alt_test(testee, min_, val_, max_) + + # Ensure that all branches were taken. + assert not any(np.all(jace_res == x) for x in (min_, val_, max_)) + + +def test_alt_compare_operation( + alt_binary_compare_ops: tuple[Callable, tuple[np.ndarray, np.ndarray]], +) -> None: + def testee(a: np.ndarray, b: np.ndarray) -> np.ndarray: + return alt_binary_compare_ops[0](a, b) + + _perform_alt_test(testee, *alt_binary_compare_ops[1]) + + +def test_alt_logical_bitwise_operation( + logical_ops: tuple[Callable, tuple[np.ndarray, ...]], +) -> None: + inputs: tuple[np.ndarray, ...] = logical_ops[1] + + def testee(*args: np.ndarray) -> np.ndarray: + return logical_ops[0](*args) + + _perform_alt_test(testee, *inputs) + + +def test_alt_unary_integer_power() -> None: + def testee(a: np.ndarray) -> np.ndarray: + return a**3 + + a = testutil.make_array((10, 2, 3)) + _perform_alt_test(testee, a) diff --git a/tests/integration_tests/primitive_translators/test_primitive_broadcast_in_dim.py b/tests/integration_tests/primitive_translators/test_primitive_broadcast_in_dim.py new file mode 100644 index 0000000..7d434fd --- /dev/null +++ b/tests/integration_tests/primitive_translators/test_primitive_broadcast_in_dim.py @@ -0,0 +1,83 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Implements tests for the `broadcast_in_dim` primitive. + +Parts of the tests are also implemented inside +`test_sub_translators_squeeze_expand_dims.py`, because this primitive has a relation to +`squeeze`. + +Todo: + - `np.meshgrid` + - `np.ix_` + - `np.indices` +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import jax +import numpy as np +import pytest +from jax import numpy as jnp + +import jace + +from tests import util as testutil + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +@pytest.fixture(params=[(10,), (10, 1), (1, 10)]) +def vector_shape(request) -> tuple[int, ...]: + """Shapes used in the `test_bid_vector()` tests.""" + return request.param + + +def test_bid_scalar() -> None: + """Broadcast a scalar to a matrix.""" + + def testee(a: float) -> jax.Array: + return jnp.broadcast_to(a, (2, 2)) + + a = 1.032 + ref = testee(a) + res = jace.jit(testee)(a) + + assert res.shape == ref.shape + assert res.dtype == ref.dtype + assert np.all(res == ref), f"Expected '{ref.tolist()}' got '{res.tolist()}'." + + +def test_bid_literal() -> None: + """Broadcast a literal to a matrix.""" + + def testee(a: float) -> jax.Array: + return jnp.broadcast_to(1.0, (10, 10)) + a + + ref = testee(0.0) + res = jace.jit(testee)(0.0) + assert res.shape == ref.shape + assert res.dtype == ref.dtype + assert np.all(res == ref) + + +def test_bid_vector(vector_shape: Sequence[int]) -> None: + """Broadcast a vector to a tensor.""" + + def testee(a: np.ndarray) -> jax.Array: + return jnp.broadcast_to(a, (10, 10)) + + a = testutil.make_array(vector_shape) + ref = testee(a) + res = jace.jit(testee)(a) + assert res.shape == ref.shape + assert res.dtype == ref.dtype + assert np.all(res == ref) diff --git a/tests/integration_tests/primitive_translators/test_primitive_concatenate.py b/tests/integration_tests/primitive_translators/test_primitive_concatenate.py new file mode 100644 index 0000000..e09160d --- /dev/null +++ b/tests/integration_tests/primitive_translators/test_primitive_concatenate.py @@ -0,0 +1,80 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import jax +import numpy as np +import pytest +from jax import numpy as jnp + +import jace +from jace.util import translation_cache as tcache + +from tests import util as testutil + + +def test_cat_1d_arrays() -> None: + """Concatenate two 1d arrays.""" + + a1 = testutil.make_array(10) + a2 = testutil.make_array(10) + + def testee(a1: np.ndarray, a2: np.ndarray) -> jax.Array: + return jax.lax.concatenate((a1, a2), 0) + + ref = testee(a1, a2) + res = jace.jit(testee)(a1, a2) + + assert res.shape == ref.shape + assert np.all(ref == res) + + +def test_cat_nd() -> None: + """Concatenate arrays of higher dimensions.""" + nb_arrays = 4 + std_shape: list[int] = [2, 3, 4, 5, 3] + + for cat_dim in range(len(std_shape)): + tcache.clear_translation_cache() + + # Create the input that we ware using. + input_arrays: list[np.ndarray] = [] + for _ in range(nb_arrays): + shape = std_shape.copy() + shape[cat_dim] = (testutil.make_array((), dtype=np.int32) % 10) + 1 # type: ignore[call-overload] # type confusion + input_arrays.append(testutil.make_array(shape)) + + def testee(inputs: list[np.ndarray]) -> np.ndarray | jax.Array: + return jax.lax.concatenate(inputs, cat_dim) # noqa: B023 [function-uses-loop-variable] + + ref = testee(input_arrays) + res = jace.jit(testee)(input_arrays) + + assert res.shape == ref.shape + assert np.all(ref == res) + + +@pytest.mark.skip(reason="JAX does not support scalars as inputs.") +def test_cat_1d_array_scalars(): + """Concatenate an 1d array with scalars. + + This does not work, it is to observe JAX. + """ + + a1 = testutil.make_array(10) + s1 = testutil.make_array(()) + s2 = testutil.make_array(()) + + def testee(a1: np.ndarray, s1: np.float64, s2: np.float64) -> np.ndarray | jax.Array: + return jnp.concatenate((s1, a1, s2), 0) + + ref = testee(a1, s1, s2) + res = jace.jit(testee)(a1, s1, s2) + + assert res.shape == ref.shape + assert np.all(ref == res) diff --git a/tests/integration_tests/primitive_translators/test_primitive_cond.py b/tests/integration_tests/primitive_translators/test_primitive_cond.py new file mode 100644 index 0000000..da8e61d --- /dev/null +++ b/tests/integration_tests/primitive_translators/test_primitive_cond.py @@ -0,0 +1,214 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +import jax +import numpy as np +import pytest +from jax import numpy as jnp + +import jace +from jace.util import translation_cache as tcache + +from tests import util as testutil + + +def _perform_cond_test( + testee: Callable[[np.float64, tuple[Any, ...]], Any], branch_args: tuple[Any, ...] +) -> None: + """ + Performs a test for the condition primitives. + + It assumes that the first argument is used for the condition and that the + conditions is applied at `0.5`. + The test function adds a prologue, that performs some operations on the + `branch_args` and performs some computations on the final value. + This is done to simulate the typical usage, as it was observed that + sometimes the optimization fails. + """ + tcache.clear_translation_cache() + + def prologue(branch_args: tuple[Any, ...]) -> tuple[Any, ...]: + return tuple( + jnp.exp(jnp.cos(jnp.sin(branch_arg))) ** i + for i, branch_arg in enumerate(branch_args, 2) + ) + + def epilogue(result: Any) -> Any: + return jnp.exp(jnp.sin(jnp.sin(result))) + + def final_testee( + val: np.float64, + branch_args: tuple[Any, ...], + ) -> Any: + return epilogue(testee(jnp.sin(val) + 0.5, prologue(branch_args))) # type: ignore[arg-type] + + vals: list[np.float64] = [np.float64(-0.5), np.float64(0.6)] + wrapped = jace.jit(testee) + + for val in vals: + res = wrapped(val, branch_args) + ref = testee(val, branch_args) + + assert np.all(res == ref) + assert (1,) if ref.shape == () else ref.shape == res.shape + + +def test_cond_full_branches() -> None: + def testee(val: np.float64, branch_args: tuple[np.ndarray, np.ndarray]) -> np.ndarray: + return jax.lax.cond( + val < 0.5, + lambda arg: jnp.sin(arg[0]), + lambda arg: jnp.cos(arg[1]), + branch_args, + ) + + branch_args = tuple(testutil.make_array(1) for _ in range(2)) + _perform_cond_test(testee, branch_args) + + +def test_cond_scalar_brnaches() -> None: + def testee(val: np.float64, branch_args: tuple[np.float64, np.float64]) -> np.float64: + return jax.lax.cond( + val < 0.5, + lambda arg: arg[0] + 2.0, + lambda arg: arg[1] + 3.0, + branch_args, + ) + + branch_args = tuple(testutil.make_array(()) for _ in range(2)) + _perform_cond_test(testee, branch_args) + + +def test_cond_literal_bool() -> None: + for branch_sel in [True, False]: + + def testee(val: np.float64, branch_args: tuple[np.ndarray, np.ndarray]) -> np.ndarray: + return jax.lax.cond( + branch_sel, # noqa: B023 [function-uses-loop-variable] + lambda arg: jnp.sin(arg[0]) + val, + lambda arg: jnp.cos(arg[1]), + branch_args, + ) + + branch_args = tuple(testutil.make_array(1) for _ in range(2)) + _perform_cond_test(testee, branch_args) + + +def test_cond_one_empty_branch() -> None: + def testee(val, branch_args: tuple[np.ndarray, np.ndarray]) -> np.ndarray: + return jax.lax.cond( + val < 0.5, + lambda xtrue: xtrue[0], + lambda xfalse: jnp.array([1]) + xfalse[1], + branch_args, + ) + + branch_args = tuple(testutil.make_array(1) for _ in range(2)) + _perform_cond_test(testee, branch_args) + + +@pytest.mark.skip(reason="Literal return value is not implemented.") +def test_cond_literal_branch() -> None: + def testee(val: np.float64, branch_args: tuple[np.ndarray, np.ndarray]) -> np.ndarray: + return jax.lax.cond( + val < 0.5, + lambda xtrue: 1.0, # noqa: ARG005 [unused-lambda-argument] + lambda xfalse: xfalse[1], + branch_args, + ) + + branch_args = tuple(testutil.make_array(()) for _ in range(2)) + _perform_cond_test(testee, branch_args) + + +def test_cond_complex_branches() -> None: + def true_branch(arg: np.ndarray) -> np.ndarray: + return jnp.where( + jnp.asin(arg) <= 0.0, + jnp.exp(jnp.cos(jnp.sin(arg))), + arg * 4.0, + ) + + def false_branch(arg: np.ndarray) -> np.ndarray: + return true_branch(jnp.exp(jnp.cos(arg) ** 7)) # type: ignore[arg-type] + + def testee(val: np.float64, branch_args: tuple[np.ndarray, np.ndarray]) -> np.ndarray: + cond_res = jax.lax.cond( + val < 0.5, + lambda arg: true_branch(arg[0]), + lambda arg: false_branch(arg[1]), + branch_args, + ) + return true_branch(cond_res) + + branch_args = tuple(testutil.make_array((100, 100)) for _ in range(2)) + _perform_cond_test(testee, branch_args) + + +def test_cond_switch() -> None: + def testee( + selector: int, + branch_args: tuple[Any, ...], + ) -> np.ndarray: + return jax.lax.switch( + selector, + ( + lambda args: jnp.sin(args[0]), + lambda args: jnp.exp(args[1]), + lambda args: jnp.cos(args[2]), + ), + branch_args, + ) + + wrapped = jace.jit(testee) + branch_args = tuple(testutil.make_array((100, 100)) for _ in range(3)) + + # These are the values that we will use for the selector. + # Note that we also use some invalid values. + selectors = [-1, 0, 1, 2, 3, 4] + + for selector in selectors: + ref = testee(selector, branch_args) + res = wrapped(selector, branch_args) + + assert ref.shape == res.shape + assert np.allclose(ref, res) + + +@pytest.mark.skip("DaCe is not able to optimize it away.") +def test_cond_switch_literal_selector() -> None: + def testee( + branch_args: tuple[Any, ...], + ) -> np.ndarray: + return jax.lax.switch( + 2, + ( + lambda args: jnp.sin(args[0]), + lambda args: jnp.exp(args[1]), + lambda args: jnp.cos(args[2]), + ), + branch_args, + ) + + branch_args = tuple(testutil.make_array((100, 100)) for _ in range(3)) + + wrapped = jace.jit(testee) + lowered = wrapped.lower(branch_args) + compiled = lowered.compile(jace.optimization.DEFAULT_OPTIMIZATIONS) + + ref = testee(branch_args) + res = wrapped(branch_args) + + assert ref.shape == res.shape + assert np.allclose(ref, res) + lowered.as_sdfg().view() + assert compiled._compiled_sdfg.sdfg.number_of_nodes() == 1 diff --git a/tests/integration_tests/primitive_translators/test_primitive_convert_element_type.py b/tests/integration_tests/primitive_translators/test_primitive_convert_element_type.py new file mode 100644 index 0000000..2e3664d --- /dev/null +++ b/tests/integration_tests/primitive_translators/test_primitive_convert_element_type.py @@ -0,0 +1,87 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Tests the element type conversion functionality. + +Todo: + The tests should only run on certain occasion. +""" + +from __future__ import annotations + +from typing import Final + +import jax +import numpy as np +import pytest +from jax import numpy as jnp + +import jace + +from tests import util as testutil + + +pytest.skip("Takes too long", allow_module_level=True) + + +# fmt: off +_DACE_REAL_TYPES: Final[list[type]] = [ # type: ignore[unreachable] + np.int_, np.int8, np.int16, np.int32, np.int64, + np.uint, np.uint8, np.uint16, np.uint32, np.uint64, + np.float64, np.float32, np.float64, +] +_DACE_COMPLEX_TYPES: Final[list[type]] = [ + np.complex128, np.complex64, np.complex128, +] +# fmt: on + + +@pytest.fixture(params=_DACE_REAL_TYPES) +def src_type(request) -> type: + """All valid source types, with the exception of bool.""" + return request.param + + +@pytest.fixture(params=_DACE_REAL_TYPES + _DACE_COMPLEX_TYPES) +def dst_type(request) -> type: + """All valid destination types, with the exception of bool. + + Includes also complex types, because going from real to complex is useful, + but the other way is not. + """ + return request.param + + +def _convert_element_type_impl(input_type: type, output_type: type) -> None: + """Implementation of the tests of the convert element types primitive.""" + lowering_cnt = [0] + a: np.ndarray = testutil.make_array((10, 10), input_type) + ref: np.ndarray = np.array(a, copy=True, dtype=output_type) + + @jace.jit + def converter(a: np.ndarray) -> jax.Array: + lowering_cnt[0] += 1 + return jnp.array(a, copy=False, dtype=output_type) + + res = converter(a) + assert lowering_cnt[0] == 1 + assert ( + res.dtype == output_type + ), f"Expected '{output_type}', but got '{res.dtype}', input was '{input_type}'." + assert np.allclose(ref, res) + + +def test_convert_element_type_main(src_type: type, dst_type: type) -> None: + _convert_element_type_impl(src_type, dst_type) + + +def test_convert_element_type_from_bool(src_type: type) -> None: + _convert_element_type_impl(np.bool_, src_type) + + +def test_convert_element_type_to_bool(src_type: type) -> None: + _convert_element_type_impl(src_type, np.bool_) diff --git a/tests/integration_tests/primitive_translators/test_primitive_copy.py b/tests/integration_tests/primitive_translators/test_primitive_copy.py new file mode 100644 index 0000000..11fefc9 --- /dev/null +++ b/tests/integration_tests/primitive_translators/test_primitive_copy.py @@ -0,0 +1,29 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import jax +import numpy as np +from jax import numpy as jnp + +import jace + +from tests import util as testutil + + +def test_copy() -> None: + @jace.jit + def testee(a: np.ndarray) -> jax.Array: + return jnp.copy(a) + + a = testutil.make_array((10, 10, 10)) + res = testee(a) + assert a.dtype == res.dtype + assert a.shape == res.shape + assert a.__array_interface__["data"][0] != res.__array_interface__["data"][0] # type: ignore[attr-defined] + assert np.all(res == a) diff --git a/tests/integration_tests/primitive_translators/test_primitive_gather.py b/tests/integration_tests/primitive_translators/test_primitive_gather.py new file mode 100644 index 0000000..35cfbb2 --- /dev/null +++ b/tests/integration_tests/primitive_translators/test_primitive_gather.py @@ -0,0 +1,83 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +import numpy as np +from jax import numpy as jnp + +import jace + +from tests import util as testutil + + +def _perform_gather_test( + testee: Callable, + *args: Any, +) -> None: + wrapped = jace.jit(testee) + + expected = testee(*args) + result = wrapped(*args) + + assert np.allclose(expected, result) + + +def test_gather_simple_1(): + def testee( + a: np.ndarray, + idx: np.ndarray, + ) -> np.ndarray: + return a[idx] + + a = testutil.make_array(100) + idx = testutil.make_array(300, dtype=np.int32, low=0, high=100) + _perform_gather_test(testee, a, idx) + + +def test_gather_1(): + def testee( + a: np.ndarray, + idx: np.ndarray, + ) -> np.ndarray: + return a[idx, :, idx] + + a = testutil.make_array((300, 3, 300)) + idx = testutil.make_array(400, dtype=np.int32, low=1, high=300) + _perform_gather_test(testee, a, idx) + + +def test_gather_2(): + def testee( + a: np.ndarray, + idx: np.ndarray, + ) -> np.ndarray: + return a[idx, :, :] + + a = testutil.make_array((300, 3, 300)) + idx = testutil.make_array(400, dtype=np.int32, low=1, high=300) + _perform_gather_test(testee, a, idx) + + +def test_gather_3(): + def testee( + a: np.ndarray, + b: np.ndarray, + idx: np.ndarray, + idx2: np.ndarray, + ) -> np.ndarray: + c = jnp.sin(a) + b + return jnp.exp(c[idx, :, idx2]) # type: ignore[return-value] # Type confusion. + + a = testutil.make_array((300, 3, 300)) + b = testutil.make_array((300, 3, 300)) + idx = testutil.make_array(400, dtype=np.int32, low=1, high=300) + idx2 = testutil.make_array(400, dtype=np.int32, low=1, high=300) + _perform_gather_test(testee, a, b, idx, idx2) diff --git a/tests/integration_tests/primitive_translators/test_primitive_iota.py b/tests/integration_tests/primitive_translators/test_primitive_iota.py new file mode 100644 index 0000000..14e4ac0 --- /dev/null +++ b/tests/integration_tests/primitive_translators/test_primitive_iota.py @@ -0,0 +1,38 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import jax +import numpy as np +import pytest +from jax import numpy as jnp + +import jace + + +def test_iota_arange() -> None: + def testee(a: int) -> jax.Array: + return jnp.arange(18, dtype=int) + a + + ref = testee(0) + res = jace.jit(testee)(0) + assert np.all(ref == res) + + +@pytest.mark.parametrize("d", [0, 1, 2, 3]) +def test_iota_broadcast(d) -> None: + shape = (2, 2, 2, 2) + + def testee(a: np.int32) -> jax.Array: + return jax.lax.broadcasted_iota("int32", shape, d) + a + + ref = testee(np.int32(0)) + res = jace.jit(testee)(np.int32(0)) + + assert res.shape == shape + assert np.all(ref == res), f"Expected: {ref.tolist()}; Got: {res.tolist()}" diff --git a/tests/integration_tests/primitive_translators/test_primitive_pjit.py b/tests/integration_tests/primitive_translators/test_primitive_pjit.py new file mode 100644 index 0000000..512185b --- /dev/null +++ b/tests/integration_tests/primitive_translators/test_primitive_pjit.py @@ -0,0 +1,68 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Tests for the `pjit` primitive.""" + +from __future__ import annotations + +from collections.abc import Generator + +import jax +import numpy as np +import pytest +from jax import numpy as jnp + +import jace + +from tests import util as testutil + + +@pytest.fixture(autouse=True) +def _disable_jit() -> Generator[None, None, None]: + """Overwrites the global `_disable_jit` fixture and enables jit operations.""" + with jax.disable_jit(disable=False): + yield + + +def test_pjit_simple() -> None: + """Simple nested Jaxpr expression.""" + + def testee(a: np.ndarray) -> np.ndarray: + return jax.jit(lambda a: jnp.sin(a))(a) # noqa: PLW0108 [unnecessary-lambda] # Lambda needed to trigger a `pjit` level. + + a = testutil.make_array((10, 10)) + + jace_wrapped = jace.jit(testee) + jace_lowered = jace_wrapped.lower(a) + res = jace_wrapped(a) + ref = testee(a) + + assert jace_lowered._jaxpr.eqns[0].primitive.name == "pjit" + assert np.allclose(res, ref) + assert res.dtype == ref.dtype + assert res.shape == ref.shape + + +@pytest.mark.parametrize("shape", [(10, 10), ()]) +def test_pjit_literal(shape) -> None: + """Test for `pjit` with literal inputs.""" + + def testee(pred: np.ndarray, fbranch: np.ndarray) -> jax.Array: + return jnp.where(pred, 2, fbranch) + + pred = testutil.make_array(shape, np.bool_) + fbranch = pred * 0 + + jace_wrapped = jace.jit(testee) + jace_lowered = jace_wrapped.lower(pred, fbranch) + res = jace_wrapped(pred, fbranch) + ref = testee(pred, fbranch) + + assert np.all(ref == res) + assert jace_lowered._jaxpr.eqns[0].primitive.name == "pjit" + assert any(isinstance(invar, jax.core.Literal) for invar in jace_lowered._jaxpr.eqns[0].invars) + assert res.dtype == ref.dtype diff --git a/tests/integration_tests/primitive_translators/test_primitive_reshape.py b/tests/integration_tests/primitive_translators/test_primitive_reshape.py new file mode 100644 index 0000000..4a504c5 --- /dev/null +++ b/tests/integration_tests/primitive_translators/test_primitive_reshape.py @@ -0,0 +1,76 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Tests the rehaping functionality.""" + +from __future__ import annotations + +from collections.abc import Sequence + +import jax +import numpy as np +import pytest +from jax import numpy as jnp + +import jace + +from tests import util as testutil + + +def _test_impl_reshaping( + src_shape: Sequence[int], dst_shape: Sequence[int], order: str = "C" +) -> None: + """Performs a reshaping from `src_shape` to `dst_shape`.""" + a = testutil.make_array(src_shape, order=order) + + def testee(a: np.ndarray) -> jax.Array: + return jnp.reshape(a, dst_shape) + + ref = testee(a) + res = jace.jit(testee)(a) + + assert res.shape == dst_shape + assert np.all(res == ref) + + +@pytest.fixture(params=["C", "F"]) +def mem_order(request) -> str: + """Gets the memory order that we want.""" + return request.param + + +@pytest.fixture(params=[(216, 1, 1), (1, 216, 1), (1, 1, 216), (1, 6, 36), (36, 1, 6)]) +def new_shape(request) -> None: + """New shapes for the `test_reshaping_same_rank()` test.""" + return request.param + + +@pytest.fixture(params=[(12, 1), (1, 12), (1, 1, 12), (1, 2, 6)]) +def expanded_shape(request) -> None: + """New shapes for the `test_reshaping_removing_rank()` test.""" + return request.param + + +@pytest.fixture(params=[(216,), (6, 36), (36, 6), (216, 1)]) +def reduced_shape(request) -> None: + """New shapes for the `test_reshaping_adding_rank()` test.""" + return request.param + + +def test_reshaping_same_rank(new_shape: Sequence[int], mem_order: str) -> None: + """The rank, numbers of dimensions, stays the same,""" + _test_impl_reshaping((6, 6, 6), new_shape, mem_order) + + +def test_reshaping_adding_rank(expanded_shape: Sequence[int], mem_order: str) -> None: + """Adding ranks to an array.""" + _test_impl_reshaping((12,), expanded_shape, mem_order) + + +def test_reshaping_removing_rank(reduced_shape: Sequence[int], mem_order: str) -> None: + """Removing ranks from an array.""" + _test_impl_reshaping((6, 6, 6), reduced_shape, mem_order) diff --git a/tests/integration_tests/primitive_translators/test_primitive_select_n.py b/tests/integration_tests/primitive_translators/test_primitive_select_n.py new file mode 100644 index 0000000..a0cab1f --- /dev/null +++ b/tests/integration_tests/primitive_translators/test_primitive_select_n.py @@ -0,0 +1,81 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Tests the `select_n` translator.""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +import jax +import numpy as np +import pytest +from jax import numpy as jnp + +import jace + +from tests import util as testutil + + +@pytest.fixture(params=[True, False]) +def pred(request) -> np.bool_: + """Predicate used in the `test_mapped_unary_scalar_literal_*` tests.""" + return np.bool_(request.param) + + +def _perform_test(testee: Callable, *args: Any) -> None: + res = testee(*args) + ref = jace.jit(testee)(*args) + assert np.all(res == ref) + + +def test_select_n_where() -> None: + def testee(pred: np.ndarray, tbranch: np.ndarray, fbranch: np.ndarray) -> jax.Array: + return jnp.where(pred, tbranch, fbranch) + + shape = (10, 10) + pred = testutil.make_array(shape, np.bool_) + tbranch = testutil.make_array(shape) + fbranch = testutil.make_array(shape) + _perform_test(testee, pred, tbranch, fbranch) + + +def test_select_n_where_literal_1(pred) -> None: + def testee(pred: np.ndarray, fbranch: np.ndarray) -> jax.Array: + return jnp.where(pred, 2, fbranch) + + fbranch = 1 + _perform_test(testee, pred, fbranch) + + +def test_select_n_where_literal_2(pred) -> None: + def testee(pred: np.ndarray, tbranch: np.ndarray) -> jax.Array: + return jnp.where(pred, tbranch, 3) + + tbranch = 2 + _perform_test(testee, pred, tbranch) + + +def test_select_n_where_literal_3(pred) -> None: + def testee(pred: np.ndarray) -> jax.Array: + return jnp.where(pred, 8, 9) + + _perform_test(testee, pred) + + +def test_select_n_many_inputs() -> None: + """Tests the generalized way of using the primitive.""" + + def testee(pred: np.ndarray, *cases: np.ndarray) -> jax.Array: + return jax.lax.select_n(pred, *cases) + + nbcases = 10 + shape = (10, 10) + cases = [np.full(shape, i) for i in range(nbcases)] + pred = np.arange(cases[0].size).reshape(shape) % nbcases + _perform_test(testee, pred, *cases) diff --git a/tests/integration_tests/primitive_translators/test_primitive_slicing.py b/tests/integration_tests/primitive_translators/test_primitive_slicing.py new file mode 100644 index 0000000..6df48e8 --- /dev/null +++ b/tests/integration_tests/primitive_translators/test_primitive_slicing.py @@ -0,0 +1,101 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Implements tests for slicing translator.""" + +from __future__ import annotations + +import jax +import numpy as np +import pytest + +import jace + +from tests import util as testutil + + +@pytest.fixture() +def a_20x20x20() -> np.ndarray: + return testutil.make_array((20, 20, 20)) + + +@pytest.fixture() +def a_4x4x4x4() -> np.ndarray: + return testutil.make_array((4, 4, 4, 4)) + + +@pytest.fixture( + params=[ + (1, 2, 1, 2), + (0, 0, 0, 0), + (3, 3, 3, 3), # Will lead to readjustment of the start index. + (3, 1, 3, 0), # Will lead to readjustment of the start index. + ] +) +def full_dynamic_start_idx(request) -> tuple[int, int, int, int]: + """Start indexes for the slice window of `test_dynamic_slice_full_dynamic()`.""" + return request.param + + +def test_slice_no_strides(a_20x20x20: np.ndarray) -> None: + """Test without strides.""" + + def testee(a: np.ndarray) -> jax.Array: + # Read as: a[2:18, 3:19, 4:17] + return jax.lax.slice(a, (2, 3, 4), (18, 19, 17), None) + + ref = testee(a_20x20x20) + res = jace.jit(testee)(a_20x20x20) + + assert ref.shape == res.shape + assert np.all(ref == res) + + +def test_slice_strides(a_20x20x20: np.ndarray) -> None: + """Test with strides.""" + + def testee(a: np.ndarray) -> jax.Array: + # Read as: a[2:18:1, 3:19:2, 4:17:3] + return jax.lax.slice(a, (2, 3, 4), (18, 19, 17), (1, 2, 3)) + + ref = testee(a_20x20x20) + res = jace.jit(testee)(a_20x20x20) + + assert ref.shape == res.shape + assert np.all(ref == res) + + +def test_dynamic_slice_full_dynamic( + a_4x4x4x4: np.ndarray, full_dynamic_start_idx: tuple[int, int, int, int] +) -> None: + def testee(a: np.ndarray, s1: int, s2: int, s3: int, s4: int) -> jax.Array: + return jax.lax.dynamic_slice(a, (s1, s2, s3, s4), (2, 2, 2, 2)) + + res = jace.jit(testee)(a_4x4x4x4, *full_dynamic_start_idx) + ref = testee(a_4x4x4x4, *full_dynamic_start_idx) + + assert np.all(ref == res) + + +def test_dynamic_slice_partially_dynamic(a_4x4x4x4: np.ndarray) -> None: + def testee(a: np.ndarray, s1: int, s2: int) -> jax.Array: + return jax.lax.dynamic_slice(a, (s1, 1, s2, 2), (2, 2, 2, 2)) + + res = jace.jit(testee)(a_4x4x4x4, 1, 2) + ref = testee(a_4x4x4x4, 1, 2) + + assert np.all(ref == res) + + +def test_dynamic_slice_full_literal(a_4x4x4x4: np.ndarray) -> None: + def testee(a: np.ndarray) -> jax.Array: + return jax.lax.dynamic_slice(a, (0, 1, 0, 3), (2, 2, 2, 2)) + + res = jace.jit(testee)(a_4x4x4x4) + ref = testee(a_4x4x4x4) + + assert np.all(ref == res) diff --git a/tests/integration_tests/primitive_translators/test_primitive_squeeze_expand_dims.py b/tests/integration_tests/primitive_translators/test_primitive_squeeze_expand_dims.py new file mode 100644 index 0000000..c823cf6 --- /dev/null +++ b/tests/integration_tests/primitive_translators/test_primitive_squeeze_expand_dims.py @@ -0,0 +1,69 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Tests about the `squeeze` primitive. + +For several reasons parts of the tests related to broadcasting, especially the +ones in which a single dimension is added, are also here. This is because of +the inverse relationship between `expand_dims` and `squeeze`. +""" + +from __future__ import annotations + +from collections.abc import Sequence + +import jax +import numpy as np +import pytest +from jax import numpy as jnp + +import jace + +from tests import util as testutil + + +def _roundtrip_implementation(shape: Sequence[int], axis: int | Sequence[int]) -> None: + """Implementation of the test for `expand_dims()` and `squeeze()`. + + It will first add dimensions and then remove them. + + Args: + shape: Shape of the input array. + axes: A series of axis that should be tried. + """ + a = testutil.make_array(shape) + a_org = a.copy() + + for ops in [jnp.expand_dims, jnp.squeeze]: + with jax.experimental.enable_x64(): + ref = ops(a, axis) # type: ignore[operator] # Function of unknown type. + res = jace.jit(lambda a: ops(a, axis))(a) # type: ignore[operator] # noqa: B023 [function-uses-loop-variable] + + assert ref.shape == res.shape, f"a.shape = {shape}; Expected: {ref.shape}; Got: {res.shape}" + assert ref.dtype == res.dtype + assert np.all(ref == res), f"Value error for shape '{shape}' and axis={axis}" + a = np.array(ref, copy=True) # It is a JAX array, and we have to reverse this. + assert a_org.shape == res.shape + assert np.all(a_org == res) + + +@pytest.fixture(params=[0, -1, 1]) +def single_axis(request) -> int: + return request.param + + +@pytest.fixture(params=[0, -1, (1, 2, 3), (3, 2, 1)]) +def multiple_axis(request) -> tuple[int, ...] | int: + return request.param + + +def test_expand_squeeze_rountrip_simple(single_axis: int) -> None: + _roundtrip_implementation((10,), single_axis) + + +def test_expand_squeeze_rountrip_big(multiple_axis: Sequence[int]) -> None: + _roundtrip_implementation((2, 3, 4, 5), multiple_axis) diff --git a/tests/integration_tests/test_empty_jaxpr.py b/tests/integration_tests/test_empty_jaxpr.py new file mode 100644 index 0000000..598efc9 --- /dev/null +++ b/tests/integration_tests/test_empty_jaxpr.py @@ -0,0 +1,118 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Implements tests for empty jaxprs. + +Todo: + Add more tests that are related to `cond`, i.e. not all inputs are needed. +""" + +from __future__ import annotations + +import jax +import numpy as np +import pytest + +import jace + + +def test_empty_single_return() -> None: + @jace.jit + def wrapped(a: np.ndarray) -> np.ndarray: + return a + + a = np.arange(12, dtype=np.float64).reshape((4, 3)) + res = wrapped(a) + + assert np.all(res == a) + assert res.__array_interface__["data"][0] != a.__array_interface__["data"][0] + + +def test_empty_multiple_return() -> None: + @jace.jit + def wrapped(a: np.ndarray, b: np.float64) -> tuple[np.ndarray, np.float64]: + return a, b + + a = np.arange(12, dtype=np.float64).reshape((4, 3)) + b = np.float64(30.0) + res = wrapped(a, b) + + assert np.all(res[0] == a) + assert res[1] == b + assert res[0].__array_interface__["data"][0] != a.__array_interface__["data"][0] + + +def test_empty_unused_argument() -> None: + """Empty body and an unused input argument.""" + + @jace.jit + def wrapped(a: np.ndarray, b: np.float64) -> np.ndarray: # noqa: ARG001 [unused-function-argument] + return a + + a = np.arange(12, dtype=np.float64).reshape((4, 3)) + b = np.float64(30.0) + lowered = wrapped.lower(a, b) + compiled = lowered.compile() + res = compiled(a, b) + + assert len(lowered._translated_sdfg.input_names) == 2 + assert len(compiled._compiled_sdfg.input_names) == 2 + assert isinstance(res, np.ndarray) + assert np.all(res == a) + assert res.__array_interface__["data"][0] != a.__array_interface__["data"][0] + + +def test_empty_scalar() -> None: + @jace.jit + def wrapped(a: np.float64) -> np.float64: + return a + + a = np.float64(np.pi) + + assert np.all(wrapped(a) == a) + + +def test_empty_nested() -> None: + @jace.jit + def wrapped(a: np.float64) -> np.float64: + return jax.jit(lambda a: a)(a) + + a = np.float64(np.pi) + + assert np.all(wrapped(a) == a) + + +@pytest.mark.skip(reason="Literal return value is not implemented.") +def test_empty_literal_return() -> None: + """An empty Jaxpr that only contains a literal return value.""" + + def testee() -> np.float64: + return np.float64(3.1415) + + ref = testee() + res = jace.jit(testee)() + + assert np.all(res == ref) + + +@pytest.mark.skip(reason="Literal return value is not implemented.") +def test_empty_with_drop_vars() -> None: + """Jaxpr only containing drop variables. + + Notes: + As a side effect the Jaxpr also has a literal return value. + """ + + @jace.grad + def testee(a: np.float64, b: np.float64) -> np.float64: + return a + b + + a = np.e + ref = testee(a) + res = jace.jit(testee)(a) + + assert np.all(ref == res) diff --git a/tests/integration_tests/test_jaxpr_translator_builder.py b/tests/integration_tests/test_jaxpr_translator_builder.py new file mode 100644 index 0000000..40b4fff --- /dev/null +++ b/tests/integration_tests/test_jaxpr_translator_builder.py @@ -0,0 +1,673 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Tests for the `JaxprTranslationBuilder` object. + +Although this is an integration test, the tests here manipulate the builder on +a low and direct level. +""" + +from __future__ import annotations + +import re + +import dace +import jax +import numpy as np +import pytest +from dace import data as dcdata +from jax import core as jax_core, numpy as jnp + +import jace +from jace import translator, util +from jace.util import JaCeVar + +from tests import util as testutil + + +# These are some JaCe variables that we use inside the tests +# Unnamed arrays +array1 = JaCeVar((10, 12), dace.float64) +array2 = JaCeVar((10, 13), dace.float32) +array3 = JaCeVar((11, 16), dace.int64) + +# Unnamed scalars +scal1 = JaCeVar((), dace.float16) +scal2 = JaCeVar((), dace.float32) +scal3 = JaCeVar((), dace.int64) + +# Named variables +narray = JaCeVar((10,), dace.float16, "narr") +nscal = JaCeVar((), dace.int32, "nscal") + + +@pytest.fixture() +def translation_builder() -> translator.JaxprTranslationBuilder: + """Returns an allocated builder instance.""" + name = "fixture_builder" + builder = translator.JaxprTranslationBuilder( + primitive_translators=translator.get_registered_primitive_translators() + ) + jaxpr = jax.make_jaxpr(lambda a: a)(1.0) # dummy jaxpr, needed for construction. + builder._allocate_translation_ctx(name=name, jaxpr=jaxpr) + return builder + + +def test_builder_alloc() -> None: + """Tests for correct allocation.""" + builder = translator.JaxprTranslationBuilder( + primitive_translators=translator.get_registered_primitive_translators() + ) + assert not builder.is_allocated(), "Builder was created allocated." + assert len(builder._ctx_stack) == 0 + + # The reserved names will be tested in `test_builder_fork()`. + sdfg_name = "qwertzuiopasdfghjkl" + jaxpr = jax.make_jaxpr(lambda x: x)(1.0) # dummy jaxpr, needed for construction. + builder._allocate_translation_ctx(name=sdfg_name, jaxpr=jaxpr) + assert len(builder._ctx_stack) == 1 + assert builder.is_root_translator() + + sdfg: dace.SDFG = builder.sdfg + + assert builder._ctx.sdfg is sdfg + assert builder.sdfg.name == sdfg_name + assert sdfg.number_of_nodes() == 1 + assert sdfg.number_of_edges() == 0 + assert sdfg.start_block is builder._ctx.start_state + assert builder._terminal_sdfg_state is builder._ctx.start_state + + +def test_builder_variable_alloc_auto_naming( + translation_builder: translator.JaxprTranslationBuilder, +) -> None: + """Tests if autonaming of variables works.""" + for i, var in enumerate([array1, array2, scal1, array3, scal2, scal3]): + sdfg_name = translation_builder.add_array(var, update_var_mapping=True) + sdfg_var = translation_builder.get_array(sdfg_name) + assert sdfg_name == chr(97 + i) + if var.shape == (): + assert isinstance(sdfg_var, dcdata.Scalar) + else: + assert isinstance(sdfg_var, dcdata.Array) + assert sdfg_var.shape == var.shape + assert sdfg_var.dtype == var.dtype + + +def test_builder_variable_alloc_mixed_naming( + translation_builder: translator.JaxprTranslationBuilder, +) -> None: + """Test automatic naming if there are variables with a given name. + + See also `test_builder_variable_alloc_mixed_naming2()`. + """ + # * b c d * f g + for i, var in enumerate([narray, array1, array2, scal1, nscal, scal2, scal3]): + sdfg_name = translation_builder.add_array(var, update_var_mapping=True) + sdfg_var = translation_builder.get_array(sdfg_name) + if var.name is None: + assert sdfg_name == chr(97 + i) + else: + assert sdfg_name == var.name + if var.shape == (): + assert isinstance(sdfg_var, dcdata.Scalar) + else: + assert isinstance(sdfg_var, dcdata.Array) + assert sdfg_var.shape == var.shape + assert sdfg_var.dtype == var.dtype + + +def test_builder_variable_alloc_mixed_naming2( + translation_builder: translator.JaxprTranslationBuilder, +) -> None: + """Tests the naming in a mixed setting. + + This time we do not use `update_var_mapping=True`, instead it now depends on the + name. This means that automatic naming will now again include all, letters, but not + in a linear order. + """ + letoff = 0 + # * a b c * d e + for var in [narray, array1, array2, scal1, nscal, scal2, scal3]: + sdfg_name = translation_builder.add_array(var, update_var_mapping=var.name is None) + sdfg_var = translation_builder.get_array(sdfg_name) + if var.name is None: + assert sdfg_name == chr(97 + letoff) + letoff += 1 + else: + assert sdfg_name == var.name + if var.shape == (): + assert isinstance(sdfg_var, dcdata.Scalar) + else: + assert isinstance(sdfg_var, dcdata.Array) + assert sdfg_var.shape == var.shape + assert sdfg_var.dtype == var.dtype + + +def test_builder_variable_alloc_auto_naming_wrapped( + translation_builder: translator.JaxprTranslationBuilder, +) -> None: + """Tests the variable naming if we have more than 26 variables.""" + single_letters = [chr(x) for x in range(97, 123)] + i = 0 + for let1 in ["", *single_letters[1:]]: # Note `z` is followed by `ba` and not by `aa`. + for let2 in single_letters: + i += 1 + # Create a variable and enter it into the variable naming. + var = JaCeVar(shape=(19, 19), dtype=dace.float64) + sdfg_name = translation_builder.add_array(arg=var, update_var_mapping=True) + mapped_name = translation_builder.map_jax_var_to_sdfg(var) + assert ( + sdfg_name == mapped_name + ), f"Mapping for '{var}' failed, expected '{sdfg_name}' got '{mapped_name}'." + + # Get the name that we really expect, we must also handle some situations. + exp_name = let1 + let2 + if exp_name in util.FORBIDDEN_SDFG_VAR_NAMES: + exp_name = "__jace_forbidden_" + exp_name + assert ( + exp_name == sdfg_name + ), f"Automated naming failed, expected '{exp_name}' but got '{sdfg_name}'." + + +def test_builder_variable_alloc_prefix_naming( + translation_builder: translator.JaxprTranslationBuilder, +) -> None: + """Using the prefix to name variables.""" + prefix_1 = "__my_special_prefix" + exp_name_1 = prefix_1 + "a" + sdfg_name_1 = translation_builder.add_array( + array1, name_prefix=prefix_1, update_var_mapping=False + ) + assert exp_name_1 == sdfg_name_1 + + # Because `update_var_mapping` is `False` above, 'a' will be reused. + prefix_2 = "__my_special_prefix_second_" + exp_name_2 = prefix_2 + "a" + sdfg_name_2 = translation_builder.add_array( + array1, name_prefix=prefix_2, update_var_mapping=False + ) + assert exp_name_2 == sdfg_name_2 + + # Now we use a named variables, which are also affected. + prefix_3 = "__my_special_prefix_third_named_" + exp_name_3 = prefix_3 + nscal.name # type: ignore[operator] # `.name` is not `None`. + sdfg_name_3 = translation_builder.add_array( + nscal, name_prefix=prefix_3, update_var_mapping=False + ) + assert exp_name_3 == sdfg_name_3 + + +def test_builder_nested(translation_builder: translator.JaxprTranslationBuilder) -> None: + """Tests the ability of the nesting of the builder.""" + + # Now add a variable to the current subtext. + name_1 = translation_builder.add_array(array1, update_var_mapping=True) + assert name_1 == "a" + assert translation_builder.map_jax_var_to_sdfg(array1) == name_1 + assert translation_builder.sdfg.arrays[name_1] is translation_builder.get_array(array1) + assert translation_builder.sdfg.arrays[name_1] is translation_builder.get_array(name_1) + + # For the sake of doing it add a new state to the SDFG. + translation_builder.append_new_state("sake_state") + assert translation_builder.sdfg.number_of_nodes() == 2 + assert translation_builder.sdfg.number_of_edges() == 1 + + # Now we go one subcontext deeper. + jaxpr = jax.make_jaxpr(lambda x: x)(1.0) # dummy jaxpr, needed for construction. + translation_builder._allocate_translation_ctx(name="builder", jaxpr=jaxpr) + assert len(translation_builder._ctx_stack) == 2 + assert translation_builder.sdfg.name == "builder" + assert translation_builder.sdfg.number_of_nodes() == 1 + assert translation_builder.sdfg.number_of_edges() == 0 + assert not translation_builder.is_root_translator() + + # Because we have a new SDFG the mapping to previous SDFG does not work, + # regardless the fact that it still exists. + with pytest.raises( + expected_exception=KeyError, + match=re.escape( + f"JAX variable '{array1}' was supposed to map to '{name_1}', but no such SDFG variable is known." + ), + ): + _ = translation_builder.map_jax_var_to_sdfg(array1) + + # Because the SDFGs are distinct it is possible to add `array1` to the nested one. + # However, it is not able to update the mapping. + with pytest.raises( + expected_exception=ValueError, + match=re.escape(f"Cannot change the mapping of '{array1}' from '{name_1}' to '{name_1}'."), + ): + _ = translation_builder.add_array(array1, update_var_mapping=True) + assert name_1 not in translation_builder.sdfg.arrays + + # Without updating the mapping it is possible create the variable. + assert name_1 == translation_builder.add_array(array1, update_var_mapping=False) + + # Now add a new variable, the map is shared, so a new name will be generated. + name_2 = translation_builder.add_array(array2, update_var_mapping=True) + assert name_2 == "b" + assert name_2 == translation_builder.map_jax_var_to_sdfg(array2) + + # Now we go one stack level back. + translation_builder._clear_translation_ctx() + assert len(translation_builder._ctx_stack) == 1 + assert translation_builder.sdfg.number_of_nodes() == 2 + assert translation_builder.sdfg.number_of_edges() == 1 + + # Again the variable that was declared in the last stack is now no longer present. + # Note if the nested SDFG was integrated into the parent SDFG it would be + # accessible + with pytest.raises( + expected_exception=KeyError, + match=re.escape( + f"JAX variable '{array2}' was supposed to map to '{name_2}', but no such SDFG variable is known." + ), + ): + _ = translation_builder.map_jax_var_to_sdfg(array2) + assert name_2 == translation_builder._jax_name_map[array2] + + # Now add a new variable, since the map is shared, we will now get the next name. + name_3 = translation_builder.add_array(array3, update_var_mapping=True) + assert name_3 == "c" + assert name_3 == translation_builder.map_jax_var_to_sdfg(array3) + + +def test_builder_append_state(translation_builder: translator.JaxprTranslationBuilder) -> None: + """Tests the functionality of appending states.""" + sdfg: dace.SDFG = translation_builder.sdfg + + terminal_state_1: dace.SDFGState = translation_builder.append_new_state("terminal_state_1") + assert sdfg.number_of_nodes() == 2 + assert sdfg.number_of_edges() == 1 + assert terminal_state_1 is translation_builder._terminal_sdfg_state + assert translation_builder._terminal_sdfg_state is translation_builder._ctx.terminal_state + assert translation_builder._ctx.start_state is sdfg.start_block + assert translation_builder._ctx.start_state is not terminal_state_1 + assert next(iter(sdfg.edges())).src is sdfg.start_block + assert next(iter(sdfg.edges())).dst is terminal_state_1 + + # Specifying an explicit append state that is the terminal should also update the + # terminal state of the builder. + terminal_state_2: dace.SDFGState = translation_builder.append_new_state( + "terminal_state_2", prev_state=terminal_state_1 + ) + assert sdfg.number_of_nodes() == 3 + assert sdfg.number_of_edges() == 2 + assert terminal_state_2 is translation_builder._terminal_sdfg_state + assert sdfg.out_degree(terminal_state_1) == 1 + assert sdfg.out_degree(terminal_state_2) == 0 + assert sdfg.in_degree(terminal_state_2) == 1 + assert next(iter(sdfg.in_edges(terminal_state_2))).src is terminal_state_1 + + # Specifying a previous node that is not the terminal state should not do anything. + non_terminal_state: dace.SDFGState = translation_builder.append_new_state( + "non_terminal_state", prev_state=terminal_state_1 + ) + assert translation_builder._terminal_sdfg_state is not non_terminal_state + assert sdfg.in_degree(non_terminal_state) == 1 + assert sdfg.out_degree(non_terminal_state) == 0 + assert next(iter(sdfg.in_edges(non_terminal_state))).src is terminal_state_1 + + +def test_builder_variable_multiple_versions( + translation_builder: translator.JaxprTranslationBuilder, +) -> None: + """Add an already known variable, but with a different name.""" + # Now we will add `array1` and then different ways of updating it. + narray1: str = translation_builder.add_array(array1, update_var_mapping=True) + + # It will fail if we use the prefix, because we also want to update. + prefix = "__jace_prefix" + prefix_expected_name = prefix + narray1 + with pytest.raises( + expected_exception=ValueError, + match=re.escape( + f"Cannot change the mapping of '{array1}' from '{translation_builder.map_jax_var_to_sdfg(array1)}' to '{prefix_expected_name}'." + ), + ): + _ = translation_builder.add_array(array1, update_var_mapping=True, name_prefix=prefix) + assert prefix_expected_name not in translation_builder.sdfg.arrays + + # But if we do not want to update it then it works. + prefix_sdfg_name = translation_builder.add_array( + array1, update_var_mapping=False, name_prefix=prefix + ) + assert prefix_expected_name == prefix_sdfg_name + assert prefix_expected_name in translation_builder.sdfg.arrays + assert narray1 == translation_builder.map_jax_var_to_sdfg(array1) + + +def test_builder_variable_invalid_prefix( + translation_builder: translator.JaxprTranslationBuilder, +) -> None: + """Use invalid prefix.""" + # It will fail if we use the prefix, because we also want to update. + for iprefix in ["0_", "_ja ", "_!"]: + with pytest.raises( + expected_exception=ValueError, + match=re.escape(f"add_array({array1}): The proposed name '{iprefix}a', is invalid."), + ): + _ = translation_builder.add_array(array1, update_var_mapping=False, name_prefix=iprefix) + assert len(translation_builder.sdfg.arrays) == 0 + + +def test_builder_variable_alloc_list( + translation_builder: translator.JaxprTranslationBuilder, +) -> None: + """Tests part of the `JaxprTranslationBuilder.create_jax_var_list()` api.""" + var_list_1 = [array1, nscal, scal2] + exp_names_1 = ["a", nscal.name, "c"] + + res_names_1 = translation_builder.create_jax_var_list(var_list_1, update_var_mapping=True) + assert len(translation_builder.arrays) == 3 + assert res_names_1 == exp_names_1 + + # Now a mixture of the collection and creation. + var_list_2 = [array2, nscal, scal1] + exp_names_2 = ["d", nscal.name, "e"] + + res_names_2 = translation_builder.create_jax_var_list(var_list_2, update_var_mapping=True) + assert res_names_2 == exp_names_2 + assert len(translation_builder.arrays) == 5 + + +@pytest.mark.skip(reason="'create_jax_var_list()' does not clean up in case of an error.") +def test_builder_variable_alloc_list_cleaning( + translation_builder: translator.JaxprTranslationBuilder, +) -> None: + """Tests part of the `JaxprTranslationBuilder.create_jax_var_list()` api. + + It will fail because `update_var_mapping=False` thus the third variable will + cause an error because it is proposed to `a`, which is already used. + """ + var_list = [array1, nscal, scal2] + + with pytest.raises( + expected_exception=ValueError, + match=re.escape(f"add_array({scal2}): The proposed name 'a', is used."), + ): + _ = translation_builder.create_jax_var_list(var_list) + + assert len(translation_builder.arrays) == 0 + + +def test_builder_variable_alloc_list_prevent_creation( + translation_builder: translator.JaxprTranslationBuilder, +) -> None: + """Tests part of the `JaxprTranslationBuilder.create_jax_var_list()` api. + + It will test the `prevent_creation` flag. + """ + # First create a variable. + translation_builder.add_array(array1, update_var_mapping=True) + assert len(translation_builder.arrays) == 1 + + # Now create the variables + var_list = [array1, array2] + + with pytest.raises( + expected_exception=ValueError, + match=re.escape(f"'prevent_creation' given but have to create '{array2}'."), + ): + translation_builder.create_jax_var_list(var_list, prevent_creation=True) + assert len(translation_builder.arrays) == 1 + assert translation_builder.map_jax_var_to_sdfg(array1) == "a" + + +@pytest.mark.skip(reason="'create_jax_var_list()' does not clean up in case of an error.") +def test_builder_variable_alloc_list_only_creation( + translation_builder: translator.JaxprTranslationBuilder, +) -> None: + """Tests part of the `JaxprTranslationBuilder.create_jax_var_list()` api. + + It will test the `only_creation` flag. + """ + # First create a variable. + translation_builder.add_array(array1, update_var_mapping=True) + assert len(translation_builder.arrays) == 1 + + # Now create the variables + var_list = [array2, array1] + + with pytest.raises( + expected_exception=ValueError, + match=re.escape(f"'only_creation' given '{array1}' already exists."), + ): + translation_builder.create_jax_var_list(var_list, only_creation=True) + assert len(translation_builder.arrays) == 1 + assert translation_builder.map_jax_var_to_sdfg(array1) == "a" + + +def test_builder_variable_alloc_list_handle_literal( + translation_builder: translator.JaxprTranslationBuilder, +) -> None: + """Tests part of the `JaxprTranslationBuilder.create_jax_var_list()` api. + + It will test the `handle_literals` flag. + """ + + val = np.array(1) + aval = jax_core.get_aval(val) + lit = jax_core.Literal(val, aval) + var_list = [lit] + + with pytest.raises( + expected_exception=ValueError, + match=re.escape("Encountered a literal but `handle_literals` was `False`."), + ): + translation_builder.create_jax_var_list(var_list, handle_literals=False) + assert len(translation_builder.arrays) == 0 + + name_list = translation_builder.create_jax_var_list(var_list, handle_literals=True) + assert len(translation_builder.arrays) == 0 + assert name_list == [None] + + +def test_builder_constants(translation_builder: translator.JaxprTranslationBuilder) -> None: + """Tests part of the `JaxprTranslationBuilder._create_constants()` api. + + See also the `test_subtranslators_alu.py::test_add3` test. + """ + # Create the Jaxpr that we need. + constant = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]] + jaxpr = jax.make_jaxpr(lambda x: x + jax.numpy.array(constant))(1.0) + + # We have to manually allocate the builder context. + # You should not do that. + translation_builder._allocate_translation_ctx(name="Manual_test", jaxpr=jaxpr) + + # No create the constants. + translation_builder._create_constants(jaxpr) + + # Test if it was created with the correct value. + assert len(translation_builder.arrays) == 1 + assert len(translation_builder._jax_name_map) == 1 + assert next(iter(translation_builder._jax_name_map.values())) == "__const_a" + assert len(translation_builder.sdfg.constants) == 1 + assert np.all(translation_builder.sdfg.constants["__const_a"] == constant) + + +def test_builder_scalar_return_value() -> None: + """Tests if scalars can be returned directly.""" + + def scalar_ops(a: float) -> float: + return a + a - a * a + + lower_cnt = [0] + + @jace.jit + def wrapped(a: float) -> float: + lower_cnt[0] += 1 + return scalar_ops(a) + + vals = testutil.make_array(100) + for i in range(vals.size): + res = wrapped(vals[i]) + ref = scalar_ops(vals[i]) + assert np.allclose(res, ref) + assert lower_cnt[0] == 1 + + +def test_builder_scalar_return_type() -> None: + """As JAX we always return an array, even for a scalar.""" + + @jace.jit + def wrapped(a: np.float64) -> np.float64: + return a + a - a * a + + a = np.float64(1.0) + res = wrapped(a) + assert res.shape == (1,) + assert res.dtype == np.float64 + assert np.all(res == np.float64(1.0)) + + +def test_builder_multiple_return_values() -> None: + """Tests the case that we return multiple value. + + Currently this is always a tuple. + """ + + @jace.jit + def wrapped(a: np.ndarray, b: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + return a + b, a - b + + a = testutil.make_array((2, 2)) + b = testutil.make_array((2, 2)) + + lowered = wrapped.lower(a, b) + compiled = lowered.compile() + + ref = (a + b, a - b) + res = compiled(a, b) + + assert len(lowered._translated_sdfg.input_names) == 2 + assert len(compiled._compiled_sdfg.input_names) == 2 + assert len(lowered._translated_sdfg.output_names) == 2 + assert len(compiled._compiled_sdfg.output_names) == 2 + assert isinstance(res, tuple), f"Expected 'tuple', but got '{type(res).__name__}'." + assert len(res) == 2 + assert np.allclose(ref, res) + + +def test_builder_direct_return() -> None: + """Tests the case, when an input value is returned as output. + + Note: + The test function below will not return a reference to its input, + but perform an actual copy. This behaviour does look strange from a + Python point of view, however, it is (at the time of writing) + consistent with what JAX does, even when passing JAX arrays directly. + """ + + @jace.jit + def wrapped(a: np.ndarray, b: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + return a + b, b, a + + a = testutil.make_array((2, 2)) + b = testutil.make_array((2, 2)) + + ref0 = a + b + res = wrapped(a, b) + + assert isinstance(res, tuple) + assert len(res) == 3 + assert np.allclose(ref0, res[0]) + assert np.all(res[2] == a) + assert res[2].__array_interface__["data"][0] != a.__array_interface__["data"][0] + assert np.all(res[1] == b) + assert res[1].__array_interface__["data"][0] != b.__array_interface__["data"][0] + + +@pytest.mark.skip(reason="Literal return values are not supported.") +def test_builder_literal_return_value() -> None: + """Tests if there can be literals in the return values.""" + + def testee(a: np.ndarray) -> tuple[np.ndarray, np.float64, np.ndarray]: + return (a + 1.0, np.float64(1.0), a - 1.0) + + a = testutil.make_array((2, 2)) + ref = testee(a) + res = jace.jit(testee)(a) + + assert isinstance(res, tuple) + assert len(res) == 3 + assert res[1].dtype is np.float64 + assert all(np.allclose(ref[i], res[i]) for i in range(3)) + + +def test_builder_unused_arg() -> None: + """Tests if there is an unused argument.""" + + def testee(a: np.ndarray, b: np.ndarray) -> np.ndarray: # noqa: ARG001 [unused-function-argument] + return a + 3.0 + + a = testutil.make_array((10, 10)) + b = testutil.make_array((11, 11)) + c = testutil.make_array((20, 20)) + + wrapped = jace.jit(testee) + lowered = wrapped.lower(a, b) + compiled = lowered.compile() + + ref = testee(a, b) + res1 = compiled(a, b) # Correct call + res2 = compiled(a, c) # wrong call to show that nothing is affected. + + assert len(lowered._translated_sdfg.input_names) == 2 + assert len(compiled._compiled_sdfg.input_names) == 2 + assert np.all(res1 == res2) + assert np.allclose(ref, res1) + + +def test_builder_jace_var() -> None: + """Simple tests about the `JaCeVar` objects.""" + for iname in ["do", "", "_ _", "9al", "_!"]: + with pytest.raises( + expected_exception=ValueError, match=re.escape(f"Supplied the invalid name '{iname}'.") + ): + _ = JaCeVar((), dace.int8, name=iname) + + +def test_builder_strides_lowering() -> None: + """Tests if we can lower without standard strides.""" + + def testee(a: np.ndarray, b: np.ndarray) -> np.ndarray: + return a + b + + a = testutil.make_array((4, 3), order="F") + b = testutil.make_array((4, 3), order="C") + ref = testee(a, b) + a_ref_strides = (1, 4) + b_ref_strides = (3, 1) + + lowered = jace.jit(testee).lower(a, b) + a_res_strides = lowered.as_sdfg().arrays["__jace_input_0"].strides + b_res_strides = lowered.as_sdfg().arrays["__jace_input_1"].strides + + compiled = lowered.compile() + res = compiled(a, b) + + assert ref.shape == res.shape + assert np.allclose(ref, res) + assert a_ref_strides == a_res_strides + assert b_ref_strides == b_res_strides + + +def test_builder_drop_variables() -> None: + """Tests if the builder can handle drop variables.""" + + @jace.grad + def testee(a: np.float64) -> jax.Array: + return jnp.exp(jnp.sin(jnp.tan(a**3))) ** 2 + + a = np.e + ref = testee(a) + res = jace.jit(testee)(a) + + assert np.allclose(ref, res) diff --git a/tests/integration_tests/test_primitive_translator_managing.py b/tests/integration_tests/test_primitive_translator_managing.py new file mode 100644 index 0000000..b7cf3d9 --- /dev/null +++ b/tests/integration_tests/test_primitive_translator_managing.py @@ -0,0 +1,217 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Implements tests for managing the primitive subtranslators.""" + +from __future__ import annotations + +import re +from collections.abc import Generator, Mapping +from typing import Any + +import numpy as np +import pytest + +import jace +from jace import translator + +from tests import util as testutil + + +@pytest.fixture(autouse=True) +def _conserve_builtin_translators() -> Generator[None, None, None]: + """Restores the set of registered subtranslators after a test.""" + initial_translators = translator.get_registered_primitive_translators() + yield + testutil.set_active_primitive_translators_to(initial_translators) + + +@pytest.fixture() +def no_builtin_translators() -> Generator[None, None, None]: # noqa: PT004 [pytest-missing-fixture-name-underscore] # This is how you should do it: https://docs.pytest.org/en/7.1.x/how-to/fixtures.html#use-fixtures-in-classes-and-modules-with-usefixtures + """This fixture can be used if the test does not want any builtin translators.""" + initial_translators = testutil.set_active_primitive_translators_to({}) + yield + testutil.set_active_primitive_translators_to(initial_translators) + + +# <------------- Definitions needed for the test + + +class SubTrans1(translator.PrimitiveTranslator): + @property + def primitive(self): + return "non_existing_primitive1" + + def __call__(self) -> None: # type: ignore[override] # Arguments + raise NotImplementedError + + +class SubTrans2(translator.PrimitiveTranslator): + @property + def primitive(self): + return "non_existing_primitive2" + + def __call__(self) -> None: # type: ignore[override] # Arguments + raise NotImplementedError + + +@translator.make_primitive_translator("non_existing_callable_primitive3") +def primitive_translator_3_callable(*args: Any, **kwargs: Any) -> None: # noqa: ARG001 [unused-function-argument] + raise NotImplementedError + + +@translator.make_primitive_translator("add") +def fake_add_translator(*args: Any, **kwargs: Any) -> None: # noqa: ARG001 [unused-function-argument] + raise NotImplementedError("'fake_add_translator()' was called.") + + +def test_has_pjit(): + assert "pjit" in translator.get_registered_primitive_translators() + + +@pytest.mark.usefixtures("no_builtin_translators") +def test_subtranslatior_managing() -> None: + """Basic functionality of the subtranslators.""" + original_active_subtrans = translator.get_registered_primitive_translators() + assert len(original_active_subtrans) == 0 + + # Create the classes. + sub1 = SubTrans1() + sub2 = SubTrans2() + + # These are all primitive translators + prim_translators = [sub1, sub2, primitive_translator_3_callable] + + # Add the instances. + for sub in prim_translators: + assert translator.register_primitive_translator(sub) is sub + + # Tests if they were correctly registered + active_subtrans = translator.get_registered_primitive_translators() + for expected in prim_translators: + assert active_subtrans[expected.primitive] is expected + assert len(active_subtrans) == 3 + + +def test_subtranslatior_managing_swap() -> None: + """Tests the `translator.set_active_primitive_translators_to()` functionality.""" + + # Allows to compare the structure of dicts. + def same_structure(d1: Mapping, d2: Mapping) -> bool: + return d1.keys() == d2.keys() and all(id(d2[k]) == id(d1[k]) for k in d1) + + initial_primitives = translator.get_registered_primitive_translators() + assert "add" in initial_primitives + + # Generate a set of translators that we swap in + new_active_primitives = initial_primitives.copy() + new_active_primitives["add"] = fake_add_translator + + # Now perform the changes. + old_active = testutil.set_active_primitive_translators_to(new_active_primitives) + assert ( + new_active_primitives is not translator.primitive_translator._PRIMITIVE_TRANSLATORS_REGISTRY + ) + assert same_structure(old_active, initial_primitives) + assert same_structure(new_active_primitives, translator.get_registered_primitive_translators()) + + +def test_subtranslatior_managing_callable_annotation() -> None: + """Test if `translator.make_primitive_translator()` works.""" + + prim_name = "non_existing_property" + + @translator.make_primitive_translator(prim_name) + def non_existing_translator(*args: Any, **kwargs: Any) -> None: # noqa: ARG001 [unused-function-argument] + raise NotImplementedError + + assert hasattr(non_existing_translator, "primitive") + assert non_existing_translator.primitive == prim_name + + +def test_subtranslatior_managing_overwriting() -> None: + """Tests if we are able to overwrite a translator in the global registry.""" + current_add_translator = translator.get_registered_primitive_translators()["add"] + + # This will not work because overwriting is not activated. + with pytest.raises( + expected_exception=ValueError, + match=re.escape( + "Explicit override=True needed for primitive 'add' to overwrite existing one." + ), + ): + translator.register_primitive_translator(fake_add_translator) + assert current_add_translator is translator.get_registered_primitive_translators()["add"] + + # Now we use overwrite. + assert fake_add_translator is translator.register_primitive_translator( + fake_add_translator, overwrite=True + ) + assert fake_add_translator is translator.get_registered_primitive_translators()["add"] + + +@pytest.mark.usefixtures("no_builtin_translators") +def test_subtranslatior_managing_overwriting_2() -> None: + """Again an overwriting test, but this time a bit more complicated. + + It also shows if the translator was actually called. + """ + + trans_cnt = [0] + + @translator.register_primitive_translator(overwrite=True) + @translator.make_primitive_translator("add") + def still_useless_but_a_bit_less(*args: Any, **kwargs: Any) -> None: # noqa: ARG001 [unused-function-argument] + trans_cnt[0] += 1 + + @jace.jit + def foo(a: int) -> int: + b = a + 1 + c = b + 1 + d = c + 1 + return d + 1 + + with pytest.warns( + UserWarning, + match='WARNING: Use of uninitialized transient "e" in state output_processing_stage', + ): + _ = foo.lower(1) + assert trans_cnt[0] == 4 + + +def test_subtranslatior_managing_decoupling() -> None: + """Shows that we have proper decoupling. + + I.e. changes to the global state, does not affect already annotated functions. + """ + + # This will use the translators that are currently installed. + @jace.jit + def foo(a: np.ndarray) -> np.ndarray: + b = a + np.int32(1) + c = b + np.int32(1) + d = c + np.int32(1) + return d + np.int32(1) + + # Now register the add translator. + translator.register_primitive_translator(fake_add_translator, overwrite=True) + + # Since `foo` was already constructed, a new registering can not change anything. + a = np.zeros((10, 10), dtype=np.int32) + assert np.all(foo(a) == 4) + + # But if we now annotate a new function, then we will get fake translator + @jace.jit + def foo_fail(a: np.ndarray) -> np.ndarray: + b = a + np.int32(1) + return b + np.int32(1) + + with pytest.raises( + expected_exception=NotImplementedError, + match=re.escape("'fake_add_translator()' was called."), + ): + _ = foo_fail.lower(a) diff --git a/tests/test_caching.py b/tests/test_caching.py deleted file mode 100644 index 01fabc9..0000000 --- a/tests/test_caching.py +++ /dev/null @@ -1,254 +0,0 @@ -# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) -# -# Copyright (c) 2024, ETH Zurich -# All rights reserved. -# -# SPDX-License-Identifier: BSD-3-Clause - -"""Tests for the caching infrastructure. -.""" - -from __future__ import annotations - -import itertools as it - -import numpy as np -import pytest - -import jace -from jace import optimization, stages -from jace.util import translation_cache as tcache - - -@pytest.fixture(autouse=True) -def _clear_translation_cache(): - """Decorator that clears the translation cache. - - Ensures that a function finds an empty cache and clears up afterwards. - """ - tcache.clear_translation_cache() - yield - tcache.clear_translation_cache() - - -def test_caching_same_sizes() -> None: - """The behaviour of the cache if same sizes are used, in two different functions.""" - - # Counter for how many time it was lowered. - lowering_cnt = [0] - - # This is the pure Python function. - def testee(A: np.ndarray, B: np.ndarray) -> np.ndarray: - return A * B - - # this is the wrapped function. - @jace.jit - def wrapped(A, B): - lowering_cnt[0] += 1 - return testee(A, B) - - # First batch of arguments. - A = np.arange(12, dtype=np.float64).reshape((4, 3)) - B = np.full((4, 3), 10, dtype=np.float64) - - # The second batch of argument, same structure but different values. - AA = A + 1.0362 - BB = B + 0.638956 - - # Now let's lower it once directly and call it. - lowered: stages.JaCeLowered = wrapped.lower(A, B) - compiled: stages.JaCeCompiled = lowered.compile() - assert lowering_cnt[0] == 1 - assert np.allclose(testee(A, B), compiled(A, B)) - - # Now lets call the wrapped object directly, since we already did the lowering - # no longering (and compiling) is needed. - assert np.allclose(testee(A, B), wrapped(A, B)) - assert lowering_cnt[0] == 1 - - # Now lets call it with different objects, that have the same structure. - # Again no lowering should happen. - assert np.allclose(testee(AA, BB), wrapped(AA, BB)) - assert wrapped.lower(AA, BB) is lowered - assert wrapped.lower(A, B) is lowered - assert lowering_cnt[0] == 1 - - -def test_caching_different_sizes(): - """The behaviour of the cache if different sizes where used.""" - - # Counter for how many time it was lowered. - lowering_cnt = [0] - - # This is the wrapped function. - @jace.jit - def wrapped(A, B): - lowering_cnt[0] += 1 - return A * B - - # First size of arguments - A = np.arange(12, dtype=np.float64).reshape((4, 3)) - B = np.full((4, 3), 10, dtype=np.float64) - - # Second size of arguments - C = np.arange(16, dtype=np.float64).reshape((4, 4)) - D = np.full((4, 4), 10, dtype=np.float64) - - # Now lower the function once for each. - lowered1 = wrapped.lower(A, B) - lowered2 = wrapped.lower(C, D) - assert lowering_cnt[0] == 2 - assert lowered1 is not lowered2 - - # Now also check if the compilation works as intended - compiled1 = lowered1.compile() - compiled2 = lowered2.compile() - assert lowering_cnt[0] == 2 - assert compiled1 is not compiled2 - - -@pytest.mark.skip("'convert_element_type' primitive is not implemented.") -def test_caching_different_structure() -> None: - """Now tests if we can handle multiple arguments with different structures. - - Todo: - - Extend with strides once they are part of the cache. - """ - - # This is the wrapped function. - lowering_cnt = [0] - - @jace.jit - def wrapped(A, B): - lowering_cnt[0] += 1 - return A * 4.0, B + 2.0 - - A = np.full((4, 30), 10, dtype=np.float64) - B = np.full((4, 3), 10, dtype=np.float64) - C = np.full((5, 3), 14, dtype=np.float64) - D = np.full((6, 3), 14, dtype=np.int64) - - # These are the known lowerings. - lowerings: dict[tuple[int, int], stages.JaCeLowered] = {} - lowering_ids: set[int] = set() - # These are the known compilations. - compilations: dict[tuple[int, int], stages.JaCeCompiled] = {} - compiled_ids: set[int] = set() - - # Generating the lowerings - for arg1, arg2 in it.permutations([A, B, C, D], 2): - lower = wrapped.lower(arg1, arg2) - compiled = lower.compile() - assert id(lower) not in lowering_ids - assert id(compiled) not in compiled_ids - lowerings[id(arg1), id(arg2)] = lower - lowering_ids.add(id(lower)) - compilations[id(arg1), id(arg2)] = compiled - compiled_ids.add(id(compiled)) - - # Now check if they are still cached. - for arg1, arg2 in it.permutations([A, B, C, D], 2): - lower = wrapped.lower(arg1, arg2) - clower = lowerings[id(arg1), id(arg2)] - assert clower is lower - - compiled1 = lower.compile() - compiled2 = clower.compile() - ccompiled = compilations[id(arg1), id(arg2)] - assert compiled1 is compiled2 - assert compiled1 is ccompiled - - -def test_caching_compilation() -> None: - """Tests the compilation cache, this is just very simple.""" - - @jace.jit - def jaceWrapped(A: np.ndarray, B: np.ndarray) -> np.ndarray: - C = A * B - D = C + A - E = D + B # Just enough state. - return A + B + C + D + E - - # These are the argument - A = np.arange(12, dtype=np.float64).reshape((4, 3)) - B = np.full((4, 3), 10, dtype=np.float64) - - # Now we lower it. - jaceLowered = jaceWrapped.lower(A, B) - - # Compiling it without any information. - optiCompiled = jaceLowered.compile() - - # This should be the same as passing the defaults directly. - assert optiCompiled is jaceLowered.compile(optimization.DEFAULT_OPTIMIZATIONS) - - # Also if we pass the empty dict, we should get the default. - assert optiCompiled is jaceLowered.compile({}) - - # Now we disable all optimizations - unoptiCompiled = jaceLowered.compile(optimization.NO_OPTIMIZATIONS) - - # Because of the way how things work the optimized must have more than the - # unoptimized. If there is sharing, then this would not be the case. - assert unoptiCompiled is not optiCompiled - assert optiCompiled._compiled_sdfg.sdfg.number_of_nodes() == 1 - assert ( - optiCompiled._compiled_sdfg.sdfg.number_of_nodes() - < unoptiCompiled._compiled_sdfg.sdfg.number_of_nodes() - ) - - -def test_caching_dtype(): - """Tests if the data type is properly included in the test.""" - - lowering_cnt = [0] - - @jace.jit - def testee(A: np.ndarray) -> np.ndarray: - lowering_cnt[0] += 1 - return A + A - - dtypes = [np.float64, np.float32, np.int32, np.int64] - shape = (10, 10) - - for i, dtype in enumerate(dtypes): - A = np.array((np.random.random(shape) - 0.5) * 10, dtype=dtype) # noqa: NPY002 - - assert lowering_cnt[0] == i - _ = testee(A) - assert lowering_cnt[0] == i + 1 - - assert np.allclose(testee(A), 2 * A) - assert lowering_cnt[0] == i + 1 - - -def test_caching_strides() -> None: - """Test if the cache detects a change in strides.""" - - @jace.jit - def wrapped(A: np.ndarray) -> np.ndarray: - return A + 10.0 - - shape = (10, 100, 1000) - C = np.array( - (np.random.random(shape) - 0.5) * 10, # noqa: NPY002 - order="C", - dtype=np.float64, - ) - F = np.array(C, copy=True, order="F") - - # First we compile run it with C strides. - C_lower = wrapped.lower(C) - C_res = wrapped(C) - - # Now we run it with FORTRAN strides. - # However, this does not work because we do not support strides at all. - # But the cache is aware of this, which helps catch some nasty bugs. - F_lower = None # Remove later - F_res = C_res.copy() # Remove later - F_lower = wrapped.lower(F) - F_res = wrapped(F) - assert F_lower is not C_lower - assert C_res is not F_res - assert np.allclose(F_res, C_res) - assert F_lower is not C_lower diff --git a/tests/test_decorator.py b/tests/test_decorator.py deleted file mode 100644 index 7971b29..0000000 --- a/tests/test_decorator.py +++ /dev/null @@ -1,95 +0,0 @@ -# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) -# -# Copyright (c) 2024, ETH Zurich -# All rights reserved. -# -# SPDX-License-Identifier: BSD-3-Clause - -"""Implements tests for the jit decorator. - -Also see the `test_jax_api.py` test file, that tests composability. -""" - -from __future__ import annotations - -import numpy as np -import pytest - -import jace -from jace.util import translation_cache as tcache - - -@pytest.fixture(autouse=True) -def _clear_translation_cache(): - """Decorator that clears the translation cache. - - Ensures that a function finds an empty cache and clears up afterwards. - - Todo: - Should be used _everywhere_. - """ - - tcache.clear_translation_cache() - yield - tcache.clear_translation_cache() - - -def test_decorator_individually(): - """Tests the compilation steps individually.""" - - def testee_(A: np.ndarray, B: np.ndarray) -> np.ndarray: - return A + B - - lowering_cnt = [0] - - @jace.jit - def testee(A, B): - lowering_cnt[0] += 1 - return testee_(A, B) - - A = np.arange(12, dtype=np.float64).reshape((4, 3)) - B = np.full((4, 3), 10, dtype=np.float64) - - lowered = testee.lower(A, B) - compiled = lowered.compile() - - ref = testee_(A, B) - res = compiled(A, B) - - assert np.allclose(ref, res), f"Expected '{ref}' got '{res}'." - assert lowering_cnt[0] == 1 - - -def test_decorator_one_go(): - """Tests the compilation steps in one go.""" - - def testee_(A: np.ndarray, B: np.ndarray) -> np.ndarray: - return A + B - - lowering_cnt = [0] - - @jace.jit - def testee(A, B): - lowering_cnt[0] += 1 - return testee_(A, B) - - A = np.arange(12, dtype=np.float64).reshape((4, 3)) - B = np.full((4, 3), 10, dtype=np.float64) - - ref = testee_(A, B) - res = testee(A, B) - - assert np.allclose(ref, res), f"Expected '{ref}' got '{res}'." - assert lowering_cnt[0] == 1 - - -def test_decorator_wrapped(): - """Tests if some properties are set correctly.""" - - def testee(A: np.ndarray, B: np.ndarray) -> np.ndarray: - return A * B - - wrapped = jace.jit(testee) - - assert wrapped.wrapped_fun is testee - assert wrapped.__wrapped__ is testee diff --git a/tests/test_empty_jaxpr.py b/tests/test_empty_jaxpr.py deleted file mode 100644 index 36e8247..0000000 --- a/tests/test_empty_jaxpr.py +++ /dev/null @@ -1,48 +0,0 @@ -# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) -# -# Copyright (c) 2024, ETH Zurich -# All rights reserved. -# -# SPDX-License-Identifier: BSD-3-Clause - -"""Implements tests for empty jaxprs. -.""" - -from __future__ import annotations - -import jax -import numpy as np -import pytest - -import jace - - -def test_empty_array(): - @jace.jit - def testee(A: np.ndarray) -> np.ndarray: - return A - - A = np.arange(12, dtype=np.float64).reshape((4, 3)) - - assert np.all(testee(A) == A) - - -def test_empty_scalar(): - @jace.jit - def testee(A: float) -> float: - return A - - A = np.pi - - assert np.all(testee(A) == A) - - -@pytest.mark.skip(reason="Nested Jaxpr are not handled.") -def test_empty_nested(): - @jace.jit - def testee3(A: float) -> float: - return jax.jit(lambda A: A)(A) - - A = np.pi - - assert np.all(testee3(A) == A) diff --git a/tests/test_jax_api.py b/tests/test_jax_api.py deleted file mode 100644 index 0c1905d..0000000 --- a/tests/test_jax_api.py +++ /dev/null @@ -1,200 +0,0 @@ -# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) -# -# Copyright (c) 2024, ETH Zurich -# All rights reserved. -# -# SPDX-License-Identifier: BSD-3-Clause - -"""Tests the compatibility of the JaCe api to Jax.""" - -from __future__ import annotations - -import jax -import numpy as np -import pytest -from jax import numpy as jnp - -import jace - - -np.random.seed(42) # noqa: NPY002 # random generator - - -def test_jit(): - """Simple add function.""" - - def testee(A: np.ndarray, B: np.ndarray) -> np.ndarray: - return A + B - - A = np.arange(12, dtype=np.float64).reshape((4, 3)) - B = np.full((4, 3), 10, dtype=np.float64) - - jax_testee = jax.jit(testee) - jace_testee = jace.jit(testee) - - ref = jax_testee(A, B) - res = jace_testee(A, B) - - assert np.allclose(ref, res), f"Expected '{ref}' got '{res}'." - - -def test_composition_itself(): - """Tests if JaCe is composable with itself.""" - - # Pure Python functions - def f_ref(x): - return jnp.sin(x) - - def df_ref(x): - return jnp.cos(x) - - def ddf_ref(x): - return -jnp.sin(x) - - # Annotated functions. - - @jace.jit - def f(x): - return f_ref(x) - - @jace.jit - def df(x): - return jace.grad(f)(x) - - @jace.jit - @jace.grad - def ddf(x): - return df(x) - - assert all(isinstance(x, jace.stages.JaCeWrapped) for x in [f, df, ddf]) - - x = 1.0 - for fun, fref in zip([f, df, ddf], [f_ref, df_ref, ddf_ref]): - ref = fref(x) - res = fun(x) - assert np.allclose(ref, res), f"f: Expected '{ref}', got '{res}'." - - -@pytest.mark.skip(reason="Nested Jaxpr are not handled.") -def test_composition_with_jax(): - """Tests if JaCe can interact with Jax and vice versa.""" - - def base_fun(A, B, C): - return A + B * jnp.sin(C) - A * B - - @jace.jit - def jace_fun(A, B, C): - return jax.jit(base_fun)(A, B, C) - - def jax_fun(A, B, C): - return jace.jit(base_fun)(A, B, C) - - A, B, C = (np.random.random((10, 3, 50)) for _ in range(3)) # noqa: NPY002 # random generator - - assert np.allclose(jace_fun(A, B, C), jax_fun(A, B, C)) - - -@pytest.mark.skip(reason="Nested Jaxpr are not handled.") -def test_composition_with_jax_2(): - """Second test if JaCe can interact with Jax and vice versa.""" - - @jax.jit - def f1_jax(A, B): - return A + B - - @jace.jit - def f2_jace(A, B, C): - return f1_jax(A, B) - C - - @jax.jit - def f3_jax(A, B, C, D): - return f2_jace(A, B, C) * D - - @jace.jit - def f3_jace(A, B, C, D): - return f3_jax(A, B, C, D) - - A, B, C, D = (np.random.random((10, 3, 50)) for _ in range(4)) # noqa: NPY002 # random generator - - ref = ((A + B) - C) * D - res_jax = f3_jax(A, B, C, D) - res_jace = f3_jace(A, B, C, D) - - assert np.allclose(ref, res_jax), "Jax failed." - assert np.allclose(ref, res_jace), "JaCe Failed." - - -def test_grad_annotation_direct(): - """Test if `jace.grad` works directly.""" - - def f(x): - return jnp.sin(jnp.exp(jnp.cos(x**2))) - - @jax.grad - def jax_ddf(x): - return jax.grad(f)(x) - - @jax.jit - def jace_ddf(x): - return jace.grad(jace.grad(f))(x) - - # These are the random numbers where we test - Xs = (np.random.random(10) - 0.5) * 10 # noqa: NPY002 # Random number generator - - for i in range(Xs.shape[0]): - x = Xs[i] - res = jace_ddf(x) - ref = jax_ddf(x) - assert np.allclose(res, ref) - - -def test_grad_control_flow(): - """Tests if `grad` and controlflow works. - - This requirement is mentioned in `https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-autodiff`. - """ - - @jace.grad - def df(x): - if x < 3: - return 3.0 * x**2 - return -4 * x - - x1 = 2.0 - df_x1 = 6 * x1 - x2 = 4.0 - df_x2 = -4.0 - - res_1 = df(x1) - res_2 = df(x2) - - assert df(x1) == df_x1, f"Failed lower branch, expected '{df_x1}', got '{res_1}'." - assert df(x2) == df_x2, f"Failed upper branch, expected '{df_x2}', got '{res_2}'." - - -@pytest.mark.skip(reason="Running JaCe with disabled 'x64' support does not work.") -def test_disabled_x64(): - """Tests the behaviour of the tool chain if x64 is disabled. - - If you want to test, if this restriction still applies, you can enable the test. - """ - - def testee(A: np.ndarray, B: np.float64) -> np.ndarray: - return A + B - - A = np.arange(12, dtype=np.float64).reshape((4, 3)) - B = np.float64(10.0) - - # Run them with disabled x64 support - with jax.experimental.disable_x64(): - # JaCe - jace_testee = jace.jit(testee) - jace_lowered = jace_testee.lower(A, B) - jace_comp = jace_lowered.compile() - res = jace_comp(A, B) - - # Jax - jax_testee = jax.jit(testee) - ref = jax_testee(A, B) - - assert np.allclose(ref, res), "Expected that: {ref.tolist()}, but got {res.tolist()}." diff --git a/tests/test_sub_translators_alu.py b/tests/test_sub_translators_alu.py deleted file mode 100644 index 603f57c..0000000 --- a/tests/test_sub_translators_alu.py +++ /dev/null @@ -1,63 +0,0 @@ -# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) -# -# Copyright (c) 2024, ETH Zurich -# All rights reserved. -# -# SPDX-License-Identifier: BSD-3-Clause - -"""Implements tests for the ALU translator.""" - -from __future__ import annotations - -import jax -import numpy as np - -import jace - - -def test_add(): - """Simple add function.""" - - def testee(A: np.ndarray, B: np.ndarray) -> np.ndarray: - return A + B - - A = np.arange(12, dtype=np.float64).reshape((4, 3)) - B = np.full((4, 3), 10, dtype=np.float64) - - ref = testee(A, B) - res = jace.jit(testee)(A, B) - - assert np.allclose(ref, res), f"Expected '{ref}' got '{res}'." - - -def test_add2(): - """Simple add function, with literal.""" - - def testee(A: np.ndarray, B: np.ndarray) -> np.ndarray: - c = A + 0.01 - d = B * 0.6 - e = c / 1.0 - f = d - 0.1 - return e + f * d - - A = np.arange(12, dtype=np.float64).reshape((4, 3)) - B = np.full((4, 3), 10, dtype=np.float64) - - ref = testee(A, B) - res = jace.jit(testee)(A, B) - - assert np.allclose(ref, res), f"Expected '{ref.tolist()}' got '{res.tolist()}'." - - -def test_add3(): - """Simple add function, with constant.""" - - def testee(A: np.ndarray) -> np.ndarray: - return A + jax.numpy.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) - - A = np.ones((3, 3), dtype=np.float64) - - ref = testee(A) - res = jace.jit(testee)(A) - - assert np.allclose(ref, res), f"Expected '{ref}' got '{res}'." diff --git a/tests/test_subtranslator_helper.py b/tests/test_subtranslator_helper.py deleted file mode 100644 index 52672b0..0000000 --- a/tests/test_subtranslator_helper.py +++ /dev/null @@ -1,221 +0,0 @@ -# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) -# -# Copyright (c) 2024, ETH Zurich -# All rights reserved. -# -# SPDX-License-Identifier: BSD-3-Clause - -"""Implements tests for managing the primitive subtranslators.""" - -from __future__ import annotations - -import re -from typing import Any - -import numpy as np -import pytest - -import jace -from jace import translator -from jace.translator import ( - get_registered_primitive_translators, - make_primitive_translator, - register_primitive_translator, -) - - -@pytest.fixture(autouse=True) -def _conserve_builtin_translators(): - """Restores the set of registered subtranslators after a test.""" - initial_translators = translator.primitive_translator._PRIMITIVE_TRANSLATORS_REGISTRY.copy() - yield - translator.primitive_translator._PRIMITIVE_TRANSLATORS_REGISTRY.clear() - translator.primitive_translator._PRIMITIVE_TRANSLATORS_REGISTRY.update(initial_translators) - - -@pytest.fixture() -def no_builtin_translators(): # noqa: PT004 # This is how you should do it: https://docs.pytest.org/en/7.1.x/how-to/fixtures.html#use-fixtures-in-classes-and-modules-with-usefixtures - """This fixture can be used if the test does not want any builtin translators.""" - initial_translators = translator.primitive_translator._PRIMITIVE_TRANSLATORS_REGISTRY.copy() - translator.primitive_translator._PRIMITIVE_TRANSLATORS_REGISTRY.clear() - yield - translator.primitive_translator._PRIMITIVE_TRANSLATORS_REGISTRY.clear() - translator.primitive_translator._PRIMITIVE_TRANSLATORS_REGISTRY.update(initial_translators) - - -# These are definitions of some Subtranslators that can be used to test things. -class SubTrans1(translator.PrimitiveTranslator): - @property - def primitive(self): - return "non_existing_primitive1" - - def __call__(self) -> None: # type: ignore[override] # Arguments - raise NotImplementedError - - -class SubTrans2(translator.PrimitiveTranslator): - @property - def primitive(self): - return "non_existing_primitive2" - - def __call__(self) -> None: # type: ignore[override] # Arguments - raise NotImplementedError - - -@make_primitive_translator("non_existing_callable_primitive3") -def SubTrans3_Callable(*args: Any, **kwargs: Any) -> None: # noqa: ARG001 - raise NotImplementedError - - -@make_primitive_translator("add") -def fake_add_translator(*args: Any, **kwargs: Any) -> None: # noqa: ARG001 - raise NotImplementedError - - -def test_are_subtranslators_imported(): - """Tests if something is inside the list of subtranslators.""" - # Must be adapted if new primitives are implemented. - assert len(get_registered_primitive_translators()) > 0 - - -@pytest.mark.usefixtures("no_builtin_translators") -def test_subtranslatior_managing(): - """Basic functionality of the subtranslators.""" - original_active_subtrans = get_registered_primitive_translators() - assert len(original_active_subtrans) == 0 - - # Create the classes. - sub1 = SubTrans1() - sub2 = SubTrans2() - - # These are all primitive translators - prim_translators = [sub1, sub2, SubTrans3_Callable] - - # Add the instances. - for sub in prim_translators: - assert register_primitive_translator(sub) is sub - - # Tests if they were correctly registered - active_subtrans = get_registered_primitive_translators() - for expected in prim_translators: - assert active_subtrans[expected.primitive] is expected - assert len(active_subtrans) == 3 - - -def test_subtranslatior_managing_isolation(): - """Tests if `get_registered_primitive_translators()` decouples.""" - assert ( - get_registered_primitive_translators() - is not translator.primitive_translator._PRIMITIVE_TRANSLATORS_REGISTRY - ) - - initial_primitives = get_registered_primitive_translators() - assert get_registered_primitive_translators() is not initial_primitives - assert "add" in initial_primitives, "For this test the 'add' primitive must be registered." - org_add_prim = initial_primitives["add"] - - initial_primitives["add"] = fake_add_translator - assert org_add_prim is not fake_add_translator - assert get_registered_primitive_translators()["add"] is org_add_prim - - -@pytest.mark.usefixtures("no_builtin_translators") -def test_subtranslatior_managing_callable_annotation(): - """Test if `make_primitive_translator()` works.""" - - prim_name = "non_existing_property" - - @make_primitive_translator(prim_name) - def non_existing_translator(*args: Any, **kwargs: Any) -> None: # noqa: ARG001 - raise NotImplementedError - - assert hasattr(non_existing_translator, "primitive") - assert non_existing_translator.primitive == prim_name - assert len(get_registered_primitive_translators()) == 0 - - -def test_subtranslatior_managing_overwriting(): - """Tests if we are able to overwrite something.""" - current_add_translator = get_registered_primitive_translators()["add"] - - @make_primitive_translator("add") - def useless_add_translator(*args: Any, **kwargs: Any) -> None: # noqa: ARG001 - raise NotImplementedError - - # This will not work because it is not overwritten. - with pytest.raises( - expected_exception=ValueError, - match=re.escape( - "Explicit override=True needed for primitive 'add' to overwrite existing one." - ), - ): - register_primitive_translator(useless_add_translator) - assert current_add_translator is get_registered_primitive_translators()["add"] - - # Now we use overwrite, thus it will now work. - assert useless_add_translator is register_primitive_translator( - useless_add_translator, overwrite=True - ) - assert useless_add_translator is get_registered_primitive_translators()["add"] - - -@pytest.mark.usefixtures("no_builtin_translators") -def test_subtranslatior_managing_overwriting_2(): - """Again an overwriting test, but this time a bit more complicated.""" - - trans_cnt = [0] - - @register_primitive_translator(overwrite=True) - @make_primitive_translator("add") - def still_useless_but_a_bit_less(*args: Any, **kwargs: Any) -> None: # noqa: ARG001 - trans_cnt[0] += 1 - - @jace.jit - def foo(A): - B = A + 1 - C = B + 1 - D = C + 1 - return D + 1 - - with pytest.warns( - UserWarning, - match=re.escape('Use of uninitialized transient "e" in state output_processing_stage'), - ): - _ = foo.lower(1) - assert trans_cnt[0] == 4 - - -def test_subtranslatior_managing_decoupling(): - """Shows that we have proper decoupling. - - I.e. changes to the global state, does not affect already annotated functions. - """ - - # This will use the translators that are currently installed. - @jace.jit - def foo(A): - B = A + 1 - C = B + 1 - D = C + 1 - return D + 1 - - @register_primitive_translator(overwrite=True) - @make_primitive_translator("add") - def useless_add_translator(*args: Any, **kwargs: Any) -> None: # noqa: ARG001 - raise NotImplementedError("The 'useless_add_translator' was called as expected.") - - # Since `foo` was already constructed, a new registering can not change anything. - A = np.zeros((10, 10)) - assert np.all(foo(A) == 4) - - # But if we now annotate a new function, then we will get the uselss translator - @jace.jit - def foo_fail(A): - B = A + 1 - return B + 1 - - with pytest.raises( - expected_exception=NotImplementedError, - match=re.escape("The 'useless_add_translator' was called as expected."), - ): - _ = foo_fail.lower(A) diff --git a/tests/unit_tests/__init__.py b/tests/unit_tests/__init__.py new file mode 100644 index 0000000..5ce9af1 --- /dev/null +++ b/tests/unit_tests/__init__.py @@ -0,0 +1,8 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""JaCe's unit tests.""" diff --git a/tests/unit_tests/test_caching.py b/tests/unit_tests/test_caching.py new file mode 100644 index 0000000..85e5b3b --- /dev/null +++ b/tests/unit_tests/test_caching.py @@ -0,0 +1,443 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Tests for the caching infrastructure. +.""" + +from __future__ import annotations + +import itertools as it + +import jax +import numpy as np +from jax import numpy as jnp + +import jace +from jace import optimization, stages +from jace.util import translation_cache as tcache + +from tests import util as testutil + + +def test_caching_working() -> None: + """Simple test if the caching actually works.""" + + lowering_cnt = [0] + + @jace.jit + def wrapped(a: np.ndarray) -> jax.Array: + lowering_cnt[0] += 1 + return jnp.sin(a) + + a = testutil.make_array((10, 10)) + ref = np.sin(a) + res_ids: set[int] = set() + # We have to store the array, because numpy does reuse the memory. + res_set: list[jax.Array] = [] + + for _ in range(10): + res = wrapped(a) + res_id = res.__array_interface__["data"][0] # type: ignore[attr-defined] + + assert np.allclose(res, ref) + assert lowering_cnt[0] == 1 + assert res_id not in res_ids + res_ids.add(res_id) + res_set.append(res) + + +def test_caching_same_sizes() -> None: + """The behaviour of the cache if same sizes are used, in two different functions.""" + + # Counter for how many time it was lowered. + lowering_cnt = [0] + + # This is the pure Python function. + def testee(a: np.ndarray, b: np.ndarray) -> np.ndarray: + return a * b + + # this is the wrapped function. + @jace.jit + def wrapped(a: np.ndarray, b: np.ndarray) -> np.ndarray: + lowering_cnt[0] += 1 + return testee(a, b) + + # First batch of arguments. + a = testutil.make_array((4, 3)) + b = testutil.make_array((4, 3)) + + # The second batch of argument, same structure, but different values. + aa = a + 1.0362 + bb = b + 0.638956 + + # Now let's lower it once directly and call it. + lowered: stages.JaCeLowered = wrapped.lower(a, b) + compiled: stages.JaCeCompiled = lowered.compile() + assert lowering_cnt[0] == 1 + assert np.allclose(testee(a, b), compiled(a, b)) + + # Now lets call the wrapped object directly, since we already did the lowering + # no lowering (and compiling) is needed. + assert np.allclose(testee(a, b), wrapped(a, b)) + assert lowering_cnt[0] == 1 + + # Now lets call it with different objects, that have the same structure. + # Again no lowering should happen. + assert np.allclose(testee(aa, bb), wrapped(aa, bb)) + assert wrapped.lower(aa, bb) is lowered + assert wrapped.lower(a, b) is lowered + assert lowering_cnt[0] == 1 + + +def test_caching_different_sizes() -> None: + """The behaviour of the cache if different sizes where used.""" + + # Counter for how many time it was lowered. + lowering_cnt = [0] + + # This is the wrapped function. + @jace.jit + def wrapped(a, b): + lowering_cnt[0] += 1 + return a * b + + # First size of arguments + a = testutil.make_array((4, 3)) + b = testutil.make_array((4, 3)) + + # Second size of arguments + c = testutil.make_array((4, 4)) + d = testutil.make_array((4, 4)) + + # Now lower the function once for each. + lowered1 = wrapped.lower(a, b) + lowered2 = wrapped.lower(c, d) + assert lowering_cnt[0] == 2 + assert lowered1 is not lowered2 + + # Now also check if the compilation works as intended + compiled1 = lowered1.compile() + compiled2 = lowered2.compile() + assert lowering_cnt[0] == 2 + assert compiled1 is not compiled2 + + +def test_caching_different_structure() -> None: + """Now tests if we can handle multiple arguments with different structures. + + Todo: + - Extend with strides once they are part of the cache. + """ + + # This is the wrapped function. + lowering_cnt = [0] + + @jace.jit + def wrapped(a, b): + lowering_cnt[0] += 1 + return a * 4.0, b + 2.0 + + a = testutil.make_array((4, 30), dtype=np.float64) + b = testutil.make_array((4, 3), dtype=np.float64) + c = testutil.make_array((4, 3), dtype=np.int64) + d = testutil.make_array((6, 3), dtype=np.int64) + + # These are the known lowered instances. + lowerings: dict[tuple[int, int], stages.JaCeLowered] = {} + lowering_ids: set[int] = set() + # These are the known compilation instances. + compilations: dict[tuple[int, int], stages.JaCeCompiled] = {} + compiled_ids: set[int] = set() + + # Generating the lowerings + for arg1, arg2 in it.permutations([a, b, c, d], 2): + lower = wrapped.lower(arg1, arg2) + compiled = lower.compile() + assert id(lower) not in lowering_ids + assert id(compiled) not in compiled_ids + lowerings[id(arg1), id(arg2)] = lower + lowering_ids.add(id(lower)) + compilations[id(arg1), id(arg2)] = compiled + compiled_ids.add(id(compiled)) + + # Now check if they are still cached. + for arg1, arg2 in it.permutations([a, b, c, d], 2): + lower = wrapped.lower(arg1, arg2) + clower = lowerings[id(arg1), id(arg2)] + assert clower is lower + + compiled1 = lower.compile() + compiled2 = clower.compile() + ccompiled = compilations[id(arg1), id(arg2)] + assert compiled1 is compiled2 + assert compiled1 is ccompiled + + +def test_caching_compilation() -> None: + """Tests the compilation cache.""" + + @jace.jit + def jace_wrapped(a: np.ndarray, b: np.ndarray) -> np.ndarray: + c = a * b + d = c + a + e = d + b # Just enough state. + return a + b + c + d + e + + # These are the argument + a = testutil.make_array((4, 3)) + b = testutil.make_array((4, 3)) + + # Now we lower it. + jace_lowered = jace_wrapped.lower(a, b) + + # Compiling it with and without optimizations enabled + optized_compiled = jace_lowered.compile(optimization.DEFAULT_OPTIMIZATIONS) + unoptized_compiled = jace_lowered.compile(optimization.NO_OPTIMIZATIONS) + + # Because of the way how things work the optimized must have more than the + # unoptimized. If there is sharing, then this would not be the case. + assert unoptized_compiled is not optized_compiled + assert optized_compiled._compiled_sdfg.sdfg.number_of_nodes() == 1 + assert ( + optized_compiled._compiled_sdfg.sdfg.number_of_nodes() + < unoptized_compiled._compiled_sdfg.sdfg.number_of_nodes() + ) + + # Now we check if they are still inside the cache. + assert optized_compiled is jace_lowered.compile(optimization.DEFAULT_OPTIMIZATIONS) + assert unoptized_compiled is jace_lowered.compile(optimization.NO_OPTIMIZATIONS) + + +def test_caching_compilation_options() -> None: + """Tests if the global optimization managing works.""" + lowering_cnt = [0] + + @jace.jit + def wrapped(a: float) -> float: + lowering_cnt[0] += 1 + return a + 1.0 + + lower_cache = wrapped._cache + lowered = wrapped.lower(1.0) + compile_cache = lowered._cache + + assert len(lower_cache) == 1 + assert len(compile_cache) == 0 + assert lowering_cnt[0] == 1 + + # Using the first set of options. + with stages.set_compiler_options(optimization.NO_OPTIMIZATIONS): + _ = wrapped(2.0) + + # Except from one entry in the compile cache, nothing should have changed. + assert len(lower_cache) == 1 + assert len(compile_cache) == 1 + assert compile_cache.front()[0].stage_id == id(lowered) + assert lowering_cnt[0] == 1 + + # Now we change the options again which then will lead to another compilation, + # but not to another lowering. + with stages.set_compiler_options(optimization.DEFAULT_OPTIMIZATIONS): + _ = wrapped(2.0) + + assert len(lower_cache) == 1 + assert len(compile_cache) == 2 + assert compile_cache.front()[0].stage_id == id(lowered) + assert lowering_cnt[0] == 1 + + +def test_caching_dtype() -> None: + """Tests if the data type is properly included in the test.""" + + lowering_cnt = [0] + + @jace.jit + def testee(a: np.ndarray) -> np.ndarray: + lowering_cnt[0] += 1 + return a + a + + dtypes = [np.float64, np.float32, np.int32, np.int64] + shape = (10, 10) + + for i, dtype in enumerate(dtypes): + a = testutil.make_array(shape, dtype=dtype) + + # First lowering + assert lowering_cnt[0] == i + _ = testee(a) + assert lowering_cnt[0] == i + 1 + + # Second, implicit, lowering, which must be cached. + assert np.allclose(testee(a), 2 * a) + assert lowering_cnt[0] == i + 1 + + +def test_caching_eviction_simple() -> None: + """Simple tests for cache eviction.""" + + @jace.jit + def testee(a: np.ndarray) -> np.ndarray: + return a + 1.0 + + cache: tcache.StageCache = testee._cache + assert len(cache) == 0 + + first_lowered = testee.lower(np.ones(10)) + first_key = cache.front()[0] + assert len(cache) == 1 + + second_lowered = testee.lower(np.ones(11)) + second_key = cache.front()[0] + assert len(cache) == 2 + assert second_key != first_key + + third_lowered = testee.lower(np.ones(12)) + third_key = cache.front()[0] + assert len(cache) == 3 + assert third_key != second_key + assert third_key != first_key + + # Test if the key association is correct. + # We have to do it in this order, because reading the key modifies the order. + assert cache.front()[0] == third_key + assert cache[first_key] is first_lowered + assert cache.front()[0] == first_key + assert cache[second_key] is second_lowered + assert cache.front()[0] == second_key + assert cache[third_key] is third_lowered + assert cache.front()[0] == third_key + + # We now evict the second key, which should not change anything on the order. + cache.popitem(second_key) + assert len(cache) == 2 + assert first_key in cache + assert second_key not in cache + assert third_key in cache + assert cache.front()[0] == third_key + + # Now we modify first_key, which moves it to the front. + cache[first_key] = first_lowered + assert len(cache) == 2 + assert first_key in cache + assert third_key in cache + assert cache.front()[0] == first_key + + # Now we evict the oldest one, which is third_key + cache.popitem(None) + assert len(cache) == 1 + assert first_key in cache + assert cache.front()[0] == first_key + + +def test_caching_eviction_complex() -> None: + """Tests if the stuff is properly evicted if the cache is full.""" + + @jace.jit + def testee(a: np.ndarray) -> np.ndarray: + return a + 1.0 + + cache: tcache.StageCache = testee._cache + capacity = cache.capacity + assert len(cache) == 0 + + # Lets fill the cache to the brim. + for i in range(capacity): + a = np.ones(i + 10) + lowered = testee.lower(a) + assert len(cache) == i + 1 + + if i == 0: + first_key: tcache.StageTransformationSpec = cache.front()[0] + first_lowered = cache[first_key] + assert lowered is first_lowered + elif i == 1: + second_key: tcache.StageTransformationSpec = cache.front()[0] + assert second_key != first_key + assert cache[second_key] is lowered + assert first_key in cache + + assert len(cache) == capacity + assert first_key in cache + assert second_key in cache + + # Now we will modify the first key, this should make it the newest. + assert cache.front()[0] != first_key + cache[first_key] = first_lowered + assert len(cache) == capacity + assert first_key in cache + assert second_key in cache + assert cache.front()[0] == first_key + + # Now we will add a new entry to the cache, this will evict the second entry. + _ = testee.lower(np.ones(capacity + 1000)) + assert len(cache) == capacity + assert cache.front()[0] != first_key + assert first_key in cache + assert second_key not in cache + + +def test_caching_strides() -> None: + """Test if the cache detects a change in strides.""" + + lower_cnt = [0] + + @jace.jit + def wrapped(a: np.ndarray) -> np.ndarray: + lower_cnt[0] += 1 + return a + 10.0 + + shape = (10, 100, 1000) + array_c = testutil.make_array(shape, order="C") + array_f = np.array(array_c, copy=True, order="F") + + # First we compile run it with c strides. + lower_c = wrapped.lower(array_c) + res_c = wrapped(array_c) + + lower_f = wrapped.lower(array_f) + res_f = lower_f.compile()(array_f) + + assert res_c is not res_f + assert lower_f is not lower_c + assert np.allclose(res_f, res_c) + + # In previous versions JAX did not cached the result of the tracing, + # but in newer version the tracing itself is also cached + if lower_c._jaxpr is lower_f._jaxpr: + assert lower_cnt[0] == 1 + else: + assert lower_cnt[0] == 2 + + +def test_caching_jax_numpy_array() -> None: + """Tests if jax arrays are handled the same way as numpy array.""" + + def _test_impl( + for_lowering: np.ndarray | jax.Array, for_calling: np.ndarray | jax.Array + ) -> None: + tcache.clear_translation_cache() + lowering_cnt = [0] + + @jace.jit + def wrapped(a: np.ndarray | jax.Array) -> np.ndarray | jax.Array: + lowering_cnt[0] += 1 + return a + 1.0 + + # Explicit lowering. + _ = wrapped(for_lowering) + assert lowering_cnt[0] == 1 + + # Now calling with the second argument, it should not longer again. + _ = wrapped(for_calling) + assert lowering_cnt[0] == 1, "Expected no further lowering." + + a_numpy = testutil.make_array((10, 10)) + a_jax = jnp.array(a_numpy, copy=True) + assert a_numpy.dtype == a_jax.dtype + + _test_impl(a_numpy, a_jax) + _test_impl(a_jax, a_numpy) diff --git a/tests/unit_tests/test_decorator.py b/tests/unit_tests/test_decorator.py new file mode 100644 index 0000000..7460491 --- /dev/null +++ b/tests/unit_tests/test_decorator.py @@ -0,0 +1,79 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Implements tests for the jit decorator. + +Also see the `test_jax_api.py` test file, that tests composability. +""" + +from __future__ import annotations + +import numpy as np + +import jace + +from tests import util as testutil + + +def test_decorator_individually() -> None: + """Tests the compilation steps individually.""" + + def testee_(a: np.ndarray, b: np.ndarray) -> np.ndarray: + return a + b + + lowering_cnt = [0] + + @jace.jit + def testee(a, b): + lowering_cnt[0] += 1 + return testee_(a, b) + + a = testutil.make_array((4, 3)) + b = testutil.make_array((4, 3)) + + lowered = testee.lower(a, b) + compiled = lowered.compile() + + ref = testee_(a, b) + res = compiled(a, b) + + assert np.allclose(ref, res), f"Expected '{ref}' got '{res}'." + assert lowering_cnt[0] == 1 + + +def test_decorator_one_go() -> None: + """Tests the compilation steps in one go.""" + + def testee_(a: np.ndarray, b: np.ndarray) -> np.ndarray: + return a + b + + lowering_cnt = [0] + + @jace.jit + def testee(a, b): + lowering_cnt[0] += 1 + return testee_(a, b) + + a = testutil.make_array((4, 3)) + b = testutil.make_array((4, 3)) + + ref = testee_(a, b) + res = testee(a, b) + + assert np.allclose(ref, res), f"Expected '{ref}' got '{res}'." + assert lowering_cnt[0] == 1 + + +def test_decorator_wrapped() -> None: + """Tests if some properties are set correctly.""" + + def testee(a: np.ndarray, b: np.ndarray) -> np.ndarray: + return a * b + + wrapped = jace.jit(testee) + + assert wrapped.wrapped_fun is testee diff --git a/tests/unit_tests/test_jax_api.py b/tests/unit_tests/test_jax_api.py new file mode 100644 index 0000000..b4327b7 --- /dev/null +++ b/tests/unit_tests/test_jax_api.py @@ -0,0 +1,282 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Tests the compatibility of the JaCe api to JAX.""" + +from __future__ import annotations + +import jax +import numpy as np +import pytest +from jax import numpy as jnp, tree_util as jax_tree + +import jace +from jace import translated_jaxpr_sdfg as tjsdfg, translator, util +from jace.translator import post_translation as ptranslation + +from tests import util as testutil + + +def test_jit() -> None: + """Simple add function.""" + + def testee(a: np.ndarray, b: np.ndarray) -> np.ndarray: + return a + b + + a = testutil.make_array((4, 3)) + b = testutil.make_array((4, 3)) + + jax_testee = jax.jit(testee) + jace_testee = jace.jit(testee) + + ref = jax_testee(a, b) + res = jace_testee(a, b) + + assert np.allclose(ref, res), f"Expected '{ref}' got '{res}'." + + +def test_composition_itself() -> None: + """Tests if JaCe is composable with itself.""" + + # Pure Python functions + def f_ref(x): + return jnp.sin(x) + + def df_ref(x): + return jnp.cos(x) + + def ddf_ref(x): + return -jnp.sin(x) + + # Annotated functions. + + @jace.jit + def f(x): + return f_ref(x) + + @jace.jit + def df(x): + return jace.grad(f)(x) + + @jace.jit + @jace.grad + def ddf(x): + return df(x) + + assert all(isinstance(x, jace.stages.JaCeWrapped) for x in [f, df, ddf]) + + x = 1.0 + for fun, fref in zip([f, df, ddf], [f_ref, df_ref, ddf_ref]): + ref = fref(x) + res = fun(x) + assert np.allclose(ref, res), f"f: Expected '{ref}', got '{res}'." + + +def test_composition_with_jax() -> None: + """Tests if JaCe can interact with JAX and vice versa.""" + + def base_fun(a, b, c): + return a + b * jnp.sin(c) - a * b + + @jace.jit + def jace_fun(a, b, c): + return jax.jit(base_fun)(a, b, c) + + def jax_fun(a, b, c): + return jace.jit(base_fun)(a, b, c) + + a, b, c = (testutil.make_array((10, 3, 50)) for _ in range(3)) + + assert np.allclose(jace_fun(a, b, c), jax_fun(a, b, c)) + + +def test_composition_with_jax_2() -> None: + """Second test if JaCe can interact with JAX and vice versa.""" + + @jax.jit + def f1_jax(a, b): + return a + b + + @jace.jit + def f2_jace(a, b, c): + return f1_jax(a, b) - c + + @jax.jit + def f3_jax(a, b, c, d): + return f2_jace(a, b, c) * d + + @jace.jit + def f3_jace(a, b, c, d): + return f3_jax(a, b, c, d) + + a, b, c, d = (testutil.make_array((10, 3, 50)) for _ in range(4)) + + ref = ((a + b) - c) * d + res_jax = f3_jax(a, b, c, d) + res_jace = f3_jace(a, b, c, d) + + assert np.allclose(ref, res_jax), "JAX failed." + assert np.allclose(ref, res_jace), "JaCe Failed." + + +def test_grad_annotation_direct() -> None: + """Test if `jace.grad` works directly.""" + + def f(x): + return jnp.sin(jnp.exp(jnp.cos(x**2))) + + @jax.grad + def jax_ddf(x): + return jax.grad(f)(x) + + @jax.jit + def jace_ddf(x): + return jace.grad(jace.grad(f))(x) + + # These are the random numbers where we test + xs = (testutil.make_array(10) - 0.5) * 10 + + for i in range(xs.shape[0]): + x = xs[i] + res = jace_ddf(x) + ref = jax_ddf(x) + assert np.allclose(res, ref) + + +def test_grad_control_flow() -> None: + """Tests if `grad` and controlflow works. + + This requirement is mentioned in the [documentation](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-autodiff). + """ + + @jace.grad + def df(x): + if x < 3: + return 3.0 * x**2 + return -4 * x + + x1 = 2.0 + df_x1 = 6 * x1 + x2 = 4.0 + df_x2 = -4.0 + + res_1 = df(x1) + res_2 = df(x2) + + assert df(x1) == df_x1, f"Failed lower branch, expected '{df_x1}', got '{res_1}'." + assert df(x2) == df_x2, f"Failed upper branch, expected '{df_x2}', got '{res_2}'." + + +def test_disabled_x64() -> None: + """Tests the behaviour of the tool chain if x64 support is disabled. + + Notes: + Once the x64 issue is resolved make this test a bit more useful. + """ + + def testee(a: np.ndarray, b: np.float64) -> np.ndarray: + return a + b + + a = testutil.make_array((4, 3)) + b = np.float64(10.0) + + # Run them with disabled x64 support + # This is basically a reimplementation of the `JaCeWrapped.lower()` function. + # but we have to do it this way to disable the x64 mode in translation. + with jax.experimental.disable_x64(): + jaxpr = jax.make_jaxpr(testee)(a, b) + + flat_call_args = jax_tree.tree_leaves(((a, b), {})) + builder = translator.JaxprTranslationBuilder( + primitive_translators=translator.get_registered_primitive_translators() + ) + trans_ctx: translator.TranslationContext = builder.translate_jaxpr(jaxpr) + + tsdfg: tjsdfg.TranslatedJaxprSDFG = ptranslation.postprocess_jaxpr_sdfg( + trans_ctx=trans_ctx, fun=testee, flat_call_args=flat_call_args + ) + + # Because x64 is disabled JAX traces the input as float32, even if we have passed + # float64 as input! Calling the resulting SDFG with the arguments we used for + # lowering will result in an error, because of the situation, + # `sizeof(float32) < sizeof(float64)`, no out of bound error would result, but the + # values are garbage. + assert all( + tsdfg.sdfg.arrays[input_name].dtype.as_numpy_dtype().type is np.float32 + for input_name in tsdfg.input_names + ) + + +@pytest.mark.usefixtures("_enable_jit") +def test_tracing_detection() -> None: + """Tests our ability to detect if tracing is going on.""" + expected_tracing_state = False + + def testee(a: float, b: int) -> float: + c = a + b + assert util.is_tracing_ongoing(a, b) == expected_tracing_state + assert util.is_tracing_ongoing() == expected_tracing_state + return a + c + + # We do not expect tracing to happen. + _ = testee(1.0, 1) + + # Now tracing is going on + expected_tracing_state = True + _ = jax.jit(testee)(1.0, 1) + _ = jace.jit(testee)(1.0, 1) + + # Tracing should now again be disabled + expected_tracing_state = False + _ = testee + + +def test_no_input() -> None: + """Tests if we can handle the case of no input.""" + + @jace.jit + def ones10x10() -> jax.Array: + return jnp.ones((10, 10), dtype=np.int32) + + res = ones10x10() + + assert res.shape == (10, 10) + assert res.dtype == np.int32 + assert np.all(res == 1) + + +def test_jax_array_as_input() -> None: + """This function tests if we use JAX arrays as inputs.""" + + def testee(a: jax.Array) -> jax.Array: + return jnp.sin(a + 1.0) + + a = jnp.array(testutil.make_array((10, 19))) + + ref = testee(a) + res = jace.jit(testee)(a) + + assert res.shape == ref.shape + assert res.dtype == ref.dtype + assert np.allclose(res, ref) + + +def test_jax_pytree() -> None: + """Perform if pytrees are handled correctly.""" + + def testee(a: dict[str, np.ndarray]) -> dict[str, jax.Array]: + mod_a = {k: jnp.sin(v) for k, v in a.items()} + mod_a["__additional"] = jnp.asin(a["a1"]) + return mod_a + + a = {f"a{i}": testutil.make_array((10, 10)) for i in range(4)} + ref = testee(a) + res = jace.jit(testee)(a) + + assert len(res) == len(ref) + assert type(res) == type(ref) + assert (np.allclose(res[k], ref[k]) for k in ref) diff --git a/tests/test_misc.py b/tests/unit_tests/test_misc.py similarity index 63% rename from tests/test_misc.py rename to tests/unit_tests/test_misc.py index 80abefd..263df3e 100644 --- a/tests/test_misc.py +++ b/tests/unit_tests/test_misc.py @@ -14,9 +14,11 @@ import jace +from tests import util as testutil + @pytest.mark.skip("Possible bug in DaCe.") -def test_mismatch_in_datatyte_calling(): +def test_mismatch_in_datatype_calling() -> None: """Tests compilation and calling with different types. Note that this more or less tests the calling implementation of the `CompiledSDFG` @@ -25,16 +27,16 @@ class in DaCe. As I understand the `CompiledSDFG::_construct_args()` function th """ @jace.jit - def testee(A: np.ndarray) -> np.ndarray: - return -A + def testee(a: np.ndarray) -> np.ndarray: + return -a # Different types. - A1 = np.arange(12, dtype=np.float32).reshape((4, 3)) - A2 = np.arange(12, dtype=np.int64).reshape((4, 3)) + a1 = testutil.make_array((4, 3), dtype=np.float32) + a2 = testutil.make_array((4, 3), dtype=np.int64) # Lower and compilation for first type - callee = testee.lower(A1).compile() + callee = testee.lower(a1).compile() # But calling with the second type - with pytest.raises(Exception): # noqa: B017, PT011 # Unknown exception. - _ = callee(A2) + with pytest.raises(Exception): # noqa: B017, PT011 [assert-raises-exception, pytest-raises-too-broad] # Unknown exception. + _ = callee(a2) diff --git a/tests/test_package.py b/tests/unit_tests/test_package.py similarity index 93% rename from tests/test_package.py rename to tests/unit_tests/test_package.py index 5237aeb..4d63fcc 100644 --- a/tests/test_package.py +++ b/tests/unit_tests/test_package.py @@ -15,5 +15,5 @@ @pytest.mark.skip(reason="This does not work yet.") -def test_version(): +def test_version() -> None: assert importlib.metadata.version("jace") == m.__version__ diff --git a/tests/util.py b/tests/util.py new file mode 100644 index 0000000..3b38bfe --- /dev/null +++ b/tests/util.py @@ -0,0 +1,91 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Utility functions for the testing infrastructure.""" + +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from typing import Any + +import numpy as np + +from jace import translator + + +__all__ = ["make_array"] + + +def make_array( + shape: Sequence[int] | int, + dtype: type = np.float64, + order: str = "C", + low: Any = None, + high: Any = None, +) -> np.ndarray: + """Generates a NumPy ndarray with shape `shape`. + + The function uses the generator that is managed by the `_reset_random_seed()` + fixture. Thus inside a function the value will be deterministic. + + Args: + shape: The shape to use. + dtype: The data type to use. + order: The order of the underlying array + low: Minimal value. + high: Maximal value. + + Note: + The exact meaning of `low` and `high` depend on the type. For `bool` they + are ignored. For float both must be specified and then values inside + `[low, high)` are generated. For integer it is possible to specify only one. + The appropriate numeric limit is used for the other. + """ + + if shape == (): + return dtype(make_array((1,), dtype)[0]) + if isinstance(shape, int): + shape = (shape,) + + if dtype == np.bool_: + res = np.random.random(shape) > 0.5 # noqa: NPY002 [numpy-legacy-random] + elif np.issubdtype(dtype, np.integer): + iinfo: np.iinfo = np.iinfo(dtype) + res = np.random.randint( # noqa: NPY002 [numpy-legacy-random] + low=iinfo.min if low is None else low, + high=iinfo.max if high is None else high, + size=shape, + dtype=dtype, + ) + elif np.issubdtype(dtype, np.complexfloating): + res = make_array(shape, np.float64) + 1.0j * make_array(shape, np.float64) + else: + res = np.random.random(shape) # type: ignore[assignment] # noqa: NPY002 [numpy-legacy-random] + if low is not None and high is not None: + res = low + (high - low) * res + assert (low is None) == (high is None) + + return np.array(res, order=order, dtype=dtype) # type: ignore[call-overload] # Because we use `str` as `order`. + + +def set_active_primitive_translators_to( + new_active_primitives: Mapping[str, translator.PrimitiveTranslator], +) -> dict[str, translator.PrimitiveTranslator]: + """Exchanges the currently active set of translators with `new_active_primitives`. + + The function will return the set of translators the were active before the call. + + Args: + new_active_primitives: The new set of active translators. + """ + assert all( + primitive_name == translator.primitive + for primitive_name, translator in new_active_primitives.items() + ) + previously_active_translators = translator.primitive_translator._PRIMITIVE_TRANSLATORS_REGISTRY + translator.primitive_translator._PRIMITIVE_TRANSLATORS_REGISTRY = {**new_active_primitives} + return previously_active_translators