-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
test: Added an Extensive Set of Tests #21
base: main
Are you sure you want to change the base?
Changes from all commits
345d1bb
979f4ec
2d2c160
8adae42
011a2d6
a3ac868
5f186f6
df82038
60b7acb
2190faf
5d6962a
6c980a9
6dd65f3
a0996eb
f7fe13d
7b7b7ef
28c77ca
054bb22
c7e8b9f
a28a9c8
5bb4b5e
959849a
047bb46
64135b9
5756d1e
8b66af8
a5826e9
6f4c45b
4735092
758adc8
7fa03c3
ba854e0
6b634bd
cd7fc01
2cdd3fd
e4bd289
37c933c
699b93e
53aae7b
eabda1f
f0ea5b4
fae3ce3
71eae32
43ba073
a23f0ab
555e815
fc91156
1be3237
9ab2e9a
e4a099d
8bdea82
f5eb441
4179d40
acfcf39
beeb9fd
cf99b02
7424687
b996041
bb9d81d
95b34cf
72c055a
66fa038
fa599c3
e164d75
7fe6f3b
b518ec7
94bdf57
380dca9
35dc1f1
93e7935
ff66995
74435a2
acb9501
5373ca9
e9597f2
4499cfb
50b4692
5c1e8c6
01cc777
9ab7360
c8b9763
f5b7ccc
1c3d3a6
7990569
972c9c0
95f04be
bbc51d0
f1846eb
bf68021
a376aad
8a79f83
a4697be
b64b666
ed23d49
3518bdc
7d6cc9c
630fcce
e46637d
c7f6cc9
73f0116
ca7f32a
3c6194a
174bea8
4f87173
c5ebe96
45fce56
a2ee92d
b7fcc80
ceb46be
b47284d
40f8574
50860bb
4f52bad
9af8a4e
647c5f7
c8c612c
4788405
414c55f
8cec411
9b40b27
894b10c
1928b95
9909f6d
5f24405
cc7f649
03d6d08
0da47fa
b7a166f
367e7b6
1e825c3
7607fd9
b759c3b
d123166
39dd108
889ec88
c1074c4
9ac17bd
6a8e2b5
c05ee3c
ce4c643
1743374
efc27e0
0221d7b
fd9fb8e
e43dd53
744c2b2
a11f111
9cd0e01
3377f41
8065ed2
272dd00
3e5408b
3f1d2ad
e6730d1
82cf898
da57db9
6eef078
1700885
90a25a6
61638df
6369908
afe5cdf
0aa1726
3424a59
d8603d5
54374d3
261d902
027ae35
7cdb5f5
21e64a8
9fe9e2d
2ecb7ee
0b52d4b
4eb806f
744738b
71b6422
e3a96ad
bcd941f
4be66b1
617c5c1
031b33f
c582669
565c69d
0dd9404
6b69a45
f7ee981
8698397
f3a65d6
969f2f8
29d09e7
d029d13
bb938c7
4f343b3
72d6771
1766bc3
d3dac22
25f52a9
ed8f653
f4104b8
95853a3
67e546c
0ec665b
5bc6220
5385a5c
1296b9d
7615be3
70710c1
44a8a80
c8a6aa8
819eebb
110b6f4
380399a
320c528
e2ebc4e
071923c
0869287
e7fa770
aa92f01
36e6b6b
a25ecb7
3048019
8932815
f9ee01a
fa65ee7
b9f9427
5a3c87f
d6265bc
25616ae
24d97fb
9f22bd2
411bd7b
c8b7d86
cb600d3
c29fc0d
b0ac641
3ee8dad
e787ee9
846a345
a546932
090c3a2
e667854
6fe78b0
d88752a
2c7b3c8
5687f40
cc10a48
18bdee9
3588a77
2169470
e17b4eb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# Main Differences Between DaCe and JaCe and JAX and JaCe | ||
|
||
Essentially JaCe is a frontend that allows DaCe to process JAX code, thus it has to be compatible with both, at least in some sense. | ||
We will now list the main differences between them, furthermore, you should also consult the ROADMAP. | ||
|
||
### JAX vs. JaCe: | ||
|
||
- JaCe always traces with enabled `x64` mode. | ||
This is a restriction that might be lifted in the future. | ||
- JAX returns scalars as zero-dimensional arrays, JaCe returns them as array with shape `(1, )`. | ||
- In JAX parts of the computation runs on CPU parts on GPU, in JaCe everything runs (currently) either on CPU or GPU. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What do you mean here? Which parts run on CPU/GPU in JAX? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The JAX compiler, i.e. XLA can decide to do this. |
||
- Currently JaCe is only able to run on CPU (will be lifted soon). | ||
- Currently JaCe is not able to run distributed (will be lifted later). | ||
- Currently not all primitives are supported. | ||
- JaCe does not return `jax.Array` instances, but NumPy/CuPy arrays. | ||
- The execution is not asynchronous. | ||
Comment on lines
+15
to
+16
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These two points could be also fixed in the future, right? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
### DaCe vs. JaCe: | ||
|
||
- JaCe accepts complex objects using JAX' pytrees. | ||
- JaCe will support scalar inputs on GPU. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) | ||
# | ||
# Copyright (c) 2024, ETH Zurich | ||
# All rights reserved. | ||
# | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
|
||
"""Contains all common fixture we need.""" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) | ||
# | ||
# Copyright (c) 2024, ETH Zurich | ||
# All rights reserved. | ||
# | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
|
||
"""General configuration for the tests. | ||
|
||
Todo: | ||
- Implement some fixture that allows to force validation. | ||
""" | ||
|
||
from __future__ import annotations | ||
|
||
from collections.abc import Generator | ||
|
||
import jax | ||
import numpy as np | ||
import pytest | ||
|
||
from jace import optimization, stages | ||
from jace.util import translation_cache as tcache | ||
|
||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For all fixtures in this module:
Additionally, I would define another fixture requesting all other fixtures expected in the standard case and use this # This file
@pytest.fixture
def standard_jace_test_settings(enable_x64_mode_in_jax, disable_jit, ...) -> ...:
....
# Other test files
pytestmark = pytest.mark.usefixtures("standard_jace_test_settings") Finally, I'd create a simpler type alias for the return type of generator fixtures as suggested here : T = TypeVar("T")
YieldFixture = Generator[T, None, None]
@pytest.fixture
def foo() -> YieldFixture[str]:
yield "foo" |
||
@pytest.fixture(autouse=True) | ||
def _enable_x64_mode_in_jax() -> Generator[None, None, None]: | ||
"""Fixture of enable the `x64` mode in JAX. | ||
|
||
Currently, JaCe requires that `x64` mode is enabled and will do all JAX | ||
things with it enabled. However, if we use JAX with the intend to compare | ||
it against JaCe we must also enable it for JAX. | ||
""" | ||
with jax.experimental.enable_x64(): | ||
yield | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
def _disable_jit() -> Generator[None, None, None]: | ||
"""Fixture for disable the dynamic jiting in JAX, used by default. | ||
|
||
Using this fixture has two effects. | ||
- JAX will not cache the results, i.e. every call to a jitted function will | ||
result in a tracing operation. | ||
- JAX will not use implicit jit operations, i.e. nested Jaxpr expressions | ||
using `pjit` are avoided. | ||
|
||
This essentially disable the `jax.jit` decorator, however, the `jace.jit` | ||
decorator is still working. | ||
|
||
Note: | ||
The second point, i.e. preventing JAX from running certain things in `pjit`, | ||
is the main reason why this fixture is used by default, without it | ||
literal substitution is useless and essentially untestable. | ||
In certain situation it can be disabled. | ||
""" | ||
with jax.disable_jit(disable=True): | ||
yield | ||
|
||
|
||
@pytest.fixture() | ||
def _enable_jit() -> Generator[None, None, None]: | ||
"""Fixture to enable jit compilation. | ||
|
||
Essentially it undoes the effects of the `_disable_jit()` fixture. | ||
It is important that this fixture is not automatically activated. | ||
""" | ||
with jax.disable_jit(disable=False): | ||
yield | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
def _clear_translation_cache() -> Generator[None, None, None]: | ||
"""Decorator that clears the translation cache. | ||
|
||
Ensures that a function finds an empty cache and clears up afterwards. | ||
""" | ||
tcache.clear_translation_cache() | ||
yield | ||
tcache.clear_translation_cache() | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
def _reset_random_seed() -> None: | ||
"""Fixture for resetting the random seed. | ||
|
||
This ensures that for every test the random seed of NumPy is reset. | ||
This seed is used by the `util.mkarray()` helper. | ||
""" | ||
np.random.seed(42) # noqa: NPY002 [numpy-legacy-random] | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
def _set_compile_options() -> Generator[None, None, None]: | ||
"""Disable all optimizations of jitted code. | ||
|
||
Without explicitly supplied arguments `JaCeLowered.compile()` will not | ||
perform any optimizations. | ||
Please not that certain tests might override this fixture. | ||
""" | ||
with stages.set_compiler_options(optimization.NO_OPTIMIZATIONS): | ||
yield |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) | ||
# | ||
# Copyright (c) 2024, ETH Zurich | ||
# All rights reserved. | ||
# | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
|
||
"""JaCe's integration tests. | ||
|
||
Currently they are mostly related to the primitive translators. | ||
""" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) | ||
# | ||
# Copyright (c) 2024, ETH Zurich | ||
# All rights reserved. | ||
# | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
|
||
"""Tests related to the actual primitive subtranslators.""" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) | ||
# | ||
# Copyright (c) 2024, ETH Zurich | ||
# All rights reserved. | ||
# | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
|
||
"""General configuration for the tests of the primitive translators.""" | ||
|
||
from __future__ import annotations | ||
|
||
from collections.abc import Generator | ||
|
||
import pytest | ||
|
||
from jace import optimization, stages | ||
|
||
|
||
@pytest.fixture( | ||
autouse=True, | ||
params=[ | ||
optimization.NO_OPTIMIZATIONS, | ||
pytest.param( | ||
optimization.DEFAULT_OPTIMIZATIONS, | ||
marks=pytest.mark.skip( | ||
"Simplify bug 'https://github.com/spcl/dace/issues/1595'; resolved > 16.1" | ||
), | ||
), | ||
], | ||
) | ||
def _set_compile_options(request) -> Generator[None, None, None]: | ||
"""Set the options used for testing the primitive translators. | ||
|
||
This fixture override the global defined fixture. | ||
|
||
Todo: | ||
Implement a system that only runs the optimization case in CI. | ||
""" | ||
with stages.set_compiler_options(request.param): | ||
yield |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The main issue is that DaCe does not have a real concept of zero dimensional arrays, as far as I know.
Consider the following two functions
if you pass a zero dimensional array to
bar()
then it will be casted to an scalar, if you pass it tobaz()
an error will happen.Furthermore, the binary interface of the SDFG can not return scalars, return values have to be arrays there is no way around that without patching the code generator and making a lot of changes to handle special cases.
So I decided to follow PEP20 and decided that this case is not special enough to change the rule.
If you want this feature then please open an issue.