Skip to content

Commit

Permalink
Merge branch 'master' into dd/sum-fields
Browse files Browse the repository at this point in the history
  • Loading branch information
dpanici authored Sep 25, 2024
2 parents f625921 + 7d9ce27 commit 016613d
Show file tree
Hide file tree
Showing 28 changed files with 681 additions and 550 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/cache_dependencies.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,14 @@ jobs:
gh cache delete "$cache_key"
done
# Update the matplotlib version if needed later
- name: Set up virtual environment
run: |
python -m venv .venv-${{ matrix.python-version }}
source .venv-${{ matrix.python-version }}/bin/activate
python -m pip install --upgrade pip
pip install -r devtools/dev-requirements.txt
pip install matplotlib==3.7.2
- name: Cache Python environment
id: cache-env
Expand Down
6 changes: 2 additions & 4 deletions .github/workflows/regression_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ jobs:
source .venv-${{ matrix.python-version }}/bin/activate
python -m pip install --upgrade pip
pip install -r devtools/dev-requirements.txt
pip install matplotlib==3.7.2
- name: Set Swap Space
if: env.has_changes == 'true'
Expand All @@ -78,7 +77,6 @@ jobs:
if: env.has_changes == 'true'
run: |
source .venv-${{ matrix.python-version }}/bin/activate
pip install matplotlib==3.7.2
pwd
lscpu
python -m pytest -v -m regression\
Expand Down Expand Up @@ -107,9 +105,9 @@ jobs:
- name: Upload coverage
if: env.has_changes == 'true'
id : codecov
uses: Wandalen/wretry.action@v1.3.0
uses: Wandalen/wretry.action@v3.5.0
with:
action: codecov/codecov-action@v3
action: codecov/codecov-action@v4
with: |
token: ${{ secrets.CODECOV_TOKEN }}
name: codecov-umbrella
Expand Down
6 changes: 2 additions & 4 deletions .github/workflows/unit_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ jobs:
source .venv-${{ matrix.combos.python_version }}/bin/activate
python -m pip install --upgrade pip
pip install -r devtools/dev-requirements.txt
pip install matplotlib==3.7.2
- name: Set Swap Space
if: env.has_changes == 'true'
Expand All @@ -84,7 +83,6 @@ jobs:
if: env.has_changes == 'true'
run: |
source .venv-${{ matrix.combos.python_version }}/bin/activate
pip install matplotlib==3.7.2
pwd
lscpu
python -m pytest -v -m unit \
Expand Down Expand Up @@ -113,9 +111,9 @@ jobs:
- name: Upload coverage
if: env.has_changes == 'true'
id : codecov
uses: Wandalen/wretry.action@v1.3.0
uses: Wandalen/wretry.action@v3.5.0
with:
action: codecov/codecov-action@v3
action: codecov/codecov-action@v4
with: |
token: ${{ secrets.CODECOV_TOKEN }}
name: codecov-umbrella
Expand Down
25 changes: 11 additions & 14 deletions CONTRIBUTING.rst
Original file line number Diff line number Diff line change
Expand Up @@ -156,46 +156,43 @@ Opening a PR will trigger a suite of tests and style/formatting checks that must
We also require approval from at least one (ideally multiple) of the main DESC developers, who may have suggested changes
or edits to your PR.

What if the ``test_compute_everything`` test fails, or there is a conflict in ``master_compute_data.pkl``?
What if the ``test_compute_everything`` test fails, or there is a conflict in ``master_compute_data_rpz.pkl``?
----------------------------------------------------------------------------------------------------------
When the outputs of the compute quantities tested by the`test_compute_everything` [test](https://github.com/PlasmaControl/DESC/blob/master/tests/test_compute_funs.py) are changed in a PR, that test will fail.
When the outputs of the compute quantities tested by the `test_compute_everything` [test](https://github.com/PlasmaControl/DESC/blob/master/tests/test_compute_everything.py) are changed in a PR, that test will fail.
The three main reasons this could occur are:

- The PR was not intended to change how things are computed, but messed up something unexpected and now the compute quantities are incorrect, if you did not expect these changes in the PR then look into why these differences are happening and fix the PR.
- The PR updated the way one of the existing compute index quantities are computed (either by a redefinition or perhaps fixing an error present in ``master``)
- The PR added a new class parametrization (such as a new subclass of ``Curve`` like ``LinearCurve`` etc)

If the 2nd case is the reason, then you must update the ``master_compute_data.pkl`` file with the correct quantities being computed by your PR:
If the 2nd case is the reason, then you must update the ``master_compute_data_rpz.pkl`` file with the correct quantities being computed by your PR:

- First, run the test with ``pytest tests -k test_compute_everything`` and inspect the compute quantities whose values are in error, to ensure that only the quantities you expect to be different are shown (and that the new values are indeed the correct ones, you should have a test elsewhere for that though).
- If the values are as expected and only the expected compute quantities are different, then replace the block

.. code-block:: python
except AssertionError as e:
error = True
print(e)
if not error_rpz and update_master_data_rpz:
# then update the master compute data
with

.. code-block:: python
except AssertionError as e:
error = False
update_master_data = True
print(e)
if True or (not error_rpz and update_master_data_rpz):
# then update the master compute data
- rerun the test ``pytest tests -k test_compute_everything``, now any compute quantity that is different between the PR and master will be updated with the PR value
- ``git restore tests/test_compute_funs.py`` to remove the change you made to the test
- ``git add tests/inputs/master_compute_data.pkl`` and commit to commit the new data file
- ``git restore tests/test_compute_everything.py`` to remove the change you made to the test
- ``git add tests/inputs/master_compute_data_rpz.pkl`` and commit to commit the new data file

If the 3rd case is the reason, then you must simply add the new parametrization to the ``test_compute_everything`` [test](https://github.com/PlasmaControl/DESC/blob/master/tests/test_compute_funs.py)
If the 3rd case is the reason, then you must simply add the new parametrization to the ``test_compute_everything`` [test](https://github.com/PlasmaControl/DESC/blob/master/tests/test_compute_everything.py)

- ``things`` dictionary with a sensible example instance of the class to use for the test, and
- to the ``grid`` dictionary with a sensible default grid to use when computing the compute quantities for the new class
- Then, rerunning the test ``pytest tests -k test_compute_everything`` will add the compute quantities for the new class and save them to the ``.pkl`` file
- ``git add tests/inputs/master_compute_data.pkl`` and commit to commit the new data file
- ``git add tests/inputs/master_compute_data_rpz.pkl`` and commit to commit the new data file

Styleguides
^^^^^^^^^^^
Expand Down
58 changes: 27 additions & 31 deletions desc/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,7 @@
from jax.numpy import bincount, flatnonzero, repeat, take
from jax.numpy.fft import irfft, rfft, rfft2
from jax.scipy.fft import dct, idct
from jax.scipy.linalg import (
block_diag,
cho_factor,
cho_solve,
eigh_tridiagonal,
qr,
solve_triangular,
)
from jax.scipy.linalg import block_diag, cho_factor, cho_solve, qr, solve_triangular
from jax.scipy.special import gammaln, logsumexp
from jax.tree_util import (
register_pytree_node,
Expand All @@ -98,6 +91,31 @@
jnp.trapezoid if hasattr(jnp, "trapezoid") else jax.scipy.integrate.trapezoid
)

def execute_on_cpu(func):
"""Decorator to set default device to CPU for a function.
Parameters
----------
func : callable
Function to decorate
Returns
-------
wrapper : callable
Decorated function that will always run on CPU even if
there are available GPUs.
"""

@functools.wraps(func)
def wrapper(*args, **kwargs):
with jax.default_device(jax.devices("cpu")[0]):
return func(*args, **kwargs)

return wrapper

# JAX implementation is not differentiable on gpu.
eigh_tridiagonal = execute_on_cpu(jax.scipy.linalg.eigh_tridiagonal)

def put(arr, inds, vals):
"""Functional interface for array "fancy indexing".
Expand All @@ -123,28 +141,6 @@ def put(arr, inds, vals):
return arr
return jnp.asarray(arr).at[inds].set(vals)

def execute_on_cpu(func):
"""Decorator to set default device to CPU for a function.
Parameters
----------
func : callable
Function to decorate
Returns
-------
wrapper : callable
Decorated function that will run always on CPU even if
there are available GPUs.
"""

@functools.wraps(func)
def wrapper(*args, **kwargs):
with jax.default_device(jax.devices("cpu")[0]):
return func(*args, **kwargs)

return wrapper

def sign(x):
"""Sign function, but returns 1 for x==0.
Expand Down Expand Up @@ -427,7 +423,7 @@ def tangent_solve(g, y):

trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz

def imap(f, xs, batch_size=None, in_axes=0, out_axes=0):
def imap(f, xs, *, batch_size=None, in_axes=0, out_axes=0):
"""Generalizes jax.lax.map; uses numpy."""
if not isinstance(xs, np.ndarray):
raise NotImplementedError(
Expand Down
6 changes: 3 additions & 3 deletions desc/compute/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1494,7 +1494,7 @@ def _Z_zzz(params, transforms, profiles, data, **kwargs):
label="\\alpha",
units="~",
units_long="None",
description="Field line label, defined on [0, 2pi)",
description="Field line label",
dim=1,
params=[],
transforms={},
Expand All @@ -1503,7 +1503,7 @@ def _Z_zzz(params, transforms, profiles, data, **kwargs):
data=["theta_PEST", "phi", "iota"],
)
def _alpha(params, transforms, profiles, data, **kwargs):
data["alpha"] = (data["theta_PEST"] - data["iota"] * data["phi"]) % (2 * jnp.pi)
data["alpha"] = data["theta_PEST"] - data["iota"] * data["phi"]
return data


Expand Down Expand Up @@ -3077,7 +3077,7 @@ def _theta(params, transforms, profiles, data, **kwargs):
data=["theta", "lambda"],
)
def _theta_PEST(params, transforms, profiles, data, **kwargs):
data["theta_PEST"] = (data["theta"] + data["lambda"]) % (2 * jnp.pi)
data["theta_PEST"] = data["theta"] + data["lambda"]
return data


Expand Down
Loading

0 comments on commit 016613d

Please sign in to comment.