diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index ad00eeef5d..68c9442106 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -4,9 +4,6 @@ on: pull_request: branches: - master - paths-ignore: - - 'docs/**' - - 'devtools/**' workflow_dispatch: inputs: debug_enabled: @@ -22,6 +19,12 @@ concurrency: jobs: benchmark: runs-on: ubuntu-latest + env: + GH_TOKEN: ${{ github.token }} + strategy: + matrix: + python-version: ['3.9'] + group: [1, 2] steps: # Enable tmate debugging of manually-triggered workflows if the input option was provided @@ -31,41 +34,91 @@ jobs: - uses: actions/checkout@v4 with: ref: ${{ github.event.pull_request.head.sha }} - - name: Set up Python 3.9 + + - name: Filter changes + id: changes + uses: dorny/paths-filter@v3 + with: + filters: | + has_changes: + - 'desc/**' + - 'tests/benchmarks/**' + - 'requirements.txt' + - 'devtools/dev-requirements.txt' + - 'setup.cfg' + - '.github/workflows/benchmark.yml' + + - name: Check for relevant changes + id: check_changes + run: echo "has_changes=${{ steps.changes.outputs.has_changes }}" >> $GITHUB_ENV + + - name: Set up Python uses: actions/setup-python@v5 with: - python-version: 3.9 - - name: Install dependencies + python-version: ${{ matrix.python-version }} + + - name: Restore Python environment cache + if: env.has_changes == 'true' + id: restore-env + uses: actions/cache/restore@v4 + with: + path: .venv-${{ matrix.python-version }} + key: ${{ runner.os }}-venv-${{ matrix.python-version }}-${{ hashFiles('devtools/dev-requirements.txt', 'requirements.txt') }} + + - name: Set up virtual environment if not restored from cache + if: steps.restore-env.outputs.cache-hit != 'true' && env.has_changes == 'true' run: | + gh cache list + python -m venv .venv-${{ matrix.python-version }} + source .venv-${{ matrix.python-version }}/bin/activate python -m pip install --upgrade pip pip install -r devtools/dev-requirements.txt - - name: Benchmark with pytest-benchmark + + - name: Benchmark with pytest-benchmark (PR) + if: env.has_changes == 'true' run: | + source .venv-${{ matrix.python-version }}/bin/activate pwd lscpu cd tests/benchmarks python -m pytest benchmark_cpu_small.py -vv \ --benchmark-save='Latest_Commit' \ --durations=0 \ - --benchmark-save-data + --benchmark-save-data \ + --splits 2 \ + --group ${{ matrix.group }} \ + --splitting-algorithm least_duration + - name: Checkout current master + if: env.has_changes == 'true' uses: actions/checkout@v4 with: ref: master clean: false + - name: Checkout benchmarks from PR head + if: env.has_changes == 'true' run: git checkout ${{ github.event.pull_request.head.sha }} -- tests/benchmarks - - name: Benchmark with pytest-benchmark + + - name: Benchmark with pytest-benchmark (MASTER) + if: env.has_changes == 'true' run: | + source .venv-${{ matrix.python-version }}/bin/activate pwd lscpu cd tests/benchmarks python -m pytest benchmark_cpu_small.py -vv \ --benchmark-save='master' \ --durations=0 \ - --benchmark-save-data - - name: put benchmark results in same folder + --benchmark-save-data \ + --splits 2 \ + --group ${{ matrix.group }} \ + --splitting-algorithm least_duration + + - name: Put benchmark results in same folder + if: env.has_changes == 'true' run: | + source .venv-${{ matrix.python-version }}/bin/activate pwd cd tests/benchmarks find .benchmarks/ -type f -printf "%T@ %p\n" | sort -n | cut -d' ' -f 2- | tail -n 1 > temp1 @@ -75,22 +128,36 @@ jobs: mkdir compare_results cp $t1 compare_results cp $t2 compare_results + + 
- name: Download artifact + if: always() && env.has_changes == 'true' + uses: actions/download-artifact@v4 + with: + pattern: benchmark_artifact_* + path: tests/benchmarks + - name: Compare latest commit results to the master branch results + if: env.has_changes == 'true' run: | - pwd + source .venv-${{ matrix.python-version }}/bin/activate cd tests/benchmarks + pwd python compare_bench_results.py cat commit_msg.txt - - name: comment PR with the results + + - name: Comment PR with the results + if: env.has_changes == 'true' uses: thollander/actions-comment-pull-request@v2 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: filePath: tests/benchmarks/commit_msg.txt comment_tag: benchmark + - name: Upload benchmark data - if: always() + if: always() && env.has_changes == 'true' uses: actions/upload-artifact@v4 with: - name: benchmark_artifact + name: benchmark_artifact_${{ matrix.group }} path: tests/benchmarks/.benchmarks + include-hidden-files: true diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml index 993e56ee4b..ae0be3dbb7 100644 --- a/.github/workflows/black.yml +++ b/.github/workflows/black.yml @@ -6,26 +6,48 @@ jobs: black_format: runs-on: ubuntu-latest + env: + GH_TOKEN: ${{ github.token }} + strategy: + matrix: + python-version: ['3.10'] + steps: - uses: actions/checkout@v4 - - name: Set up Python 3.10 + - name: Set up Python uses: actions/setup-python@v5 with: - python-version: '3.10' - - name: Install dependencies + python-version: ${{ matrix.python-version }} + + - name: Restore Python environment cache + id: restore-env + uses: actions/cache/restore@v4 + with: + path: .venv-${{ matrix.python-version }} + key: ${{ runner.os }}-venv-${{ matrix.python-version }}-${{ hashFiles('devtools/dev-requirements.txt', 'requirements.txt') }} + + - name: Set up virtual environment if not restored from cache + if: steps.restore-env.outputs.cache-hit != 'true' run: | + gh cache list + python -m venv .venv-${{ matrix.python-version }} + source .venv-${{ matrix.python-version }}/bin/activate python -m pip install --upgrade pip pip install -r devtools/dev-requirements.txt + - name: Check files using the black formatter run: | + source .venv-${{ matrix.python-version }}/bin/activate black --version black --check desc/ tests/ || black_return_code=$? 
echo "BLACK_RETURN_CODE=$black_return_code" >> $GITHUB_ENV black desc/ tests/ + - name: Annotate diff changes using reviewdog uses: reviewdog/action-suggester@v1 with: tool_name: blackfmt + - name: Fail if not formatted run: | exit ${{ env.BLACK_RETURN_CODE }} diff --git a/.github/workflows/cache_dependencies.yml b/.github/workflows/cache_dependencies.yml new file mode 100644 index 0000000000..55da0c2e6d --- /dev/null +++ b/.github/workflows/cache_dependencies.yml @@ -0,0 +1,56 @@ +name: Cache dependencies +# This workflow is triggered every 2 days and updates the Python +# and pip dependencies cache +on: + schedule: + - cron: '30 8 */2 * *' # This triggers the workflow at 4:30 AM ET every 2 days + # cron syntax uses UTC time, so 4:30 AM ET is 8:30 AM UTC (for daylight time) + workflow_dispatch: + +jobs: + build: + runs-on: ubuntu-latest + env: + GH_TOKEN: ${{ github.token }} + strategy: + matrix: + python-version: ['3.9', '3.10', '3.11', '3.12'] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Delete old cached file with same python version + run: | + echo "Current Cached files list" + gh cache list + echo "Deleting cached files with pattern: ${{ runner.os }}-venv-${{ matrix.python-version }}-" + for cache_key in $(gh cache list --json key -q ".[] | select(.key | startswith(\"${{ runner.os }}-venv-${{ matrix.python-version }}-\")) | .key"); do + echo "Deleting cache with key: $cache_key" + gh cache delete "$cache_key" + done + + - name: Set up virtual environment + run: | + python -m venv .venv-${{ matrix.python-version }} + source .venv-${{ matrix.python-version }}/bin/activate + python -m pip install --upgrade pip + pip install -r devtools/dev-requirements.txt + + - name: Cache Python environment + id: cache-env + uses: actions/cache@v4 + with: + path: .venv-${{ matrix.python-version }} + key: ${{ runner.os }}-venv-${{ matrix.python-version }}-${{ hashFiles('devtools/dev-requirements.txt', 'requirements.txt') }} + + - name: Verify virtual environment activation + run: | + source .venv-${{ matrix.python-version }}/bin/activate + python --version + pip --version + pip list diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml index c3ea0f96e9..2eef77dcc0 100644 --- a/.github/workflows/linting.yml +++ b/.github/workflows/linting.yml @@ -5,17 +5,24 @@ on: [pull_request, workflow_dispatch] jobs: flake8_linting: runs-on: ubuntu-latest + strategy: + matrix: + python-version: ['3.10'] + name: Linting steps: - uses: actions/checkout@v4 - - name: Set up Python 3.10 + - name: Set up Python uses: actions/setup-python@v5 with: - python-version: '3.10' + python-version: ${{ matrix.python-version }} + + # For some reason, loading venv makes this way slower - name: Install dependencies run: | python -m pip install --upgrade pip pip install -r devtools/dev-requirements.txt + - name: flake8 Lint uses: reviewdog/action-flake8@v3 with: diff --git a/.github/workflows/nbtests.yml b/.github/workflows/nbtests.yml deleted file mode 100644 index 6d1fc6ca24..0000000000 --- a/.github/workflows/nbtests.yml +++ /dev/null @@ -1,45 +0,0 @@ -name: Notebook tests - -on: - push: - branches: - - master - - dev - pull_request: - branches: - - master - workflow_dispatch: - -concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} - cancel-in-progress: true - -jobs: - notebook_tests: - - runs-on: ubuntu-latest - 
strategy: - matrix: - python-version: ['3.10'] - group: [1, 2] - - steps: - - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -r devtools/dev-requirements.txt - - name: Test notebooks with pytest and nbmake - run: | - pwd - lscpu - export PYTHONPATH=$(pwd) - pytest -v --nbmake "./docs/notebooks" \ - --nbmake-timeout=2000 \ - --ignore=./docs/notebooks/zernike_eval.ipynb \ - --splits 2 \ - --group ${{ matrix.group }} \ diff --git a/.github/workflows/notebook_tests.yml b/.github/workflows/notebook_tests.yml new file mode 100644 index 0000000000..88c6570c75 --- /dev/null +++ b/.github/workflows/notebook_tests.yml @@ -0,0 +1,83 @@ +name: Notebook tests + +on: + push: + branches: + - master + - dev + pull_request: + branches: + - master + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + notebook_tests: + + runs-on: ubuntu-latest + env: + GH_TOKEN: ${{ github.token }} + strategy: + matrix: + python-version: ['3.10'] + group: [1, 2, 3] + + steps: + - uses: actions/checkout@v4 + + - name: Filter changes + id: changes + uses: dorny/paths-filter@v3 + with: + filters: | + has_changes: + - 'desc/**' + - 'docs/notebooks/**' + - 'requirements.txt' + - 'devtools/dev-requirements.txt' + - 'setup.cfg' + - '.github/workflows/notebook_tests.yml' + + - name: Check for relevant changes + id: check_changes + run: echo "has_changes=${{ steps.changes.outputs.has_changes }}" >> $GITHUB_ENV + + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Restore Python environment cache + if: env.has_changes == 'true' + id: restore-env + uses: actions/cache/restore@v4 + with: + path: .venv-${{ matrix.python-version }} + key: ${{ runner.os }}-venv-${{ matrix.python-version }}-${{ hashFiles('devtools/dev-requirements.txt', 'requirements.txt') }} + + - name: Set up virtual environment if not restored from cache + if: steps.restore-env.outputs.cache-hit != 'true' && env.has_changes == 'true' + run: | + gh cache list + python -m venv .venv-${{ matrix.python-version }} + source .venv-${{ matrix.python-version }}/bin/activate + python -m pip install --upgrade pip + pip install -r devtools/dev-requirements.txt + + - name: Test notebooks with pytest and nbmake + if: env.has_changes == 'true' + run: | + source .venv-${{ matrix.python-version }}/bin/activate + pwd + lscpu + export PYTHONPATH=$(pwd) + pytest -v --nbmake "./docs/notebooks" \ + --nbmake-timeout=2000 \ + --ignore=./docs/notebooks/zernike_eval.ipynb \ + --splits 3 \ + --group ${{ matrix.group }} \ + --splitting-algorithm least_duration diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_tests.yml similarity index 57% rename from .github/workflows/regression_test.yml rename to .github/workflows/regression_tests.yml index d9ef1072d6..12ef17b1e0 100644 --- a/.github/workflows/regression_test.yml +++ b/.github/workflows/regression_tests.yml @@ -18,6 +18,8 @@ jobs: regression_tests: runs-on: ubuntu-latest + env: + GH_TOKEN: ${{ github.token }} strategy: matrix: python-version: ['3.10'] @@ -25,24 +27,61 @@ jobs: steps: - uses: actions/checkout@v4 + + - name: Filter changes + id: changes + uses: dorny/paths-filter@v3 + with: + filters: | + 
has_changes: + - 'desc/**' + - 'tests/**' + - 'requirements.txt' + - 'devtools/dev-requirements.txt' + - 'setup.cfg' + - '.github/workflows/regression_tests.yml' + + - name: Check for relevant changes + id: check_changes + run: echo "has_changes=${{ steps.changes.outputs.has_changes }}" >> $GITHUB_ENV + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - name: Install dependencies + + - name: Restore Python environment cache + if: env.has_changes == 'true' + id: restore-env + uses: actions/cache/restore@v4 + with: + path: .venv-${{ matrix.python-version }} + key: ${{ runner.os }}-venv-${{ matrix.python-version }}-${{ hashFiles('devtools/dev-requirements.txt', 'requirements.txt') }} + + - name: Set up virtual environment if not restored from cache + if: steps.restore-env.outputs.cache-hit != 'true' && env.has_changes == 'true' run: | + gh cache list + python -m venv .venv-${{ matrix.python-version }} + source .venv-${{ matrix.python-version }}/bin/activate python -m pip install --upgrade pip pip install -r devtools/dev-requirements.txt pip install matplotlib==3.7.2 + - name: Set Swap Space + if: env.has_changes == 'true' uses: pierotofy/set-swap-space@master with: swap-size-gb: 10 + - name: Test with pytest + if: env.has_changes == 'true' run: | + source .venv-${{ matrix.python-version }}/bin/activate + pip install matplotlib==3.7.2 pwd lscpu - python -m pytest -v -m regression \ + python -m pytest -v -m regression \ --durations=0 \ --cov-report xml:cov.xml \ --cov-config=setup.cfg \ @@ -54,8 +93,9 @@ jobs: --group ${{ matrix.group }} \ --splitting-algorithm least_duration \ --db ./prof.db + - name: save coverage file and plot comparison results - if: always() + if: always() && env.has_changes == 'true' uses: actions/upload-artifact@v4 with: name: regression_test_artifact-${{ matrix.python-version }}-${{ matrix.group }} path: | ./cov.xml ./mpl_results.html ./prof.db + - name: Upload coverage + if: env.has_changes == 'true' id : codecov uses: Wandalen/wretry.action@v1.3.0 with: diff --git a/.github/workflows/unittest.yml b/.github/workflows/unit_tests.yml similarity index 54% rename from .github/workflows/unittest.yml rename to .github/workflows/unit_tests.yml index 57e5881a05..fe58f21953 100644 --- a/.github/workflows/unittest.yml +++ b/.github/workflows/unit_tests.yml @@ -18,30 +18,73 @@ jobs: unit_tests: runs-on: ubuntu-latest + env: + GH_TOKEN: ${{ github.token }} strategy: matrix: combos: [{group: 1, python_version: '3.9'}, {group: 2, python_version: '3.10'}, {group: 3, python_version: '3.11'}, - {group: 4, python_version: '3.12'}] + {group: 4, python_version: '3.12'}, + {group: 5, python_version: '3.12'}, + {group: 6, python_version: '3.12'}, + {group: 7, python_version: '3.12'}, + {group: 8, python_version: '3.12'}] steps: - uses: actions/checkout@v4 + + - name: Filter changes + id: changes + uses: dorny/paths-filter@v3 + with: + filters: | + has_changes: + - 'desc/**' + - 'tests/**' + - 'requirements.txt' + - 'devtools/dev-requirements.txt' + - 'setup.cfg' + - '.github/workflows/unit_tests.yml' + + - name: Check for relevant changes + id: check_changes + run: echo "has_changes=${{ steps.changes.outputs.has_changes }}" >> $GITHUB_ENV + - name: Set up Python ${{ matrix.combos.python_version }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.combos.python_version }} - - name: Install dependencies + + - name: Restore Python environment cache + if: env.has_changes == 'true' +
id: restore-env + uses: actions/cache/restore@v4 + with: + path: .venv-${{ matrix.combos.python_version }} + key: ${{ runner.os }}-venv-${{ matrix.combos.python_version }}-${{ hashFiles('devtools/dev-requirements.txt', 'requirements.txt') }} + + - name: Set up virtual environment if not restored from cache + if: steps.restore-env.outputs.cache-hit != 'true' && env.has_changes == 'true' run: | + gh cache list + python -m venv .venv-${{ matrix.combos.python_version }} + source .venv-${{ matrix.combos.python_version }}/bin/activate python -m pip install --upgrade pip pip install -r devtools/dev-requirements.txt pip install matplotlib==3.7.2 + - name: Set Swap Space + if: env.has_changes == 'true' uses: pierotofy/set-swap-space@master with: swap-size-gb: 10 + - name: Test with pytest + if: env.has_changes == 'true' run: | + source .venv-${{ matrix.combos.python_version }}/bin/activate + pip install matplotlib==3.7.2 pwd lscpu python -m pytest -v -m unit \ @@ -52,12 +95,13 @@ jobs: --mpl \ --mpl-results-path=mpl_results.html \ --mpl-generate-summary=html \ - --splits 4 \ + --splits 8 \ --group ${{ matrix.combos.group }} \ --splitting-algorithm least_duration \ --db ./prof.db + - name: save coverage file and plot comparison results - if: always() + if: always() && env.has_changes == 'true' uses: actions/upload-artifact@v4 with: name: unit_test_artifact-${{ matrix.combos.python_version }}-${{ matrix.combos.group }} @@ -65,7 +109,9 @@ jobs: ./cov.xml ./mpl_results.html ./prof.db + - name: Upload coverage + if: env.has_changes == 'true' id : codecov uses: Wandalen/wretry.action@v1.3.0 with: diff --git a/.github/workflows/scheduled.yml b/.github/workflows/weekly_tests.yml similarity index 75% rename from .github/workflows/scheduled.yml rename to .github/workflows/weekly_tests.yml index a584db5cb8..2fb309bd8a 100644 --- a/.github/workflows/scheduled.yml +++ b/.github/workflows/weekly_tests.yml @@ -11,11 +11,10 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - combos: [{group: 1, python_version: '3.8'}, - {group: 2, python_version: '3.9'}, - {group: 3, python_version: '3.10'}, - {group: 4, python_version: '3.11'}, - {group: 5, python_version: '3.12'}] + combos: [{group: 1, python_version: '3.9'}, + {group: 2, python_version: '3.10'}, + {group: 3, python_version: '3.11'}, + {group: 4, python_version: '3.12'}] steps: - uses: actions/checkout@v4 @@ -37,6 +36,6 @@ jobs: lscpu python -m pytest -v -m unit \ --durations=0 \ - --splits 5 \ + --splits 4 \ --group ${{ matrix.combos.group }} \ --splitting-algorithm least_duration diff --git a/CHANGELOG.md b/CHANGELOG.md index aab80a4173..ccf156230a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,21 @@ Changelog New Features - Add ``use_signed_distance`` flag to ``PlasmaVesselDistance`` which will use a signed distance as the target, which is positive when the plasma is inside of the vessel surface and negative if the plasma is outside of the vessel surface, to allow optimizer to distinguish if the equilbrium surface exits the vessel surface and guard against it by targeting a positive signed distance. +- Add ``VectorPotentialField`` class to allow calculation of magnetic fields from a user-specified + vector potential function. +- Add ``compute_magnetic_vector_potential`` methods to most ``MagneticField`` objects to allow vector potential + computation. +- Add ability to save and load vector potential information from ``mgrid`` files. 
+- Changes ``ToroidalFlux`` objective to default to using a 1D loop integral of the vector potential +to compute the toroidal flux when possible, as opposed to a 2D surface integral of the magnetic field dotted with ``n_zeta``. +- Allow specification of Nyquist spectrum maximum mode numbers when using ``VMECIO.save`` to save a DESC .h5 file as a VMEC-format wout file. + +Bug Fixes + +- Fixes bugs that occur when saving asymmetric equilibria as wout files. +- Fixes a bug that occurs when using ``VMECIO.plot_vmec_comparison`` to compare to an asymmetric wout file. + + v0.12.1 ------- diff --git a/README.rst b/README.rst index 18e4400c79..24e56055ad 100644 --- a/README.rst +++ b/README.rst @@ -111,12 +111,12 @@ Contribute :target: https://desc-docs.readthedocs.io/en/latest/?badge=latest :alt: Documentation -.. |UnitTests| image:: https://github.com/PlasmaControl/DESC/actions/workflows/unittest.yml/badge.svg - :target: https://github.com/PlasmaControl/DESC/actions/workflows/unittest.yml +.. |UnitTests| image:: https://github.com/PlasmaControl/DESC/actions/workflows/unit_tests.yml/badge.svg + :target: https://github.com/PlasmaControl/DESC/actions/workflows/unit_tests.yml :alt: UnitTests -.. |RegressionTests| image:: https://github.com/PlasmaControl/DESC/actions/workflows/regression_test.yml/badge.svg - :target: https://github.com/PlasmaControl/DESC/actions/workflows/regression_test.yml +.. |RegressionTests| image:: https://github.com/PlasmaControl/DESC/actions/workflows/regression_tests.yml/badge.svg + :target: https://github.com/PlasmaControl/DESC/actions/workflows/regression_tests.yml :alt: RegressionTests .. |Codecov| image:: https://codecov.io/gh/PlasmaControl/DESC/branch/master/graph/badge.svg?token=5LDR4B1O7Z diff --git a/codecov.yml b/codecov.yml index 8d3a272f14..e4c14a2bc1 100644 --- a/codecov.yml +++ b/codecov.yml @@ -4,7 +4,7 @@ comment: # this is a top-level key require_changes: false # if true: only post the comment if coverage changes require_base: true # [true :: must have a base report to post] require_head: true # [true :: must have a head report to post] - after_n_builds: 10 + after_n_builds: 14 coverage: status: patch: diff --git a/desc/backend.py b/desc/backend.py index c26213b045..3b47cba4c5 100644 --- a/desc/backend.py +++ b/desc/backend.py @@ -66,19 +66,23 @@ ) if use_jax: # noqa: C901 - FIXME: simplify this, define globally and then assign? - jit = jax.jit - fori_loop = jax.lax.fori_loop - cond = jax.lax.cond - switch = jax.lax.switch - while_loop = jax.lax.while_loop - vmap = jax.vmap - bincount = jnp.bincount - repeat = jnp.repeat - take = jnp.take - scan = jax.lax.scan - from jax import custom_jvp + from jax import custom_jvp, jit, vmap + + imap = jax.lax.map from jax.experimental.ode import odeint - from jax.scipy.linalg import block_diag, cho_factor, cho_solve, qr, solve_triangular + from jax.lax import cond, fori_loop, scan, switch, while_loop + from jax.nn import softmax as softargmax + from jax.numpy import bincount, flatnonzero, repeat, take + from jax.numpy.fft import irfft, rfft, rfft2 + from jax.scipy.fft import dct, idct + from jax.scipy.linalg import ( + block_diag, + cho_factor, + cho_solve, + eigh_tridiagonal, + qr, + solve_triangular, + ) from jax.scipy.special import gammaln, logsumexp from jax.tree_util import ( register_pytree_node, @@ -90,6 +94,10 @@ treedef_is_leaf, ) + trapezoid = ( + jnp.trapezoid if hasattr(jnp, "trapezoid") else jax.scipy.integrate.trapezoid + ) + def put(arr, inds, vals): """Functional interface for array "fancy indexing".
@@ -328,6 +336,8 @@ def root( This routine may be used on over or under-determined systems, in which case it will solve it in a least squares / least norm sense. """ + from desc.utils import safenorm + if fixup is None: fixup = lambda x, *args: x if jac is None: @@ -392,7 +402,7 @@ def tangent_solve(g, y): x, (res, niter) = jax.lax.custom_root( res, x0, solve, tangent_solve, has_aux=True ) - return x, (jnp.linalg.norm(res), niter) + return x, (safenorm(res), niter) # we can't really test the numpy backend stuff in automated testing, so we ignore it @@ -401,15 +411,54 @@ def tangent_solve(g, y): jit = lambda func, *args, **kwargs: func execute_on_cpu = lambda func: func import scipy.optimize + from numpy.fft import irfft, rfft, rfft2 # noqa: F401 + from scipy.fft import dct, idct # noqa: F401 from scipy.integrate import odeint # noqa: F401 from scipy.linalg import ( # noqa: F401 block_diag, cho_factor, cho_solve, + eigh_tridiagonal, qr, solve_triangular, ) from scipy.special import gammaln, logsumexp # noqa: F401 + from scipy.special import softmax as softargmax # noqa: F401 + + trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz + + def imap(f, xs, batch_size=None, in_axes=0, out_axes=0): + """Generalizes jax.lax.map; uses numpy.""" + if not isinstance(xs, np.ndarray): + raise NotImplementedError( + "Require numpy array input, or install jax to support pytrees." + ) + xs = np.moveaxis(xs, source=in_axes, destination=0) + return np.stack([f(x) for x in xs], axis=out_axes) + + def vmap(fun, in_axes=0, out_axes=0): + """A numpy implementation of jax.lax.map whose API is a subset of jax.vmap. + + Like Python's builtin map, + except inputs and outputs are in the form of stacked arrays, + and the returned object is a vectorized version of the input function. + + Parameters + ---------- + fun: callable + Function (A -> B) + in_axes: int + Axis to map over. + out_axes: int + An integer indicating where the mapped axis should appear in the output. + + Returns + ------- + fun_vmap: callable + Vectorized version of fun. + + """ + return lambda xs: imap(fun, xs, in_axes=in_axes, out_axes=out_axes) def tree_stack(*args, **kwargs): """Stack pytree for numpy backend.""" @@ -592,32 +641,6 @@ def while_loop(cond_fun, body_fun, init_val): val = body_fun(val) return val - def vmap(fun, out_axes=0): - """A numpy implementation of jax.lax.map whose API is a subset of jax.vmap. - - Like Python's builtin map, - except inputs and outputs are in the form of stacked arrays, - and the returned object is a vectorized version of the input function. - - Parameters - ---------- - fun: callable - Function (A -> B) - out_axes: int - An integer indicating where the mapped axis should appear in the output. - - Returns - ------- - fun_vmap: callable - Vectorized version of fun. - - """ - - def fun_vmap(fun_inputs): - return np.stack([fun(fun_input) for fun_input in fun_inputs], axis=out_axes) - - return fun_vmap - def scan(f, init, xs, length=None, reverse=False, unroll=1): """Scan a function over leading array axes while carrying along state. 
@@ -657,9 +680,14 @@ def scan(f, init, xs, length=None, reverse=False, unroll=1): ys.append(y) return carry, np.stack(ys) - def bincount(x, weights=None, minlength=None, length=None): - """Same as np.bincount but with a dummy parameter to match jnp.bincount API.""" - return np.bincount(x, weights, minlength) + def bincount(x, weights=None, minlength=0, length=None): + """A numpy implementation of jnp.bincount.""" + x = np.clip(x, 0, None) + if length is None: + length = max(minlength, x.max() + 1) + else: + minlength = max(minlength, length) + return np.bincount(x, weights, minlength)[:length] def repeat(a, repeats, axis=None, total_repeat_length=None): """A numpy implementation of jnp.repeat.""" @@ -778,6 +806,13 @@ def root( out = scipy.optimize.root(fun, x0, args, jac=jac, tol=tol) return out.x, out + def flatnonzero(a, size=None, fill_value=0): + """A numpy implementation of jnp.flatnonzero.""" + nz = np.flatnonzero(a) + if size is not None: + nz = np.pad(nz, (0, max(size - nz.size, 0)), constant_values=fill_value) + return nz + def take( a, indices, diff --git a/desc/coils.py b/desc/coils.py index f184365918..9ffc5015c7 100644 --- a/desc/coils.py +++ b/desc/coils.py @@ -19,7 +19,6 @@ from desc.compute import get_params, rpz2xyz, rpz2xyz_vec, xyz2rpz, xyz2rpz_vec from desc.compute.geom_utils import reflection_matrix from desc.compute.utils import _compute as compute_fun -from desc.compute.utils import safenorm from desc.geometry import ( FourierPlanarCurve, FourierRZCurve, @@ -29,7 +28,7 @@ from desc.grid import LinearGrid from desc.magnetic_fields import _MagneticField from desc.optimizable import Optimizable, OptimizableCollection, optimizable_parameter -from desc.utils import equals, errorif, flatten_list, warnif +from desc.utils import equals, errorif, flatten_list, safenorm, warnif @jit @@ -82,6 +81,53 @@ def biot_savart_hh(eval_pts, coil_pts_start, coil_pts_end, current): return B +@jit +def biot_savart_vector_potential_hh(eval_pts, coil_pts_start, coil_pts_end, current): + """Biot-Savart law for vector potential for filamentary coils following [1]. + + The coil is approximated by a series of straight line segments + and an analytic expression is used to evaluate the vector potential from each + segment. This expression assumes the Coulomb gauge. + + Parameters + ---------- + eval_pts : array-like shape(n,3) + Evaluation points in cartesian coordinates + coil_pts_start, coil_pts_end : array-like shape(m,3) + Points in cartesian space defining the start and end of each segment. + Should be a closed curve, such that coil_pts_start[0] == coil_pts_end[-1] + though this is not checked. + current : float + Current through the coil (in Amps). 
+ + Returns + ------- + A : ndarray, shape(n,3) + Magnetic vector potential in cartesian components at specified points + + [1] Hanson & Hirshman, "Compact expressions for the Biot-Savart + fields of a filamentary segment" (2002) + """ + d_vec = coil_pts_end - coil_pts_start + L = jnp.linalg.norm(d_vec, axis=-1) + d_vec_over_L = ((1 / L) * d_vec.T).T + + Ri_vec = eval_pts[jnp.newaxis, :] - coil_pts_start[:, jnp.newaxis, :] + Ri = jnp.linalg.norm(Ri_vec, axis=-1) + Rf = jnp.linalg.norm( + eval_pts[jnp.newaxis, :] - coil_pts_end[:, jnp.newaxis, :], axis=-1 + ) + Ri_p_Rf = Ri + Rf + + eps = L[:, jnp.newaxis] / (Ri_p_Rf) + + A_mag = 1.0e-7 * current * jnp.log((1 + eps) / (1 - eps)) # 1.0e-7 == mu_0/(4 pi) + + # Now just need to multiply by e^ = d_vec/L = (x_f - x_i)/L + A = jnp.sum(A_mag[:, :, jnp.newaxis] * d_vec_over_L[:, jnp.newaxis, :], axis=0) + return A + + @jit def biot_savart_quad(eval_pts, coil_pts, tangents, current): """Biot-Savart law for filamentary coil using numerical quadrature. @@ -123,6 +169,42 @@ def biot_savart_quad(eval_pts, coil_pts, tangents, current): return B +@jit +def biot_savart_vector_potential_quad(eval_pts, coil_pts, tangents, current): + """Biot-Savart law (for A) for filamentary coil using numerical quadrature. + + This expression assumes the Coulomb gauge. + + Parameters + ---------- + eval_pts : array-like shape(n,3) + Evaluation points in cartesian coordinates + coil_pts : array-like shape(m,3) + Points in cartesian space defining coil + tangents : array-like, shape(m,3) + Tangent vectors to the coil at coil_pts. If the curve is given + by x(s) with curve parameter s, coil_pts = x, tangents = dx/ds*ds where + ds is the spacing between points. + current : float + Current through the coil (in Amps). + + Returns + ------- + A : ndarray, shape(n,3) + Magnetic vector potential in cartesian components at specified points. + """ + dl = tangents + R_vec = eval_pts[jnp.newaxis, :] - coil_pts[:, jnp.newaxis, :] + R_mag = jnp.linalg.norm(R_vec, axis=-1) + + vec = dl[:, jnp.newaxis, :] + denom = R_mag + + # 1e-7 == mu_0/(4 pi) + A = jnp.sum(1.0e-7 * current * vec / denom[:, :, None], axis=0) + return A + + class _Coil(_MagneticField, Optimizable, ABC): """Base class representing a magnetic field coil. @@ -187,10 +269,16 @@ def _compute_position(self, params=None, grid=None, **kwargs): x = x.at[:, :, 1].set(jnp.mod(x[:, :, 1], 2 * jnp.pi)) return x - def compute_magnetic_field( - self, coords, params=None, basis="rpz", source_grid=None, transforms=None + def _compute_A_or_B( + self, + coords, + params=None, + basis="rpz", + source_grid=None, + transforms=None, + compute_A_or_B="B", ): - """Compute magnetic field at a set of points. + """Compute magnetic field or vector potential at a set of points. The coil current may be overridden by including `current` in the `params` dictionary. @@ -208,6 +296,9 @@ def compute_magnetic_field( points. Should NOT include endpoint at 2pi. transforms : dict of Transform or array-like Transforms for R, Z, lambda, etc. Default is to build from grid. + compute_A_or_B: {"A", "B"}, optional + whether to compute the magnetic vector potential "A" or the magnetic field + "B". Defaults to "B" Returns @@ -223,6 +314,14 @@ def compute_magnetic_field( may not be zero if not fully converged. 
""" + errorif( + compute_A_or_B not in ["A", "B"], + ValueError, + f'Expected "A" or "B" for compute_A_or_B, instead got {compute_A_or_B}', + ) + op = {"B": biot_savart_quad, "A": biot_savart_vector_potential_quad}[ + compute_A_or_B + ] assert basis.lower() in ["rpz", "xyz"] coords = jnp.atleast_2d(jnp.asarray(coords)) if basis.lower() == "rpz": @@ -256,13 +355,87 @@ def compute_magnetic_field( data["x_s"] = rpz2xyz_vec(data["x_s"], phi=data["x"][:, 1]) data["x"] = rpz2xyz(data["x"]) - B = biot_savart_quad( - coords, data["x"], data["x_s"] * data["ds"][:, None], current - ) + AB = op(coords, data["x"], data["x_s"] * data["ds"][:, None], current) if basis.lower() == "rpz": - B = xyz2rpz_vec(B, phi=phi) - return B + AB = xyz2rpz_vec(AB, phi=phi) + return AB + + def compute_magnetic_field( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic field at a set of points. + + The coil current may be overridden by including `current` + in the `params` dictionary. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate field at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict, optional + Parameters to pass to Curve. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic field. + source_grid : Grid, int or None, optional + Grid used to discretize coil. If an integer, uses that many equally spaced + points. Should NOT include endpoint at 2pi. + transforms : dict of Transform or array-like + Transforms for R, Z, lambda, etc. Default is to build from grid. + + + Returns + ------- + field : ndarray, shape(n,3) + magnetic field at specified points, in either rpz or xyz coordinates + + Notes + ----- + Uses direct quadrature of the Biot-Savart integral for filamentary coils with + tangents provided by the underlying curve class. Convergence should be + exponential in the number of points used to discretize the curve, though curl(B) + may not be zero if not fully converged. + + """ + return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "B") + + def compute_magnetic_vector_potential( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic vector potential at a set of points. + + The coil current may be overridden by including `current` + in the `params` dictionary. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate field at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict, optional + Parameters to pass to Curve. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic field. + source_grid : Grid, int or None, optional + Grid used to discretize coil. If an integer, uses that many equally spaced + points. Should NOT include endpoint at 2pi. + transforms : dict of Transform or array-like + Transforms for R, Z, lambda, etc. Default is to build from grid. + + Returns + ------- + vector_potential : ndarray, shape(n,3) + Magnetic vector potential at specified points, in either rpz or + xyz coordinates. + + Notes + ----- + Uses direct quadrature of the Biot-Savart integral for filamentary coils with + tangents provided by the underlying curve class. Convergence should be + exponential in the number of points used to discretize the curve, though curl(B) + may not be zero if not fully converged. 
+ + """ + return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "A") def __repr__(self): """Get the string form of the object.""" @@ -783,10 +956,16 @@ def __init__( ): super().__init__(current, X, Y, Z, knots, method, name) - def compute_magnetic_field( - self, coords, params=None, basis="rpz", source_grid=None, transforms=None + def _compute_A_or_B( + self, + coords, + params=None, + basis="rpz", + source_grid=None, + transforms=None, + compute_A_or_B="B", ): - """Compute magnetic field at a set of points. + """Compute magnetic field or vector potential at a set of points. The coil current may be overridden by including `current` in the `params` dictionary. @@ -804,6 +983,9 @@ def compute_magnetic_field( points. Should NOT include endpoint at 2pi. transforms : dict of Transform or array-like Transforms for R, Z, lambda, etc. Default is to build from grid. + compute_A_or_B: {"A", "B"}, optional + whether to compute the magnetic vector potential "A" or the magnetic field + "B". Defaults to "B" Returns ------- @@ -817,6 +999,12 @@ def compute_magnetic_field( is approximately quadratic in the number of coil points. """ + errorif( + compute_A_or_B not in ["A", "B"], + ValueError, + f'Expected "A" or "B" for compute_A_or_B, instead got {compute_A_or_B}', + ) + op = {"B": biot_savart_hh, "A": biot_savart_vector_potential_hh}[compute_A_or_B] assert basis.lower() in ["rpz", "xyz"] coords = jnp.atleast_2d(jnp.asarray(coords)) if basis == "rpz": @@ -826,7 +1014,9 @@ def compute_magnetic_field( else: current = params.pop("current", self.current) - data = self.compute(["x"], grid=source_grid, params=params, basis="xyz") + data = self.compute( + ["x"], grid=source_grid, params=params, basis="xyz", transforms=transforms + ) # need to make sure the curve is closed. If it's already closed, this doesn't # do anything (effectively just adds a segment of zero length which has no # effect on the overall result) @@ -837,11 +1027,85 @@ def compute_magnetic_field( # coils curvature which is a 2nd derivative of the position, and doing that # with only possibly c1 cubic splines is inaccurate, so we don't do it # (for now, maybe in the future?) - B = biot_savart_hh(coords, coil_pts_start, coil_pts_end, current) + AB = op(coords, coil_pts_start, coil_pts_end, current) if basis == "rpz": - B = xyz2rpz_vec(B, x=coords[:, 0], y=coords[:, 1]) - return B + AB = xyz2rpz_vec(AB, x=coords[:, 0], y=coords[:, 1]) + return AB + + def compute_magnetic_field( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic field at a set of points. + + The coil current may be overridden by including `current` + in the `params` dictionary. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate field at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict, optional + Parameters to pass to Curve. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic field. + source_grid : Grid, int or None, optional + Grid used to discretize coil. If an integer, uses that many equally spaced + points. Should NOT include endpoint at 2pi. + transforms : dict of Transform or array-like + Transforms for R, Z, lambda, etc. Default is to build from grid. + + Returns + ------- + field : ndarray, shape(n,3) + magnetic field at specified points, in either rpz or xyz coordinates + + Notes + ----- + Discretizes the coil into straight segments between grid points, and uses the + Hanson-Hirshman expression for exact field from a straight segment. 
Convergence + is approximately quadratic in the number of coil points. + + """ + return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "B") + + def compute_magnetic_vector_potential( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic vector potential at a set of points. + + The coil current may be overridden by including `current` + in the `params` dictionary. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate magnetic vector potential at in [R,phi,Z] + or [X,Y,Z] coordinates. + params : dict, optional + Parameters to pass to Curve. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic vector potential. + source_grid : Grid, int or None, optional + Grid used to discretize coil. If an integer, uses that many equally spaced + points. Should NOT include endpoint at 2pi. + transforms : dict of Transform or array-like + Transforms for R, Z, lambda, etc. Default is to build from grid. + + Returns + ------- + A : ndarray, shape(n,3) + Magnetic vector potential at specified points, in either + rpz or xyz coordinates + + Notes + ----- + Discretizes the coil into straight segments between grid points, and uses the + Hanson-Hirshman expression for exact vector potential from a straight segment. + Convergence is approximately quadratic in the number of coil points. + + """ + return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "A") @classmethod def from_values( @@ -1153,8 +1417,14 @@ def _compute_position(self, params=None, grid=None, **kwargs): x = rpz return x - def compute_magnetic_field( - self, coords, params=None, basis="rpz", source_grid=None, transforms=None + def _compute_A_or_B( + self, + coords, + params=None, + basis="rpz", + source_grid=None, + transforms=None, + compute_A_or_B="B", ): """Compute magnetic field at a set of points. @@ -1171,13 +1441,22 @@ def compute_magnetic_field( points. Should NOT include endpoint at 2pi. transforms : dict of Transform or array-like Transforms for R, Z, lambda, etc. Default is to build from grid. + compute_A_or_B: {"A", "B"}, optional + whether to compute the magnetic vector potential "A" or the magnetic field + "B". Defaults to "B" Returns ------- field : ndarray, shape(n,3) - Magnetic field at specified nodes, in [R,phi,Z] or [X,Y,Z] coordinates. + Magnetic field or vector potential at specified nodes, in [R,phi,Z] + or [X,Y,Z] coordinates. 
""" + errorif( + compute_A_or_B not in ["A", "B"], + ValueError, + f'Expected "A" or "B" for compute_A_or_B, instead got {compute_A_or_B}', + ) assert basis.lower() in ["rpz", "xyz"] coords = jnp.atleast_2d(jnp.asarray(coords)) if params is None: @@ -1207,31 +1486,89 @@ def compute_magnetic_field( # field period rotation is easiest in [R,phi,Z] coordinates coords_rpz = xyz2rpz(coords_xyz) + op = { + "B": self[0].compute_magnetic_field, + "A": self[0].compute_magnetic_vector_potential, + }[compute_A_or_B] # sum the magnetic fields from each field period - def nfp_loop(k, B): + def nfp_loop(k, AB): coords_nfp = coords_rpz + jnp.array([0, 2 * jnp.pi * k / self.NFP, 0]) - def body(B, x): - B += self[0].compute_magnetic_field( - coords_nfp, params=x, basis="rpz", source_grid=source_grid - ) - return B, None + def body(AB, x): + AB += op(coords_nfp, params=x, basis="rpz", source_grid=source_grid) + return AB, None - B += scan(body, jnp.zeros(coords_nfp.shape), tree_stack(params))[0] - return B + AB += scan(body, jnp.zeros(coords_nfp.shape), tree_stack(params))[0] + return AB - B = fori_loop(0, self.NFP, nfp_loop, jnp.zeros_like(coords_rpz)) + AB = fori_loop(0, self.NFP, nfp_loop, jnp.zeros_like(coords_rpz)) - # sum the magnetic fields from both halves of the symmetric field period + # sum the magnetic field/potential from both halves of + # the symmetric field period if self.sym: - B = B[: coords.shape[0], :] + B[coords.shape[0] :, :] * jnp.array( + AB = AB[: coords.shape[0], :] + AB[coords.shape[0] :, :] * jnp.array( [-1, 1, 1] ) if basis.lower() == "xyz": - B = rpz2xyz_vec(B, x=coords[:, 0], y=coords[:, 1]) - return B + AB = rpz2xyz_vec(AB, x=coords[:, 0], y=coords[:, 1]) + return AB + + def compute_magnetic_field( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic field at a set of points. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate field at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Parameters to pass to coils, either the same for all coils or one for each. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic field. + source_grid : Grid, int or None, optional + Grid used to discretize coils. If an integer, uses that many equally spaced + points. Should NOT include endpoint at 2pi. + transforms : dict of Transform or array-like + Transforms for R, Z, lambda, etc. Default is to build from grid. + + Returns + ------- + field : ndarray, shape(n,3) + Magnetic field at specified nodes, in [R,phi,Z] or [X,Y,Z] coordinates. + + """ + return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "B") + + def compute_magnetic_vector_potential( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic vector potential at a set of points. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate potential at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Parameters to pass to coils, either the same for all coils or one for each. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic field. + source_grid : Grid, int or None, optional + Grid used to discretize coils. If an integer, uses that many equally spaced + points. Should NOT include endpoint at 2pi. + transforms : dict of Transform or array-like + Transforms for R, Z, lambda, etc. Default is to build from grid. 
+ + Returns + ------- + vector_potential : ndarray, shape(n,3) + magnetic vector potential at specified points, in either rpz + or xyz coordinates + + """ + return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "A") @classmethod def linspaced_angular( @@ -2002,6 +2339,65 @@ def _compute_position(self, params=None, grid=None, **kwargs): ) return x + def _compute_A_or_B( + self, + coords, + params=None, + basis="rpz", + source_grid=None, + transforms=None, + compute_A_or_B="B", + ): + """Compute magnetic field or vector potential at a set of points. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate field at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Parameters to pass to coils, either the same for all coils or one for each. + If array-like, should be 1 value per coil. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic field. + source_grid : Grid, int or None or array-like, optional + Grid used to discretize coils. If an integer, uses that many equally spaced + points. Should NOT include endpoint at 2pi. + If array-like, should be 1 value per coil. + transforms : dict of Transform or array-like + Transforms for R, Z, lambda, etc. Default is to build from grid. + compute_A_or_B: {"A", "B"}, optional + whether to compute the magnetic vector potential "A" or the magnetic field + "B". Defaults to "B" + + Returns + ------- + field : ndarray, shape(n,3) + magnetic field or vector potential at specified points, in either rpz + or xyz coordinates + + """ + errorif( + compute_A_or_B not in ["A", "B"], + ValueError, + f'Expected "A" or "B" for compute_A_or_B, instead got {compute_A_or_B}', + ) + params = self._make_arraylike(params) + source_grid = self._make_arraylike(source_grid) + transforms = self._make_arraylike(transforms) + + AB = 0 + if compute_A_or_B == "B": + for coil, par, grd, tr in zip(self.coils, params, source_grid, transforms): + AB += coil.compute_magnetic_field( + coords, par, basis, grd, transforms=tr + ) + elif compute_A_or_B == "A": + for coil, par, grd, tr in zip(self.coils, params, source_grid, transforms): + AB += coil.compute_magnetic_vector_potential( + coords, par, basis, grd, transforms=tr + ) + return AB + def compute_magnetic_field( self, coords, params=None, basis="rpz", source_grid=None, transforms=None ): @@ -2029,15 +2425,37 @@ def compute_magnetic_field( magnetic field at specified points, in either rpz or xyz coordinates """ - params = self._make_arraylike(params) - source_grid = self._make_arraylike(source_grid) - transforms = self._make_arraylike(transforms) + return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "B") - B = 0 - for coil, par, grd, tr in zip(self.coils, params, source_grid, transforms): - B += coil.compute_magnetic_field(coords, par, basis, grd, transforms=tr) + def compute_magnetic_vector_potential( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic vector potential at a set of points. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate potential at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Parameters to pass to coils, either the same for all coils or one for each. + If array-like, should be 1 value per coil. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic field. + source_grid : Grid, int or None or array-like, optional + Grid used to discretize coils. 
If an integer, uses that many equally spaced + points. Should NOT include endpoint at 2pi. + If array-like, should be 1 value per coil. + transforms : dict of Transform or array-like + Transforms for R, Z, lambda, etc. Default is to build from grid. - return B + Returns + ------- + vector_potential : ndarray, shape(n,3) + magnetic vector potential at specified points, in either rpz + or xyz coordinates + + """ + return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "A") def to_FourierPlanar( self, N=10, grid=None, basis="xyz", name="", check_intersection=False diff --git a/desc/compat.py b/desc/compat.py index bea90d470e..38f3d9520a 100644 --- a/desc/compat.py +++ b/desc/compat.py @@ -111,6 +111,46 @@ def flip_helicity(eq): return eq +def flip_theta(eq): + """Change the gauge freedom of the poloidal angle of an Equilibrium. + + Equivalent to redefining theta_new = theta_old + π + + Parameters + ---------- + eq : Equilibrium or iterable of Equilibrium + Equilibria to redefine the poloidal angle of. + + Returns + ------- + eq : Equilibrium or iterable of Equilibrium + Same as input, but with the poloidal angle redefined. + + """ + # maybe it's iterable: + if hasattr(eq, "__len__"): + for e in eq: + flip_theta(e) + return eq + + rone = np.ones_like(eq.R_lmn) + rone[eq.R_basis.modes[:, 1] % 2 == 1] *= -1 + eq.R_lmn *= rone + + zone = np.ones_like(eq.Z_lmn) + zone[eq.Z_basis.modes[:, 1] % 2 == 1] *= -1 + eq.Z_lmn *= zone + + lone = np.ones_like(eq.L_lmn) + lone[eq.L_basis.modes[:, 1] % 2 == 1] *= -1 + eq.L_lmn *= lone + + eq.axis = eq.get_axis() + eq.surface = eq.get_surface_at(rho=1) + + return eq + + def rescale( eq, L=("R0", None), B=("B0", None), scale_pressure=True, copy=False, verbose=0 ): diff --git a/desc/compute/_basis_vectors.py b/desc/compute/_basis_vectors.py index 8fca1346d2..72803a1d7a 100644 --- a/desc/compute/_basis_vectors.py +++ b/desc/compute/_basis_vectors.py @@ -11,8 +11,8 @@ from desc.backend import jnp +from ..utils import cross, dot, safediv from .data_index import register_compute_fun -from .utils import cross, dot, safediv @register_compute_fun( diff --git a/desc/compute/_bootstrap.py b/desc/compute/_bootstrap.py index 48af83b4e5..2329682c06 100644 --- a/desc/compute/_bootstrap.py +++ b/desc/compute/_bootstrap.py @@ -13,7 +13,7 @@ from scipy.special import roots_legendre from ..backend import fori_loop, jnp -from ..integrals import surface_averages_map +from ..integrals.surface_integral import surface_averages_map from .data_index import register_compute_fun diff --git a/desc/compute/_curve.py b/desc/compute/_curve.py index 2e96e7a767..f8e4bbeb8c 100644 --- a/desc/compute/_curve.py +++ b/desc/compute/_curve.py @@ -2,9 +2,9 @@ from desc.backend import jnp, sign +from ..utils import cross, dot, safenormalize from .data_index import register_compute_fun from .geom_utils import rotation_matrix, rpz2xyz, rpz2xyz_vec, xyz2rpz, xyz2rpz_vec -from .utils import cross, dot, safenormalize @register_compute_fun( diff --git a/desc/compute/_equil.py b/desc/compute/_equil.py index 44975de7ac..de0fad7797 100644 --- a/desc/compute/_equil.py +++ b/desc/compute/_equil.py @@ -14,9 +14,9 @@ from desc.backend import jnp -from ..integrals import surface_averages +from ..integrals.surface_integral import surface_averages +from ..utils import cross, dot, safediv, safenorm from .data_index import register_compute_fun -from .utils import cross, dot, safediv, safenorm @register_compute_fun( @@ -625,7 +625,7 @@ def _e_sup_helical_times_sqrt_g_mag(params, transforms, profiles, 
data, **kwargs @register_compute_fun( name="F_anisotropic", - label="F_{anisotropic}", + label="F_{\\mathrm{anisotropic}}", units="N \\cdot m^{-3}", units_long="Newtons / cubic meter", description="Anisotropic force balance error", diff --git a/desc/compute/_field.py b/desc/compute/_field.py index 1ca0adb1fb..97cb44515f 100644 --- a/desc/compute/_field.py +++ b/desc/compute/_field.py @@ -13,14 +13,14 @@ from desc.backend import jnp -from ..integrals import ( +from ..integrals.surface_integral import ( surface_averages, surface_integrals_map, surface_max, surface_min, ) +from ..utils import cross, dot, safediv, safenorm from .data_index import register_compute_fun -from .utils import cross, dot, safediv, safenorm @register_compute_fun( @@ -1644,7 +1644,7 @@ def _B_sub_zeta(params, transforms, profiles, data, **kwargs): @register_compute_fun( name="B_phi|r,t", - label="B_{\\phi} = B \\dot \\mathbf{e}_{\\phi} |_{\\rho, \\theta}", + label="B_{\\phi} = B \\cdot \\mathbf{e}_{\\phi} |_{\\rho, \\theta}", units="T \\cdot m", units_long="Tesla * meters", description="Covariant toroidal component of magnetic field in (ρ,θ,ϕ) " @@ -2269,7 +2269,7 @@ def _B_sub_zeta_rz(params, transforms, profiles, data, **kwargs): @register_compute_fun( name="<|B|>_axis", - label="\\lange |\\mathbf{B}| \\rangle_{axis}", + label="\\langle |\\mathbf{B}| \\rangle_{axis}", units="T", units_long="Tesla", description="Average magnitude of magnetic field on the magnetic axis", diff --git a/desc/compute/_geometry.py b/desc/compute/_geometry.py index 139f91f537..662413501b 100644 --- a/desc/compute/_geometry.py +++ b/desc/compute/_geometry.py @@ -12,8 +12,8 @@ from desc.backend import jnp from ..integrals.surface_integral import line_integrals, surface_integrals +from ..utils import cross, dot, safenorm from .data_index import register_compute_fun -from .utils import cross, dot, safenorm @register_compute_fun( diff --git a/desc/compute/_metric.py b/desc/compute/_metric.py index 536bd05bb7..ed4ea48145 100644 --- a/desc/compute/_metric.py +++ b/desc/compute/_metric.py @@ -13,9 +13,9 @@ from desc.backend import jnp -from ..integrals import surface_averages +from ..integrals.surface_integral import surface_averages +from ..utils import cross, dot, safediv, safenorm from .data_index import register_compute_fun -from .utils import cross, dot, safediv, safenorm @register_compute_fun( diff --git a/desc/compute/_omnigenity.py b/desc/compute/_omnigenity.py index cbc54561c9..db37766882 100644 --- a/desc/compute/_omnigenity.py +++ b/desc/compute/_omnigenity.py @@ -13,29 +13,37 @@ from desc.backend import jnp, sign, vmap +from ..utils import cross, dot, safediv from .data_index import register_compute_fun -from .utils import cross, dot, safediv @register_compute_fun( name="B_theta_mn", label="B_{\\theta, m, n}", - units="T \\cdot m}", + units="T \\cdot m", units_long="Tesla * meters", description="Fourier coefficients for covariant poloidal component of " "magnetic field.", dim=1, params=[], - transforms={"B": [[0, 0, 0]]}, + transforms={"B": [[0, 0, 0]], "grid": []}, profiles=[], coordinates="rtz", data=["B_theta"], + resolution_requirement="tz", + grid_requirement={"is_meshgrid": True}, M_booz="int: Maximum poloidal mode number for Boozer harmonics. Default 2*eq.M", N_booz="int: Maximum toroidal mode number for Boozer harmonics. 
Default 2*eq.N", - resolution_requirement="tz", ) def _B_theta_mn(params, transforms, profiles, data, **kwargs): - data["B_theta_mn"] = transforms["B"].fit(data["B_theta"]) + B_theta = transforms["grid"].meshgrid_reshape(data["B_theta"], "rtz") + + def fitfun(x): + return transforms["B"].fit(x.flatten(order="F")) + + B_theta_mn = vmap(fitfun)(B_theta) + # modes stored as shape(rho, mn) flattened + data["B_theta_mn"] = B_theta_mn.flatten() return data @@ -43,7 +51,7 @@ def _B_theta_mn(params, transforms, profiles, data, **kwargs): @register_compute_fun( name="B_phi_mn", label="B_{\\phi, m, n}", - units="T \\cdot m}", + units="T \\cdot m", units_long="Tesla * meters", description="Fourier coefficients for covariant toroidal component of " "magnetic field in (ρ,θ,ϕ) coordinates.", @@ -53,13 +61,21 @@ def _B_theta_mn(params, transforms, profiles, data, **kwargs): profiles=[], coordinates="rtz", data=["B_phi|r,t"], - M_booz="int: Maximum poloidal mode number for Boozer harmonics. Default 2*eq.M", - N_booz="int: Maximum toroidal mode number for Boozer harmonics. Default 2*eq.N", resolution_requirement="tz", + grid_requirement={"is_meshgrid": True}, aliases="B_zeta_mn", # TODO: remove when phi != zeta + M_booz="int: Maximum poloidal mode number for Boozer harmonics. Default 2*eq.M", + N_booz="int: Maximum toroidal mode number for Boozer harmonics. Default 2*eq.N", ) def _B_phi_mn(params, transforms, profiles, data, **kwargs): - data["B_phi_mn"] = transforms["B"].fit(data["B_phi|r,t"]) + B_phi = transforms["grid"].meshgrid_reshape(data["B_phi|r,t"], "rtz") + + def fitfun(x): + return transforms["B"].fit(x.flatten(order="F")) + + B_zeta_mn = vmap(fitfun)(B_phi) + # modes stored as shape(rho, mn) flattened + data["B_phi_mn"] = B_zeta_mn.flatten() return data @@ -72,15 +88,16 @@ def _B_phi_mn(params, transforms, profiles, data, **kwargs): + "Boozer Coordinates'", dim=1, params=[], - transforms={"w": [[0, 0, 0]], "B": [[0, 0, 0]]}, + transforms={"w": [[0, 0, 0]], "B": [[0, 0, 0]], "grid": []}, profiles=[], coordinates="rtz", data=["B_theta_mn", "B_phi_mn"], + grid_requirement={"is_meshgrid": True}, M_booz="int: Maximum poloidal mode number for Boozer harmonics. Default 2*eq.M", N_booz="int: Maximum toroidal mode number for Boozer harmonics. 
Default 2*eq.N", ) def _w_mn(params, transforms, profiles, data, **kwargs): - w_mn = jnp.zeros((transforms["w"].basis.num_modes,)) + w_mn = jnp.zeros((transforms["grid"].num_rho, transforms["w"].basis.num_modes)) Bm = transforms["B"].basis.modes[:, 1] Bn = transforms["B"].basis.modes[:, 2] wm = transforms["w"].basis.modes[:, 1] @@ -89,15 +106,19 @@ def _w_mn(params, transforms, profiles, data, **kwargs): mask_t = (Bm[:, None] == -wm) & (Bn[:, None] == wn) & (wm != 0) mask_z = (Bm[:, None] == wm) & (Bn[:, None] == -wn) & (wm == 0) & (wn != 0) - num_t = (mask_t @ sign(wn)) * data["B_theta_mn"] + num_t = (mask_t @ sign(wn)) * data["B_theta_mn"].reshape( + (transforms["grid"].num_rho, -1) + ) den_t = mask_t @ jnp.abs(wm) - num_z = (mask_z @ sign(wm)) * data["B_phi_mn"] + num_z = (mask_z @ sign(wm)) * data["B_phi_mn"].reshape( + (transforms["grid"].num_rho, -1) + ) den_z = mask_z @ jnp.abs(NFP * wn) - w_mn = jnp.where(mask_t.any(axis=0), mask_t.T @ safediv(num_t, den_t), w_mn) - w_mn = jnp.where(mask_z.any(axis=0), mask_z.T @ safediv(num_z, den_z), w_mn) + w_mn = jnp.where(mask_t.any(axis=0), (mask_t.T @ safediv(num_t, den_t).T).T, w_mn) + w_mn = jnp.where(mask_z.any(axis=0), (mask_z.T @ safediv(num_z, den_z).T).T, w_mn) - data["w_Boozer_mn"] = w_mn + data["w_Boozer_mn"] = w_mn.flatten() return data @@ -110,16 +131,22 @@ def _w_mn(params, transforms, profiles, data, **kwargs): + "'Transformation from VMEC to Boozer Coordinates'", dim=1, params=[], - transforms={"w": [[0, 0, 0]]}, + transforms={"w": [[0, 0, 0]], "grid": []}, profiles=[], coordinates="rtz", data=["w_Boozer_mn"], resolution_requirement="tz", + grid_requirement={"is_meshgrid": True}, M_booz="int: Maximum poloidal mode number for Boozer harmonics. Default 2*eq.M", N_booz="int: Maximum toroidal mode number for Boozer harmonics. Default 2*eq.N", ) def _w(params, transforms, profiles, data, **kwargs): - data["w_Boozer"] = transforms["w"].transform(data["w_Boozer_mn"]) + grid = transforms["grid"] + w_mn = data["w_Boozer_mn"].reshape((grid.num_rho, -1)) + w = vmap(transforms["w"].transform)(w_mn) # shape(rho, theta*zeta) + w = w.reshape((grid.num_rho, grid.num_theta, grid.num_zeta), order="F") + w = jnp.moveaxis(w, 0, 1) + data["w_Boozer"] = w.flatten(order="F") return data @@ -132,16 +159,24 @@ def _w(params, transforms, profiles, data, **kwargs): + "'Transformation from VMEC to Boozer Coordinates', poloidal derivative", dim=1, params=[], - transforms={"w": [[0, 1, 0]]}, + transforms={"w": [[0, 1, 0]], "grid": []}, profiles=[], coordinates="rtz", data=["w_Boozer_mn"], resolution_requirement="tz", + grid_requirement={"is_meshgrid": True}, M_booz="int: Maximum poloidal mode number for Boozer harmonics. Default 2*eq.M", N_booz="int: Maximum toroidal mode number for Boozer harmonics. 
Default 2*eq.N", ) def _w_t(params, transforms, profiles, data, **kwargs): - data["w_Boozer_t"] = transforms["w"].transform(data["w_Boozer_mn"], dt=1) + grid = transforms["grid"] + w_mn = data["w_Boozer_mn"].reshape((grid.num_rho, -1)) + # need to close over dt which can't be vmapped + fun = lambda x: transforms["w"].transform(x, dt=1) + w_t = vmap(fun)(w_mn) # shape(rho, theta*zeta) + w_t = w_t.reshape((grid.num_rho, grid.num_theta, grid.num_zeta), order="F") + w_t = jnp.moveaxis(w_t, 0, 1) + data["w_Boozer_t"] = w_t.flatten(order="F") return data @@ -154,16 +189,24 @@ def _w_t(params, transforms, profiles, data, **kwargs): + "'Transformation from VMEC to Boozer Coordinates', toroidal derivative", dim=1, params=[], - transforms={"w": [[0, 0, 1]]}, + transforms={"w": [[0, 0, 1]], "grid": []}, profiles=[], coordinates="rtz", data=["w_Boozer_mn"], resolution_requirement="tz", + grid_requirement={"is_meshgrid": True}, M_booz="int: Maximum poloidal mode number for Boozer harmonics. Default 2*eq.M", N_booz="int: Maximum toroidal mode number for Boozer harmonics. Default 2*eq.N", ) def _w_z(params, transforms, profiles, data, **kwargs): - data["w_Boozer_z"] = transforms["w"].transform(data["w_Boozer_mn"], dz=1) + grid = transforms["grid"] + w_mn = data["w_Boozer_mn"].reshape((grid.num_rho, -1)) + # need to close over dz which can't be vmapped + fun = lambda x: transforms["w"].transform(x, dz=1) + w_z = vmap(fun)(w_mn) # shape(rho, theta*zeta) + w_z = w_z.reshape((grid.num_rho, grid.num_theta, grid.num_zeta), order="F") + w_z = jnp.moveaxis(w_z, 0, 1) + data["w_Boozer_z"] = w_z.flatten(order="F") return data @@ -290,21 +333,38 @@ def _sqrtg_B(params, transforms, profiles, data, **kwargs): description="Boozer harmonics of magnetic field", dim=1, params=[], - transforms={"B": [[0, 0, 0]]}, + transforms={"B": [[0, 0, 0]], "grid": []}, profiles=[], coordinates="rtz", data=["sqrt(g)_B", "|B|", "rho", "theta_B", "zeta_B"], + resolution_requirement="tz", + grid_requirement={"is_meshgrid": True}, M_booz="int: Maximum poloidal mode number for Boozer harmonics. Default 2*eq.M", N_booz="int: Maximum toroidal mode number for Boozer harmonics. 
Default 2*eq.N", ) def _B_mn(params, transforms, profiles, data, **kwargs): - nodes = jnp.array([data["rho"], data["theta_B"], data["zeta_B"]]).T norm = 2 ** (3 - jnp.sum((transforms["B"].basis.modes == 0), axis=1)) - data["|B|_mn"] = ( - norm # 1 if m=n=0, 2 if m=0 or n=0, 4 if m!=0 and n!=0 - * (transforms["B"].basis.evaluate(nodes).T @ (data["sqrt(g)_B"] * data["|B|"])) - / transforms["B"].grid.num_nodes + grid = transforms["grid"] + + def fun(rho, theta_B, zeta_B, sqrtg_B, B): + # this fits Boozer modes on a single surface + nodes = jnp.array([rho, theta_B, zeta_B]).T + B_mn = ( + norm # 1 if m=n=0, 2 if m=0 or n=0, 4 if m!=0 and n!=0 + * (transforms["B"].basis.evaluate(nodes).T @ (sqrtg_B * B)) + / transforms["B"].grid.num_nodes + ) + return B_mn + + def reshape(x): + return grid.meshgrid_reshape(x, "rtz").reshape((grid.num_rho, -1)) + + rho, theta_B, zeta_B, sqrtg_B, B = map( + reshape, + (data["rho"], data["theta_B"], data["zeta_B"], data["sqrt(g)_B"], data["|B|"]), ) + B_mn = vmap(fun)(rho, theta_B, zeta_B, sqrtg_B, B) + data["|B|_mn"] = B_mn.flatten() return data diff --git a/desc/compute/_profiles.py b/desc/compute/_profiles.py index 940a463951..65bca54b59 100644 --- a/desc/compute/_profiles.py +++ b/desc/compute/_profiles.py @@ -13,9 +13,9 @@ from desc.backend import cond, jnp -from ..integrals import surface_averages, surface_integrals +from ..integrals.surface_integral import surface_averages, surface_integrals +from ..utils import cumtrapz, dot, safediv from .data_index import register_compute_fun -from .utils import cumtrapz, dot, safediv @register_compute_fun( diff --git a/desc/compute/_stability.py b/desc/compute/_stability.py index 4a985a4dc5..1757fee0ba 100644 --- a/desc/compute/_stability.py +++ b/desc/compute/_stability.py @@ -13,9 +13,9 @@ from desc.backend import jnp -from ..integrals import surface_integrals_map +from ..integrals.surface_integral import surface_integrals_map +from ..utils import dot from .data_index import register_compute_fun -from .utils import dot @register_compute_fun( diff --git a/desc/compute/data_index.py b/desc/compute/data_index.py index 26341ec587..f8f30fa36d 100644 --- a/desc/compute/data_index.py +++ b/desc/compute/data_index.py @@ -63,6 +63,7 @@ def register_compute_fun( # noqa: C901 aliases=None, parameterization="desc.equilibrium.equilibrium.Equilibrium", resolution_requirement="", + grid_requirement=None, source_grid_requirement=None, **kwargs, ): @@ -110,6 +111,11 @@ def register_compute_fun( # noqa: C901 If the computation simply performs pointwise operations, instead of a reduction (such as integration) over a coordinate, then an empty string may be used to indicate no requirements. + grid_requirement : dict + Attributes of the grid that the compute function requires. + Also assumes dependencies were computed on such a grid. + As an example, quantities that require tensor product grids over 2 or more + coordinates may specify ``grid_requirement={"is_meshgrid": True}``. source_grid_requirement : dict Attributes of the source grid that the compute function requires. Also assumes dependencies were computed on such a grid. 
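A sketch of how a compute function can opt in to the new ``grid_requirement`` metadata. The quantity name and body below are hypothetical, not part of this diff; only the decorator pattern and the ``grid_requirement={"is_meshgrid": True}`` flag, which the new check in ``desc/compute/utils.py`` later in this diff enforces, follow the PR.

.. code-block:: python

    from desc.backend import jnp
    from desc.compute.data_index import register_compute_fun


    @register_compute_fun(
        name="toy_quantity",  # hypothetical quantity, for illustration only
        label="Q",
        units="T",
        units_long="Tesla",
        description="Example quantity that needs a tensor-product grid",
        dim=1,
        params=[],
        transforms={"grid": []},
        profiles=[],
        coordinates="rtz",
        data=["|B|"],
        grid_requirement={"is_meshgrid": True},  # verified by compute() in this diff
    )
    def _toy_quantity(params, transforms, profiles, data, **kwargs):
        grid = transforms["grid"]
        # the requirement guarantees this reshape is well defined
        B = grid.meshgrid_reshape(data["|B|"], "rtz")
        data["toy_quantity"] = jnp.max(B) * jnp.ones(grid.num_nodes)
        return data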
@@ -130,6 +136,8 @@ def register_compute_fun( # noqa: C901 aliases = [] if source_grid_requirement is None: source_grid_requirement = {} + if grid_requirement is None: + grid_requirement = {} if not isinstance(parameterization, (tuple, list)): parameterization = [parameterization] if not isinstance(aliases, (tuple, list)): @@ -168,6 +176,7 @@ def _decorator(func): "dependencies": deps, "aliases": aliases, "resolution_requirement": resolution_requirement, + "grid_requirement": grid_requirement, "source_grid_requirement": source_grid_requirement, } for p in parameterization: diff --git a/desc/compute/geom_utils.py b/desc/compute/geom_utils.py index fc5e1dab83..eeda658b61 100644 --- a/desc/compute/geom_utils.py +++ b/desc/compute/geom_utils.py @@ -4,7 +4,7 @@ from desc.backend import jnp -from .utils import safenorm, safenormalize +from ..utils import safenorm, safenormalize def reflection_matrix(normal): diff --git a/desc/compute/utils.py b/desc/compute/utils.py index 0c6e2f7de3..b5bbe8cbbc 100644 --- a/desc/compute/utils.py +++ b/desc/compute/utils.py @@ -33,7 +33,9 @@ def _parse_parameterization(p): return module + "." + klass.__qualname__ -def compute(parameterization, names, params, transforms, profiles, data=None, **kwargs): +def compute( # noqa: C901 + parameterization, names, params, transforms, profiles, data=None, **kwargs +): """Compute the quantity given by name on grid. Parameters @@ -88,6 +90,15 @@ def compute(parameterization, names, params, transforms, profiles, data=None, ** if "grid" in transforms: def check_fun(name): + reqs = data_index[p][name]["grid_requirement"] + for req in reqs: + errorif( + not hasattr(transforms["grid"], req) + or reqs[req] != getattr(transforms["grid"], req), + AttributeError, + f"Expected grid with '{req}:{reqs[req]}' to compute {name}.", + ) + reqs = data_index[p][name]["source_grid_requirement"] errorif( reqs and not hasattr(transforms["grid"], "source_grid"), @@ -517,6 +528,7 @@ def get_transforms( """ from desc.basis import DoubleFourierSeries + from desc.grid import LinearGrid from desc.transform import Transform method = "jitable" if jitable or kwargs.get("method") == "jitable" else "auto" @@ -556,8 +568,15 @@ def get_transforms( ) transforms[c] = c_transform elif c == "B": # used for Boozer transform + # assume grid is a meshgrid but only care about a single surface + if grid.num_rho > 1: + theta = grid.nodes[grid.unique_theta_idx, 1] + zeta = grid.nodes[grid.unique_zeta_idx, 2] + grid_B = LinearGrid(theta=theta, zeta=zeta, NFP=grid.NFP, sym=grid.sym) + else: + grid_B = grid transforms["B"] = Transform( - grid, + grid_B, DoubleFourierSeries( M=kwargs.get("M_booz", 2 * obj.M), N=kwargs.get("N_booz", 2 * obj.N), @@ -570,8 +589,15 @@ def get_transforms( method=method, ) elif c == "w": # used for Boozer transform + # assume grid is a meshgrid but only care about a single surface + if grid.num_rho > 1: + theta = grid.nodes[grid.unique_theta_idx, 1] + zeta = grid.nodes[grid.unique_zeta_idx, 2] + grid_w = LinearGrid(theta=theta, zeta=zeta, NFP=grid.NFP, sym=grid.sym) + else: + grid_w = grid transforms["w"] = Transform( - grid, + grid_w, DoubleFourierSeries( M=kwargs.get("M_booz", 2 * obj.M), N=kwargs.get("N_booz", 2 * obj.N), @@ -685,187 +711,3 @@ def _has_transforms(qty, transforms, parameterization): [d in transforms[key].derivatives.tolist() for d in derivs[key]] ).all() return all(flags.values()) - - -def dot(a, b, axis=-1): - """Batched vector dot product. - - Parameters - ---------- - a : array-like - First array of vectors. 
- b : array-like - Second array of vectors. - axis : int - Axis along which vectors are stored. - - Returns - ------- - y : array-like - y = sum(a*b, axis=axis) - - """ - return jnp.sum(a * b, axis=axis, keepdims=False) - - -def cross(a, b, axis=-1): - """Batched vector cross product. - - Parameters - ---------- - a : array-like - First array of vectors. - b : array-like - Second array of vectors. - axis : int - Axis along which vectors are stored. - - Returns - ------- - y : array-like - y = a x b - - """ - return jnp.cross(a, b, axis=axis) - - -def safenorm(x, ord=None, axis=None, fill=0, threshold=0): - """Like jnp.linalg.norm, but without nan gradient at x=0. - - Parameters - ---------- - x : ndarray - Vector or array to norm. - ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional - Order of norm. - axis : {None, int, 2-tuple of ints}, optional - Axis to take norm along. - fill : float, ndarray, optional - Value to return where x is zero. - threshold : float >= 0 - How small is x allowed to be. - - """ - is_zero = (jnp.abs(x) <= threshold).all(axis=axis, keepdims=True) - y = jnp.where(is_zero, jnp.ones_like(x), x) # replace x with ones if is_zero - n = jnp.linalg.norm(y, ord=ord, axis=axis) - n = jnp.where(is_zero.squeeze(), fill, n) # replace norm with zero if is_zero - return n - - -def safenormalize(x, ord=None, axis=None, fill=0, threshold=0): - """Normalize a vector to unit length, but without nan gradient at x=0. - - Parameters - ---------- - x : ndarray - Vector or array to norm. - ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional - Order of norm. - axis : {None, int, 2-tuple of ints}, optional - Axis to take norm along. - fill : float, ndarray, optional - Value to return where x is zero. - threshold : float >= 0 - How small is x allowed to be. - - """ - is_zero = (jnp.abs(x) <= threshold).all(axis=axis, keepdims=True) - y = jnp.where(is_zero, jnp.ones_like(x), x) # replace x with ones if is_zero - n = safenorm(x, ord, axis, fill, threshold) * jnp.ones_like(x) - # return unit vector with equal components if norm <= threshold - return jnp.where(n <= threshold, jnp.ones_like(y) / jnp.sqrt(y.size), y / n) - - -def safediv(a, b, fill=0, threshold=0): - """Divide a/b with guards for division by zero. - - Parameters - ---------- - a, b : ndarray - Numerator and denominator. - fill : float, ndarray, optional - Value to return where b is zero. - threshold : float >= 0 - How small is b allowed to be. - """ - mask = jnp.abs(b) <= threshold - num = jnp.where(mask, fill, a) - den = jnp.where(mask, 1, b) - return num / den - - -def cumtrapz(y, x=None, dx=1.0, axis=-1, initial=None): - """Cumulatively integrate y(x) using the composite trapezoidal rule. - - Taken from SciPy, but changed NumPy references to JAX.NumPy: - https://github.com/scipy/scipy/blob/v1.10.1/scipy/integrate/_quadrature.py - - Parameters - ---------- - y : array_like - Values to integrate. - x : array_like, optional - The coordinate to integrate along. If None (default), use spacing `dx` - between consecutive elements in `y`. - dx : float, optional - Spacing between elements of `y`. Only used if `x` is None. - axis : int, optional - Specifies the axis to cumulate. Default is -1 (last axis). - initial : scalar, optional - If given, insert this value at the beginning of the returned result. - Typically, this value should be 0. Default is None, which means no - value at ``x[0]`` is returned and `res` has one element less than `y` - along the axis of integration. 
- - Returns - ------- - res : ndarray - The result of cumulative integration of `y` along `axis`. - If `initial` is None, the shape is such that the axis of integration - has one less value than `y`. If `initial` is given, the shape is equal - to that of `y`. - - """ - y = jnp.asarray(y) - if x is None: - d = dx - else: - x = jnp.asarray(x) - if x.ndim == 1: - d = jnp.diff(x) - # reshape to correct shape - shape = [1] * y.ndim - shape[axis] = -1 - d = d.reshape(shape) - elif len(x.shape) != len(y.shape): - raise ValueError("If given, shape of x must be 1-D or the " "same as y.") - else: - d = jnp.diff(x, axis=axis) - - if d.shape[axis] != y.shape[axis] - 1: - raise ValueError( - "If given, length of x along axis must be the " "same as y." - ) - - def tupleset(t, i, value): - l = list(t) - l[i] = value - return tuple(l) - - nd = len(y.shape) - slice1 = tupleset((slice(None),) * nd, axis, slice(1, None)) - slice2 = tupleset((slice(None),) * nd, axis, slice(None, -1)) - res = jnp.cumsum(d * (y[slice1] + y[slice2]) / 2.0, axis=axis) - - if initial is not None: - if not jnp.isscalar(initial): - raise ValueError("`initial` parameter should be a scalar.") - - shape = list(res.shape) - shape[axis] = 1 - res = jnp.concatenate( - [jnp.full(shape, initial, dtype=res.dtype), res], axis=axis - ) - - return res diff --git a/desc/equilibrium/coords.py b/desc/equilibrium/coords.py index bb9b5b8be9..d21c1cd73b 100644 --- a/desc/equilibrium/coords.py +++ b/desc/equilibrium/coords.py @@ -598,8 +598,8 @@ def to_sfl( M_grid = M_grid or int(2 * M) N_grid = N_grid or int(2 * N) - grid = ConcentricGrid(L_grid, M_grid, N_grid, node_pattern="ocs") - bdry_grid = LinearGrid(M=M, N=N, rho=1.0) + grid = ConcentricGrid(L_grid, M_grid, N_grid, node_pattern="ocs", NFP=eq.NFP) + bdry_grid = LinearGrid(M=M, N=N, rho=1.0, NFP=eq.NFP) toroidal_coords = eq.compute(["R", "Z", "lambda"], grid=grid) theta = grid.nodes[:, 1] @@ -685,11 +685,14 @@ def get_rtz_grid( rvp : rho, theta_PEST, phi rtz : rho, theta, zeta period : tuple of float - Assumed periodicity for each quantity in inbasis. + Assumed periodicity for functions of the given coordinates. Use ``np.inf`` to denote no periodicity. jitable : bool, optional If false the returned grid has additional attributes. Required to be false to retain nodes at magnetic axis. + kwargs + Additional parameters to supply to the coordinate mapping function. + See ``desc.equilibrium.coords.map_coordinates``. Returns ------- @@ -701,7 +704,7 @@ def get_rtz_grid( [radial, poloidal, toroidal], coordinates=coordinates, period=period ) if "iota" in kwargs: - kwargs["iota"] = grid.expand(kwargs["iota"]) + kwargs["iota"] = grid.expand(jnp.atleast_1d(kwargs["iota"])) inbasis = { "r": "rho", "t": "theta", diff --git a/desc/equilibrium/equilibrium.py b/desc/equilibrium/equilibrium.py index 8d09d5f64b..a13164dbe6 100644 --- a/desc/equilibrium/equilibrium.py +++ b/desc/equilibrium/equilibrium.py @@ -1255,7 +1255,11 @@ def compute_theta_coords( point. Only returned if ``full_output`` is True. """ - warnif(True, DeprecationWarning, msg="Use map_coordinates instead.") + warnif( + True, + DeprecationWarning, + "Use map_coordinates instead of compute_theta_coords.", + ) return map_coordinates( self, flux_coords, diff --git a/desc/grid.py b/desc/grid.py index 4f318afcaf..6a8ab78fe3 100644 --- a/desc/grid.py +++ b/desc/grid.py @@ -619,6 +619,7 @@ def meshgrid_reshape(self, x, order): ------- x : ndarray Data reshaped to align with grid nodes. 
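+
+        Examples
+        --------
+        .. code-block:: python
+
+            # Illustrative usage, assuming ``grid`` is a meshgrid and ``x`` is
+            # scalar data on its nodes; the axes of the result follow the
+            # requested coordinate order.
+            x_rtz = grid.meshgrid_reshape(x, "rtz")
+            assert x_rtz.shape == (grid.num_rho, grid.num_theta, grid.num_zeta)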
+ """ errorif( not self.is_meshgrid, @@ -637,7 +638,8 @@ def meshgrid_reshape(self, x, order): vec = True shape += (-1,) x = x.reshape(shape, order="F") - x = jnp.moveaxis(x, 1, 0) # now shape rtz/raz etc + # swap to change shape from trz/arz to rtz/raz etc. + x = jnp.swapaxes(x, 1, 0) newax = tuple(self.coordinates.index(c) for c in order) if vec: newax += (3,) @@ -788,10 +790,11 @@ def create_meshgrid( rtz : rho, theta, zeta period : tuple of float Assumed periodicity for each coordinate. - Use np.inf to denote no periodicity. + Use ``np.inf`` to denote no periodicity. NFP : int Number of field periods (Default = 1). - Only makes sense to change from 1 if ``period[2]==2π``. + Only makes sense to change from 1 if last coordinate is periodic + with some constant divided by ``NFP``. Returns ------- @@ -1885,8 +1888,13 @@ def _periodic_spacing(x, period=2 * jnp.pi, sort=False, jnp=jnp): x = jnp.sort(x, axis=0) # choose dx to be half the distance between its neighbors if x.size > 1: - dx_0 = x[1] + (period - x[-1]) % period - dx_1 = x[0] + (period - x[-2]) % period + if np.isfinite(period): + dx_0 = x[1] + (period - x[-1]) % period + dx_1 = x[0] + (period - x[-2]) % period + else: + # just set to 0 to stop nan gradient, even though above gives expected value + dx_0 = 0 + dx_1 = 0 if x.size == 2: # then dx[0] == period and dx[-1] == 0, so fix this dx_1 = dx_0 diff --git a/desc/integrals/__init__.py b/desc/integrals/__init__.py index f223e39606..88cc3001ca 100644 --- a/desc/integrals/__init__.py +++ b/desc/integrals/__init__.py @@ -1,5 +1,6 @@ """Classes for function integration.""" +from .bounce_integral import Bounce1D from .singularities import ( DFTInterpolator, FFTInterpolator, diff --git a/desc/integrals/basis.py b/desc/integrals/basis.py new file mode 100644 index 0000000000..91a31edf60 --- /dev/null +++ b/desc/integrals/basis.py @@ -0,0 +1,109 @@ +"""Fast transformable basis.""" + +from functools import partial + +from desc.backend import flatnonzero, jnp, put +from desc.utils import setdefault + + +@partial(jnp.vectorize, signature="(m),(m)->(m)") +def _in_epigraph_and(is_intersect, df_dy_sign, /): + """Set and epigraph of function f with the given set of points. + + Used to return only intersects where the straight line path between + adjacent intersects resides in the epigraph of a continuous map ``f``. + + Parameters + ---------- + is_intersect : jnp.ndarray + Boolean array indicating whether index corresponds to an intersect. + df_dy_sign : jnp.ndarray + Shape ``is_intersect.shape``. + Sign of ∂f/∂y (yᵢ) for f(yᵢ) = 0. + + Returns + ------- + is_intersect : jnp.ndarray + Boolean array indicating whether element is an intersect + and satisfies the stated condition. + + Examples + -------- + See ``desc/integrals/bounce_utils.py::bounce_points``. + This is used there to ensure the domains of integration are magnetic wells. + + """ + # The pairs ``y1`` and ``y2`` are boundaries of an integral only if ``y1 <= y2``. + # For the integrals to be over wells, it is required that the first intersect + # has a non-positive derivative. Now, by continuity, + # ``df_dy_sign[...,k]<=0`` implies ``df_dy_sign[...,k+1]>=0``, + # so there can be at most one inversion, and if it exists, the inversion + # must be at the first pair. To correct the inversion, it suffices to disqualify the + # first intersect as a right boundary, except under an edge case of a series of + # inflection points. 
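+ # Concretely: if the first two intersects have df_dy_sign = (0, -1), the first
+ # rides an inflection point and is kept as a right boundary; otherwise the
+ # first intersect is disqualified below.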
+ idx = flatnonzero(is_intersect, size=2, fill_value=-1) + edge_case = ( + (df_dy_sign[idx[0]] == 0) + & (df_dy_sign[idx[1]] < 0) + & is_intersect[idx[0]] + & is_intersect[idx[1]] + # In theory, we need to keep propagating this edge case, e.g. + # (df_dy_sign[..., 1] < 0) | ( + # (df_dy_sign[..., 1] == 0) & (df_dy_sign[..., 2] < 0)... + # ). + # At each step, the likelihood that an intersection has already been lost + # due to floating point errors grows, so the real solution is to pick a less + # degenerate pitch value - one that does not ride the global extrema of f. + ) + return put(is_intersect, idx[0], edge_case) + + +def _add2legend(legend, lines): + """Add lines to legend if it's not already in it.""" + for line in setdefault(lines, [lines], hasattr(lines, "__iter__")): + label = line.get_label() + if label not in legend: + legend[label] = line + + +def _plot_intersect(ax, legend, z1, z2, k, k_transparency, klabel): + """Plot intersects on ``ax``.""" + if k is None: + return + + k = jnp.atleast_1d(jnp.squeeze(k)) + assert k.ndim == 1 + z1, z2 = jnp.atleast_2d(z1, z2) + assert z1.ndim == z2.ndim >= 2 + assert k.shape[0] == z1.shape[0] == z2.shape[0] + for p in k: + _add2legend( + legend, + ax.axhline(p, color="tab:purple", alpha=k_transparency, label=klabel), + ) + for i in range(k.size): + _z1, _z2 = z1[i], z2[i] + if _z1.size == _z2.size: + mask = (_z1 - _z2) != 0.0 + _z1 = _z1[mask] + _z2 = _z2[mask] + _add2legend( + legend, + ax.scatter( + _z1, + jnp.full_like(_z1, k[i]), + marker="v", + color="tab:red", + label=r"$z_1$", + ), + ) + _add2legend( + legend, + ax.scatter( + _z2, + jnp.full_like(_z2, k[i]), + marker="^", + color="tab:green", + label=r"$z_2$", + ), + ) diff --git a/desc/integrals/bounce_integral.py b/desc/integrals/bounce_integral.py new file mode 100644 index 0000000000..dff4db396c --- /dev/null +++ b/desc/integrals/bounce_integral.py @@ -0,0 +1,428 @@ +"""Methods for computing bounce integrals (singular or otherwise).""" + +from interpax import CubicHermiteSpline, PPoly +from orthax.legendre import leggauss + +from desc.backend import jnp +from desc.integrals.bounce_utils import ( + _bounce_quadrature, + _check_bounce_points, + _set_default_plot_kwargs, + bounce_points, + get_pitch_inv, + interp_to_argmin, + plot_ppoly, +) +from desc.integrals.interp_utils import polyder_vec +from desc.integrals.quad_utils import ( + automorphism_sin, + get_quadrature, + grad_automorphism_sin, +) +from desc.io import IOAble +from desc.utils import errorif, setdefault, warnif + + +class Bounce1D(IOAble): + """Computes bounce integrals using one-dimensional local spline methods. + + The bounce integral is defined as ∫ f(λ, ℓ) dℓ, where + dℓ parameterizes the distance along the field line in meters, + f(λ, ℓ) is the quantity to integrate along the field line, + and the boundaries of the integral are bounce points ℓ₁, ℓ₂ s.t. λ|B|(ℓᵢ) = 1, + where λ is a constant defining the integral proportional to the magnetic moment + over energy and |B| is the norm of the magnetic field. + + For a particle with fixed λ, bounce points are defined to be the location on the + field line such that the particle's velocity parallel to the magnetic field is zero. + The bounce integral is defined up to a sign. We choose the sign that corresponds to + the particle's guiding center trajectory traveling in the direction of increasing + field-line-following coordinate ζ. + + Notes + ----- + Brief description of algorithm for developers. 
+ + For applications which reduce to computing a nonlinear function of distance + along field lines between bounce points, it is required to identify these + points with field-line-following coordinates. (In the special case of a linear + function summing integrals between bounce points over a flux surface, arbitrary + coordinate systems may be used as this operation reduces to a surface integral, + which is invariant to the order of summation). + + The DESC coordinate system is related to field-line-following coordinate + systems by a relation whose solution is best found with Newton iteration. + There is a unique real solution to this equation, so Newton iteration is a + globally convergent root-finding algorithm here. For the task of finding + bounce points, even if the inverse map: θ(α, ζ) was known, Newton iteration + is not a globally convergent algorithm to find the real roots of + f : ζ ↦ |B|(ζ) − 1/λ where ζ is a field-line-following coordinate. + For this, function approximation of |B| is necessary. + + The function approximation in ``Bounce1D`` is ignorant that the objects to + approximate are defined on a bounded subset of ℝ². Instead, the domain is + projected to ℝ, where information sampled about the function at infinity + cannot support reconstruction of the function near the origin. As the + functions of interest do not vanish at infinity, pseudo-spectral techniques + are not used. Instead, function approximation is done with local splines. + This is useful if one can efficiently obtain data along field lines and + most efficient if the number of toroidal transits to follow a field line is + not too large. + + After computing the bounce points, the supplied quadrature is performed. + By default, this is a Gauss quadrature after removing the singularity. + Local splines interpolate functions in the integrand to the quadrature nodes. + + See Also + -------- + Bounce2D : Uses two-dimensional pseudo-spectral techniques for the same task. + + Examples + -------- + See ``tests/test_integrals.py::TestBounce1D::test_bounce1d_checks``. + + Attributes + ---------- + required_names : list + Names in ``data_index`` required to compute bounce integrals. + B : jnp.ndarray + Shape (M, L, N - 1, B.shape[-1]). + Polynomial coefficients of the spline of |B| in local power basis. + Last axis enumerates the coefficients of power series. For a polynomial + given by ∑ᵢⁿ cᵢ xⁱ, coefficient cᵢ is stored at ``B[...,n-i]``. + Third axis enumerates the polynomials that compose a particular spline. + Second axis enumerates flux surfaces. + First axis enumerates field lines of a particular flux surface. + + """ + + required_names = ["B^zeta", "B^zeta_z|r,a", "|B|", "|B|_z|r,a"] + get_pitch_inv = staticmethod(get_pitch_inv) + + def __init__( + self, + grid, + data, + quad=leggauss(32), + automorphism=(automorphism_sin, grad_automorphism_sin), + Bref=1.0, + Lref=1.0, + *, + is_reshaped=False, + check=False, + **kwargs, + ): + """Returns an object to compute bounce integrals. + + Parameters + ---------- + grid : Grid + Clebsch coordinate (ρ, α, ζ) tensor-product grid. + The ζ coordinates (the unique values prior to taking the tensor-product) + must be strictly increasing and preferably uniformly spaced. These are used + as knots to construct splines. A reference knot density is 100 knots per + toroidal transit. Note that below shape notation defines + L = ``grid.num_rho``, M = ``grid.num_alpha``, and N = ``grid.num_zeta``. + data : dict[str, jnp.ndarray] + Data evaluated on ``grid``. 
+ Must include names in ``Bounce1D.required_names``. + quad : (jnp.ndarray, jnp.ndarray) + Quadrature points xₖ and weights wₖ for the approximate evaluation of an + integral ∫₋₁¹ g(x) dx = ∑ₖ wₖ g(xₖ). Default is 32 points. + automorphism : (Callable, Callable) or None + The first callable should be an automorphism of the real interval [-1, 1]. + The second callable should be the derivative of the first. This map defines + a change of variable for the bounce integral. The choice made for the + automorphism will affect the performance of the quadrature method. + Bref : float + Optional. Reference magnetic field strength for normalization. + Lref : float + Optional. Reference length scale for normalization. + is_reshaped : bool + Whether the arrays in ``data`` are already reshaped to the expected form of + shape (..., N) or (..., L, N) or (M, L, N). This option can be used to + iteratively compute bounce integrals one field line or one flux surface + at a time, respectively, potentially reducing memory usage. To do so, + set to true and provide only those axes of the reshaped data. + Default is false. + check : bool + Flag for debugging. Must be false for JAX transformations. + + """ + # Strictly increasing zeta knots enforces dζ > 0. + # To retain dℓ = (|B|/B^ζ) dζ > 0 after fixing dζ > 0, we require + # B^ζ = B⋅∇ζ > 0. This is equivalent to changing the sign of ∇ζ or [∂ℓ/∂ζ]|ρ,a. + # Recall dζ = ∇ζ⋅dR, implying 1 = ∇ζ⋅(e_ζ|ρ,a). Hence, a sign change in ∇ζ + # requires the same sign change in e_ζ|ρ,a to retain the metric identity. + warnif( + check and kwargs.pop("warn", True) and jnp.any(data["B^zeta"] <= 0), + msg="(∂ℓ/∂ζ)|ρ,a > 0 is required. Enforcing positive B^ζ.", + ) + data = { + "B^zeta": jnp.abs(data["B^zeta"]) * Lref / Bref, + "B^zeta_z|r,a": data["B^zeta_z|r,a"] + * jnp.sign(data["B^zeta"]) + * Lref + / Bref, + "|B|": data["|B|"] / Bref, + "|B|_z|r,a": data["|B|_z|r,a"] / Bref, # This is already the correct sign. + } + self._data = ( + data + if is_reshaped + else dict(zip(data.keys(), Bounce1D.reshape_data(grid, *data.values()))) + ) + self._x, self._w = get_quadrature(quad, automorphism) + + # Compute local splines. + self._zeta = grid.compress(grid.nodes[:, 2], surface_label="zeta") + self.B = jnp.moveaxis( + CubicHermiteSpline( + x=self._zeta, + y=self._data["|B|"], + dydx=self._data["|B|_z|r,a"], + axis=-1, + check=check, + ).c, + source=(0, 1), + destination=(-1, -2), + ) + self._dB_dz = polyder_vec(self.B) + + # Add axis here instead of in ``_bounce_quadrature``. + for name in self._data: + self._data[name] = self._data[name][..., jnp.newaxis, :] + + @staticmethod + def reshape_data(grid, *arys): + """Reshape arrays for acceptable input to ``integrate``. + + Parameters + ---------- + grid : Grid + Clebsch coordinate (ρ, α, ζ) tensor-product grid. + arys : jnp.ndarray + Data evaluated on grid. + + Returns + ------- + f : jnp.ndarray + Shape (M, L, N). + Reshaped data which may be given to ``integrate``. + + """ + f = [grid.meshgrid_reshape(d, "arz") for d in arys] + return f if len(f) > 1 else f[0] + + def points(self, pitch_inv, *, num_well=None): + """Compute bounce points. + + Parameters + ---------- + pitch_inv : jnp.ndarray + Shape (M, L, P). + 1/λ values to compute the bounce points at each field line. 1/λ(α,ρ) is + specified by ``pitch_inv[α,ρ]`` where in the latter the labels + are interpreted as the indices that correspond to that field line. 
+ num_well : int or None + Specify to return the first ``num_well`` pairs of bounce points for each + pitch along each field line. This is useful if ``num_well`` tightly + bounds the actual number. As a reference, there are typically 20 wells + per toroidal transit for a given pitch. You can check this by plotting + the field lines with the ``check_points`` method. + + If not specified, then all bounce points are returned. If there were fewer + wells detected along a field line than the size of the last axis of the + returned arrays, then that axis is padded with zero. + + Returns + ------- + z1, z2 : (jnp.ndarray, jnp.ndarray) + Shape (M, L, P, num_well). + ζ coordinates of bounce points. The points are ordered and grouped such + that the straight line path between ``z1`` and ``z2`` resides in the + epigraph of |B|. + + If there were less than ``num_well`` wells detected along a field line, + then the last axis, which enumerates bounce points for a particular field + line and pitch, is padded with zero. + + """ + return bounce_points(pitch_inv, self._zeta, self.B, self._dB_dz, num_well) + + def check_points(self, z1, z2, pitch_inv, *, plot=True, **kwargs): + """Check that bounce points are computed correctly. + + Parameters + ---------- + z1, z2 : (jnp.ndarray, jnp.ndarray) + Shape (M, L, P, num_well). + ζ coordinates of bounce points. The points are ordered and grouped such + that the straight line path between ``z1`` and ``z2`` resides in the + epigraph of |B|. + pitch_inv : jnp.ndarray + Shape (M, L, P). + 1/λ values to compute the bounce points at each field line. 1/λ(α,ρ) is + specified by ``pitch_inv[α,ρ]`` where in the latter the labels + are interpreted as the indices that correspond to that field line. + plot : bool + Whether to plot the field lines and bounce points of the given pitch angles. + kwargs + Keyword arguments into ``desc/integrals/bounce_utils.py::plot_ppoly``. + + Returns + ------- + plots : list + Matplotlib (fig, ax) tuples for the 1D plot of each field line. + + """ + return _check_bounce_points( + z1=z1, + z2=z2, + pitch_inv=pitch_inv, + knots=self._zeta, + B=self.B, + plot=plot, + **kwargs, + ) + + def integrate( + self, + integrand, + pitch_inv, + f=None, + weight=None, + *, + num_well=None, + method="cubic", + batch=True, + check=False, + plot=False, + ): + """Bounce integrate ∫ f(λ, ℓ) dℓ. + + Computes the bounce integral ∫ f(λ, ℓ) dℓ for every field line and pitch. + + Parameters + ---------- + integrand : callable + The composition operator on the set of functions in ``f`` that maps the + functions in ``f`` to the integrand f(λ, ℓ) in ∫ f(λ, ℓ) dℓ. It should + accept the arrays in ``f`` as arguments as well as the additional keyword + arguments: ``B`` and ``pitch``. A quadrature will be performed to + approximate the bounce integral of ``integrand(*f,B=B,pitch=pitch)``. + pitch_inv : jnp.ndarray + Shape (M, L, P). + 1/λ values to compute the bounce integrals. 1/λ(α,ρ) is specified by + ``pitch_inv[α,ρ]`` where in the latter the labels are interpreted + as the indices that correspond to that field line. + f : list[jnp.ndarray] or jnp.ndarray + Shape (M, L, N). + Real scalar-valued functions evaluated on the ``grid`` supplied to + construct this object. These functions should be arguments to the callable + ``integrand``. Use the method ``self.reshape_data`` to reshape the data + into the expected shape. + weight : jnp.ndarray + Shape (M, L, N). 
+ If supplied, the bounce integral labeled by well j is weighted such that + the returned value is w(j) ∫ f(λ, ℓ) dℓ, where w(j) is ``weight`` + interpolated to the deepest point in that magnetic well. Use the method + ``self.reshape_data`` to reshape the data into the expected shape. + num_well : int or None + Specify to return the first ``num_well`` pairs of bounce points for each + pitch along each field line. This is useful if ``num_well`` tightly + bounds the actual number. As a reference, there are typically 20 wells + per toroidal transit for a given pitch. You can check this by plotting + the field lines with the ``check_points`` method. + + If not specified, then all bounce points are returned. If there were fewer + wells detected along a field line than the size of the last axis of the + returned arrays, then that axis is padded with zero. + method : str + Method of interpolation. + See https://interpax.readthedocs.io/en/latest/_api/interpax.interp1d.html. + Default is cubic C1 local spline. + batch : bool + Whether to perform computation in a batched manner. Default is true. + check : bool + Flag for debugging. Must be false for JAX transformations. + plot : bool + Whether to plot the quantities in the integrand interpolated to the + quadrature points of each integral. Ignored if ``check`` is false. + + Returns + ------- + result : jnp.ndarray + Shape (M, L, P, num_well). + Last axis enumerates the bounce integrals for a given field line, + flux surface, and pitch value. + + """ + z1, z2 = self.points(pitch_inv, num_well=num_well) + result = _bounce_quadrature( + x=self._x, + w=self._w, + z1=z1, + z2=z2, + integrand=integrand, + pitch_inv=pitch_inv, + f=setdefault(f, []), + data=self._data, + knots=self._zeta, + method=method, + batch=batch, + check=check, + plot=plot, + ) + if weight is not None: + result *= interp_to_argmin( + weight, + z1, + z2, + self._zeta, + self.B, + self._dB_dz, + method, + ) + assert result.shape == z1.shape + return result + + def plot(self, m, l, pitch_inv=None, /, **kwargs): + """Plot the field line and bounce points of the given pitch angles. + + Parameters + ---------- + m, l : int, int + Indices into the nodes of the grid supplied to make this object. + ``alpha,rho=grid.meshgrid_reshape(grid.nodes[:,:2],"arz")[m,l,0]``. + pitch_inv : jnp.ndarray + Shape (P, ). + Optional, 1/λ values whose corresponding bounce points on the field line + specified by Clebsch coordinate α(m), ρ(l) will be plotted. + kwargs + Keyword arguments into ``desc/integrals/bounce_utils.py::plot_ppoly``. + + Returns + ------- + fig, ax + Matplotlib (fig, ax) tuple. 
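+
+        Examples
+        --------
+        .. code-block:: python
+
+            # Illustrative usage, assuming ``bounce = Bounce1D(grid, data)`` and
+            # ``pitch_inv`` of shape (M, L, P) from ``Bounce1D.get_pitch_inv``.
+            # Note that m, l, and pitch_inv are positional-only arguments.
+            fig, ax = bounce.plot(0, 0, pitch_inv[0, 0])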
+ + """ + B, dB_dz = self.B, self._dB_dz + if B.ndim == 4: + B = B[m] + dB_dz = dB_dz[m] + if B.ndim == 3: + B = B[l] + dB_dz = dB_dz[l] + if pitch_inv is not None: + errorif( + pitch_inv.ndim > 1, + msg=f"Got pitch_inv.ndim={pitch_inv.ndim}, but expected 1.", + ) + z1, z2 = bounce_points(pitch_inv, self._zeta, B, dB_dz) + kwargs["z1"] = z1 + kwargs["z2"] = z2 + kwargs["k"] = pitch_inv + fig, ax = plot_ppoly(PPoly(B.T, self._zeta), **_set_default_plot_kwargs(kwargs)) + return fig, ax diff --git a/desc/integrals/bounce_utils.py b/desc/integrals/bounce_utils.py new file mode 100644 index 0000000000..c63477c0cc --- /dev/null +++ b/desc/integrals/bounce_utils.py @@ -0,0 +1,809 @@ +"""Utilities and functional programming interface for bounce integrals.""" + +import numpy as np +from interpax import PPoly +from matplotlib import pyplot as plt + +from desc.backend import imap, jnp, softargmax +from desc.integrals.basis import _add2legend, _in_epigraph_and, _plot_intersect +from desc.integrals.interp_utils import ( + interp1d_Hermite_vec, + interp1d_vec, + polyroot_vec, + polyval_vec, +) +from desc.integrals.quad_utils import ( + bijection_from_disc, + composite_linspace, + grad_bijection_from_disc, +) +from desc.utils import ( + atleast_nd, + errorif, + flatten_matrix, + is_broadcastable, + setdefault, + take_mask, +) + + +def get_pitch_inv(min_B, max_B, num, relative_shift=1e-6): + """Return 1/λ values for quadrature between ``min_B`` and ``max_B``. + + Parameters + ---------- + min_B : jnp.ndarray + Minimum |B| value. + max_B : jnp.ndarray + Maximum |B| value. + num : int + Number of values, not including endpoints. + relative_shift : float + Relative amount to shift maxima down and minima up to avoid floating point + errors in downstream routines. + + Returns + ------- + pitch_inv : jnp.ndarray + Shape (*min_B.shape, num + 2). + 1/λ values. + + """ + # Floating point error impedes consistent detection of bounce points riding + # extrema. Shift values slightly to resolve this issue. + min_B = (1 + relative_shift) * min_B + max_B = (1 - relative_shift) * max_B + # Samples should be uniformly spaced in |B| and not λ (GitHub issue #1228). + pitch_inv = jnp.moveaxis(composite_linspace(jnp.stack([min_B, max_B]), num), 0, -1) + assert pitch_inv.shape == (*min_B.shape, num + 2) + return pitch_inv + + +def _check_spline_shape(knots, g, dg_dz, pitch_inv=None): + """Ensure inputs have compatible shape. + + Parameters + ---------- + knots : jnp.ndarray + Shape (N, ). + ζ coordinates of spline knots. Must be strictly increasing. + g : jnp.ndarray + Shape (..., N - 1, g.shape[-1]). + Polynomial coefficients of the spline of g in local power basis. + Last axis enumerates the coefficients of power series. Second to + last axis enumerates the polynomials that compose a particular spline. + dg_dz : jnp.ndarray + Shape (..., N - 1, g.shape[-1] - 1). + Polynomial coefficients of the spline of ∂g/∂ζ in local power basis. + Last axis enumerates the coefficients of power series. Second to + last axis enumerates the polynomials that compose a particular spline. + pitch_inv : jnp.ndarray + Shape (..., P). + 1/λ values. 1/λ(α,ρ) is specified by ``pitch_inv[α,ρ]`` where in + the latter the labels are interpreted as the indices that correspond + to that field line. + + """ + errorif(knots.ndim != 1, msg=f"knots should be 1d; got shape {knots.shape}.") + errorif( + g.shape[-2] != (knots.size - 1), + msg=( + "Second to last axis does not enumerate polynomials of spline. " + f"Spline shape {g.shape}. 
Knots shape {knots.shape}." + ), + ) + errorif( + not (g.ndim == dg_dz.ndim < 5) + or g.shape != (*dg_dz.shape[:-1], dg_dz.shape[-1] + 1), + msg=f"Invalid shape {g.shape} for spline and derivative {dg_dz.shape}.", + ) + g, dg_dz = jnp.atleast_2d(g, dg_dz) + if pitch_inv is not None: + pitch_inv = jnp.atleast_1d(pitch_inv) + errorif( + pitch_inv.ndim > 3 + or not is_broadcastable(pitch_inv.shape[:-1], g.shape[:-2]), + msg=f"Invalid shape {pitch_inv.shape} for pitch angles.", + ) + return g, dg_dz, pitch_inv + + +def bounce_points( + pitch_inv, knots, B, dB_dz, num_well=None, check=False, plot=True, **kwargs +): + """Compute the bounce points given spline of |B| and pitch λ. + + Parameters + ---------- + pitch_inv : jnp.ndarray + Shape (..., P). + 1/λ values to compute the bounce points. + knots : jnp.ndarray + Shape (N, ). + ζ coordinates of spline knots. Must be strictly increasing. + B : jnp.ndarray + Shape (..., N - 1, B.shape[-1]). + Polynomial coefficients of the spline of |B| in local power basis. + Last axis enumerates the coefficients of power series. Second to + last axis enumerates the polynomials that compose a particular spline. + dB_dz : jnp.ndarray + Shape (..., N - 1, B.shape[-1] - 1). + Polynomial coefficients of the spline of (∂|B|/∂ζ)|(ρ,α) in local power basis. + Last axis enumerates the coefficients of power series. Second to + last axis enumerates the polynomials that compose a particular spline. + num_well : int or None + Specify to return the first ``num_well`` pairs of bounce points for each + pitch along each field line. This is useful if ``num_well`` tightly + bounds the actual number. As a reference, there are typically 20 wells + per toroidal transit for a given pitch. You can check this by plotting + the field lines with the ``_check_bounce_points`` method. + + If not specified, then all bounce points are returned. If there were fewer + wells detected along a field line than the size of the last axis of the + returned arrays, then that axis is padded with zero. + check : bool + Flag for debugging. Must be false for JAX transformations. + plot : bool + Whether to plot some things if check is true. Default is true. + kwargs + Keyword arguments into ``plot_ppoly``. + + Returns + ------- + z1, z2 : (jnp.ndarray, jnp.ndarray) + Shape (..., P, num_well). + ζ coordinates of bounce points. The points are ordered and grouped such + that the straight line path between ``z1`` and ``z2`` resides in the + epigraph of |B|. + + If there were less than ``num_well`` wells detected along a field line, + then the last axis, which enumerates bounce points for a particular field + line and pitch, is padded with zero. + + """ + B, dB_dz, pitch_inv = _check_spline_shape(knots, B, dB_dz, pitch_inv) + intersect = polyroot_vec( + c=B[..., jnp.newaxis, :, :], # Add P axis + k=pitch_inv[..., jnp.newaxis], # Add N axis + a_min=jnp.array([0.0]), + a_max=jnp.diff(knots), + sort=True, + sentinel=-1.0, + distinct=True, + ) + assert intersect.shape[-3:] == ( + pitch_inv.shape[-1], + knots.size - 1, + B.shape[-1] - 1, + ) + + # Reshape so that last axis enumerates intersects of a pitch along a field line. + dB_sign = flatten_matrix( + jnp.sign(polyval_vec(x=intersect, c=dB_dz[..., jnp.newaxis, :, jnp.newaxis, :])) + ) + # Only consider intersect if it is within knots that bound that polynomial. 
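+ # polyroot_vec was called with sentinel=-1.0 both for roots falling outside
+ # the local interval [0, knots[i+1] - knots[i]] and for padding, so the filter
+ # below keeps exactly the real intersections.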
+ is_intersect = flatten_matrix(intersect) >= 0 + # Following discussion on page 3 and 5 of https://doi.org/10.1063/1.873749, + # we ignore the bounce points of particles only assigned to a class that are + # trapped outside this snapshot of the field line. + is_z1 = (dB_sign <= 0) & is_intersect + is_z2 = (dB_sign >= 0) & _in_epigraph_and(is_intersect, dB_sign) + + # Transform out of local power basis expansion. + intersect = flatten_matrix(intersect + knots[:-1, jnp.newaxis]) + # New versions of JAX only like static sentinels. + sentinel = -10000000.0 # instead of knots[0] - 1 + z1 = take_mask(intersect, is_z1, size=num_well, fill_value=sentinel) + z2 = take_mask(intersect, is_z2, size=num_well, fill_value=sentinel) + + mask = (z1 > sentinel) & (z2 > sentinel) + # Set outside mask to same value so integration is over set of measure zero. + z1 = jnp.where(mask, z1, 0.0) + z2 = jnp.where(mask, z2, 0.0) + + if check: + _check_bounce_points(z1, z2, pitch_inv, knots, B, plot, **kwargs) + + return z1, z2 + + +def _set_default_plot_kwargs(kwargs): + kwargs.setdefault( + "title", + r"Intersects $\zeta$ in epigraph($\vert B \vert$) s.t. " + r"$\vert B \vert(\zeta) = 1/\lambda$", + ) + kwargs.setdefault("klabel", r"$1/\lambda$") + kwargs.setdefault("hlabel", r"$\zeta$") + kwargs.setdefault("vlabel", r"$\vert B \vert$") + return kwargs + + +def _check_bounce_points(z1, z2, pitch_inv, knots, B, plot=True, **kwargs): + """Check that bounce points are computed correctly.""" + z1 = atleast_nd(4, z1) + z2 = atleast_nd(4, z2) + pitch_inv = atleast_nd(3, pitch_inv) + B = atleast_nd(4, B) + + kwargs = _set_default_plot_kwargs(kwargs) + plots = [] + + assert z1.shape == z2.shape + mask = (z1 - z2) != 0.0 + z1 = jnp.where(mask, z1, jnp.nan) + z2 = jnp.where(mask, z2, jnp.nan) + + err_1 = jnp.any(z1 > z2, axis=-1) + err_2 = jnp.any(z1[..., 1:] < z2[..., :-1], axis=-1) + + eps = kwargs.pop("eps", jnp.finfo(jnp.array(1.0).dtype).eps * 10) + for ml in np.ndindex(B.shape[:-2]): + ppoly = PPoly(B[ml].T, knots) + for p in range(pitch_inv.shape[-1]): + idx = (*ml, p) + B_midpoint = ppoly((z1[idx] + z2[idx]) / 2) + err_3 = jnp.any(B_midpoint > pitch_inv[idx] + eps) + if not (err_1[idx] or err_2[idx] or err_3): + continue + _z1 = z1[idx][mask[idx]] + _z2 = z2[idx][mask[idx]] + if plot: + plot_ppoly( + ppoly=ppoly, + z1=_z1, + z2=_z2, + k=pitch_inv[idx], + title=kwargs.pop("title") + f", (m,l,p)={idx}", + **kwargs, + ) + + print(" z1 | z2") + print(jnp.column_stack([_z1, _z2])) + assert not err_1[idx], "Intersects have an inversion.\n" + assert not err_2[idx], "Detected discontinuity.\n" + assert not err_3, ( + f"Detected |B| = {B_midpoint[mask[idx]]} > {pitch_inv[idx] + eps} " + "= 1/λ in well, implying the straight line path between " + "bounce points is in hypograph(|B|). Use more knots.\n" + ) + if plot: + plots.append( + plot_ppoly( + ppoly=ppoly, + z1=z1[ml], + z2=z2[ml], + k=pitch_inv[ml], + **kwargs, + ) + ) + return plots + + +def _bounce_quadrature( + x, + w, + z1, + z2, + integrand, + pitch_inv, + f, + data, + knots, + method="cubic", + batch=True, + check=False, + plot=False, +): + """Bounce integrate ∫ f(λ, ℓ) dℓ. + + Parameters + ---------- + x : jnp.ndarray + Shape (w.size, ). + Quadrature points in [-1, 1]. + w : jnp.ndarray + Shape (w.size, ). + Quadrature weights. + z1, z2 : jnp.ndarray + Shape (..., P, num_well). + ζ coordinates of bounce points. The points are ordered and grouped such + that the straight line path between ``z1`` and ``z2`` resides in the + epigraph of |B|. 
+ integrand : callable + The composition operator on the set of functions in ``f`` that maps the + functions in ``f`` to the integrand f(λ, ℓ) in ∫ f(λ, ℓ) dℓ. It should + accept the arrays in ``f`` as arguments as well as the additional keyword + arguments: ``B`` and ``pitch``. A quadrature will be performed to + approximate the bounce integral of ``integrand(*f,B=B,pitch=pitch)``. + pitch_inv : jnp.ndarray + Shape (..., P). + 1/λ values to compute the bounce integrals. + f : list[jnp.ndarray] + Shape (..., N). + Real scalar-valued functions evaluated on the ``knots``. + These functions should be arguments to the callable ``integrand``. + data : dict[str, jnp.ndarray] + Shape (..., 1, N). + Required data evaluated on ``grid`` and reshaped with ``Bounce1D.reshape_data``. + Must include names in ``Bounce1D.required_names``. + knots : jnp.ndarray + Shape (N, ). + Unique ζ coordinates where the arrays in ``data`` and ``f`` were evaluated. + method : str + Method of interpolation. + See https://interpax.readthedocs.io/en/latest/_api/interpax.interp1d.html. + Default is cubic C1 local spline. + batch : bool + Whether to perform computation in a batched manner. Default is true. + check : bool + Flag for debugging. Must be false for JAX transformations. + Ignored if ``batch`` is false. + plot : bool + Whether to plot the quantities in the integrand interpolated to the + quadrature points of each integral. Ignored if ``check`` is false. + + Returns + ------- + result : jnp.ndarray + Shape (..., P, num_well). + Last axis enumerates the bounce integrals for a field line, + flux surface, and pitch. + + """ + errorif(x.ndim != 1 or x.shape != w.shape) + errorif(z1.ndim < 2 or z1.shape != z2.shape) + pitch_inv = jnp.atleast_1d(pitch_inv) + if not isinstance(f, (list, tuple)): + f = [f] if isinstance(f, (jnp.ndarray, np.ndarray)) else list(f) + + # Integrate and complete the change of variable. + if batch: + result = _interpolate_and_integrate( + w=w, + Q=bijection_from_disc(x, z1[..., jnp.newaxis], z2[..., jnp.newaxis]), + pitch_inv=pitch_inv, + integrand=integrand, + f=f, + data=data, + knots=knots, + method=method, + check=check, + plot=plot, + ) + else: + # TODO: Use batched vmap. + def loop(z): # over num well axis + z1, z2 = z + # Need to return tuple because input was tuple; artifact of JAX map. + return None, _interpolate_and_integrate( + w=w, + Q=bijection_from_disc(x, z1[..., jnp.newaxis], z2[..., jnp.newaxis]), + pitch_inv=pitch_inv, + integrand=integrand, + f=f, + data=data, + knots=knots, + method=method, + check=False, + plot=False, + batch=True, + ) + + result = jnp.moveaxis( + imap(loop, (jnp.moveaxis(z1, -1, 0), jnp.moveaxis(z2, -1, 0)))[1], + source=0, + destination=-1, + ) + + return result * grad_bijection_from_disc(z1, z2) + + +def _interpolate_and_integrate( + w, + Q, + pitch_inv, + integrand, + f, + data, + knots, + method, + check, + plot, + batch=False, +): + """Interpolate given functions to points ``Q`` and perform quadrature. + + Parameters + ---------- + w : jnp.ndarray + Shape (w.size, ). + Quadrature weights. + Q : jnp.ndarray + Shape (..., P, Q.shape[-2], w.size). + Quadrature points in ζ coordinates. + + Returns + ------- + result : jnp.ndarray + Shape Q.shape[:-1]. + Quadrature result. 
+ + """ + assert w.ndim == 1 and Q.shape[-1] == w.size + assert Q.shape[-3 + batch] == pitch_inv.shape[-1] + assert data["|B|"].shape[-1] == knots.size + + shape = Q.shape + if not batch: + Q = flatten_matrix(Q) + b_sup_z = interp1d_Hermite_vec( + Q, + knots, + data["B^zeta"] / data["|B|"], + data["B^zeta_z|r,a"] / data["|B|"] + - data["B^zeta"] * data["|B|_z|r,a"] / data["|B|"] ** 2, + ) + B = interp1d_Hermite_vec(Q, knots, data["|B|"], data["|B|_z|r,a"]) + # Spline each function separately so that operations in the integrand + # that do not preserve smoothness can be captured. + f = [interp1d_vec(Q, knots, f_i[..., jnp.newaxis, :], method=method) for f_i in f] + result = ( + (integrand(*f, B=B, pitch=1 / pitch_inv[..., jnp.newaxis]) / b_sup_z) + .reshape(shape) + .dot(w) + ) + if check: + _check_interp(shape, Q, f, b_sup_z, B, result, plot) + + return result + + +def _check_interp(shape, Q, f, b_sup_z, B, result, plot): + """Check for interpolation failures and floating point issues. + + Parameters + ---------- + shape : tuple + (..., P, Q.shape[-2], w.size). + Q : jnp.ndarray + Quadrature points in ζ coordinates. + f : list[jnp.ndarray] + Arguments to the integrand, interpolated to Q. + b_sup_z : jnp.ndarray + Contravariant toroidal component of magnetic field, interpolated to Q. + B : jnp.ndarray + Norm of magnetic field, interpolated to Q. + result : jnp.ndarray + Output of ``_interpolate_and_integrate``. + plot : bool + Whether to plot stuff. + + """ + assert jnp.isfinite(Q).all(), "NaN interpolation point." + assert not ( + jnp.isclose(B, 0).any() or jnp.isclose(b_sup_z, 0).any() + ), "|B| has vanished, violating the hairy ball theorem." + + # Integrals that we should be computing. + marked = jnp.any(Q.reshape(shape) != 0.0, axis=-1) + goal = marked.sum() + + assert goal == (marked & jnp.isfinite(b_sup_z).reshape(shape).all(axis=-1)).sum() + assert goal == (marked & jnp.isfinite(B).reshape(shape).all(axis=-1)).sum() + for f_i in f: + assert goal == (marked & jnp.isfinite(f_i).reshape(shape).all(axis=-1)).sum() + + # Number of those integrals that were computed. + actual = (marked & jnp.isfinite(result)).sum() + assert goal == actual, ( + f"Lost {goal - actual} integrals from NaN generation in the integrand. This " + "is caused by floating point error, usually due to a poor quadrature choice." + ) + if plot: + Q = Q.reshape(shape) + _plot_check_interp(Q, B.reshape(shape), name=r"$\vert B \vert$") + _plot_check_interp( + Q, b_sup_z.reshape(shape), name=r"$ (B / \vert B \vert) \cdot e^{\zeta}$" + ) + + +def _plot_check_interp(Q, V, name=""): + """Plot V[..., λ, (ζ₁, ζ₂)](Q).""" + for idx in np.ndindex(Q.shape[:3]): + marked = jnp.nonzero(jnp.any(Q[idx] != 0.0, axis=-1))[0] + if marked.size == 0: + continue + fig, ax = plt.subplots() + ax.set_xlabel(r"$\zeta$") + ax.set_ylabel(name) + ax.set_title(f"Interpolation of {name} to quadrature points, (m,l,p)={idx}") + for i in marked: + ax.plot(Q[(*idx, i)], V[(*idx, i)], marker="o") + fig.text(0.01, 0.01, "Each color specifies a particular integral.") + plt.tight_layout() + plt.show() + + +def _get_extrema(knots, g, dg_dz, sentinel=jnp.nan): + """Return extrema (z*, g(z*)). + + Parameters + ---------- + knots : jnp.ndarray + Shape (N, ). + ζ coordinates of spline knots. Must be strictly increasing. + g : jnp.ndarray + Shape (..., N - 1, g.shape[-1]). + Polynomial coefficients of the spline of g in local power basis. + Last axis enumerates the coefficients of power series. 
Second to + last axis enumerates the polynomials that compose a particular spline. + dg_dz : jnp.ndarray + Shape (..., N - 1, g.shape[-1] - 1). + Polynomial coefficients of the spline of ∂g/∂z in local power basis. + Last axis enumerates the coefficients of power series. Second to + last axis enumerates the polynomials that compose a particular spline. + sentinel : float + Value with which to pad array to return fixed shape. + + Returns + ------- + ext, g_ext : jnp.ndarray + Shape (..., (N - 1) * (g.shape[-1] - 2)). + First array enumerates z*. Second array enumerates g(z*) + Sorting order of extrema is arbitrary. + + """ + g, dg_dz, _ = _check_spline_shape(knots, g, dg_dz) + ext = polyroot_vec( + c=dg_dz, a_min=jnp.array([0.0]), a_max=jnp.diff(knots), sentinel=sentinel + ) + g_ext = flatten_matrix(polyval_vec(x=ext, c=g[..., jnp.newaxis, :])) + # Transform out of local power basis expansion. + ext = flatten_matrix(ext + knots[:-1, jnp.newaxis]) + assert ext.shape == g_ext.shape and ext.shape[-1] == g.shape[-2] * (g.shape[-1] - 2) + return ext, g_ext + + +def _where_for_argmin(z1, z2, ext, g_ext, upper_sentinel): + return jnp.where( + (z1[..., jnp.newaxis] < ext[..., jnp.newaxis, jnp.newaxis, :]) + & (ext[..., jnp.newaxis, jnp.newaxis, :] < z2[..., jnp.newaxis]), + g_ext[..., jnp.newaxis, jnp.newaxis, :], + upper_sentinel, + ) + + +def interp_to_argmin( + h, z1, z2, knots, g, dg_dz, method="cubic", beta=-100, upper_sentinel=1e2 +): + """Interpolate ``h`` to the deepest point of ``g`` between ``z1`` and ``z2``. + + Let E = {ζ ∣ ζ₁ < ζ < ζ₂} and A = argmin_E g(ζ). Returns mean_A h(ζ). + + Parameters + ---------- + h : jnp.ndarray + Shape (..., N). + Values evaluated on ``knots`` to interpolate. + z1, z2 : jnp.ndarray + Shape (..., P, W). + Boundaries to detect argmin between. + knots : jnp.ndarray + Shape (N, ). + z coordinates of spline knots. Must be strictly increasing. + g : jnp.ndarray + Shape (..., N - 1, g.shape[-1]). + Polynomial coefficients of the spline of g in local power basis. + Last axis enumerates the coefficients of power series. Second to + last axis enumerates the polynomials that compose a particular spline. + dg_dz : jnp.ndarray + Shape (..., N - 1, g.shape[-1] - 1). + Polynomial coefficients of the spline of ∂g/∂z in local power basis. + Last axis enumerates the coefficients of power series. Second to + last axis enumerates the polynomials that compose a particular spline. + method : str + Method of interpolation. + See https://interpax.readthedocs.io/en/latest/_api/interpax.interp1d.html. + Default is cubic C1 local spline. + beta : float + More negative gives exponentially better approximation at the + expense of noisier gradients - noisier in the physics sense (unrelated + to the automatic differentiation). + upper_sentinel : float + Something larger than g. Choose value such that + exp(max(g)) << exp(``upper_sentinel``). Don't make too large or numerical + resolution is lost. + + Warnings + -------- + Recall that if g is small then the effect of β is reduced. + If the intention is to use this function as argmax, be sure to supply + a lower sentinel for ``upper_sentinel``. + + Returns + ------- + h : jnp.ndarray + Shape (..., P, W). + + """ + assert z1.ndim == z2.ndim >= 2 and z1.shape == z2.shape + ext, g_ext = _get_extrema(knots, g, dg_dz, sentinel=0) + # Our softargmax(x) does the proper shift to compute softargmax(x - max(x)), + # but it's still not a good idea to compute over a large length scale, so we + # warn in docstring to choose upper sentinel properly. 
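+ # With beta < 0, softargmax(beta * g) acts as a soft argmin over the extrema
+ # axis; entries filled with upper_sentinel receive negligible weight.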
+ argmin = softargmax( + beta * _where_for_argmin(z1, z2, ext, g_ext, upper_sentinel), + axis=-1, + ) + h = jnp.linalg.vecdot( + argmin, + interp1d_vec(ext, knots, h, method=method)[..., jnp.newaxis, jnp.newaxis, :], + ) + assert h.shape == z1.shape + return h + + +def interp_to_argmin_hard(h, z1, z2, knots, g, dg_dz, method="cubic"): + """Interpolate ``h`` to the deepest point of ``g`` between ``z1`` and ``z2``. + + Let E = {ζ ∣ ζ₁ < ζ < ζ₂} and A ∈ argmin_E g(ζ). Returns h(A). + + See Also + -------- + interp_to_argmin + Accomplishes the same task, but handles the case of non-unique global minima + more correctly. It is also more efficient if P >> 1. + + Parameters + ---------- + h : jnp.ndarray + Shape (..., N). + Values evaluated on ``knots`` to interpolate. + z1, z2 : jnp.ndarray + Shape (..., P, W). + Boundaries to detect argmin between. + knots : jnp.ndarray + Shape (N, ). + z coordinates of spline knots. Must be strictly increasing. + g : jnp.ndarray + Shape (..., N - 1, g.shape[-1]). + Polynomial coefficients of the spline of g in local power basis. + Last axis enumerates the coefficients of power series. Second to + last axis enumerates the polynomials that compose a particular spline. + dg_dz : jnp.ndarray + Shape (..., N - 1, g.shape[-1] - 1). + Polynomial coefficients of the spline of ∂g/∂z in local power basis. + Last axis enumerates the coefficients of power series. Second to + last axis enumerates the polynomials that compose a particular spline. + method : str + Method of interpolation. + See https://interpax.readthedocs.io/en/latest/_api/interpax.interp1d.html. + Default is cubic C1 local spline. + + Returns + ------- + h : jnp.ndarray + Shape (..., P, W). + + """ + assert z1.ndim == z2.ndim >= 2 and z1.shape == z2.shape + ext, g_ext = _get_extrema(knots, g, dg_dz, sentinel=0) + # We can use the non-differentiable max because we actually want the gradients + # to accumulate through only the minimum since we are differentiating how our + # physics objective changes wrt equilibrium perturbations not wrt which of the + # extrema get interpolated to. + argmin = jnp.argmin( + _where_for_argmin(z1, z2, ext, g_ext, jnp.max(g_ext) + 1), + axis=-1, + ) + h = interp1d_vec( + jnp.take_along_axis(ext[jnp.newaxis], argmin, axis=-1), + knots, + h[..., jnp.newaxis, :], + method=method, + ) + assert h.shape == z1.shape, h.shape + return h + + +def plot_ppoly( + ppoly, + num=1000, + z1=None, + z2=None, + k=None, + k_transparency=0.5, + klabel=r"$k$", + title=r"Intersects $z$ in epigraph($f$) s.t. $f(z) = k$", + hlabel=r"$z$", + vlabel=r"$f$", + show=True, + start=None, + stop=None, + include_knots=False, + knot_transparency=0.2, + include_legend=True, +): + """Plot the piecewise polynomial ``ppoly``. + + Parameters + ---------- + ppoly : PPoly + Piecewise polynomial f. + num : int + Number of points to evaluate for plot. + z1 : jnp.ndarray + Shape (k.shape[0], W). + Optional, intersects with ∂f/∂z <= 0. + z2 : jnp.ndarray + Shape (k.shape[0], W). + Optional, intersects with ∂f/∂z >= 0. + k : jnp.ndarray + Shape (k.shape[0], ). + Optional, k such that f(z) = k. + k_transparency : float + Transparency of intersect lines. + klabel : str + Label of intersect lines. + title : str + Plot title. + hlabel : str + Horizontal axis label. + vlabel : str + Vertical axis label. + show : bool + Whether to show the plot. Default is true. + start : float + Minimum z on plot. + stop : float + Maximum z on plot. + include_knots : bool + Whether to plot vertical lines at the knots. 
+ knot_transparency : float + Transparency of knot lines. + include_legend : bool + Whether to include the legend in the plot. Default is true. + + Returns + ------- + fig, ax + Matplotlib (fig, ax) tuple. + + """ + fig, ax = plt.subplots() + legend = {} + if include_knots: + for knot in ppoly.x: + _add2legend( + legend, + ax.axvline( + x=knot, color="tab:blue", alpha=knot_transparency, label="knot" + ), + ) + + z = jnp.linspace( + start=setdefault(start, ppoly.x[0]), + stop=setdefault(stop, ppoly.x[-1]), + num=num, + ) + _add2legend(legend, ax.plot(z, ppoly(z), label=vlabel)) + _plot_intersect( + ax=ax, + legend=legend, + z1=z1, + z2=z2, + k=k, + k_transparency=k_transparency, + klabel=klabel, + ) + ax.set_xlabel(hlabel) + ax.set_ylabel(vlabel) + if include_legend: + ax.legend(legend.values(), legend.keys(), loc="lower right") + ax.set_title(title) + plt.tight_layout() + if show: + plt.show() + plt.close() + return fig, ax diff --git a/desc/integrals/interp_utils.py b/desc/integrals/interp_utils.py new file mode 100644 index 0000000000..42f34271e1 --- /dev/null +++ b/desc/integrals/interp_utils.py @@ -0,0 +1,292 @@ +"""Fast interpolation utilities. + +Notes +----- +These polynomial utilities are chosen for performance on gpu among +methods that have the best (asymptotic) algorithmic complexity. +For example, we prefer to not use Horner's method. +""" + +from functools import partial + +from interpax import interp1d + +from desc.backend import jnp +from desc.utils import safediv + +# Warning: method must be specified as keyword argument. +interp1d_vec = jnp.vectorize( + interp1d, signature="(m),(n),(n)->(m)", excluded={"method"} +) + + +@partial(jnp.vectorize, signature="(m),(n),(n),(n)->(m)") +def interp1d_Hermite_vec(xq, x, f, fx, /): + """Vectorized cubic Hermite spline.""" + return interp1d(xq, x, f, method="cubic", fx=fx) + + +def polyder_vec(c): + """Coefficients for the derivatives of the given set of polynomials. + + Parameters + ---------- + c : jnp.ndarray + Last axis should store coefficients of a polynomial. For a polynomial given by + ∑ᵢⁿ cᵢ xⁱ, where n is ``c.shape[-1]-1``, coefficient cᵢ should be stored at + ``c[...,n-i]``. + + Returns + ------- + poly : jnp.ndarray + Coefficients of polynomial derivative, ignoring the arbitrary constant. That is, + ``poly[...,i]`` stores the coefficient of the monomial xⁿ⁻ⁱ⁻¹, where n is + ``c.shape[-1]-1``. + + """ + return c[..., :-1] * jnp.arange(c.shape[-1] - 1, 0, -1) + + +def polyval_vec(*, x, c): + """Evaluate the set of polynomials ``c`` at the points ``x``. + + Parameters + ---------- + x : jnp.ndarray + Coordinates at which to evaluate the set of polynomials. + c : jnp.ndarray + Last axis should store coefficients of a polynomial. For a polynomial given by + ∑ᵢⁿ cᵢ xⁱ, where n is ``c.shape[-1]-1``, coefficient cᵢ should be stored at + ``c[...,n-i]``. + + Returns + ------- + val : jnp.ndarray + Polynomial with given coefficients evaluated at given points. + + Examples + -------- + .. code-block:: python + + np.testing.assert_allclose( + polyval_vec(x=x, c=c), + np.sum(polyvander(x, c.shape[-1] - 1) * c[..., ::-1], axis=-1), + ) + + """ + # Better than Horner's method as we expect to evaluate low order polynomials. + # No need to use fast multipoint evaluation techniques for the same reason. + return jnp.sum( + c * x[..., jnp.newaxis] ** jnp.arange(c.shape[-1] - 1, -1, -1), + axis=-1, + ) + + +# TODO: Eventually do a PR to move this stuff into interpax. 
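+
+# A quick sanity check of the coefficient convention above (illustrative
+# comment only, not exercised by the library):
+#
+#     c = jnp.array([1.0, -2.0, 1.0])  # (x - 1)**2 = x**2 - 2x + 1
+#     x = jnp.array([0.0, 1.0, 2.0])
+#     polyval_vec(x=x, c=c)            # -> [1., 0., 1.]
+#     polyder_vec(c)                   # -> [2., -2.], the derivative 2x - 2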
+ + +def _subtract_last(c, k): + """Subtract ``k`` from last index of last axis of ``c``. + + Semantically same as ``return c.copy().at[...,-1].add(-k)``, + but allows dimension to increase. + """ + c_1 = c[..., -1] - k + c = jnp.concatenate( + [ + jnp.broadcast_to(c[..., :-1], (*c_1.shape, c.shape[-1] - 1)), + c_1[..., jnp.newaxis], + ], + axis=-1, + ) + return c + + +def _filter_distinct(r, sentinel, eps): + """Set all but one of matching adjacent elements in ``r`` to ``sentinel``.""" + # eps needs to be low enough that close distinct roots do not get removed. + # Otherwise, algorithms relying on continuity will fail. + mask = jnp.isclose(jnp.diff(r, axis=-1, prepend=sentinel), 0, atol=eps) + r = jnp.where(mask, sentinel, r) + return r + + +_roots = jnp.vectorize(partial(jnp.roots, strip_zeros=False), signature="(m)->(n)") + + +def polyroot_vec( + c, + k=0, + a_min=None, + a_max=None, + sort=False, + sentinel=jnp.nan, + eps=max(jnp.finfo(jnp.array(1.0).dtype).eps, 2.5e-12), + distinct=False, +): + """Roots of polynomial with given coefficients. + + Parameters + ---------- + c : jnp.ndarray + Last axis should store coefficients of a polynomial. For a polynomial given by + ∑ᵢⁿ cᵢ xⁱ, where n is ``c.shape[-1]-1``, coefficient cᵢ should be stored at + ``c[...,n-i]``. + k : jnp.ndarray + Shape (..., *c.shape[:-1]). + Specify to find solutions to ∑ᵢⁿ cᵢ xⁱ = ``k``. + a_min : jnp.ndarray + Shape (..., *c.shape[:-1]). + Minimum ``a_min`` and maximum ``a_max`` value to return roots between. + If specified only real roots are returned, otherwise returns all complex roots. + a_max : jnp.ndarray + Shape (..., *c.shape[:-1]). + Minimum ``a_min`` and maximum ``a_max`` value to return roots between. + If specified only real roots are returned, otherwise returns all complex roots. + sort : bool + Whether to sort the roots. + sentinel : float + Value with which to pad array in place of filtered elements. + Anything less than ``a_min`` or greater than ``a_max`` plus some floating point + error buffer will work just like nan while avoiding ``nan`` gradient. + eps : float + Absolute tolerance with which to consider value as zero. + distinct : bool + Whether to only return the distinct roots. If true, when the multiplicity is + greater than one, the repeated roots are set to ``sentinel``. + + Returns + ------- + r : jnp.ndarray + Shape (..., *c.shape[:-1], c.shape[-1] - 1). + The roots of the polynomial, iterated over the last axis. + + """ + get_only_real_roots = not (a_min is None and a_max is None) + num_coef = c.shape[-1] + c = _subtract_last(c, k) + func = {2: _root_linear, 3: _root_quadratic, 4: _root_cubic} + + if ( + num_coef in func + and get_only_real_roots + and not (jnp.iscomplexobj(c) or jnp.iscomplexobj(k)) + ): + # Compute from analytic formula to avoid the issue of complex roots with small + # imaginary parts and to avoid nan in gradient. + r = func[num_coef](C=c, sentinel=sentinel, eps=eps, distinct=distinct) + # We already filtered distinct roots for quadratics. + distinct = distinct and num_coef > 3 + else: + # Compute from eigenvalues of polynomial companion matrix. 
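+        # ``strip_zeros=False`` keeps the output shape static at
+        # ``c.shape[-1] - 1`` roots per polynomial, which is what lets this
+        # branch be vectorized and jit compiled.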
+ r = _roots(c) + + if get_only_real_roots: + a_min = -jnp.inf if a_min is None else a_min[..., jnp.newaxis] + a_max = +jnp.inf if a_max is None else a_max[..., jnp.newaxis] + r = jnp.where( + (jnp.abs(r.imag) <= eps) & (a_min <= r.real) & (r.real <= a_max), + r.real, + sentinel, + ) + + if sort or distinct: + r = jnp.sort(r, axis=-1) + r = _filter_distinct(r, sentinel, eps) if distinct else r + assert r.shape[-1] == num_coef - 1 + return r + + +def _root_cubic(C, sentinel, eps, distinct): + """Return real cubic root assuming real coefficients.""" + # numerical.recipes/book.html, page 228 + + def irreducible(Q, R, b, mask): + # Three irrational real roots. + theta = jnp.arccos(R / jnp.sqrt(jnp.where(mask, Q**3, R**2 + 1))) + return jnp.moveaxis( + -2 + * jnp.sqrt(Q) + * jnp.stack( + [ + jnp.cos(theta / 3), + jnp.cos((theta + 2 * jnp.pi) / 3), + jnp.cos((theta - 2 * jnp.pi) / 3), + ] + ) + - b / 3, + source=0, + destination=-1, + ) + + def reducible(Q, R, b): + # One real and two complex roots. + A = -jnp.sign(R) * (jnp.abs(R) + jnp.sqrt(jnp.abs(R**2 - Q**3))) ** (1 / 3) + B = safediv(Q, A) + r1 = (A + B) - b / 3 + return _concat_sentinel(r1[..., jnp.newaxis], sentinel, num=2) + + def root(b, c, d): + b = safediv(b, a) + c = safediv(c, a) + d = safediv(d, a) + Q = (b**2 - 3 * c) / 9 + R = (2 * b**3 - 9 * b * c + 27 * d) / 54 + mask = R**2 < Q**3 + return jnp.where( + mask[..., jnp.newaxis], + irreducible(jnp.abs(Q), R, b, mask), + reducible(Q, R, b), + ) + + a = C[..., 0] + b = C[..., 1] + c = C[..., 2] + d = C[..., 3] + return jnp.where( + # Tests catch failure here if eps < 1e-12 for 64 bit precision. + jnp.expand_dims(jnp.abs(a) <= eps, axis=-1), + _concat_sentinel( + _root_quadratic( + C=C[..., 1:], sentinel=sentinel, eps=eps, distinct=distinct + ), + sentinel, + ), + root(b, c, d), + ) + + +def _root_quadratic(C, sentinel, eps, distinct): + """Return real quadratic root assuming real coefficients.""" + # numerical.recipes/book.html, page 227 + a = C[..., 0] + b = C[..., 1] + c = C[..., 2] + + discriminant = b**2 - 4 * a * c + q = -0.5 * (b + jnp.sign(b) * jnp.sqrt(jnp.abs(discriminant))) + r1 = jnp.where( + discriminant < 0, + sentinel, + safediv(q, a, _root_linear(C=C[..., 1:], sentinel=sentinel, eps=eps)), + ) + r2 = jnp.where( + # more robust to remove repeated roots with discriminant + (discriminant < 0) | (distinct & (discriminant <= eps)), + sentinel, + safediv(c, q, sentinel), + ) + return jnp.stack([r1, r2], axis=-1) + + +def _root_linear(C, sentinel, eps, distinct=False): + """Return real linear root assuming real coefficients.""" + a = C[..., 0] + b = C[..., 1] + return safediv(-b, a, jnp.where(jnp.abs(b) <= eps, 0, sentinel)) + + +def _concat_sentinel(r, sentinel, num=1): + """Append ``sentinel`` ``num`` times to ``r`` on last axis.""" + sent = jnp.broadcast_to(sentinel, (*r.shape[:-1], num)) + return jnp.append(r, sent, axis=-1) diff --git a/desc/integrals/quad_utils.py b/desc/integrals/quad_utils.py new file mode 100644 index 0000000000..692149e84e --- /dev/null +++ b/desc/integrals/quad_utils.py @@ -0,0 +1,246 @@ +"""Utilities for quadratures.""" + +from orthax.legendre import legder, legval + +from desc.backend import eigh_tridiagonal, jnp, put +from desc.utils import errorif + + +def bijection_to_disc(x, a, b): + """[a, b] ∋ x ↦ y ∈ [−1, 1].""" + y = 2.0 * (x - a) / (b - a) - 1.0 + return y + + +def bijection_from_disc(x, a, b): + """[−1, 1] ∋ x ↦ y ∈ [a, b].""" + y = 0.5 * (b - a) * (x + 1.0) + a + return y + + +def grad_bijection_from_disc(a, b): + """Gradient wrt 
``x`` of ``bijection_from_disc``.""" + dy_dx = 0.5 * (b - a) + return dy_dx + + +def automorphism_arcsin(x): + """[-1, 1] ∋ x ↦ y ∈ [−1, 1]. + + The arcsin transformation introduces a singularity that augments the singularity + in the bounce integral, so the quadrature scheme used to evaluate the integral must + work well on functions with large derivative near the boundary. + + Parameters + ---------- + x : jnp.ndarray + Points to transform. + + Returns + ------- + y : jnp.ndarray + Transformed points. + + """ + y = 2.0 * jnp.arcsin(x) / jnp.pi + return y + + +def grad_automorphism_arcsin(x): + """Gradient of arcsin automorphism.""" + dy_dx = 2.0 / (jnp.sqrt(1.0 - x**2) * jnp.pi) + return dy_dx + + +grad_automorphism_arcsin.__doc__ += "\n" + automorphism_arcsin.__doc__ + + +def automorphism_sin(x, s=0, m=10): + """[-1, 1] ∋ x ↦ y ∈ [−1, 1]. + + When used as the change of variable map for the bounce integral, the Lipschitzness + of the sin transformation prevents generation of new singularities. Furthermore, + its derivative vanishes to zero slowly near the boundary, which will suppress the + large derivatives near the boundary of singular integrals. + + In effect, this map pulls the mass of the integral away from the singularities, + which should improve convergence if the quadrature performs better on less singular + integrands. Pairs well with Gauss-Legendre quadrature. + + Parameters + ---------- + x : jnp.ndarray + Points to transform. + s : float + Strength of derivative suppression, s ∈ [0, 1]. + m : float + Number of machine epsilons used for floating point error buffer. + + Returns + ------- + y : jnp.ndarray + Transformed points. + + """ + errorif(not (0 <= s <= 1)) + # s = 0 -> derivative vanishes like cosine. + # s = 1 -> derivative vanishes like cosine^k. + y0 = jnp.sin(0.5 * jnp.pi * x) + y1 = x + jnp.sin(jnp.pi * x) / jnp.pi # k = 2 + y = (1 - s) * y0 + s * y1 + # y is an expansion, so y(x) > x near x ∈ {−1, 1} and there is a tendency + # for floating point error to overshoot the true value. + eps = m * jnp.finfo(jnp.array(1.0).dtype).eps + return jnp.clip(y, -1 + eps, 1 - eps) + + +def grad_automorphism_sin(x, s=0): + """Gradient of sin automorphism.""" + dy0_dx = 0.5 * jnp.pi * jnp.cos(0.5 * jnp.pi * x) + dy1_dx = 1.0 + jnp.cos(jnp.pi * x) + dy_dx = (1 - s) * dy0_dx + s * dy1_dx + return dy_dx + + +grad_automorphism_sin.__doc__ += "\n" + automorphism_sin.__doc__ + + +def tanh_sinh(deg, m=10): + """Tanh-Sinh quadrature. + + Returns quadrature points xₖ and weights wₖ for the approximate evaluation of the + integral ∫₋₁¹ f(x) dx ≈ ∑ₖ wₖ f(xₖ). + + Parameters + ---------- + deg : int + Number of quadrature points. + m : float + Number of machine epsilons used for floating point error buffer. Larger implies + less floating point error, but increases the minimum achievable error. + + Returns + ------- + x, w : (jnp.ndarray, jnp.ndarray) + Shape (deg, ). + Quadrature points and weights. + + """ + # buffer to avoid numerical instability + x_max = jnp.array(1.0) + x_max = x_max - m * jnp.finfo(x_max.dtype).eps + t_max = jnp.arcsinh(2 * jnp.arctanh(x_max) / jnp.pi) + # maximal-spacing scheme, doi.org/10.48550/arXiv.2007.15057 + t = jnp.linspace(-t_max, t_max, deg) + dt = 2 * t_max / (deg - 1) + arg = 0.5 * jnp.pi * jnp.sinh(t) + x = jnp.tanh(arg) # x = g(t) + w = 0.5 * jnp.pi * jnp.cosh(t) / jnp.cosh(arg) ** 2 * dt # w = (dg/dt) dt + return x, w + + +def leggauss_lob(deg, interior_only=False): + """Lobatto-Gauss-Legendre quadrature. 
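+
+    Unlike Gauss-Legendre rules, Lobatto rules place nodes at the interval
+    endpoints ±1; ``interior_only`` drops those two endpoint nodes.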
+
+    Returns quadrature points xₖ and weights wₖ for the approximate evaluation of the
+    integral ∫₋₁¹ f(x) dx ≈ ∑ₖ wₖ f(xₖ).
+
+    Parameters
+    ----------
+    deg : int
+        Number of quadrature points.
+    interior_only : bool
+        Whether to exclude the points and weights at -1 and 1;
+        useful if f(-1) = f(1) = 0. If true, then ``deg`` points are still
+        returned; these are the interior points for a Lobatto quadrature with
+        ``deg + 2`` points.
+
+    Returns
+    -------
+    x, w : (jnp.ndarray, jnp.ndarray)
+        Shape (deg, ).
+        Quadrature points and weights.
+
+    """
+    N = deg + 2 * bool(interior_only)
+    errorif(N < 2)
+
+    # Golub-Welsh algorithm
+    n = jnp.arange(2, N - 1)
+    x = eigh_tridiagonal(
+        jnp.zeros(N - 2),
+        jnp.sqrt((n**2 - 1) / (4 * n**2 - 1)),
+        eigvals_only=True,
+    )
+    c0 = put(jnp.zeros(N), -1, 1)
+
+    # Improve (single multiplicity) roots by one application of Newton.
+    c = legder(c0)
+    dy = legval(x=x, c=c)
+    df = legval(x=x, c=legder(c))
+    x -= dy / df
+
+    w = 2 / (N * (N - 1) * legval(x=x, c=c0) ** 2)
+
+    if not interior_only:
+        x = jnp.hstack([-1.0, x, 1.0])
+        w_end = 2 / (deg * (deg - 1))
+        w = jnp.hstack([w_end, w, w_end])
+
+    assert x.size == w.size == deg
+    return x, w
+
+
+def get_quadrature(quad, automorphism):
+    """Apply automorphism to given quadrature.
+
+    Parameters
+    ----------
+    quad : (jnp.ndarray, jnp.ndarray)
+        Quadrature points xₖ and weights wₖ for the approximate evaluation of an
+        integral ∫₋₁¹ g(x) dx ≈ ∑ₖ wₖ g(xₖ).
+    automorphism : (Callable, Callable) or None
+        The first callable should be an automorphism of the real interval [-1, 1].
+        The second callable should be the derivative of the first. This map defines
+        a change of variable for the bounce integral. The choice made for the
+        automorphism will affect the performance of the quadrature method.
+
+    Returns
+    -------
+    x, w : (jnp.ndarray, jnp.ndarray)
+        Quadrature points and weights.
+
+    """
+    x, w = quad
+    assert x.ndim == w.ndim == 1
+    if automorphism is not None:
+        auto, grad_auto = automorphism
+        w = w * grad_auto(x)
+        # Recall bijection_from_disc(auto(x), ζ₁, ζ₂) = ζ.
+        x = auto(x)
+    return x, w
+
+
+def composite_linspace(x, num):
+    """Returns linearly spaced values between every pair of values in ``x``.
+
+    Parameters
+    ----------
+    x : jnp.ndarray
+        The first axis holds the values between which to generate linearly
+        spaced points. The remaining axes are batch axes. Assumes input is
+        sorted along first axis.
+    num : int
+        Number of values between every pair of values in ``x``.
+
+    Returns
+    -------
+    vals : jnp.ndarray
+        Shape ((x.shape[0] - 1) * num + x.shape[0], *x.shape[1:]).
+        Linearly spaced values between ``x``.
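+
+    Examples
+    --------
+    .. code-block:: python
+
+        x = jnp.array([0.0, 3.0])
+        composite_linspace(x, 2)  # -> [0.0, 1.0, 2.0, 3.0]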
+ + """ + x = jnp.atleast_1d(x) + vals = jnp.linspace(x[:-1], x[1:], num + 1, endpoint=False) + vals = jnp.swapaxes(vals, 0, 1).reshape(-1, *x.shape[1:]) + vals = jnp.append(vals, x[jnp.newaxis, -1], axis=0) + assert vals.shape == ((x.shape[0] - 1) * num + x.shape[0], *x.shape[1:]) + return vals diff --git a/desc/integrals/singularities.py b/desc/integrals/singularities.py index 3730c172af..ab2371a839 100644 --- a/desc/integrals/singularities.py +++ b/desc/integrals/singularities.py @@ -9,10 +9,9 @@ from desc.backend import fori_loop, jnp, put, vmap from desc.basis import DoubleFourierSeries from desc.compute.geom_utils import rpz2xyz, rpz2xyz_vec, xyz2rpz_vec -from desc.compute.utils import safediv, safenorm from desc.grid import LinearGrid from desc.io import IOAble -from desc.utils import isalmostequal, islinspaced +from desc.utils import isalmostequal, islinspaced, safediv, safenorm def _get_quadrature_nodes(q): diff --git a/desc/integrals/surface_integral.py b/desc/integrals/surface_integral.py index acc1e6c1b9..944a711904 100644 --- a/desc/integrals/surface_integral.py +++ b/desc/integrals/surface_integral.py @@ -100,7 +100,7 @@ def line_integrals( The coordinate curve to compute the integration over. To clarify, a theta (poloidal) curve is the intersection of a rho surface (flux surface) and zeta (toroidal) surface. - fix_surface : str, float + fix_surface : (str, float) A tuple of the form: label, value. ``fix_surface`` label should differ from ``line_label``. By default, ``fix_surface`` is chosen to be the flux surface at rho=1. diff --git a/desc/io/optimizable_io.py b/desc/io/optimizable_io.py index 554cdac070..e15a21756e 100644 --- a/desc/io/optimizable_io.py +++ b/desc/io/optimizable_io.py @@ -169,16 +169,17 @@ class IOAble(ABC, metaclass=_CombinedMeta): """Abstract Base Class for savable and loadable objects. Objects inheriting from this class can be saved and loaded via hdf5 or pickle. - To save properly, each object should have an attribute `_io_attrs_` which + To save properly, each object should have an attribute ``_io_attrs_`` which is a list of strings of the object attributes or properties that should be saved and loaded. - For saved objects to be loaded correctly, the __init__ method of any custom - types being saved should only assign attributes that are listed in `_io_attrs_`. + For saved objects to be loaded correctly, the ``__init__`` method of any custom + types being saved should only assign attributes that are listed in ``_io_attrs_``. Other attributes or other initialization should be done in a separate - `set_up()` method that can be called during __init__. The loading process - will involve creating an empty object, bypassing init, then setting any `_io_attrs_` - of the object, then calling `_set_up()` without any arguments, if it exists. + ``set_up()`` method that can be called during ``__init__``. The loading process + will involve creating an empty object, bypassing init, then setting any + ``_io_attrs_`` of the object, then calling ``_set_up()`` without any arguments, + if it exists. 
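+
+    A minimal sketch of this convention (illustrative only)::
+
+        class MyThing(IOAble):
+            _io_attrs_ = ["_x"]
+
+            def __init__(self, x):
+                self._x = x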
""" diff --git a/desc/magnetic_fields/__init__.py b/desc/magnetic_fields/__init__.py index 0a8f18abd8..173b04e7ee 100644 --- a/desc/magnetic_fields/__init__.py +++ b/desc/magnetic_fields/__init__.py @@ -9,6 +9,7 @@ SplineMagneticField, SumMagneticField, ToroidalMagneticField, + VectorPotentialField, VerticalMagneticField, _MagneticField, field_line_integrate, diff --git a/desc/magnetic_fields/_core.py b/desc/magnetic_fields/_core.py index 7b32a1217a..760f4e372b 100644 --- a/desc/magnetic_fields/_core.py +++ b/desc/magnetic_fields/_core.py @@ -15,7 +15,7 @@ DoubleFourierSeries, ) from desc.compute import compute as compute_fun -from desc.compute import rpz2xyz, rpz2xyz_vec, xyz2rpz +from desc.compute import rpz2xyz, rpz2xyz_vec, xyz2rpz, xyz2rpz_vec from desc.compute.utils import get_params, get_transforms from desc.derivatives import Derivative from desc.equilibrium import EquilibriaFamily, Equilibrium @@ -62,6 +62,40 @@ def body(i, B): return 1e-7 * fori_loop(0, J.shape[0], body, B) +def biot_savart_general_vector_potential(re, rs, J, dV): + """Biot-Savart law for arbitrary sources for vector potential. + + Parameters + ---------- + re : ndarray, shape(n_eval_pts, 3) + evaluation points to evaluate B at, in cartesian. + rs : ndarray, shape(n_src_pts, 3) + source points for current density J, in cartesian. + J : ndarray, shape(n_src_pts, 3) + current density vector at source points, in cartesian. + dV : ndarray, shape(n_src_pts) + volume element at source points + + Returns + ------- + A : ndarray, shape(n,3) + magnetic vector potential in cartesian components at specified points + """ + re, rs, J, dV = map(jnp.asarray, (re, rs, J, dV)) + assert J.shape == rs.shape + JdV = J * dV[:, None] + A = jnp.zeros_like(re) + + def body(i, A): + r = re - rs[i, :] + num = JdV[i, :] + den = jnp.linalg.norm(r, axis=-1) + A = A + jnp.where(den[:, None] == 0, 0, num / den[:, None]) + return A + + return 1e-7 * fori_loop(0, J.shape[0], body, A) + + def read_BNORM_file(fname, surface, eval_grid=None, scale_by_curpol=True): """Read BNORM-style .txt file containing Bnormal Fourier coefficients. @@ -193,6 +227,8 @@ def compute_magnetic_field( source_grid : Grid, int or None or array-like, optional Grid used to discretize MagneticField object if calculating B from Biot-Savart. Should NOT include endpoint at 2pi. + transforms : dict of Transform + Transforms for R, Z, lambda, etc. Default is to build from source_grid Returns ------- @@ -205,6 +241,33 @@ def __call__(self, grid, params=None, basis="rpz"): """Compute magnetic field at a set of points.""" return self.compute_magnetic_field(grid, params, basis) + @abstractmethod + def compute_magnetic_vector_potential( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic vector potential at a set of points. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate vector potential at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Dictionary of optimizable parameters, eg field.params_dict. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic vector potential. + source_grid : Grid, int or None or array-like, optional + Grid used to discretize MagneticField object if calculating A from + Biot-Savart. Should NOT include endpoint at 2pi. + transforms : dict of Transform + Transforms for R, Z, lambda, etc. 
Default is to build from source_grid + + Returns + ------- + A : ndarray, shape(N,3) + magnetic vector potential at specified points + + """ + def compute_Bnormal( self, surface, @@ -410,6 +473,7 @@ def save_mgrid( nR=101, nZ=101, nphi=90, + save_vector_potential=True, ): """Save the magnetic field to an mgrid NetCDF file in "raw" format. @@ -431,6 +495,9 @@ def save_mgrid( Number of grid points in the Z coordinate (default = 101). nphi : int, optional Number of grid points in the toroidal angle (default = 90). + save_vector_potential : bool, optional + Whether or not to save the magnetic vector potential to the mgrid + file, in addition to the magnetic field. Defaults to True. Returns ------- @@ -451,6 +518,15 @@ def save_mgrid( B_phi = field[:, 1].reshape(nphi, nZ, nR) B_Z = field[:, 2].reshape(nphi, nZ, nR) + # evaluate magnetic vector potential on grid + if save_vector_potential: + field = self.compute_magnetic_vector_potential(grid, basis="rpz") + A_R = field[:, 0].reshape(nphi, nZ, nR) + A_phi = field[:, 1].reshape(nphi, nZ, nR) + A_Z = field[:, 2].reshape(nphi, nZ, nR) + else: + A_R = None + # write mgrid file file = Dataset(path, mode="w", format="NETCDF3_64BIT_OFFSET") @@ -537,6 +613,28 @@ def save_mgrid( ) bz_001[:] = B_Z + if save_vector_potential: + ar_001 = file.createVariable("ar_001", np.float64, ("phi", "zee", "rad")) + ar_001.long_name = ( + "A_R = radial component of magnetic vector potential " + "in lab frame (T/m)." + ) + ar_001[:] = A_R + + ap_001 = file.createVariable("ap_001", np.float64, ("phi", "zee", "rad")) + ap_001.long_name = ( + "A_phi = toroidal component of magnetic vector potential " + "in lab frame (T/m)." + ) + ap_001[:] = A_phi + + az_001 = file.createVariable("az_001", np.float64, ("phi", "zee", "rad")) + az_001.long_name = ( + "A_Z = vertical component of magnetic vector potential " + "in lab frame (T/m)." + ) + az_001[:] = A_Z + file.close() @@ -618,6 +716,33 @@ def compute_magnetic_field( B = rpz2xyz_vec(B, phi=coords[:, 1]) return B + def compute_magnetic_vector_potential( + self, coords, params=None, basis="rpz", source_grid=None + ): + """Compute magnetic vector potential at a set of points. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate vector potential at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Dict of values for B0. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic vector potential. + source_grid : Grid, int or None or array-like, optional + Unused by this MagneticField class. + + Returns + ------- + A : ndarray, shape(N,3) + magnetic vector potential at specified points + + """ + raise NotImplementedError( + "MagneticFieldFromUser does not have vector potential calculation " + "implemented." + ) + class ScaledMagneticField(_MagneticField, Optimizable): """Magnetic field scaled by a scalar value. @@ -703,6 +828,35 @@ def compute_magnetic_field( coords, params, basis, source_grid ) + def compute_magnetic_vector_potential( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic vector potential at a set of points. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate vector potential at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Dictionary of optimizable parameters, eg field.params_dict. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic vector potential. 
+ source_grid : Grid, int or None or array-like, optional + Grid used to discretize MagneticField object if calculating A from + Biot-Savart. Should NOT include endpoint at 2pi. + transforms : dict of Transform + Transforms for R, Z, lambda, etc. Default is to build from source_grid + + Returns + ------- + A : ndarray, shape(N,3) + scaled magnetic vector potential at specified points + + """ + return self._scale * self._field.compute_magnetic_vector_potential( + coords, params, basis, source_grid + ) + class SumMagneticField(_MagneticField, MutableSequence, OptimizableCollection): """Sum of two or more magnetic field sources. @@ -724,10 +878,16 @@ def __init__(self, *fields): ) self._fields = fields - def compute_magnetic_field( - self, coords, params=None, basis="rpz", source_grid=None, transforms=None + def _compute_A_or_B( + self, + coords, + params=None, + basis="rpz", + source_grid=None, + transforms=None, + compute_A_or_B="B", ): - """Compute magnetic field at a set of points. + """Compute magnetic field or vector potential at a set of points. Parameters ---------- @@ -742,6 +902,9 @@ def compute_magnetic_field( Biot-Savart. Should NOT include endpoint at 2pi. transforms : dict of Transform Transforms for R, Z, lambda, etc. Default is to build from source_grid + compute_A_or_B: {"A", "B"}, optional + whether to compute the magnetic vector potential "A" or the magnetic field + "B". Defaults to "B" Returns ------- @@ -749,6 +912,11 @@ def compute_magnetic_field( scaled magnetic field at specified points """ + errorif( + compute_A_or_B not in ["A", "B"], + ValueError, + f'Expected "A" or "B" for compute_A_or_B, instead got {compute_A_or_B}', + ) if params is None: params = [None] * len(self._fields) if isinstance(params, dict): @@ -770,13 +938,74 @@ def compute_magnetic_field( # zip does not terminate early transforms = transforms * len(self._fields) - B = 0 + op = {"B": "compute_magnetic_field", "A": "compute_magnetic_vector_potential"}[ + compute_A_or_B + ] + + AB = 0 for i, (field, g, tr) in enumerate(zip(self._fields, source_grid, transforms)): - B += field.compute_magnetic_field( + AB += getattr(field, op)( coords, params[i % len(params)], basis, source_grid=g, transforms=tr ) + return AB - return B + def compute_magnetic_field( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic field at a set of points. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate field at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Dictionary of optimizable parameters, eg field.params_dict. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic field. + source_grid : Grid, int or None or array-like, optional + Grid used to discretize MagneticField object if calculating B from + Biot-Savart. Should NOT include endpoint at 2pi. + transforms : dict of Transform + Transforms for R, Z, lambda, etc. Default is to build from source_grid + + Returns + ------- + field : ndarray, shape(N,3) + sum magnetic field at specified points + + """ + return self._compute_A_or_B( + coords, params, basis, source_grid, transforms, compute_A_or_B="B" + ) + + def compute_magnetic_vector_potential( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic vector potential at a set of points. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate vector potential at in [R,phi,Z] or [X,Y,Z] coordinates. 
+ params : dict or array-like of dict, optional + Dictionary of optimizable parameters, eg field.params_dict. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic vector potential. + source_grid : Grid, int or None or array-like, optional + Grid used to discretize MagneticField object if calculating A from + Biot-Savart. Should NOT include endpoint at 2pi. + transforms : dict of Transform + Transforms for R, Z, lambda, etc. Default is to build from source_grid + + Returns + ------- + A : ndarray, shape(N,3) + sum magnetic vector potential at specified points + + """ + return self._compute_A_or_B( + coords, params, basis, source_grid, transforms, compute_A_or_B="A" + ) # dunder methods required by MutableSequence def __getitem__(self, i): @@ -886,10 +1115,54 @@ def compute_magnetic_field( return B + def compute_magnetic_vector_potential( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic vector potential at a set of points. + + The vector potential is specified assuming the Coulomb Gauge. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate vector potential at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Dict of values for R0 and B0. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic vector potential. + source_grid : Grid, int or None or array-like, optional + Unused by this MagneticField class. + transforms : dict of Transform + Transforms for R, Z, lambda, etc. Default is to build from source_grid + Unused by this MagneticField class. + + Returns + ------- + A : ndarray, shape(N,3) + magnetic vector potential at specified points + + """ + params = setdefault(params, {}) + B0 = params.get("B0", self.B0) + R0 = params.get("R0", self.R0) + + assert basis.lower() in ["rpz", "xyz"] + coords = jnp.atleast_2d(jnp.asarray(coords)) + if basis == "xyz": + coords = xyz2rpz(coords) + az = -B0 * R0 * jnp.log(coords[:, 0]) + arp = jnp.zeros_like(az) + A = jnp.array([arp, arp, az]).T + # b/c it only has a nonzero z component, no need + # to switch bases back if xyz is given + return A + class VerticalMagneticField(_MagneticField, Optimizable): """Uniform magnetic field purely in the vertical (Z) direction. + The vector potential is specified assuming the Coulomb Gauge. + Parameters ---------- B0 : float @@ -940,18 +1213,63 @@ def compute_magnetic_field( params = setdefault(params, {}) B0 = params.get("B0", self.B0) - assert basis.lower() in ["rpz", "xyz"] coords = jnp.atleast_2d(jnp.asarray(coords)) - if basis == "xyz": - coords = xyz2rpz(coords) bz = B0 * jnp.ones_like(coords[:, 2]) brp = jnp.zeros_like(bz) B = jnp.array([brp, brp, bz]).T - if basis == "xyz": - B = rpz2xyz_vec(B, phi=coords[:, 1]) + # b/c it only has a nonzero z component, no need + # to switch bases back if xyz is given return B + def compute_magnetic_vector_potential( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic vector potential at a set of points. + + The vector potential is specified assuming the Coulomb Gauge. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate vector potential at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Dict of values for B0. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic vector potential. + source_grid : Grid, int or None or array-like, optional + Unused by this MagneticField class. 
+ transforms : dict of Transform + Transforms for R, Z, lambda, etc. Default is to build from source_grid + Unused by this MagneticField class. + + Returns + ------- + A : ndarray, shape(N,3) + magnetic vector potential at specified points + + """ + params = setdefault(params, {}) + B0 = params.get("B0", self.B0) + + coords = jnp.atleast_2d(jnp.asarray(coords)) + + if basis == "xyz": + coords_xyz = coords + coords_rpz = xyz2rpz(coords) + else: + coords_rpz = coords + coords_xyz = rpz2xyz(coords) + ax = B0 / 2 * coords_xyz[:, 1] + ay = -B0 / 2 * coords_xyz[:, 0] + + az = jnp.zeros_like(ax) + A = jnp.array([ax, ay, az]).T + if basis == "rpz": + A = xyz2rpz_vec(A, phi=coords_rpz[:, 1]) + + return A + class PoloidalMagneticField(_MagneticField, Optimizable): """Pure poloidal magnetic field (ie in theta direction). @@ -1062,6 +1380,36 @@ def compute_magnetic_field( return B + def compute_magnetic_vector_potential( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic vector potential at a set of points. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate vector potential at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Dict of values for B0. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic vector potential. + source_grid : Grid, int or None or array-like, optional + Unused by this MagneticField class. + transforms : dict of Transform + Transforms for R, Z, lambda, etc. Default is to build from source_grid + Unused by this MagneticField class. + + Returns + ------- + A : ndarray, shape(N,3) + magnetic vector potential at specified points + + """ + raise NotImplementedError( + "PoloidalMagneticField has nonzero divergence, therefore it can't be " + "represented with a vector potential." + ) + class SplineMagneticField(_MagneticField, Optimizable): """Magnetic field from precomputed values on a grid. @@ -1080,6 +1428,12 @@ class SplineMagneticField(_MagneticField, Optimizable): toroidal magnetic field on grid BZ : array-like, shape(NR,Nphi,NZ,Ngroups) vertical magnetic field on grid + AR : array-like, shape(NR,Nphi,NZ,Ngroups) + radial magnetic vector potential on grid, optional + aphi : array-like, shape(NR,Nphi,NZ,Ngroups) + toroidal magnetic vector potential on grid, optional + AZ : array-like, shape(NR,Nphi,NZ,Ngroups) + vertical magnetic vector potential on grid, optional currents : array-like, shape(Ngroups) Currents or scaling factors for each field group. 
NFP : int, optional @@ -1098,6 +1452,9 @@ class SplineMagneticField(_MagneticField, Optimizable): "_BR", "_Bphi", "_BZ", + "_AR", + "_Aphi", + "_AZ", "_method", "_extrap", "_derivs", @@ -1110,7 +1467,20 @@ class SplineMagneticField(_MagneticField, Optimizable): _static_attrs = ["_extrap", "_period"] def __init__( - self, R, phi, Z, BR, Bphi, BZ, currents=1.0, NFP=1, method="cubic", extrap=False + self, + R, + phi, + Z, + BR, + Bphi, + BZ, + AR=None, + Aphi=None, + AZ=None, + currents=1.0, + NFP=1, + method="cubic", + extrap=False, ): R, phi, Z, currents = map( lambda x: jnp.atleast_1d(jnp.asarray(x)), (R, phi, Z, currents) @@ -1152,6 +1522,17 @@ def _atleast_4d(x): self._derivs["BR"] = self._approx_derivs(self._BR) self._derivs["Bphi"] = self._approx_derivs(self._Bphi) self._derivs["BZ"] = self._approx_derivs(self._BZ) + if AR is not None and Aphi is not None and AZ is not None: + AR, Aphi, AZ = map(_atleast_4d, (AR, Aphi, AZ)) + assert AR.shape == Aphi.shape == AZ.shape == shape + self._AR = AR + self._Aphi = Aphi + self._AZ = AZ + self._derivs["AR"] = self._approx_derivs(self._AR) + self._derivs["Aphi"] = self._approx_derivs(self._Aphi) + self._derivs["AZ"] = self._approx_derivs(self._AZ) + else: + self._AR = self._Aphi = self._AZ = None @property def NFP(self): @@ -1190,10 +1571,16 @@ def _approx_derivs(self, Bi): tempdict[key] = val[:, 0, :] return tempdict - def compute_magnetic_field( - self, coords, params=None, basis="rpz", source_grid=None, transforms=None + def _compute_A_or_B( + self, + coords, + params=None, + basis="rpz", + source_grid=None, + transforms=None, + compute_A_or_B="B", ): - """Compute magnetic field at a set of points. + """Compute magnetic field or magnetic vector potential at a set of points. Parameters ---------- @@ -1208,107 +1595,185 @@ def compute_magnetic_field( transforms : dict of Transform Transforms for R, Z, lambda, etc. Default is to build from source_grid Unused by this MagneticField class. + compute_A_or_B: {"A", "B"}, optional + whether to compute the magnetic vector potential "A" or the magnetic field + "B". 
Defaults to "B" Returns ------- field : ndarray, shape(N,3) - magnetic field at specified points, in cylindrical form [BR, Bphi,BZ] + magnetic field or vector potential at specified points, + in cylindrical form [BR, Bphi,BZ] """ + errorif( + compute_A_or_B not in ["A", "B"], + ValueError, + f'Expected "A" or "B" for compute_A_or_B, instead got {compute_A_or_B}', + ) + errorif( + compute_A_or_B == "A" and self._AR is None, + ValueError, + "Cannot calculate vector potential" + " as no vector potential spline values exist.", + ) assert basis.lower() in ["rpz", "xyz"] currents = self.currents if params is None else params["currents"] coords = jnp.atleast_2d(jnp.asarray(coords)) if basis == "xyz": coords = xyz2rpz(coords) Rq, phiq, Zq = coords.T + if compute_A_or_B == "B": + A_or_B_R = self._BR + A_or_B_phi = self._Bphi + A_or_B_Z = self._BZ + elif compute_A_or_B == "A": + A_or_B_R = self._AR + A_or_B_phi = self._Aphi + A_or_B_Z = self._AZ + if self._axisym: - BRq = interp2d( + ABRq = interp2d( Rq, Zq, self._R, self._Z, - self._BR[:, 0, :], + A_or_B_R[:, 0, :], self._method, (0, 0), self._extrap, (None, None), - **self._derivs["BR"], + **self._derivs[compute_A_or_B + "R"], ) - Bphiq = interp2d( + ABphiq = interp2d( Rq, Zq, self._R, self._Z, - self._Bphi[:, 0, :], + A_or_B_phi[:, 0, :], self._method, (0, 0), self._extrap, (None, None), - **self._derivs["Bphi"], + **self._derivs[compute_A_or_B + "phi"], ) - BZq = interp2d( + ABZq = interp2d( Rq, Zq, self._R, self._Z, - self._BZ[:, 0, :], + A_or_B_Z[:, 0, :], self._method, (0, 0), self._extrap, (None, None), - **self._derivs["BZ"], + **self._derivs[compute_A_or_B + "Z"], ) else: - BRq = interp3d( + ABRq = interp3d( Rq, phiq, Zq, self._R, self._phi, self._Z, - self._BR, + A_or_B_R, self._method, (0, 0, 0), self._extrap, (None, 2 * np.pi / self.NFP, None), - **self._derivs["BR"], + **self._derivs[compute_A_or_B + "R"], ) - Bphiq = interp3d( + ABphiq = interp3d( Rq, phiq, Zq, self._R, self._phi, self._Z, - self._Bphi, + A_or_B_phi, self._method, (0, 0, 0), self._extrap, (None, 2 * np.pi / self.NFP, None), - **self._derivs["Bphi"], + **self._derivs[compute_A_or_B + "phi"], ) - BZq = interp3d( + ABZq = interp3d( Rq, phiq, Zq, self._R, self._phi, self._Z, - self._BZ, + A_or_B_Z, self._method, (0, 0, 0), self._extrap, (None, 2 * np.pi / self.NFP, None), - **self._derivs["BZ"], + **self._derivs[compute_A_or_B + "Z"], ) - # BRq etc shape(nq, ngroups) - B = jnp.stack([BRq, Bphiq, BZq], axis=1) - # B shape(nq, 3, ngroups) - B = jnp.sum(B * currents, axis=-1) + # ABRq etc shape(nq, ngroups) + AB = jnp.stack([ABRq, ABphiq, ABZq], axis=1) + # AB shape(nq, 3, ngroups) + AB = jnp.sum(AB * currents, axis=-1) if basis == "xyz": - B = rpz2xyz_vec(B, phi=coords[:, 1]) - return B + AB = rpz2xyz_vec(AB, phi=coords[:, 1]) + return AB + + def compute_magnetic_field( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic field at a set of points. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate field at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Dictionary of optimizable parameters, eg field.params_dict. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic field. + source_grid : Grid, int or None + Unused by this MagneticField class. + transforms : dict of Transform + Transforms for R, Z, lambda, etc. Default is to build from source_grid + Unused by this MagneticField class. 
+ + Returns + ------- + field : ndarray, shape(N,3) + magnetic field at specified points, in cylindrical form [BR, Bphi,BZ] + + """ + return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "B") + + def compute_magnetic_vector_potential( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic vector potential at a set of points. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate vector potential at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Dict of values for B0. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic vector potential. + source_grid : Grid, int or None or array-like, optional + Unused by this MagneticField class. + transforms : dict of Transform + Transforms for R, Z, lambda, etc. Default is to build from source_grid + Unused by this MagneticField class. + + Returns + ------- + A : ndarray, shape(N,3) + magnetic vector potential at specified points + + """ + return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "A") @classmethod def from_mgrid(cls, mgrid_file, extcur=None, method="cubic", extrap=False): @@ -1366,8 +1831,40 @@ def from_mgrid(cls, mgrid_file, extcur=None, method="cubic", extrap=False): bp = np.moveaxis(bp, (0, 1, 2), (1, 2, 0)) bz = np.moveaxis(bz, (0, 1, 2), (1, 2, 0)) + # sum magnetic vector potentials from each coil + ar = np.zeros([kp, jz, ir, nextcur]) + ap = np.zeros([kp, jz, ir, nextcur]) + az = np.zeros([kp, jz, ir, nextcur]) + try: + for i in range(nextcur): + coil_id = "%03d" % (i + 1,) + ar[:, :, :, i] += mgrid["ar_" + coil_id][ + () + ] # A_R radial mag. vec. potential + ap[:, :, :, i] += mgrid["ap_" + coil_id][ + () + ] # A_phi toroidal mag. vec. potential + az[:, :, :, i] += mgrid["az_" + coil_id][ + () + ] # A_Z vertical mag. vec. potential + + # shift axes to correct order + ar = np.moveaxis(ar, (0, 1, 2), (1, 2, 0)) + ap = np.moveaxis(ap, (0, 1, 2), (1, 2, 0)) + az = np.moveaxis(az, (0, 1, 2), (1, 2, 0)) + except IndexError: + warnif( + True, + UserWarning, + "mgrid does not appear to contain vector potential information." + " Vector potential will not be computable.", + ) + ar = ap = az = None + mgrid.close() - return cls(Rgrid, pgrid, Zgrid, br, bp, bz, extcur, nfp, method, extrap) + return cls( + Rgrid, pgrid, Zgrid, br, bp, bz, ar, ap, az, extcur, nfp, method, extrap + ) @classmethod def from_field( @@ -1397,6 +1894,15 @@ def from_field( shp = rr.shape coords = np.array([rr.flatten(), pp.flatten(), zz.flatten()]).T BR, BP, BZ = field.compute_magnetic_field(coords, params, basis="rpz").T + try: + AR, AP, AZ = field.compute_magnetic_vector_potential( + coords, params, basis="rpz" + ).T + AR = AR.reshape(shp) + AP = AP.reshape(shp) + AZ = AZ.reshape(shp) + except NotImplementedError: + AR = AP = AZ = None return cls( R, phi, @@ -1404,6 +1910,9 @@ def from_field( BR.reshape(shp), BP.reshape(shp), BZ.reshape(shp), + AR=AR, + Aphi=AP, + AZ=AZ, currents=1.0, NFP=NFP, method=method, @@ -1474,6 +1983,187 @@ def compute_magnetic_field( B = rpz2xyz_vec(B, phi=coords[:, 1]) return B + def compute_magnetic_vector_potential( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic vector potential at a set of points. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate vector potential at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Dict of values for B0. 
+ basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic vector potential. + source_grid : Grid, int or None or array-like, optional + Unused by this MagneticField class. + transforms : dict of Transform + Transforms for R, Z, lambda, etc. Default is to build from source_grid + Unused by this MagneticField class. + + Returns + ------- + A : ndarray, shape(N,3) + magnetic vector potential at specified points + + """ + raise NotImplementedError( + "ScalarPotentialField does not have vector potential calculation " + "implemented." + ) + + +class VectorPotentialField(_MagneticField): + """Magnetic field due to a vector magnetic potential in cylindrical coordinates. + + Parameters + ---------- + potential : callable + function to compute the vector potential. Should have a signature of + the form potential(R,phi,Z,*params) -> ndarray. + R,phi,Z are arrays of cylindrical coordinates. + params : dict, optional + default parameters to pass to potential function + + """ + + def __init__(self, potential, params=None): + self._potential = potential + self._params = params + + def _compute_A_or_B( + self, + coords, + params=None, + basis="rpz", + source_grid=None, + transforms=None, + compute_A_or_B="B", + ): + """Compute magnetic field or vector potential at a set of points. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate field at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Dictionary of optimizable parameters, eg field.params_dict. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic field. + source_grid : Grid, int or None + Unused by this MagneticField class. + transforms : dict of Transform + Transforms for R, Z, lambda, etc. Default is to build from source_grid + Unused by this MagneticField class. + compute_A_or_B: {"A", "B"}, optional + whether to compute the magnetic vector potential "A" or the magnetic field + "B". Defaults to "B" + + Returns + ------- + field : ndarray, shape(N,3) + magnetic field at specified points + + """ + errorif( + compute_A_or_B not in ["A", "B"], + ValueError, + f'Expected "A" or "B" for compute_A_or_B, instead got {compute_A_or_B}', + ) + assert basis.lower() in ["rpz", "xyz"] + coords = jnp.atleast_2d(jnp.asarray(coords)) + coords = coords.astype(float) + if basis == "xyz": + coords = xyz2rpz(coords) + + if params is None: + params = self._params + r, p, z = coords.T + + if compute_A_or_B == "B": + funR = lambda x: self._potential(x, p, z, **params) + funP = lambda x: self._potential(r, x, z, **params) + funZ = lambda x: self._potential(r, p, x, **params) + + ap = self._potential(r, p, z, **params)[:, 1] + + # these are the gradients of each component of A + dAdr = Derivative.compute_jvp(funR, 0, (jnp.ones_like(r),), r) + dAdp = Derivative.compute_jvp(funP, 0, (jnp.ones_like(p),), p) + dAdz = Derivative.compute_jvp(funZ, 0, (jnp.ones_like(z),), z) + + # form the B components with the appropriate combinations + B = jnp.array( + [ + dAdp[:, 2] / r - dAdz[:, 1], + dAdz[:, 0] - dAdr[:, 2], + dAdr[:, 1] + (ap - dAdp[:, 0]) / r, + ] + ).T + if basis == "xyz": + B = rpz2xyz_vec(B, phi=coords[:, 1]) + return B + elif compute_A_or_B == "A": + A = self._potential(r, p, z, **params) + if basis == "xyz": + A = rpz2xyz_vec(A, phi=coords[:, 1]) + return A + + def compute_magnetic_field( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic field at a set of points. 
+ + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate field at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Dictionary of optimizable parameters, eg field.params_dict. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic field. + source_grid : Grid, int or None + Unused by this MagneticField class. + transforms : dict of Transform + Transforms for R, Z, lambda, etc. Default is to build from source_grid + Unused by this MagneticField class. + + Returns + ------- + field : ndarray, shape(N,3) + magnetic field at specified points + + """ + return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "B") + + def compute_magnetic_vector_potential( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic vector potential at a set of points. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate vector potential at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Dict of values for B0. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic vector potential. + source_grid : Grid, int or None or array-like, optional + Unused by this MagneticField class. + transforms : dict of Transform + Transforms for R, Z, lambda, etc. Default is to build from source_grid + Unused by this MagneticField class. + + Returns + ------- + A : ndarray, shape(N,3) + magnetic vector potential at specified points + + """ + return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "A") + def field_line_integrate( r0, diff --git a/desc/magnetic_fields/_current_potential.py b/desc/magnetic_fields/_current_potential.py index ab8a42d909..a8759155ee 100644 --- a/desc/magnetic_fields/_current_potential.py +++ b/desc/magnetic_fields/_current_potential.py @@ -11,7 +11,11 @@ from desc.optimizable import Optimizable, optimizable_parameter from desc.utils import copy_coeffs, errorif, setdefault, warnif -from ._core import _MagneticField, biot_savart_general +from ._core import ( + _MagneticField, + biot_savart_general, + biot_savart_general_vector_potential, +) class CurrentPotentialField(_MagneticField, FourierRZToroidalSurface): @@ -177,10 +181,16 @@ def save(self, file_name, file_format=None, file_mode="w"): " as the potential function cannot be serialized." ) - def compute_magnetic_field( - self, coords, params=None, basis="rpz", source_grid=None, transforms=None + def _compute_A_or_B( + self, + coords, + params=None, + basis="rpz", + source_grid=None, + transforms=None, + compute_A_or_B="B", ): - """Compute magnetic field at a set of points. + """Compute magnetic field or vector potential at a set of points. Parameters ---------- @@ -194,11 +204,14 @@ def compute_magnetic_field( Source grid upon which to evaluate the surface current density K. transforms : dict of Transform Transforms for R, Z, lambda, etc. Default is to build from source_grid + compute_A_or_B: {"A", "B"}, optional + whether to compute the magnetic vector potential "A" or the magnetic field + "B". 
Defaults to "B" Returns ------- field : ndarray, shape(N,3) - magnetic field at specified points + magnetic field or vector potential at specified points """ source_grid = source_grid or LinearGrid( @@ -206,15 +219,70 @@ def compute_magnetic_field( N=30 + 2 * self.N, NFP=self.NFP, ) - return _compute_magnetic_field_from_CurrentPotentialField( + return _compute_A_or_B_from_CurrentPotentialField( field=self, coords=coords, params=params, basis=basis, source_grid=source_grid, transforms=transforms, + compute_A_or_B=compute_A_or_B, ) + def compute_magnetic_field( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic field at a set of points. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate field at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Dictionary of optimizable parameters, eg field.params_dict. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic field. + source_grid : Grid, int or None or array-like, optional + Source grid upon which to evaluate the surface current density K. + transforms : dict of Transform + Transforms for R, Z, lambda, etc. Default is to build from source_grid + + Returns + ------- + field : ndarray, shape(N,3) + magnetic field at specified points + + """ + return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "B") + + def compute_magnetic_vector_potential( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic vector potential at a set of points. + + This assumes the Coulomb gauge. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate vector potential at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Dictionary of optimizable parameters, eg field.params_dict. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic vector potential. + source_grid : Grid, int or None or array-like, optional + Source grid upon which to evaluate the surface current density K. + transforms : dict of Transform + Transforms for R, Z, lambda, etc. Default is to build from source_grid + + Returns + ------- + A : ndarray, shape(N,3) + Magnetic vector potential at specified points. + + """ + return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "A") + @classmethod def from_surface( cls, @@ -496,10 +564,16 @@ def change_Phi_resolution(self, M=None, N=None, NFP=None, sym_Phi=None): NFP=NFP ) # make sure surface and Phi basis NFP are the same - def compute_magnetic_field( - self, coords, params=None, basis="rpz", source_grid=None, transforms=None + def _compute_A_or_B( + self, + coords, + params=None, + basis="rpz", + source_grid=None, + transforms=None, + compute_A_or_B="B", ): - """Compute magnetic field at a set of points. + """Compute magnetic field or vector potential at a set of points. Parameters ---------- @@ -513,11 +587,14 @@ def compute_magnetic_field( Source grid upon which to evaluate the surface current density K. transforms : dict of Transform Transforms for R, Z, lambda, etc. Default is to build from source_grid + compute_A_or_B: {"A", "B"}, optional + whether to compute the magnetic vector potential "A" or the magnetic field + "B". 
Defaults to "B" Returns ------- field : ndarray, shape(N,3) - magnetic field at specified points + magnetic field or vector potential at specified points """ source_grid = source_grid or LinearGrid( @@ -525,15 +602,70 @@ def compute_magnetic_field( N=30 + 2 * max(self.N, self.N_Phi), NFP=self.NFP, ) - return _compute_magnetic_field_from_CurrentPotentialField( + return _compute_A_or_B_from_CurrentPotentialField( field=self, coords=coords, params=params, basis=basis, source_grid=source_grid, transforms=transforms, + compute_A_or_B=compute_A_or_B, ) + def compute_magnetic_field( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic field at a set of points. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate field at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Dictionary of optimizable parameters, eg field.params_dict. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic field. + source_grid : Grid, int or None or array-like, optional + Source grid upon which to evaluate the surface current density K. + transforms : dict of Transform + Transforms for R, Z, lambda, etc. Default is to build from source_grid + + Returns + ------- + field : ndarray, shape(N,3) + magnetic field at specified points + + """ + return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "B") + + def compute_magnetic_vector_potential( + self, coords, params=None, basis="rpz", source_grid=None, transforms=None + ): + """Compute magnetic vector potential at a set of points. + + This assumes the Coulomb gauge. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate vector potential at in [R,phi,Z] or [X,Y,Z] coordinates. + params : dict or array-like of dict, optional + Dictionary of optimizable parameters, eg field.params_dict. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic vector potential. + source_grid : Grid, int or None or array-like, optional + Source grid upon which to evaluate the surface current density K. + transforms : dict of Transform + Transforms for R, Z, lambda, etc. Default is to build from source_grid + + Returns + ------- + A : ndarray, shape(N,3) + Magnetic vector potential at specified points. + + """ + return self._compute_A_or_B(coords, params, basis, source_grid, transforms, "A") + @classmethod def from_surface( cls, @@ -613,10 +745,16 @@ def from_surface( ) -def _compute_magnetic_field_from_CurrentPotentialField( - field, coords, source_grid, params=None, basis="rpz", transforms=None +def _compute_A_or_B_from_CurrentPotentialField( + field, + coords, + source_grid, + params=None, + basis="rpz", + transforms=None, + compute_A_or_B="B", ): - """Compute magnetic field at a set of points. + """Compute magnetic field or vector potential at a set of points. Parameters ---------- @@ -631,25 +769,36 @@ def _compute_magnetic_field_from_CurrentPotentialField( should include the potential basis : {"rpz", "xyz"} basis for input coordinates and returned magnetic field + compute_A_or_B: {"A", "B"}, optional + whether to compute the magnetic vector potential "A" or the magnetic field + "B". 
Defaults to "B" Returns ------- field : ndarray, shape(N,3) - magnetic field at specified points + magnetic field or vector potential at specified points """ + errorif( + compute_A_or_B not in ["A", "B"], + ValueError, + f'Expected "A" or "B" for compute_A_or_B, instead got {compute_A_or_B}', + ) assert basis.lower() in ["rpz", "xyz"] coords = jnp.atleast_2d(jnp.asarray(coords)) if basis == "rpz": coords = rpz2xyz(coords) - + op = {"B": biot_savart_general, "A": biot_savart_general_vector_potential}[ + compute_A_or_B + ] # compute surface current, and store grid quantities # needed for integration in class if not params or not transforms: data = field.compute( ["K", "x"], grid=source_grid, + basis="rpz", params=params, transforms=transforms, jitable=True, @@ -680,7 +829,7 @@ def nfp_loop(j, f): rs = jnp.vstack((_rs[:, 0], phi, _rs[:, 2])).T rs = rpz2xyz(rs) K = rpz2xyz_vec(_K, phi=phi) - fj = biot_savart_general( + fj = op( coords, rs, K, diff --git a/desc/objectives/_coils.py b/desc/objectives/_coils.py index 7af1e61cce..1fc9247295 100644 --- a/desc/objectives/_coils.py +++ b/desc/objectives/_coils.py @@ -12,10 +12,9 @@ ) from desc.compute import get_profiles, get_transforms, rpz2xyz from desc.compute.utils import _compute as compute_fun -from desc.compute.utils import safenorm from desc.grid import LinearGrid, _Grid from desc.integrals import compute_B_plasma -from desc.utils import Timer, errorif, warnif +from desc.utils import Timer, errorif, safenorm, warnif from .normalization import compute_scaling_factors from .objective_funs import _Objective @@ -124,6 +123,12 @@ def _prune_coilset_tree(coilset): # get individual coils from coilset coils, structure = tree_flatten(coil, is_leaf=_is_single_coil) + for c in coils: + errorif( + not isinstance(c, _Coil), + TypeError, + f"Expected object of type Coil, got {type(c)}", + ) self._num_coils = len(coils) # map grid to list of length coils @@ -1305,6 +1310,14 @@ class ToroidalFlux(_Objective): by making the coil currents zero. Instead, this objective ensures the coils create the necessary toroidal flux for the equilibrium field. + Will try to use the vector potential method to calculate the toroidal flux + (Φ = ∮ 𝐀 ⋅ 𝐝𝐥 over the perimeter of a constant zeta plane) + instead of the brute force method using the magnetic field + (Φ = ∯ 𝐁 ⋅ 𝐝𝐒 over a constant zeta XS). The vector potential method + is much more efficient, however not every ``MagneticField`` object + has a vector potential available to compute, so in those cases + the magnetic field method is used. + Parameters ---------- eq : Equilibrium @@ -1350,6 +1363,7 @@ class ToroidalFlux(_Objective): name : str, optional Name of the objective function. 
+ """ _coordinates = "rtz" @@ -1377,6 +1391,7 @@ def __init__( self._field_grid = field_grid self._eval_grid = eval_grid self._eq = eq + # TODO: add eq_fixed option so this can be used in single stage super().__init__( things=[field], @@ -1402,9 +1417,17 @@ def build(self, use_jit=True, verbose=1): """ eq = self._eq + self._use_vector_potential = True + try: + self._field.compute_magnetic_vector_potential([0, 0, 0]) + except (NotImplementedError, ValueError): + self._use_vector_potential = False if self._eval_grid is None: eval_grid = LinearGrid( - L=eq.L_grid, M=eq.M_grid, zeta=jnp.array(0.0), NFP=eq.NFP + L=eq.L_grid if not self._use_vector_potential else 0, + M=eq.M_grid, + zeta=jnp.array(0.0), + NFP=eq.NFP, ) self._eval_grid = eval_grid eval_grid = self._eval_grid @@ -1439,10 +1462,12 @@ def build(self, use_jit=True, verbose=1): if verbose > 0: print("Precomputing transforms") timer.start("Precomputing transforms") - - data = eq.compute( - ["R", "phi", "Z", "|e_rho x e_theta|", "n_zeta"], grid=eval_grid - ) + data_keys = ["R", "phi", "Z"] + if self._use_vector_potential: + data_keys += ["e_theta"] + else: + data_keys += ["|e_rho x e_theta|", "n_zeta"] + data = eq.compute(data_keys, grid=eval_grid) plasma_coords = jnp.array([data["R"], data["phi"], data["Z"]]).T @@ -1484,22 +1509,32 @@ def compute(self, field_params=None, constants=None): data = constants["equil_data"] plasma_coords = constants["plasma_coords"] - - B = constants["field"].compute_magnetic_field( - plasma_coords, - basis="rpz", - source_grid=constants["field_grid"], - params=field_params, - ) grid = constants["eval_grid"] - B_dot_n_zeta = jnp.sum(B * data["n_zeta"], axis=1) + if self._use_vector_potential: + A = constants["field"].compute_magnetic_vector_potential( + plasma_coords, + basis="rpz", + source_grid=constants["field_grid"], + params=field_params, + ) - Psi = jnp.sum( - grid.spacing[:, 0] - * grid.spacing[:, 1] - * data["|e_rho x e_theta|"] - * B_dot_n_zeta - ) + A_dot_e_theta = jnp.sum(A * data["e_theta"], axis=1) + Psi = jnp.sum(grid.spacing[:, 1] * A_dot_e_theta) + else: + B = constants["field"].compute_magnetic_field( + plasma_coords, + basis="rpz", + source_grid=constants["field_grid"], + params=field_params, + ) + + B_dot_n_zeta = jnp.sum(B * data["n_zeta"], axis=1) + Psi = jnp.sum( + grid.spacing[:, 0] + * grid.spacing[:, 1] + * data["|e_rho x e_theta|"] + * B_dot_n_zeta + ) return Psi diff --git a/desc/objectives/_equilibrium.py b/desc/objectives/_equilibrium.py index 624fd99023..dc2f4bbb22 100644 --- a/desc/objectives/_equilibrium.py +++ b/desc/objectives/_equilibrium.py @@ -557,7 +557,7 @@ class HelicalForceBalance(_Objective): _equilibrium = True _coordinates = "rtz" _units = "(N)" - _print_value_fmt = "Helical force error: {:10.3e}, " + _print_value_fmt = "Helical force error: " def __init__( self, diff --git a/desc/objectives/_geometry.py b/desc/objectives/_geometry.py index 57b1eebe46..e405609c79 100644 --- a/desc/objectives/_geometry.py +++ b/desc/objectives/_geometry.py @@ -5,9 +5,8 @@ from desc.backend import jnp, vmap from desc.compute import get_profiles, get_transforms, rpz2xyz, xyz2rpz from desc.compute.utils import _compute as compute_fun -from desc.compute.utils import safenorm from desc.grid import LinearGrid, QuadratureGrid -from desc.utils import Timer, errorif, parse_argname_change, warnif +from desc.utils import Timer, errorif, parse_argname_change, safenorm, warnif from .normalization import compute_scaling_factors from .objective_funs import _Objective diff --git 
a/desc/objectives/_omnigenity.py b/desc/objectives/_omnigenity.py
index 05f08356c0..1eb7a8da6e 100644
--- a/desc/objectives/_omnigenity.py
+++ b/desc/objectives/_omnigenity.py
@@ -47,7 +47,7 @@ class QuasisymmetryBoozer(_Objective):
         reverse mode and forward over reverse mode respectively.
     grid : Grid, optional
         Collocation grid containing the nodes to evaluate at.
-        Must be a LinearGrid with a single flux surface and sym=False.
+        Must be a LinearGrid with sym=False.
         Defaults to ``LinearGrid(M=M_booz, N=N_booz)``.
     helicity : tuple, optional
         Type of quasi-symmetry (M, N). Default = quasi-axisymmetry (1, 0).
@@ -122,12 +122,6 @@ def build(self, use_jit=True, verbose=1):
         grid = self._grid

         errorif(grid.sym, ValueError, "QuasisymmetryBoozer grid must be non-symmetric")
-        errorif(
-            grid.num_rho != 1,
-            ValueError,
-            "QuasisymmetryBoozer grid must be on a single surface. "
-            "To target multiple surfaces, use multiple objectives.",
-        )
         warnif(
             grid.num_theta < 2 * eq.M,
             RuntimeWarning,
@@ -195,7 +189,7 @@ def compute(self, params, constants=None):
         Returns
         -------
         f : ndarray
-            Quasi-symmetry flux function error at each node (T^3).
+            Symmetry-breaking harmonics of B (T).

         """
         if constants is None:
@@ -207,8 +201,11 @@ def compute(self, params, constants=None):
             transforms=constants["transforms"],
             profiles=constants["profiles"],
         )
-        B_mn = constants["matrix"] @ data["|B|_mn"]
-        return B_mn[constants["idx"]]
+        B_mn = data["|B|_mn"].reshape((constants["transforms"]["grid"].num_rho, -1))
+        B_mn = constants["matrix"] @ B_mn.T
+        # output order = (rho, mn).flatten(), i.e., all the surfaces
+        # concatenated one after the other
+        return B_mn[constants["idx"]].T.flatten()

     @property
     def helicity(self):
diff --git a/desc/objectives/_power_balance.py b/desc/objectives/_power_balance.py
index 74b679ee2a..299b248358 100644
--- a/desc/objectives/_power_balance.py
+++ b/desc/objectives/_power_balance.py
@@ -61,7 +61,7 @@ class FusionPower(_Objective):

     _scalar = True
     _units = "(W)"
-    _print_value_fmt = "Fusion power: {:10.3e} "
+    _print_value_fmt = "Fusion power: "

     def __init__(
         self,
@@ -246,7 +246,7 @@ class HeatingPowerISS04(_Objective):

     _scalar = True
     _units = "(W)"
-    _print_value_fmt = "Heating power: {:10.3e} "
+    _print_value_fmt = "Heating power: "

     def __init__(
         self,
diff --git a/desc/plotting.py b/desc/plotting.py
index 55fac468f3..b3c8fdeb17 100644
--- a/desc/plotting.py
+++ b/desc/plotting.py
@@ -971,9 +971,9 @@ def plot_3d(
     if grid.num_rho == 1:
         n1, n2 = grid.num_theta, grid.num_zeta
         if not grid.nodes[-1][2] == 2 * np.pi:
-            p1, p2 = True, False
+            p1, p2 = False, False
         else:
-            p1, p2 = True, True
+            p1, p2 = False, True
     elif grid.num_theta == 1:
         n1, n2 = grid.num_rho, grid.num_zeta
         p1, p2 = False, True
@@ -2614,7 +2614,7 @@ def plot_boozer_modes(  # noqa: C901
     elif np.isscalar(rho) and rho > 1:
         rho = np.linspace(1, 0, num=rho, endpoint=False)

-    B_mn = np.array([[]])
+    rho = np.sort(rho)
     M_booz = kwargs.pop("M_booz", 2 * eq.M)
     N_booz = kwargs.pop("N_booz", 2 * eq.N)
     linestyle = kwargs.pop("ls", "-")
@@ -2632,16 +2632,15 @@
     else:
         matrix, modes = ptolemy_linear_transform(basis.modes)

-    for i, r in enumerate(rho):
-        grid = LinearGrid(M=2 * eq.M_grid, N=2 * eq.N_grid, NFP=eq.NFP, rho=np.array(r))
-        transforms = get_transforms(
-            "|B|_mn", obj=eq, grid=grid, M_booz=M_booz, N_booz=N_booz
-        )
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore")
-            data = eq.compute("|B|_mn", grid=grid, transforms=transforms)
-        b_mn = np.atleast_2d(matrix @ data["|B|_mn"])
-        B_mn = 
np.vstack((B_mn, b_mn)) if B_mn.size else b_mn + grid = LinearGrid(M=2 * eq.M_grid, N=2 * eq.N_grid, NFP=eq.NFP, rho=rho) + transforms = get_transforms( + "|B|_mn", obj=eq, grid=grid, M_booz=M_booz, N_booz=N_booz + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + data = eq.compute("|B|_mn", grid=grid, transforms=transforms) + B_mn = data["|B|_mn"].reshape((len(rho), -1)) + B_mn = np.atleast_2d(matrix @ B_mn.T).T zidx = np.where((modes[:, 1:] == np.array([[0, 0]])).all(axis=1))[0] if norm: @@ -3010,6 +3009,7 @@ def plot_qs_error( # noqa: 16 fxn too complex rho = np.linspace(1, 0, num=20, endpoint=False) elif np.isscalar(rho) and rho > 1: rho = np.linspace(1, 0, num=rho, endpoint=False) + rho = np.sort(rho) fig, ax = _format_ax(ax, figsize=kwargs.pop("figsize", None)) @@ -3027,119 +3027,92 @@ def plot_qs_error( # noqa: 16 fxn too complex R0 = data["R0"] B0 = np.mean(data["|B|"] * data["sqrt(g)"]) / np.mean(data["sqrt(g)"]) - f_B = np.array([]) - f_C = np.array([]) - f_T = np.array([]) - plot_data = {} - for i, r in enumerate(rho): - grid = LinearGrid(M=2 * eq.M_grid, N=2 * eq.N_grid, NFP=eq.NFP, rho=np.array(r)) - if fB: - transforms = get_transforms( - "|B|_mn", obj=eq, grid=grid, M_booz=M_booz, N_booz=N_booz - ) - if i == 0: # only need to do this once for the first rho surface - matrix, modes, idx = ptolemy_linear_transform( - transforms["B"].basis.modes, - helicity=helicity, - NFP=transforms["B"].basis.NFP, - ) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - data = eq.compute( - ["|B|_mn", "B modes"], grid=grid, transforms=transforms - ) - B_mn = matrix @ data["|B|_mn"] - f_b = np.sqrt(np.sum(B_mn[idx] ** 2)) / np.sqrt(np.sum(B_mn**2)) - f_B = np.append(f_B, f_b) - if fC: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - data = eq.compute("f_C", grid=grid, helicity=helicity) - f_c = ( - np.mean(np.abs(data["f_C"]) * data["sqrt(g)"]) - / np.mean(data["sqrt(g)"]) - / B0**3 - ) - f_C = np.append(f_C, f_c) - if fT: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - data = eq.compute("f_T", grid=grid) - f_t = ( - np.mean(np.abs(data["f_T"]) * data["sqrt(g)"]) - / np.mean(data["sqrt(g)"]) - * R0**2 - / B0**4 - ) - f_T = np.append(f_T, f_t) + plot_data = {"rho": rho} - plot_data["f_B"] = f_B - plot_data["f_C"] = f_C - plot_data["f_T"] = f_T - plot_data["rho"] = rho + grid = LinearGrid(M=2 * eq.M_grid, N=2 * eq.N_grid, NFP=eq.NFP, rho=rho) + names = [] + if fB: + names += ["|B|_mn"] + transforms = get_transforms( + "|B|_mn", obj=eq, grid=grid, M_booz=M_booz, N_booz=N_booz + ) + matrix, modes, idx = ptolemy_linear_transform( + transforms["B"].basis.modes, + helicity=helicity, + NFP=transforms["B"].basis.NFP, + ) + if fC or fT: + names += ["sqrt(g)"] + if fC: + names += ["f_C"] + if fT: + names += ["f_T"] - if log: - if fB: - ax.semilogy( - rho, - f_B, - ls=ls[0 % len(ls)], - c=colors[0 % len(colors)], - marker=markers[0 % len(markers)], - label=labels[0 % len(labels)], - lw=lw[0 % len(lw)], - ) - if fC: - ax.semilogy( - rho, - f_C, - ls=ls[1 % len(ls)], - c=colors[1 % len(colors)], - marker=markers[1 % len(markers)], - label=labels[1 % len(labels)], - lw=lw[1 % len(lw)], - ) - if fT: - ax.semilogy( - rho, - f_T, - ls=ls[2 % len(ls)], - c=colors[2 % len(colors)], - marker=markers[2 % len(markers)], - label=labels[2 % len(labels)], - lw=lw[2 % len(lw)], - ) - else: - if fB: - ax.plot( - rho, - f_B, - ls=ls[0 % len(ls)], - c=colors[0 % len(colors)], - marker=markers[0 % len(markers)], - label=labels[0 % len(labels)], 
- lw=lw[0 % len(lw)], - ) - if fC: - ax.plot( - rho, - f_C, - ls=ls[1 % len(ls)], - c=colors[1 % len(colors)], - marker=markers[1 % len(markers)], - label=labels[1 % len(labels)], - lw=lw[1 % len(lw)], - ) - if fT: - ax.plot( - rho, - f_T, - ls=ls[2 % len(ls)], - c=colors[2 % len(colors)], - marker=markers[2 % len(markers)], - label=labels[2 % len(labels)], - lw=lw[2 % len(lw)], - ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + data = eq.compute( + names, grid=grid, M_booz=M_booz, N_booz=N_booz, helicity=helicity + ) + + if fB: + B_mn = data["|B|_mn"].reshape((len(rho), -1)) + B_mn = (matrix @ B_mn.T).T + f_B = np.sqrt(np.sum(B_mn[:, idx] ** 2, axis=-1)) / np.sqrt( + np.sum(B_mn**2, axis=-1) + ) + plot_data["f_B"] = f_B + if fC: + sqrtg = grid.meshgrid_reshape(data["sqrt(g)"], "rtz") + f_C = grid.meshgrid_reshape(data["f_C"], "rtz") + f_C = ( + np.mean(np.abs(f_C) * sqrtg, axis=(1, 2)) + / np.mean(sqrtg, axis=(1, 2)) + / B0**3 + ) + plot_data["f_C"] = f_C + if fT: + sqrtg = grid.meshgrid_reshape(data["sqrt(g)"], "rtz") + f_T = grid.meshgrid_reshape(data["f_T"], "rtz") + f_T = ( + np.mean(np.abs(f_T) * sqrtg, axis=(1, 2)) + / np.mean(sqrtg, axis=(1, 2)) + * R0**2 + / B0**4 + ) + plot_data["f_T"] = f_T + + plot_op = ax.semilogy if log else ax.plot + + if fB: + plot_op( + rho, + f_B, + ls=ls[0 % len(ls)], + c=colors[0 % len(colors)], + marker=markers[0 % len(markers)], + label=labels[0 % len(labels)], + lw=lw[0 % len(lw)], + ) + if fC: + plot_op( + rho, + f_C, + ls=ls[1 % len(ls)], + c=colors[1 % len(colors)], + marker=markers[1 % len(markers)], + label=labels[1 % len(labels)], + lw=lw[1 % len(lw)], + ) + if fT: + plot_op( + rho, + f_T, + ls=ls[2 % len(ls)], + c=colors[2 % len(colors)], + marker=markers[2 % len(markers)], + label=labels[2 % len(labels)], + lw=lw[2 % len(lw)], + ) ax.set_xlabel(_AXIS_LABELS_RTZ[0], fontsize=xlabel_fontsize) if ylabel: diff --git a/desc/utils.py b/desc/utils.py index 44f744dcb6..72dd10f975 100644 --- a/desc/utils.py +++ b/desc/utils.py @@ -2,13 +2,14 @@ import operator import warnings +from functools import partial from itertools import combinations_with_replacement, permutations import numpy as np from scipy.special import factorial from termcolor import colored -from desc.backend import fori_loop, jit, jnp +from desc.backend import flatnonzero, fori_loop, jit, jnp, take class Timer: @@ -184,6 +185,13 @@ class _Indexable: def __getitem__(self, index): return index + @staticmethod + def get(stuff, axis, ndim): + slices = [slice(None)] * ndim + slices[axis] = stuff + slices = tuple(slices) + return slices + """ Helper object for building indexes for indexed update functions. @@ -684,4 +692,238 @@ def broadcast_tree(tree_in, tree_out, dtype=int): raise ValueError("trees must be nested lists of dicts") +@partial(jnp.vectorize, signature="(m),(m)->(n)", excluded={"size", "fill_value"}) +def take_mask(a, mask, /, *, size=None, fill_value=None): + """JIT compilable method to return ``a[mask][:size]`` padded by ``fill_value``. + + Parameters + ---------- + a : jnp.ndarray + The source array. + mask : jnp.ndarray + Boolean mask to index into ``a``. Should have same shape as ``a``. + size : int + Elements of ``a`` at the first size True indices of ``mask`` will be returned. + If there are fewer elements than size indicates, the returned array will be + padded with ``fill_value``. The size default is ``mask.size``. 
+ fill_value : Any + When there are fewer than the indicated number of elements, the remaining + elements will be filled with ``fill_value``. Defaults to NaN for inexact types, + the largest negative value for signed types, the largest positive value for + unsigned types, and True for booleans. + + Returns + ------- + result : jnp.ndarray + Shape (size, ). + + """ + assert a.shape == mask.shape + idx = flatnonzero(mask, size=setdefault(size, mask.size), fill_value=mask.size) + return take( + a, + idx, + mode="fill", + fill_value=fill_value, + unique_indices=True, + indices_are_sorted=True, + ) + + +def flatten_matrix(y): + """Flatten matrix to vector.""" + return y.reshape(*y.shape[:-2], -1) + + +# TODO: Eventually remove and use numpy's stuff. +# https://github.com/numpy/numpy/issues/25805 +def atleast_nd(ndmin, ary): + """Adds dimensions to front if necessary.""" + return jnp.array(ary, ndmin=ndmin) if jnp.ndim(ary) < ndmin else ary + + PRINT_WIDTH = 60 # current longest name is BootstrapRedlConsistency with pre-text + + +def dot(a, b, axis=-1): + """Batched vector dot product. + + Parameters + ---------- + a : array-like + First array of vectors. + b : array-like + Second array of vectors. + axis : int + Axis along which vectors are stored. + + Returns + ------- + y : array-like + y = sum(a*b, axis=axis) + + """ + return jnp.sum(a * b, axis=axis, keepdims=False) + + +def cross(a, b, axis=-1): + """Batched vector cross product. + + Parameters + ---------- + a : array-like + First array of vectors. + b : array-like + Second array of vectors. + axis : int + Axis along which vectors are stored. + + Returns + ------- + y : array-like + y = a x b + + """ + return jnp.cross(a, b, axis=axis) + + +def safenorm(x, ord=None, axis=None, fill=0, threshold=0): + """Like jnp.linalg.norm, but without nan gradient at x=0. + + Parameters + ---------- + x : ndarray + Vector or array to norm. + ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional + Order of norm. + axis : {None, int, 2-tuple of ints}, optional + Axis to take norm along. + fill : float, ndarray, optional + Value to return where x is zero. + threshold : float >= 0 + How small is x allowed to be. + + """ + is_zero = (jnp.abs(x) <= threshold).all(axis=axis, keepdims=True) + y = jnp.where(is_zero, jnp.ones_like(x), x) # replace x with ones if is_zero + n = jnp.linalg.norm(y, ord=ord, axis=axis) + n = jnp.where(is_zero.squeeze(), fill, n) # replace norm with zero if is_zero + return n + + +def safenormalize(x, ord=None, axis=None, fill=0, threshold=0): + """Normalize a vector to unit length, but without nan gradient at x=0. + + Parameters + ---------- + x : ndarray + Vector or array to norm. + ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional + Order of norm. + axis : {None, int, 2-tuple of ints}, optional + Axis to take norm along. + fill : float, ndarray, optional + Value to return where x is zero. + threshold : float >= 0 + How small is x allowed to be. + + """ + is_zero = (jnp.abs(x) <= threshold).all(axis=axis, keepdims=True) + y = jnp.where(is_zero, jnp.ones_like(x), x) # replace x with ones if is_zero + n = safenorm(x, ord, axis, fill, threshold) * jnp.ones_like(x) + # return unit vector with equal components if norm <= threshold + return jnp.where(n <= threshold, jnp.ones_like(y) / jnp.sqrt(y.size), y / n) + + +def safediv(a, b, fill=0, threshold=0): + """Divide a/b with guards for division by zero. + + Parameters + ---------- + a, b : ndarray + Numerator and denominator. 
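
# --- Editor's note, not part of the patch: a sketch of why the double-where
# pattern in safenorm/safediv above avoids NaN gradients at x = 0, plus the
# static-shape padding that take_mask provides. The helpers are re-derived
# locally (safe_norm is a simplified stand-in for desc.utils.safenorm) so the
# snippet runs standalone; only jax is assumed.
import jax
import jax.numpy as jnp

def naive_norm(x):
    return jnp.linalg.norm(x)

def safe_norm(x):
    is_zero = jnp.all(x == 0)
    y = jnp.where(is_zero, jnp.ones_like(x), x)  # dummy values where x == 0
    return jnp.where(is_zero, 0.0, jnp.linalg.norm(y))

x0 = jnp.zeros(3)
print(jax.grad(naive_norm)(x0))  # [nan nan nan]: gradient of sqrt(sum(x**2)) is 0/0
print(jax.grad(safe_norm)(x0))   # [0. 0. 0.]: both where-branches are differentiable

# take_mask pads a[mask][:size] to a static shape, e.g. with NaN for floats:
a = jnp.array([3.0, 7.0, 1.0, 9.0])
mask = jnp.array([True, False, True, False])
idx = jnp.flatnonzero(mask, size=mask.size, fill_value=mask.size)
print(jnp.take(a, idx, mode="fill", fill_value=jnp.nan))  # [3. 1. nan nan]
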
+ fill : float, ndarray, optional + Value to return where b is zero. + threshold : float >= 0 + How small is b allowed to be. + """ + mask = jnp.abs(b) <= threshold + num = jnp.where(mask, fill, a) + den = jnp.where(mask, 1, b) + return num / den + + +def cumtrapz(y, x=None, dx=1.0, axis=-1, initial=None): + """Cumulatively integrate y(x) using the composite trapezoidal rule. + + Taken from SciPy, but changed NumPy references to JAX.NumPy: + https://github.com/scipy/scipy/blob/v1.10.1/scipy/integrate/_quadrature.py + + Parameters + ---------- + y : array_like + Values to integrate. + x : array_like, optional + The coordinate to integrate along. If None (default), use spacing `dx` + between consecutive elements in `y`. + dx : float, optional + Spacing between elements of `y`. Only used if `x` is None. + axis : int, optional + Specifies the axis to cumulate. Default is -1 (last axis). + initial : scalar, optional + If given, insert this value at the beginning of the returned result. + Typically, this value should be 0. Default is None, which means no + value at ``x[0]`` is returned and `res` has one element less than `y` + along the axis of integration. + + Returns + ------- + res : ndarray + The result of cumulative integration of `y` along `axis`. + If `initial` is None, the shape is such that the axis of integration + has one less value than `y`. If `initial` is given, the shape is equal + to that of `y`. + + """ + y = jnp.asarray(y) + if x is None: + d = dx + else: + x = jnp.asarray(x) + if x.ndim == 1: + d = jnp.diff(x) + # reshape to correct shape + shape = [1] * y.ndim + shape[axis] = -1 + d = d.reshape(shape) + elif len(x.shape) != len(y.shape): + raise ValueError("If given, shape of x must be 1-D or the " "same as y.") + else: + d = jnp.diff(x, axis=axis) + + if d.shape[axis] != y.shape[axis] - 1: + raise ValueError( + "If given, length of x along axis must be the " "same as y." + ) + + def tupleset(t, i, value): + l = list(t) + l[i] = value + return tuple(l) + + nd = len(y.shape) + slice1 = tupleset((slice(None),) * nd, axis, slice(1, None)) + slice2 = tupleset((slice(None),) * nd, axis, slice(None, -1)) + res = jnp.cumsum(d * (y[slice1] + y[slice2]) / 2.0, axis=axis) + + if initial is not None: + if not jnp.isscalar(initial): + raise ValueError("`initial` parameter should be a scalar.") + + shape = list(res.shape) + shape[axis] = 1 + res = jnp.concatenate( + [jnp.full(shape, initial, dtype=res.dtype), res], axis=axis + ) + + return res diff --git a/desc/vmec.py b/desc/vmec.py index fc6fc5498f..17e7bf3b30 100644 --- a/desc/vmec.py +++ b/desc/vmec.py @@ -25,7 +25,7 @@ from desc.objectives.utils import factorize_linear_constraints from desc.profiles import PowerSeriesProfile, SplineProfile from desc.transform import Transform -from desc.utils import Timer +from desc.utils import Timer, warnif from desc.vmec_utils import ( fourier_to_zernike, ptolemy_identity_fwd, @@ -158,7 +158,7 @@ def load( zax_cs = file.variables["zaxis_cs"][:].filled() try: rax_cs = file.variables["raxis_cs"][:].filled() - rax_cc = file.variables["zaxis_cc"][:].filled() + zax_cc = file.variables["zaxis_cc"][:].filled() except KeyError: rax_cs = np.zeros_like(rax_cc) zax_cc = np.zeros_like(zax_cs) @@ -208,7 +208,9 @@ def load( return eq @classmethod - def save(cls, eq, path, surfs=128, verbose=1): # noqa: C901 - FIXME - simplify + def save( # noqa: C901 - FIXME - simplify + cls, eq, path, surfs=128, verbose=1, M_nyq=None, N_nyq=None + ): """Save an Equilibrium as a netCDF file in the VMEC format. 
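
# --- Editor's note, not part of the patch: the Nyquist-resolution defaulting
# that this save() change introduces (see the M_nyq/N_nyq logic further below),
# re-derived in isolation. nyquist_defaults is a hypothetical helper, not part
# of DESC: M_nyq/N_nyq fall back to M+4 / N+2, and N_nyq is forced to zero for
# axisymmetric equilibria (N == 0).
def nyquist_defaults(M, N, M_nyq=None, N_nyq=None):
    M_nyq = M + 4 if M_nyq is None else M_nyq
    N_nyq = N + 2 if N_nyq is None else N_nyq
    return M_nyq, 0 if N == 0 else N_nyq

assert nyquist_defaults(6, 0, N_nyq=3) == (10, 0)  # axisymmetric: N_nyq forced to 0
assert nyquist_defaults(6, 4) == (10, 6)           # defaults M+4, N+2
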
Parameters
@@ -224,6 +226,10 @@ def save(cls, eq, path, surfs=128, verbose=1):  # noqa: C901 - FIXME - simplify
             * 0: no output
             * 1: status of quantities computed
             * 2: as above plus timing information
+        M_nyq, N_nyq : int, optional
+            Maximum poloidal and toroidal mode numbers to use in the Nyquist
+            spectrum that the derived quantities are Fourier-fit with.
+            Defaults are M+4 and N+2, respectively.

         Returns
         -------
@@ -242,8 +248,14 @@ def save(cls, eq, path, surfs=128, verbose=1):  # noqa: C901 - FIXME - simplify
         NFP = eq.NFP
         M = eq.M
         N = eq.N
-        M_nyq = M + 4
-        N_nyq = N + 2 if N > 0 else 0
+        M_nyq = M + 4 if M_nyq is None else M_nyq
+        warnif(
+            N_nyq is not None and int(N) == 0,
+            UserWarning,
+            "Passed in N_nyq but equilibrium is axisymmetric, setting N_nyq to zero",
+        )
+        N_nyq = N + 2 if N_nyq is None else N_nyq
+        N_nyq = 0 if int(N) == 0 else N_nyq

         # VMEC radial coordinate: s = rho^2 = Psi / Psi(LCFS)
         s_full = np.linspace(0, 1, surfs)
@@ -807,6 +819,14 @@ def save(cls, eq, path, surfs=128, verbose=1):  # noqa: C901 - FIXME - simplify
         lmnc.long_name = "cos(m*t-n*p) component of lambda, on half mesh"
         lmnc.units = "rad"
         l1 = np.ones_like(eq.L_lmn)
+        # lambda coefficients should be negated because theta_DESC + lambda =
+        # theta_PEST and we are reversing the theta direction (and with it the
+        # theta_PEST direction): -theta_PEST = -theta_DESC - lambda, so the
+        # negative of lambda is what should be saved, which would mean negating
+        # all of eq.L_lmn. BUT reversing the poloidal angle direction also
+        # negates the coeffs of L_lmn corresponding to m<0 (the sin(theta)
+        # modes in DESC), so the net effect is to negate only the cos(theta)
+        # (m>0) lambda modes.
         l1[eq.L_basis.modes[:, 1] >= 0] *= -1
         m, n, x_mn = zernike_to_fourier(l1 * eq.L_lmn, basis=eq.L_basis, rho=r_half)
         xm, xn, s, c = ptolemy_identity_rev(m, n, x_mn)
@@ -823,7 +843,7 @@

         sin_basis = DoubleFourierSeries(M=M_nyq, N=N_nyq, NFP=NFP, sym="sin")
         cos_basis = DoubleFourierSeries(M=M_nyq, N=N_nyq, NFP=NFP, sym="cos")
-        full_basis = DoubleFourierSeries(M=M_nyq, N=N_nyq, NFP=NFP, sym=None)
+        full_basis = DoubleFourierSeries(M=M_nyq, N=N_nyq, NFP=NFP, sym=False)
         if eq.sym:
             sin_transform = Transform(
                 grid=grid_lcfs, basis=sin_basis, build=False, build_pinv=True
@@ -932,7 +952,7 @@ def fullfit(x):
             if eq.sym:
                 x_mn[i, :] = cosfit(data[i, :])
             else:
-                x_mn[i, :] = full_transform.fit(data[i, :])
+                x_mn[i, :] = fullfit(data[i, :])
         xm, xn, s, c = ptolemy_identity_rev(m, n, x_mn)
         bmnc[0, :] = 0
         bmnc[1:, :] = c
@@ -975,7 +995,7 @@ def fullfit(x):
             if eq.sym:
                 x_mn[i, :] = cosfit(data[i, :])
             else:
-                x_mn[i, :] = full_transform.fit(data[i, :])
+                x_mn[i, :] = fullfit(data[i, :])
         xm, xn, s, c = ptolemy_identity_rev(m, n, x_mn)
         bsupumnc[0, :] = 0
         bsupumnc[1:, :] = -c  # negative sign for negative Jacobian
@@ -1018,7 +1038,7 @@ def fullfit(x):
             if eq.sym:
                 x_mn[i, :] = cosfit(data[i, :])
             else:
-                x_mn[i, :] = full_transform.fit(data[i, :])
+                x_mn[i, :] = fullfit(data[i, :])
         xm, xn, s, c = ptolemy_identity_rev(m, n, x_mn)
         bsupvmnc[0, :] = 0
         bsupvmnc[1:, :] = c
@@ -1641,13 +1661,15 @@ def vmec_interpolate(Cmn, Smn, xm, xn, theta, phi, s=None, si=None, sym=True):
         return C + S

     @classmethod
-    def compute_theta_coords(cls, lmns, xm, xn, s, theta_star, zeta, si=None):
+    def compute_theta_coords(
+        cls, lmns, xm, xn, s, theta_star, zeta, si=None, lmnc=None
+    ):
         """Find theta such that theta + lambda(theta) == theta_star.
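
# --- Editor's note, not part of the patch: the root-solve this method performs,
# shown for a toy single-mode lambda(theta) = 0.1*sin(theta) (hypothetical
# coefficients; scipy's brentq is used for brevity in place of the method's own
# Newton-like iteration).
import numpy as np
from scipy.optimize import brentq

theta_star = np.pi / 3
theta = brentq(lambda th: th + 0.1 * np.sin(th) - theta_star, 0, 2 * np.pi)
np.testing.assert_allclose(theta + 0.1 * np.sin(theta), theta_star)
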
Parameters ---------- lmns : array-like - fourier coefficients for lambda + sin(mt-nz) Fourier coefficients for lambda xm : array-like poloidal mode numbers xn : array-like @@ -1662,6 +1684,8 @@ def compute_theta_coords(cls, lmns, xm, xn, s, theta_star, zeta, si=None): si : ndarray values of radial coordinates where lmns are defined. Defaults to linearly spaced on half grid between (0,1) + lmnc : array-like, optional + cos(mt-nz) Fourier coefficients for lambda Returns ------- @@ -1672,19 +1696,30 @@ def compute_theta_coords(cls, lmns, xm, xn, s, theta_star, zeta, si=None): if si is None: si = np.linspace(0, 1, lmns.shape[0]) si[1:] = si[0:-1] + 0.5 / (lmns.shape[0] - 1) - lmbda_mn = interpolate.CubicSpline(si, lmns) + lmbda_mns = interpolate.CubicSpline(si, lmns) + if lmnc is None: + lmbda_mnc = lambda s: 0 + else: + lmbda_mnc = interpolate.CubicSpline(si, lmnc) # Note: theta* (also known as vartheta) is the poloidal straight field line # angle in PEST-like flux coordinates def root_fun(theta): lmbda = np.sum( - lmbda_mn(s) + lmbda_mns(s) * np.sin( xm[np.newaxis] * theta[:, np.newaxis] - xn[np.newaxis] * zeta[:, np.newaxis] ), axis=-1, + ) + np.sum( + lmbda_mnc(s) + * np.cos( + xm[np.newaxis] * theta[:, np.newaxis] + - xn[np.newaxis] * zeta[:, np.newaxis] + ), + axis=-1, ) theta_star_k = theta + lmbda # theta* = theta + lambda err = theta_star - theta_star_k # FIXME: mod by 2pi @@ -1782,6 +1817,8 @@ def compute_coord_surfaces(cls, equil, vmec_data, Nr=10, Nt=8, Nz=None, **kwargs t_nodes = t_grid.nodes t_nodes[:, 0] = t_nodes[:, 0] ** 2 + sym = "lmnc" not in vmec_data.keys() + v_nodes = cls.compute_theta_coords( vmec_data["lmns"], vmec_data["xm"], @@ -1789,29 +1826,71 @@ def compute_coord_surfaces(cls, equil, vmec_data, Nr=10, Nt=8, Nz=None, **kwargs t_nodes[:, 0], t_nodes[:, 1], t_nodes[:, 2], + lmnc=vmec_data["lmnc"] if not sym else None, ) t_nodes[:, 1] = v_nodes + if sym: + Rr_vmec, Zr_vmec = cls.vmec_interpolate( + vmec_data["rmnc"], + vmec_data["zmns"], + vmec_data["xm"], + vmec_data["xn"], + theta=r_nodes[:, 1], + phi=r_nodes[:, 2], + s=r_nodes[:, 0], + ) - Rr_vmec, Zr_vmec = cls.vmec_interpolate( - vmec_data["rmnc"], - vmec_data["zmns"], - vmec_data["xm"], - vmec_data["xn"], - theta=r_nodes[:, 1], - phi=r_nodes[:, 2], - s=r_nodes[:, 0], - ) - - Rv_vmec, Zv_vmec = cls.vmec_interpolate( - vmec_data["rmnc"], - vmec_data["zmns"], - vmec_data["xm"], - vmec_data["xn"], - theta=t_nodes[:, 1], - phi=t_nodes[:, 2], - s=t_nodes[:, 0], - ) + Rv_vmec, Zv_vmec = cls.vmec_interpolate( + vmec_data["rmnc"], + vmec_data["zmns"], + vmec_data["xm"], + vmec_data["xn"], + theta=t_nodes[:, 1], + phi=t_nodes[:, 2], + s=t_nodes[:, 0], + ) + else: + Rr_vmec = cls.vmec_interpolate( + vmec_data["rmnc"], + vmec_data["rmns"], + vmec_data["xm"], + vmec_data["xn"], + theta=r_nodes[:, 1], + phi=r_nodes[:, 2], + s=r_nodes[:, 0], + sym=False, + ) + Zr_vmec = cls.vmec_interpolate( + vmec_data["zmnc"], + vmec_data["zmns"], + vmec_data["xm"], + vmec_data["xn"], + theta=r_nodes[:, 1], + phi=r_nodes[:, 2], + s=r_nodes[:, 0], + sym=False, + ) + Rv_vmec = cls.vmec_interpolate( + vmec_data["rmnc"], + vmec_data["rmns"], + vmec_data["xm"], + vmec_data["xn"], + theta=t_nodes[:, 1], + phi=t_nodes[:, 2], + s=t_nodes[:, 0], + sym=False, + ) + Zv_vmec = cls.vmec_interpolate( + vmec_data["zmnc"], + vmec_data["zmns"], + vmec_data["xm"], + vmec_data["xn"], + theta=t_nodes[:, 1], + phi=t_nodes[:, 2], + s=t_nodes[:, 0], + sym=False, + ) coords = { "Rr_desc": Rr_desc, diff --git a/devtools/dev-requirements_conda.yml 
b/devtools/dev-requirements_conda.yml index 5f5076a57e..5aa77689dd 100644 --- a/devtools/dev-requirements_conda.yml +++ b/devtools/dev-requirements_conda.yml @@ -15,9 +15,10 @@ dependencies: - pip: # Conda only parses a single list of pip requirements. # If two pip lists are given, all but the last list is skipped. - - interpax + - interpax >= 0.3.3 - jax[cpu] >= 0.3.2, < 0.5.0 - nvgpu + - orthax - plotly >= 5.16, < 6.0 - pylatexenc >= 2.0, < 3.0 # building the docs diff --git a/docs/api.rst b/docs/api.rst index c5147d4077..02c7cc8c73 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -41,6 +41,7 @@ Compatibility desc.compat.ensure_positive_jacobian desc.compat.flip_helicity + desc.compat.flip_theta desc.compat.rescale Continuation diff --git a/docs/api_equilibrium.rst b/docs/api_equilibrium.rst index 2adc6296c8..5b349d79e5 100644 --- a/docs/api_equilibrium.rst +++ b/docs/api_equilibrium.rst @@ -66,4 +66,5 @@ equilibria to a given size and/or field strength. desc.compat.ensure_positive_jacobian desc.compat.flip_helicity + desc.compat.flip_theta desc.compat.rescale diff --git a/requirements.txt b/requirements.txt index a667a2a2db..fa5b86bba9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,13 @@ colorama h5py >= 3.0.0, < 4.0 -interpax +interpax >= 0.3.3 jax[cpu] >= 0.3.2, < 0.5.0 matplotlib >= 3.5.0, < 4.0.0 mpmath >= 1.0.0, < 2.0 netcdf4 >= 1.5.4, < 2.0 numpy >= 1.20.0, < 2.0.0 nvgpu +orthax plotly >= 5.16, < 6.0 psutil pylatexenc >= 2.0, < 3.0 diff --git a/requirements_conda.yml b/requirements_conda.yml index a151388648..da2996429a 100644 --- a/requirements_conda.yml +++ b/requirements_conda.yml @@ -14,8 +14,9 @@ dependencies: - pip: # Conda only parses a single list of pip requirements. # If two pip lists are given, all but the last list is skipped. 
-      - interpax
+      - interpax >= 0.3.3
       - jax[cpu] >= 0.3.2, < 0.5.0
       - nvgpu
+      - orthax
       - plotly >= 5.16, < 6.0
       - pylatexenc >= 2.0, < 3.0
diff --git a/tests/baseline/test_binormal_drift_bounce1d.png b/tests/baseline/test_binormal_drift_bounce1d.png
new file mode 100644
index 0000000000..95339623df
Binary files /dev/null and b/tests/baseline/test_binormal_drift_bounce1d.png differ
diff --git a/tests/baseline/test_bounce1d_checks.png b/tests/baseline/test_bounce1d_checks.png
new file mode 100644
index 0000000000..51e5a4d94f
Binary files /dev/null and b/tests/baseline/test_bounce1d_checks.png differ
diff --git a/tests/benchmarks/compare_bench_results.py b/tests/benchmarks/compare_bench_results.py
index 09fc580e22..ab56816153 100644
--- a/tests/benchmarks/compare_bench_results.py
+++ b/tests/benchmarks/compare_bench_results.py
@@ -8,60 +8,87 @@
 cwd = os.getcwd()

 data = {}
-master_idx = 0
-latest_idx = 0
+master_idx = []
+latest_idx = []
 commit_ind = 0
-for diret in os.walk(cwd + "/compare_results"):
-    files = diret[2]
-    timing_file_exists = False
-
-    for filename in files:
-        if filename.find("json") != -1:  # check if json output file is present
-            try:
-                filepath = os.path.join(diret[0], filename)
-                with open(filepath) as f:
-                    print(filepath)
-                    curr_data = json.load(f)
-                    commit_id = curr_data["commit_info"]["id"][0:7]
-                    data[commit_id] = curr_data
-                    if filepath.find("master") != -1:
-                        master_idx = commit_ind
-                    elif filepath.find("Latest_Commit") != -1:
-                        latest_idx = commit_ind
-                    commit_ind += 1
-            except Exception as e:
-                print(e)
-                continue
-
+folder_names = []
+
+for root1, dirs1, files1 in os.walk(cwd):
+    for dir_name in dirs1:
+        if dir_name == "compare_results" or dir_name.startswith("benchmark_artifact"):
+            print("Including folder: " + dir_name)
+            # "compare_results" is the folder containing the benchmark results
+            # from this job; "benchmark_artifact*" folders contain the results
+            # from the other jobs. If we change the Python version of the
+            # benchmarks in the future, the hard-coded
+            # "/Linux-CPython-3.9-64bit" path below will need to be updated.
+            files2walk = (
+                os.walk(cwd + "/" + dir_name)
+                if dir_name == "compare_results"
+                else os.walk(cwd + "/" + dir_name + "/Linux-CPython-3.9-64bit")
+            )
+            for root, dirs, files in files2walk:
+                for filename in files:
+                    if (
+                        filename.find("json") != -1
+                    ):  # check if json output file is present
+                        try:
+                            filepath = os.path.join(root, filename)
+                            with open(filepath) as f:
+                                curr_data = json.load(f)
+                                commit_id = curr_data["commit_info"]["id"][0:7]
+                                data[commit_ind] = curr_data["benchmarks"]
+                                if filepath.find("master") != -1:
+                                    master_idx.append(commit_ind)
+                                elif filepath.find("Latest_Commit") != -1:
+                                    latest_idx.append(commit_ind)
+                                commit_ind += 1
+                        except Exception as e:
+                            print(e)
+                            continue
 # need arrays of size [ num benchmarks x num commits ]
 # one for mean one for stddev
 # number of benchmark cases
-num_benchmarks = len(data[list(data.keys())[0]]["benchmarks"])
-num_commits = len(list(data.keys()))
+num_benchmarks = 0
+# total number of benchmarks, summed over the jobs the suite was split into
+for split in master_idx:
+    num_benchmarks += len(data[split])
+num_commits = 2
+
 times = np.zeros([num_benchmarks, num_commits])
 stddevs = np.zeros([num_benchmarks, num_commits])
 commit_ids = []
 test_names = [None] * num_benchmarks

-for id_num, commit_id in enumerate(data.keys()):
-    commit_ids.append(commit_id)
-    for i, test in enumerate(data[commit_id]["benchmarks"]):
+id_num = 0
+for i in master_idx:
+    for test in data[i]:
         t_mean = test["stats"]["median"]
         t_stddev = 
test["stats"]["iqr"] - times[i, id_num] = t_mean - stddevs[i, id_num] = t_stddev - test_names[i] = test["name"] - + times[id_num, 0] = t_mean + stddevs[id_num, 0] = t_stddev + test_names[id_num] = test["name"] + id_num = id_num + 1 + +id_num = 0 +for i in latest_idx: + for test in data[i]: + t_mean = test["stats"]["median"] + t_stddev = test["stats"]["iqr"] + times[id_num, 1] = t_mean + stddevs[id_num, 1] = t_stddev + test_names[id_num] = test["name"] + id_num = id_num + 1 # we say a slowdown/speedup has occurred if the mean time difference is greater than # n_sigma * (stdev of time delta) significance = 3 # n_sigmas of normal distribution, ie z score of 3 colors = [" "] * num_benchmarks # g if faster, w if similar, r if slower -delta_times_ms = times[:, latest_idx] - times[:, master_idx] -delta_stds_ms = np.sqrt(stddevs[:, latest_idx] ** 2 + stddevs[:, master_idx] ** 2) -delta_times_pct = delta_times_ms / times[:, master_idx] * 100 -delta_stds_pct = delta_stds_ms / times[:, master_idx] * 100 +delta_times_ms = times[:, 1] - times[:, 0] +delta_stds_ms = np.sqrt(stddevs[:, 1] ** 2 + stddevs[:, 0] ** 2) +delta_times_pct = delta_times_ms / times[:, 0] * 100 +delta_stds_pct = delta_stds_ms / times[:, 0] * 100 for i, (pct, spct) in enumerate(zip(delta_times_pct, delta_stds_pct)): if pct > 0 and pct > significance * spct: colors[i] = "-" # this will make the line red @@ -72,8 +99,6 @@ # now make the commit message, save as a txt file # benchmark_name dt(%) dt(s) t_new(s) t_old(s) -print(latest_idx) -print(master_idx) commit_msg_lines = [ "```diff", f"| {'benchmark_name':^38} | {'dt(%)':^22} | {'dt(s)':^22} |" @@ -88,8 +113,8 @@ line = f"{colors[i]:>1}{test_names[i]:<39} |" line += f" {f'{dpct:+6.2f} +/- {sdpct:4.2f}':^22} |" line += f" {f'{dt:+.2e} +/- {sdt:.2e}':^22} |" - line += f" {f'{times[i, latest_idx]:.2e} +/- {stddevs[i, latest_idx]:.1e}':^22} |" - line += f" {f'{times[i, master_idx]:.2e} +/- {stddevs[i, master_idx]:.1e}':^22} |" + line += f" {f'{times[i, 1]:.2e} +/- {stddevs[i, 1]:.1e}':^22} |" + line += f" {f'{times[i, 0]:.2e} +/- {stddevs[i, 0]:.1e}':^22} |" commit_msg_lines.append(line) diff --git a/tests/conftest.py b/tests/conftest.py index 873d2c3f0a..ccab0e07a6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -335,3 +335,22 @@ def VMEC_save(SOLOVEV, tmpdir_factory): ) desc = Dataset(str(SOLOVEV["desc_nc_path"]), mode="r") return vmec, desc + + +@pytest.fixture(scope="session") +def VMEC_save_asym(tmpdir_factory): + """Save an asymmetric equilibrium in VMEC netcdf format for comparison.""" + tmpdir = tmpdir_factory.mktemp("asym_wout") + filename = tmpdir.join("wout_HELIO_asym_desc.nc") + vmec = Dataset("./tests/inputs/wout_HELIOTRON_asym_NTHETA50_NZETA100.nc", mode="r") + eq = Equilibrium.load("./tests/inputs/HELIO_asym.h5") + VMECIO.save( + eq, + filename, + surfs=vmec.variables["ns"][:], + verbose=0, + M_nyq=round(np.max(vmec.variables["xm_nyq"][:])), + N_nyq=round(np.max(vmec.variables["xn_nyq"][:]) / eq.NFP), + ) + desc = Dataset(filename, mode="r") + return vmec, desc, eq diff --git a/tests/inputs/HELIO_asym.h5 b/tests/inputs/HELIO_asym.h5 new file mode 100644 index 0000000000..c66a6cb100 Binary files /dev/null and b/tests/inputs/HELIO_asym.h5 differ diff --git a/tests/inputs/low-beta-shifted-circle.h5 b/tests/inputs/low-beta-shifted-circle.h5 index 31f4fab80b..dd75392a09 100644 Binary files a/tests/inputs/low-beta-shifted-circle.h5 and b/tests/inputs/low-beta-shifted-circle.h5 differ diff --git a/tests/inputs/master_compute_data_rpz.pkl 
b/tests/inputs/master_compute_data_rpz.pkl index eef5bbf2f6..d72778328e 100644 Binary files a/tests/inputs/master_compute_data_rpz.pkl and b/tests/inputs/master_compute_data_rpz.pkl differ diff --git a/tests/inputs/wout_HELIOTRON_asym_NTHETA50_NZETA100.nc b/tests/inputs/wout_HELIOTRON_asym_NTHETA50_NZETA100.nc new file mode 100644 index 0000000000..cc51c535a3 Binary files /dev/null and b/tests/inputs/wout_HELIOTRON_asym_NTHETA50_NZETA100.nc differ diff --git a/tests/test_axis_limits.py b/tests/test_axis_limits.py index 8c847ef3a0..e204dc423d 100644 --- a/tests/test_axis_limits.py +++ b/tests/test_axis_limits.py @@ -12,12 +12,13 @@ import pytest from desc.compute import data_index -from desc.compute.utils import _grow_seeds, dot +from desc.compute.utils import _grow_seeds from desc.equilibrium import Equilibrium from desc.examples import get from desc.grid import LinearGrid from desc.integrals import surface_integrals_map from desc.objectives import GenericObjective, ObjectiveFunction +from desc.utils import dot # Unless mentioned in the source code of the compute function, the assumptions # made to compute the magnetic axis limit can be reduced to assuming that these @@ -63,7 +64,6 @@ "gbdrift", "cvdrift", "grad(alpha)", - "cvdrift0", "|e^helical|", "|grad(theta)|", " Redl", # may not exist for all configurations @@ -94,7 +94,6 @@ "K_vc", # only defined on surface "iota_num_rrr", "iota_den_rrr", - "cvdrift0", } @@ -135,6 +134,14 @@ def _skip_this(eq, name): or (eq.anisotropy is None and "beta_a" in name) or (eq.pressure is not None and " Redl" in name) or (eq.current is None and "iota_num" in name) + # These quantities require a coordinate mapping to compute and special grids, so + # it's not economical to test their axis limits here. Instead, a grid that + # includes the axis should be used in existing unit tests for these quantities. 
+ or bool( + data_index["desc.equilibrium.equilibrium.Equilibrium"][name][ + "source_grid_requirement" + ] + ) ) @@ -388,3 +395,4 @@ def test_reverse_mode_ad_axis(name): obj.build(verbose=0) g = obj.grad(obj.x()) assert not np.any(np.isnan(g)) + print(np.count_nonzero(g), name) diff --git a/tests/test_coils.py b/tests/test_coils.py index 704ad5f761..71127660da 100644 --- a/tests/test_coils.py +++ b/tests/test_coils.py @@ -4,7 +4,9 @@ import numpy as np import pytest +import scipy +from desc.backend import jnp from desc.coils import ( CoilSet, FourierPlanarCoil, @@ -13,12 +15,13 @@ MixedCoilSet, SplineXYZCoil, ) -from desc.compute import get_params, get_transforms, xyz2rpz, xyz2rpz_vec +from desc.compute import get_params, get_transforms, rpz2xyz, xyz2rpz, xyz2rpz_vec from desc.examples import get -from desc.geometry import FourierRZCurve, FourierRZToroidalSurface +from desc.geometry import FourierRZCurve, FourierRZToroidalSurface, FourierXYZCurve from desc.grid import Grid, LinearGrid from desc.io import load from desc.magnetic_fields import SumMagneticField, VerticalMagneticField +from desc.utils import dot class TestCoil: @@ -149,6 +152,198 @@ def test_biot_savart_all_coils(self): B_true_rpz_phi, B_rpz, rtol=1e-3, atol=1e-10, err_msg="Using FourierRZCoil" ) + @pytest.mark.unit + def test_biot_savart_vector_potential_all_coils(self): + """Test biot-savart vec potential implementation against analytic formula.""" + coil_grid = LinearGrid(zeta=100, endpoint=False) + + R = 2 + y = 1 + I = 1e7 + + A_true = np.atleast_2d([0, 0, 0]) + grid_xyz = np.atleast_2d([10, y, 0]) + grid_rpz = xyz2rpz(grid_xyz) + + def test(coil, grid_xyz, grid_rpz): + A_xyz = coil.compute_magnetic_vector_potential( + grid_xyz, basis="xyz", source_grid=coil_grid + ) + A_rpz = coil.compute_magnetic_vector_potential( + grid_rpz, basis="rpz", source_grid=coil_grid + ) + np.testing.assert_allclose( + A_true, A_xyz, rtol=1e-3, atol=1e-10, err_msg=f"Using {coil}" + ) + np.testing.assert_allclose( + A_true, A_rpz, rtol=1e-3, atol=1e-10, err_msg=f"Using {coil}" + ) + np.testing.assert_allclose( + A_true, A_rpz, rtol=1e-3, atol=1e-10, err_msg=f"Using {coil}" + ) + + # FourierXYZCoil + coil = FourierXYZCoil(I) + test(coil, grid_xyz, grid_rpz) + + # SplineXYZCoil + x = coil.compute("x", grid=coil_grid, basis="xyz")["x"] + coil = SplineXYZCoil(I, X=x[:, 0], Y=x[:, 1], Z=x[:, 2]) + test(coil, grid_xyz, grid_rpz) + + # FourierPlanarCoil + coil = FourierPlanarCoil(I) + test(coil, grid_xyz, grid_rpz) + + grid_xyz = np.atleast_2d([0, 0, y]) + grid_rpz = xyz2rpz(grid_xyz) + + # FourierRZCoil + coil = FourierRZCoil(I, R_n=np.array([R]), modes_R=np.array([0])) + test(coil, grid_xyz, grid_rpz) + # test in a CoilSet + coil2 = CoilSet(coil) + test(coil2, grid_xyz, grid_rpz) + # test in a MixedCoilSet + coil3 = MixedCoilSet(coil2, coil, check_intersection=False) + coil3[1].current = 0 + test(coil3, grid_xyz, grid_rpz) + + @pytest.mark.unit + def test_biot_savart_vector_potential_integral_all_coils(self): + """Test analytic expression of flux integral for all coils.""" + # taken from analytic benchmark in + # "A Magnetic Diagnostic Code for 3D Fusion Equilibria", Lazerson 2013 + # find flux for concentric loops of varying radii to a circular coil + + coil_grid = LinearGrid(zeta=1000, endpoint=False) + + R = 1 + I = 1e7 + + # analytic eqn for "A_phi" (phi is in dl direction for loop) + def _A_analytic(r): + # elliptic integral arguments must be k^2, not k, + # error in original paper and apparently in Jackson EM book too. 
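
# --- Editor's note, not part of the patch: the k vs k**2 convention flagged in
# the comment above, as a standalone check. SciPy's ellipk/ellipe take the
# parameter m = k**2, not the modulus k, and ellipkm1(p) evaluates K(m) at
# m = 1 - p, which is why the code below builds k_sqd before calling them.
import numpy as np
from scipy.special import ellipk, ellipkm1

m = 0.5
np.testing.assert_allclose(ellipk(m), ellipkm1(1 - m))
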
+ theta = np.pi / 2 + arg = R**2 + r**2 + 2 * r * R * np.sin(theta) + term_1_num = 4.0e-7 * I * R + term_1_den = np.sqrt(arg) + k_sqd = 4 * r * R * np.sin(theta) / arg + term_2_num = (2 - k_sqd) * scipy.special.ellipk( + k_sqd + ) - 2 * scipy.special.ellipe(k_sqd) + term_2_den = k_sqd + return term_1_num * term_2_num / term_1_den / term_2_den + + # we only evaluate it at theta=np.pi/2 (b/c it is in spherical coords) + rs = np.linspace(0.1, 3, 10, endpoint=True) + N = 200 + curve_grid = LinearGrid(zeta=N) + + def test( + coil, grid_xyz, grid_rpz, A_true_rpz, correct_flux, rtol=1e-10, atol=1e-12 + ): + """Test that we compute the correct flux for the given coil.""" + A_xyz = coil.compute_magnetic_vector_potential( + grid_xyz, basis="xyz", source_grid=coil_grid + ) + A_rpz = coil.compute_magnetic_vector_potential( + grid_rpz, basis="rpz", source_grid=coil_grid + ) + flux_xyz = jnp.sum( + dot(A_xyz, curve_data["x_s"], axis=-1) * curve_grid.spacing[:, 2] + ) + flux_rpz = jnp.sum( + dot(A_rpz, curve_data_rpz["x_s"], axis=-1) * curve_grid.spacing[:, 2] + ) + + np.testing.assert_allclose( + correct_flux, flux_xyz, rtol=rtol, err_msg=f"Using {coil}" + ) + np.testing.assert_allclose( + correct_flux, flux_rpz, rtol=rtol, err_msg=f"Using {coil}" + ) + np.testing.assert_allclose( + A_true_rpz, + A_rpz, + rtol=rtol, + atol=atol, + err_msg=f"Using {coil}", + ) + + for r in rs: + # A_phi is constant around the loop (no phi dependence) + A_true_phi = _A_analytic(r) * np.ones(N) + A_true_rpz = np.vstack( + (np.zeros_like(A_true_phi), A_true_phi, np.zeros_like(A_true_phi)) + ).T + correct_flux = np.sum(r * A_true_phi * 2 * np.pi / N) + + curve = FourierXYZCurve( + X_n=[-r, 0, 0], Y_n=[0, 0, r], Z_n=[0, 0, 0] + ) # flux loop to integrate A over + + curve_data = curve.compute(["x", "x_s"], grid=curve_grid, basis="xyz") + curve_data_rpz = curve.compute(["x", "x_s"], grid=curve_grid, basis="rpz") + + grid_rpz = np.vstack( + [ + curve_data_rpz["x"][:, 0], + curve_data_rpz["x"][:, 1], + curve_data_rpz["x"][:, 2], + ] + ).T + grid_xyz = rpz2xyz(grid_rpz) + # FourierXYZCoil + coil = FourierXYZCoil(I, X_n=[-R, 0, 0], Y_n=[0, 0, R], Z_n=[0, 0, 0]) + test( + coil, + grid_xyz, + grid_rpz, + A_true_rpz, + correct_flux, + rtol=1e-8, + atol=1e-12, + ) + + # SplineXYZCoil + x = coil.compute("x", grid=coil_grid, basis="xyz")["x"] + coil = SplineXYZCoil(I, X=x[:, 0], Y=x[:, 1], Z=x[:, 2]) + test( + coil, + grid_xyz, + grid_rpz, + A_true_rpz, + correct_flux, + rtol=1e-4, + atol=1e-12, + ) + + # FourierPlanarCoil + coil = FourierPlanarCoil(I, center=[0, 0, 0], normal=[0, 0, -1], r_n=R) + test( + coil, + grid_xyz, + grid_rpz, + A_true_rpz, + correct_flux, + rtol=1e-8, + atol=1e-12, + ) + + # FourierRZCoil + coil = FourierRZCoil(I, R_n=np.array([R]), modes_R=np.array([0])) + test( + coil, + grid_xyz, + grid_rpz, + A_true_rpz, + correct_flux, + rtol=1e-8, + atol=1e-12, + ) + @pytest.mark.unit def test_properties(self): """Test getting/setting attributes for Coil class.""" diff --git a/tests/test_compat.py b/tests/test_compat.py index 7a13387fce..ca1a0c55da 100644 --- a/tests/test_compat.py +++ b/tests/test_compat.py @@ -3,7 +3,7 @@ import numpy as np import pytest -from desc.compat import flip_helicity, rescale +from desc.compat import flip_helicity, flip_theta, rescale from desc.examples import get from desc.grid import Grid, LinearGrid, QuadratureGrid @@ -47,7 +47,6 @@ def test_flip_helicity_axisym(): @pytest.mark.unit -@pytest.mark.solve def test_flip_helicity_iota(): """Test flip_helicity on an Equilibrium with an iota 
profile.""" eq = get("HELIOTRON") @@ -90,7 +89,6 @@ def test_flip_helicity_iota(): @pytest.mark.unit -@pytest.mark.solve def test_flip_helicity_current(): """Test flip_helicity on an Equilibrium with a current profile.""" eq = get("HSX") @@ -135,6 +133,87 @@ def test_flip_helicity_current(): np.testing.assert_allclose(data_old["f_C"], data_flip["f_C"], atol=1e-8) +@pytest.mark.unit +def test_flip_theta_axisym(): + """Test flip_theta on an axisymmetric Equilibrium.""" + eq = get("DSHAPE") + + grid = LinearGrid( + L=eq.L_grid, + theta=2 * eq.M_grid, + N=eq.N_grid, + NFP=eq.NFP, + sym=eq.sym, + axis=False, + ) + data_keys = ["current", "|F|", "D_Mercier"] + + data_old = eq.compute(data_keys, grid=grid) + eq = flip_theta(eq) + data_new = eq.compute(data_keys, grid=grid) + + # check that Jacobian and force balance did not change + np.testing.assert_allclose( + data_old["sqrt(g)"].reshape((grid.num_rho, grid.num_theta)), + np.fliplr(data_new["sqrt(g)"].reshape((grid.num_rho, grid.num_theta))), + ) + np.testing.assert_allclose( + data_old["|F|"].reshape((grid.num_rho, grid.num_theta)), + np.fliplr(data_new["|F|"].reshape((grid.num_rho, grid.num_theta))), + rtol=2e-5, + ) + + # check that profiles did not change + np.testing.assert_allclose( + grid.compress(data_old["iota"]), grid.compress(data_new["iota"]) + ) + np.testing.assert_allclose( + grid.compress(data_old["current"]), grid.compress(data_new["current"]) + ) + np.testing.assert_allclose( + grid.compress(data_old["D_Mercier"]), grid.compress(data_new["D_Mercier"]) + ) + + +@pytest.mark.unit +def test_flip_theta_nonaxisym(): + """Test flip_theta on a non-axisymmetric Equilibrium.""" + eq = get("HSX") + + grid = QuadratureGrid(L=eq.L_grid, M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP) + nodes = grid.nodes.copy() + nodes[:, 1] = np.mod(nodes[:, 1] + np.pi, 2 * np.pi) + grid_flip = Grid(nodes) # grid with flipped theta values + data_keys = ["current", "|F|", "D_Mercier", "f_C"] + + data_old = eq.compute(data_keys, grid=grid, helicity=(1, eq.NFP)) + eq = flip_theta(eq) + data_new = eq.compute(data_keys, grid=grid_flip, helicity=(1, eq.NFP)) + + # check that basis vectors did not change + np.testing.assert_allclose(data_old["e_rho"], data_new["e_rho"], atol=1e-15) + np.testing.assert_allclose(data_old["e_theta"], data_new["e_theta"], atol=1e-15) + np.testing.assert_allclose(data_old["e^zeta"], data_new["e^zeta"], atol=1e-15) + + # check that Jacobian is still positive + np.testing.assert_array_less(0, grid.compress(data_new["sqrt(g)"])) + + # check that stability did not change + np.testing.assert_allclose( + grid.compress(data_old["D_Mercier"]), + grid.compress(data_new["D_Mercier"]), + rtol=2e-2, + ) + + # check that the total force balance error on each surface did not change + # (equivalent collocation points now corresond to theta + pi) + np.testing.assert_allclose(data_old["|F|"], data_new["|F|"], rtol=1e-3) + + # check that the QH helicity did not change + # (equivalent collocation points now corresond to theta + pi) + np.testing.assert_allclose(data_old["f_C"], data_new["f_C"], atol=1e-8) + + @pytest.mark.unit def test_rescale(): """Test rescale function.""" diff --git a/tests/test_compute_funs.py b/tests/test_compute_funs.py index 43c3d81449..9a9216cc8e 100644 --- a/tests/test_compute_funs.py +++ b/tests/test_compute_funs.py @@ -5,12 +5,12 @@ from scipy.signal import convolve2d from desc.compute import rpz2xyz_vec -from desc.compute.utils import dot from desc.equilibrium import Equilibrium from desc.examples import get from desc.geometry import 
FourierRZToroidalSurface from desc.grid import LinearGrid from desc.io import load +from desc.utils import dot # convolve kernel is reverse of FD coeffs FD_COEF_1_2 = np.array([-1 / 2, 0, 1 / 2])[::-1] @@ -1134,6 +1134,24 @@ def test_boozer_transform(): ) +@pytest.mark.unit +def test_boozer_transform_multiple_surfaces(): + """Test that computing over multiple surfaces is the same as over 1 at a time.""" + eq = get("HELIOTRON") + grid1 = LinearGrid(rho=0.6, M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP) + grid2 = LinearGrid(rho=0.8, M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP) + grid3 = LinearGrid(rho=np.array([0.6, 0.8]), M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP) + data1 = eq.compute("|B|_mn", grid=grid1, M_booz=eq.M, N_booz=eq.N) + data2 = eq.compute("|B|_mn", grid=grid2, M_booz=eq.M, N_booz=eq.N) + data3 = eq.compute("|B|_mn", grid=grid3, M_booz=eq.M, N_booz=eq.N) + np.testing.assert_allclose( + data1["|B|_mn"], data3["|B|_mn"].reshape((grid3.num_rho, -1))[0] + ) + np.testing.assert_allclose( + data2["|B|_mn"], data3["|B|_mn"].reshape((grid3.num_rho, -1))[1] + ) + + @pytest.mark.unit def test_compute_averages(): """Test that computing averages uses the correct grid.""" diff --git a/tests/test_examples.py b/tests/test_examples.py index 6bae9dd16b..a03e364c35 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1077,9 +1077,12 @@ def test_freeb_axisym(): -6.588300858364606e04, -3.560589388468855e05, ] - ext_field = SplineMagneticField.from_mgrid( - r"tests/inputs/mgrid_solovev.nc", extcur=extcur - ) + with pytest.warns(UserWarning, match="Vector potential"): + # the mgrid file does not have the vector potential + # saved so we will ignore the thrown warning + ext_field = SplineMagneticField.from_mgrid( + r"tests/inputs/mgrid_solovev.nc", extcur=extcur + ) pres = PowerSeriesProfile([1.25e-1, 0, -1.25e-1]) iota = PowerSeriesProfile([-4.9e-1, 0, 3.0e-1]) diff --git a/tests/test_grid.py b/tests/test_grid.py index 051ba1b89f..929a1bbe57 100644 --- a/tests/test_grid.py +++ b/tests/test_grid.py @@ -793,26 +793,23 @@ def test_meshgrid_reshape(self): zeta = np.linspace(0, 6 * np.pi, 5) grid = Grid.create_meshgrid([rho, alpha, zeta], coordinates="raz") r, a, z = grid.nodes.T - r = grid.meshgrid_reshape(r, "raz") - a = grid.meshgrid_reshape(a, "raz") - z = grid.meshgrid_reshape(z, "raz") # functions of zeta should separate along first two axes # since those are contiguous, this should work - f = z.reshape(-1, zeta.size) + f = grid.meshgrid_reshape(z, "raz").reshape(-1, zeta.size) for i in range(1, f.shape[0]): np.testing.assert_allclose(f[i - 1], f[i]) # likewise for rho - f = r.reshape(rho.size, -1) + f = grid.meshgrid_reshape(r, "raz").reshape(rho.size, -1) for i in range(1, f.shape[-1]): np.testing.assert_allclose(f[:, i - 1], f[:, i]) # test reshaping result won't mix data - f = (a**2 + z).reshape(rho.size, alpha.size, zeta.size) + f = grid.meshgrid_reshape(a**2 + z, "raz") for i in range(1, f.shape[0]): np.testing.assert_allclose(f[i - 1], f[i]) - f = (r**2 + z).reshape(rho.size, alpha.size, zeta.size) + f = grid.meshgrid_reshape(r**2 + z, "raz") for i in range(1, f.shape[1]): np.testing.assert_allclose(f[:, i - 1], f[:, i]) - f = (r**2 + a).reshape(rho.size, alpha.size, zeta.size) + f = grid.meshgrid_reshape(r**2 + a, "raz") for i in range(1, f.shape[-1]): np.testing.assert_allclose(f[..., i - 1], f[..., i]) diff --git a/tests/test_integrals.py b/tests/test_integrals.py index b15b019283..c909516e00 100644 --- a/tests/test_integrals.py +++ b/tests/test_integrals.py @@ -1,13 +1,26 @@ """Test 
integration algorithms.""" +from functools import partial + import numpy as np import pytest - +from jax import grad +from matplotlib import pyplot as plt +from numpy.polynomial.chebyshev import chebgauss, chebweight +from numpy.polynomial.legendre import leggauss +from scipy import integrate +from scipy.interpolate import CubicHermiteSpline +from scipy.special import ellipe, ellipkm1, roots_chebyu +from tests.test_plotting import tol_1d + +from desc.backend import jnp from desc.basis import FourierZernikeBasis from desc.equilibrium import Equilibrium +from desc.equilibrium.coords import get_rtz_grid from desc.examples import get -from desc.grid import ConcentricGrid, LinearGrid, QuadratureGrid +from desc.grid import ConcentricGrid, Grid, LinearGrid, QuadratureGrid from desc.integrals import ( + Bounce1D, DFTInterpolator, FFTInterpolator, line_integrals, @@ -20,9 +33,26 @@ surface_variance, virtual_casing_biot_savart, ) +from desc.integrals.bounce_utils import ( + _get_extrema, + bounce_points, + get_pitch_inv, + interp_to_argmin, + interp_to_argmin_hard, +) +from desc.integrals.quad_utils import ( + automorphism_sin, + bijection_from_disc, + get_quadrature, + grad_automorphism_sin, + grad_bijection_from_disc, + leggauss_lob, + tanh_sinh, +) from desc.integrals.singularities import _get_quadrature_nodes from desc.integrals.surface_integral import _get_grid_surface from desc.transform import Transform +from desc.utils import dot, safediv class TestSurfaceIntegral: @@ -688,3 +718,746 @@ def test_biest_interpolators(self): g2 = interp2(f(source_theta, source_zeta), i) np.testing.assert_allclose(g1, g2) np.testing.assert_allclose(g1, ff) + + +class TestBounce1DPoints: + """Test that bounce points are computed correctly.""" + + @staticmethod + def filter(z1, z2): + """Remove bounce points whose integrals have zero measure.""" + mask = (z1 - z2) != 0.0 + return z1[mask], z2[mask] + + @pytest.mark.unit + def test_z1_first(self): + """Case where straight line through first two intersects is in epigraph.""" + start = np.pi / 3 + end = 6 * np.pi + knots = np.linspace(start, end, 5) + B = CubicHermiteSpline(knots, np.cos(knots), -np.sin(knots)) + pitch_inv = 0.5 + intersect = B.solve(pitch_inv, extrapolate=False) + z1, z2 = bounce_points( + pitch_inv, knots, B.c.T, B.derivative().c.T, check=True, include_knots=True + ) + z1, z2 = TestBounce1DPoints.filter(z1, z2) + assert z1.size and z2.size + np.testing.assert_allclose(z1, intersect[0::2]) + np.testing.assert_allclose(z2, intersect[1::2]) + + @pytest.mark.unit + def test_z2_first(self): + """Case where straight line through first two intersects is in hypograph.""" + start = -3 * np.pi + end = -start + k = np.linspace(start, end, 5) + B = CubicHermiteSpline(k, np.cos(k), -np.sin(k)) + pitch_inv = 0.5 + intersect = B.solve(pitch_inv, extrapolate=False) + z1, z2 = bounce_points( + pitch_inv, k, B.c.T, B.derivative().c.T, check=True, include_knots=True + ) + z1, z2 = TestBounce1DPoints.filter(z1, z2) + assert z1.size and z2.size + np.testing.assert_allclose(z1, intersect[1:-1:2]) + np.testing.assert_allclose(z2, intersect[0::2][1:]) + + @pytest.mark.unit + def test_z1_before_extrema(self): + """Case where local maximum is the shared intersect between two wells.""" + # To make sure both regions in epigraph left and right of extrema are + # integrated over. 
+ start = -np.pi
+ end = -2 * start
+ k = np.linspace(start, end, 5)
+ B = CubicHermiteSpline(
+ k, np.cos(k) + 2 * np.sin(-2 * k), -np.sin(k) - 4 * np.cos(-2 * k)
+ )
+ dB_dz = B.derivative()
+ pitch_inv = B(dB_dz.roots(extrapolate=False))[3] - 1e-13
+ z1, z2 = bounce_points(
+ pitch_inv, k, B.c.T, dB_dz.c.T, check=True, include_knots=True
+ )
+ z1, z2 = TestBounce1DPoints.filter(z1, z2)
+ assert z1.size and z2.size
+ intersect = B.solve(pitch_inv, extrapolate=False)
+ np.testing.assert_allclose(z1[1], 1.982767, rtol=1e-6)
+ np.testing.assert_allclose(z1, intersect[[1, 2]], rtol=1e-6)
+ # intersect array could not resolve double root as single at index 2,3
+ np.testing.assert_allclose(intersect[2], intersect[3], rtol=1e-6)
+ np.testing.assert_allclose(z2, intersect[[3, 4]], rtol=1e-6)
+
+ @pytest.mark.unit
+ def test_z2_before_extrema(self):
+ """Case where local minimum is the shared intersect between two wells."""
+ # To make sure both regions in hypograph left and right of extrema are not
+ # integrated over.
+ start = -1.2 * np.pi
+ end = -2 * start
+ k = np.linspace(start, end, 7)
+ B = CubicHermiteSpline(
+ k,
+ np.cos(k) + 2 * np.sin(-2 * k) + k / 4,
+ -np.sin(k) - 4 * np.cos(-2 * k) + 1 / 4,
+ )
+ dB_dz = B.derivative()
+ pitch_inv = B(dB_dz.roots(extrapolate=False))[2]
+ z1, z2 = bounce_points(
+ pitch_inv, k, B.c.T, dB_dz.c.T, check=True, include_knots=True
+ )
+ z1, z2 = TestBounce1DPoints.filter(z1, z2)
+ assert z1.size and z2.size
+ intersect = B.solve(pitch_inv, extrapolate=False)
+ np.testing.assert_allclose(z1, intersect[[0, -2]])
+ np.testing.assert_allclose(z2, intersect[[1, -1]])
+
+ @pytest.mark.unit
+ def test_extrema_first_and_before_z1(self):
+ """Case where first intersect is extrema and second enters epigraph."""
+ # To make sure we don't perform the integral between the first pair of intersects.
+ start = -1.2 * np.pi
+ end = -2 * start
+ k = np.linspace(start, end, 7)
+ B = CubicHermiteSpline(
+ k,
+ np.cos(k) + 2 * np.sin(-2 * k) + k / 20,
+ -np.sin(k) - 4 * np.cos(-2 * k) + 1 / 20,
+ )
+ dB_dz = B.derivative()
+ pitch_inv = B(dB_dz.roots(extrapolate=False))[2] + 1e-13
+ z1, z2 = bounce_points(
+ pitch_inv,
+ k[2:],
+ B.c[:, 2:].T,
+ dB_dz.c[:, 2:].T,
+ check=True,
+ start=k[2],
+ include_knots=True,
+ )
+ z1, z2 = TestBounce1DPoints.filter(z1, z2)
+ assert z1.size and z2.size
+ intersect = B.solve(pitch_inv, extrapolate=False)
+ np.testing.assert_allclose(z1[0], 0.835319, rtol=1e-6)
+ intersect = intersect[intersect >= k[2]]
+ np.testing.assert_allclose(z1, intersect[[0, 2, 4]], rtol=1e-6)
+ np.testing.assert_allclose(z2, intersect[[0, 3, 5]], rtol=1e-6)
+
+ @pytest.mark.unit
+ def test_extrema_first_and_before_z2(self):
+ """Case where first intersect is extrema and second exits epigraph."""
+ # To make sure we do perform the integral between the first pair of intersects.
+ start = -1.2 * np.pi
+ end = -2 * start + 1
+ k = np.linspace(start, end, 7)
+ B = CubicHermiteSpline(
+ k,
+ np.cos(k) + 2 * np.sin(-2 * k) + k / 10,
+ -np.sin(k) - 4 * np.cos(-2 * k) + 1 / 10,
+ )
+ dB_dz = B.derivative()
+ pitch_inv = B(dB_dz.roots(extrapolate=False))[1] - 1e-13
+ z1, z2 = bounce_points(
+ pitch_inv, k, B.c.T, dB_dz.c.T, check=True, include_knots=True
+ )
+ z1, z2 = TestBounce1DPoints.filter(z1, z2)
+ assert z1.size and z2.size
+ # Our routine correctly detects the intersection, while scipy and jnp.roots fail.
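+ # Editor's sketch of the failure mode (hypothetical standalone example):
+ # a tangency is a double root, which floating point splits into a complex
+ # pair or two nearby reals, so naive root counting can miss it.
+ rq = np.roots([1.0, -2.0, 1.0 + 1e-12]) # (z - 1)**2 perturbed at machine level
+ assert rq.size == 2 # the double root at z = 1 appears as two split roots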
+ intersect = B.solve(pitch_inv, extrapolate=False)
+ np.testing.assert_allclose(z1[0], -0.671904, rtol=1e-6)
+ np.testing.assert_allclose(z1, intersect[[0, 3, 5]], rtol=1e-5)
+ # intersect array could not resolve double root as single at index 0,1
+ np.testing.assert_allclose(intersect[0], intersect[1], rtol=1e-5)
+ np.testing.assert_allclose(z2, intersect[[2, 4, 6]], rtol=1e-5)
+
+ @pytest.mark.unit
+ def test_get_extrema(self):
+ """Test computation of extrema of |B|."""
+ start = -np.pi
+ end = -2 * start
+ k = np.linspace(start, end, 5)
+ B = CubicHermiteSpline(
+ k, np.cos(k) + 2 * np.sin(-2 * k), -np.sin(k) - 4 * np.cos(-2 * k)
+ )
+ dB_dz = B.derivative()
+ ext, B_ext = _get_extrema(k, B.c.T, dB_dz.c.T)
+ mask = ~np.isnan(ext)
+ ext, B_ext = ext[mask], B_ext[mask]
+ idx = np.argsort(ext)
+
+ ext_scipy = np.sort(dB_dz.roots(extrapolate=False))
+ B_ext_scipy = B(ext_scipy)
+ assert ext.size == ext_scipy.size
+ np.testing.assert_allclose(ext[idx], ext_scipy)
+ np.testing.assert_allclose(B_ext[idx], B_ext_scipy)
+
+
+def _mod_cheb_gauss(deg):
+ x, w = chebgauss(deg)
+ w /= chebweight(x)
+ return x, w
+
+
+def _mod_chebu_gauss(deg):
+ x, w = roots_chebyu(deg)
+ w *= chebweight(x)
+ return x, w
+
+
+class TestBounce1DQuadrature:
+ """Test bounce quadrature."""
+
+ @pytest.mark.unit
+ @pytest.mark.parametrize(
+ "is_strong, quad, automorphism",
+ [
+ (True, tanh_sinh(40), None),
+ (True, leggauss(25), "default"),
+ (False, tanh_sinh(20), None),
+ (False, leggauss_lob(10), "default"),
+ # sin automorphism still helps out chebyshev quadrature
+ (True, _mod_cheb_gauss(30), "default"),
+ (False, _mod_chebu_gauss(10), "default"),
+ ],
+ )
+ def test_bounce_quadrature(self, is_strong, quad, automorphism):
+ """Test quadrature matches singular (strong and weak) elliptic integrals."""
+ p = 1e-4
+ m = 1 - p
+ # Some prime number that doesn't appear anywhere in calculation.
+ # Ensures no lucky cancellation occurs from (ζ₂ − ζ₁)/π = π/(ζ₂ − ζ₁),
+ # which could mask errors since π appears often in transformations.
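+ # Editor's note on the setup below (assumed forms, for orientation): the
+ # quadrature nodes x in [-1, 1] are mapped to the bounce interval by the
+ # affine bijection ζ = ζ₁ + (x + 1)(ζ₂ − ζ₁)/2, and an automorphism such
+ # as x ↦ sin(πx/2) first pushes nodes toward ±1 to better resolve the
+ # endpoint singularity of 1/√(1 − λ|B|). A minimal sketch:
+ # xq, wq = leggauss(25)
+ # zq = z1 + (np.sin(np.pi * xq / 2) + 1) * (z2 - z1) / 2 # clustered nodes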
+ v = 7 + z1 = -np.pi / 2 * v + z2 = -z1 + knots = np.linspace(z1, z2, 50) + pitch_inv = 1 - 50 * jnp.finfo(jnp.array(1.0).dtype).eps + b = np.clip(np.sin(knots / v) ** 2, 1e-7, 1) + db = np.sin(2 * knots / v) / v + data = {"B^zeta": b, "B^zeta_z|r,a": db, "|B|": b, "|B|_z|r,a": db} + + if is_strong: + integrand = lambda B, pitch: 1 / jnp.sqrt(1 - m * pitch * B) + truth = v * 2 * ellipkm1(p) + else: + integrand = lambda B, pitch: jnp.sqrt(1 - m * pitch * B) + truth = v * 2 * ellipe(m) + kwargs = {} + if automorphism != "default": + kwargs["automorphism"] = automorphism + bounce = Bounce1D( + Grid.create_meshgrid([1, 0, knots], coordinates="raz"), + data, + quad, + check=True, + **kwargs, + ) + result = bounce.integrate(integrand, pitch_inv, check=True, plot=True) + assert np.count_nonzero(result) == 1 + np.testing.assert_allclose(result.sum(), truth, rtol=1e-4) + + @staticmethod + @partial(np.vectorize, excluded={0}) + def _adaptive_elliptic(integrand, k): + a = 0 + b = 2 * np.arcsin(k) + return integrate.quad(integrand, a, b, args=(k,), points=b)[0] + + @staticmethod + def _fixed_elliptic(integrand, k, deg): + k = np.atleast_1d(k) + a = np.zeros_like(k) + b = 2 * np.arcsin(k) + x, w = get_quadrature(leggauss(deg), (automorphism_sin, grad_automorphism_sin)) + Z = bijection_from_disc(x, a[..., np.newaxis], b[..., np.newaxis]) + k = k[..., np.newaxis] + quad = integrand(Z, k).dot(w) * grad_bijection_from_disc(a, b) + return quad + + # TODO: add the analytical test that converts incomplete elliptic integrals to + # complete ones using the Reciprocal Modulus transformation + # https://dlmf.nist.gov/19.7#E4. + @staticmethod + def elliptic_incomplete(k2): + """Calculate elliptic integrals for bounce averaged binormal drift. + + The test is nice because it is independent of all the bounce integrals + and splines. One can test performance of different quadrature methods + by using that method in the ``_fixed_elliptic`` method above. + + """ + K_integrand = lambda Z, k: 2 / np.sqrt(k**2 - np.sin(Z / 2) ** 2) * (k / 4) + E_integrand = lambda Z, k: 2 * np.sqrt(k**2 - np.sin(Z / 2) ** 2) / (k * 4) + # Scipy's elliptic integrals are broken. + # https://github.com/scipy/scipy/issues/20525. + k = np.sqrt(k2) + K = TestBounce1DQuadrature._adaptive_elliptic(K_integrand, k) + E = TestBounce1DQuadrature._adaptive_elliptic(E_integrand, k) + # Make sure scipy's adaptive quadrature is not broken. + np.testing.assert_allclose( + K, TestBounce1DQuadrature._fixed_elliptic(K_integrand, k, 10) + ) + np.testing.assert_allclose( + E, TestBounce1DQuadrature._fixed_elliptic(E_integrand, k, 10) + ) + + I_0 = 4 / k * K + I_1 = 4 * k * E + I_2 = 16 * k * E + I_3 = 16 * k / 9 * (2 * (-1 + 2 * k2) * E - (-1 + k2) * K) + I_4 = 16 * k / 3 * ((-1 + 2 * k2) * E - 2 * (-1 + k2) * K) + I_5 = 32 * k / 30 * (2 * (1 - k2 + k2**2) * E - (1 - 3 * k2 + 2 * k2**2) * K) + I_6 = 4 / k * (2 * k2 * E + (1 - 2 * k2) * K) + I_7 = 2 * k / 3 * ((-2 + 4 * k2) * E - 4 * (-1 + k2) * K) + # Check for math mistakes. 
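+ # Editor's spot check (hypothetical, standalone): with sin(Z/2) = k sin(φ),
+ # the bare K integrand 2/√(k² − sin²(Z/2)) integrates to the complete
+ # elliptic integral 4 K(m = k²), e.g. for k = 1/2:
+ # from scipy.special import ellipk
+ # val, _ = integrate.quad(
+ #     lambda Z: 2 / np.sqrt(0.25 - np.sin(Z / 2) ** 2), 0, 2 * np.arcsin(0.5)
+ # )
+ # np.testing.assert_allclose(val, 4 * ellipk(0.25), rtol=1e-9)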
+ np.testing.assert_allclose( + I_2, + TestBounce1DQuadrature._adaptive_elliptic( + lambda Z, k: 2 / np.sqrt(k**2 - np.sin(Z / 2) ** 2) * Z * np.sin(Z), k + ), + ) + np.testing.assert_allclose( + I_3, + TestBounce1DQuadrature._adaptive_elliptic( + lambda Z, k: 2 * np.sqrt(k**2 - np.sin(Z / 2) ** 2) * Z * np.sin(Z), k + ), + ) + np.testing.assert_allclose( + I_4, + TestBounce1DQuadrature._adaptive_elliptic( + lambda Z, k: 2 / np.sqrt(k**2 - np.sin(Z / 2) ** 2) * np.sin(Z) ** 2, k + ), + ) + np.testing.assert_allclose( + I_5, + TestBounce1DQuadrature._adaptive_elliptic( + lambda Z, k: 2 * np.sqrt(k**2 - np.sin(Z / 2) ** 2) * np.sin(Z) ** 2, k + ), + ) + # scipy fails + np.testing.assert_allclose( + I_6, + TestBounce1DQuadrature._fixed_elliptic( + lambda Z, k: 2 / np.sqrt(k**2 - np.sin(Z / 2) ** 2) * np.cos(Z), + k, + deg=11, + ), + ) + np.testing.assert_allclose( + I_7, + TestBounce1DQuadrature._adaptive_elliptic( + lambda Z, k: 2 * np.sqrt(k**2 - np.sin(Z / 2) ** 2) * np.cos(Z), k + ), + ) + return I_0, I_1, I_2, I_3, I_4, I_5, I_6, I_7 + + +class TestBounce1D: + """Test bounce integration with one-dimensional local spline methods.""" + + @staticmethod + def _example_numerator(g_zz, B, pitch): + f = (1 - 0.5 * pitch * B) * g_zz + return safediv(f, jnp.sqrt(jnp.abs(1 - pitch * B))) + + @staticmethod + def _example_denominator(B, pitch): + return safediv(1, jnp.sqrt(jnp.abs(1 - pitch * B))) + + @pytest.mark.unit + @pytest.mark.mpl_image_compare(remove_text=True, tolerance=tol_1d * 4) + def test_bounce1d_checks(self): + """Test that all the internal correctness checks pass for real example.""" + # noqa: D202 + # Suppose we want to compute a bounce average of the function + # f(ℓ) = (1 − λ|B|/2) * g_zz, where g_zz is the squared norm of the + # toroidal basis vector on some set of field lines specified by (ρ, α) + # coordinates. This is defined as + # [∫ f(ℓ) / √(1 − λ|B|) dℓ] / [∫ 1 / √(1 − λ|B|) dℓ] + + # 1. Define python functions for the integrands. We do that above. + # 2. Pick flux surfaces, field lines, and how far to follow the field + # line in Clebsch coordinates ρ, α, ζ. + rho = np.linspace(0.1, 1, 6) + alpha = np.array([0, 0.5]) + zeta = np.linspace(-2 * np.pi, 2 * np.pi, 200) + + eq = get("HELIOTRON") + # 3. Convert above coordinates to DESC computational coordinates. + grid = get_rtz_grid( + eq, rho, alpha, zeta, coordinates="raz", period=(np.inf, 2 * np.pi, np.inf) + ) + # 4. Compute input data. + data = eq.compute( + Bounce1D.required_names + ["min_tz |B|", "max_tz |B|", "g_zz"], grid=grid + ) + # 5. Make the bounce integration operator. + bounce = Bounce1D( + grid.source_grid, + data, + quad=leggauss(3), # not checking quadrature accuracy in this test + check=True, + ) + pitch_inv = bounce.get_pitch_inv( + grid.compress(data["min_tz |B|"]), grid.compress(data["max_tz |B|"]), 10 + ) + num = bounce.integrate( + integrand=TestBounce1D._example_numerator, + pitch_inv=pitch_inv, + f=Bounce1D.reshape_data(grid.source_grid, data["g_zz"]), + check=True, + ) + den = bounce.integrate( + integrand=TestBounce1D._example_denominator, + pitch_inv=pitch_inv, + check=True, + batch=False, + ) + avg = safediv(num, den) + assert np.isfinite(avg).all() and np.count_nonzero(avg) + + # 6. Basic manipulation of the output. + # Sum all bounce averages across a particular field line, for every field line. + result = avg.sum(axis=-1) + # Group the result by pitch and flux surface. 
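+ # Editor's note (assumed axis ordering): the leading axes of ``result``
+ # enumerate field lines in the (α, ρ) meshgrid order used to build the
+ # grid, so the reshape below only attaches labels, i.e. afterwards
+ # result[m, l, p] belongs to field line α = alpha[m] on surface ρ = rho[l]
+ # with pitch λ = 1/pitch_inv[l, p].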
+ result = result.reshape(alpha.size, rho.size, pitch_inv.shape[-1])
+ # The result stored at
+ m, l, p = 0, 1, 3
+ print("Result(α, ρ, λ):", result[m, l, p])
+ # corresponds to the 1/λ value
+ print("1/λ(α, ρ):", pitch_inv[l, p])
+ # for the Clebsch-type field line coordinates
+ nodes = grid.source_grid.meshgrid_reshape(grid.source_grid.nodes[:, :2], "arz")
+ print("(α, ρ):", nodes[m, l, 0])
+
+ # 7. Optionally check for correctness of bounce points
+ bounce.check_points(*bounce.points(pitch_inv), pitch_inv, plot=False)
+
+ # 8. Plotting
+ fig, ax = bounce.plot(m, l, pitch_inv[l], include_legend=False, show=False)
+ return fig
+
+ @pytest.mark.unit
+ @pytest.mark.parametrize("func", [interp_to_argmin, interp_to_argmin_hard])
+ def test_interp_to_argmin(self, func):
+ """Test argmin interpolation.""" # noqa: D202
+
+ # Test functions chosen with purpose; don't change unless plotted and compared.
+ def h(z):
+ return np.cos(3 * z) * np.sin(2 * np.cos(z)) + np.cos(1.2 * z)
+
+ def g(z):
+ return np.sin(3 * z) * np.cos(1 / (1 + z)) * np.cos(z**2) * z
+
+ def dg_dz(z):
+ return (
+ 3 * z * np.cos(3 * z) * np.cos(z**2) * np.cos(1 / (1 + z))
+ - 2 * z**2 * np.sin(3 * z) * np.sin(z**2) * np.cos(1 / (1 + z))
+ + z * np.sin(3 * z) * np.sin(1 / (1 + z)) * np.cos(z**2) / (1 + z) ** 2
+ + np.sin(3 * z) * np.cos(z**2) * np.cos(1 / (1 + z))
+ )
+
+ zeta = np.linspace(0, 3 * np.pi, 175)
+ bounce = Bounce1D(
+ Grid.create_meshgrid([1, 0, zeta], coordinates="raz"),
+ {
+ "B^zeta": np.ones_like(zeta),
+ "B^zeta_z|r,a": np.ones_like(zeta),
+ "|B|": g(zeta),
+ "|B|_z|r,a": dg_dz(zeta),
+ },
+ )
+ z1 = np.array(0, ndmin=4)
+ z2 = np.array(2 * np.pi, ndmin=4)
+ argmin = 5.61719
+ h_min = h(argmin)
+ result = func(
+ h=h(zeta),
+ z1=z1,
+ z2=z2,
+ knots=zeta,
+ g=bounce.B,
+ dg_dz=bounce._dB_dz,
+ )
+ assert result.shape == z1.shape
+ np.testing.assert_allclose(h_min, result, rtol=1e-3)
+
+ # TODO: stellarator geometry test with ripples
+ @staticmethod
+ def drift_analytic(data):
+ """Compute analytic approximation for bounce-averaged binormal drift.
+
+ Returns
+ -------
+ drift_analytic : jnp.ndarray
+ Analytic approximation for the true result that the numerical computation
+ should attempt to match.
+ cvdrift, gbdrift : jnp.ndarray
+ Numerically computed ``data["cvdrift"]`` and ``data["gbdrift"]`` normalized
+ by some scale factors for this unit test. These should be fed to the bounce
+ integration as input.
+ pitch_inv : jnp.ndarray
+ Shape (P, ).
+ 1/λ values used.
+
+ """
+ B = data["|B|"] / data["Bref"]
+ B0 = np.mean(B)
+ # epsilon should be changed to dimensionless, and computed in a way that
+ # is independent of normalization length scales, like "effective r/R0".
+ epsilon = data["a"] * data["rho"] # Inverse aspect ratio of the flux surface.
+ np.testing.assert_allclose(epsilon, 0.05)
+ theta_PEST = data["alpha"] + data["iota"] * data["zeta"]
+ # same as 1 / (1 + epsilon cos(theta)) assuming epsilon << 1
+ B_analytic = B0 * (1 - epsilon * np.cos(theta_PEST))
+ np.testing.assert_allclose(B, B_analytic, atol=3e-3)
+
+ gradpar = data["a"] * data["B^zeta"] / data["|B|"]
+ # This method of computing G0 suggests a fixed point iteration.
+ G0 = data["a"] + gradpar_analytic = G0 * (1 - epsilon * np.cos(theta_PEST)) + gradpar_theta_analytic = data["iota"] * gradpar_analytic + G0 = np.mean(gradpar_theta_analytic) + np.testing.assert_allclose(gradpar, gradpar_analytic, atol=5e-3) + + # Comparing coefficient calculation here with coefficients from compute/_metric + normalization = -np.sign(data["psi"]) * data["Bref"] * data["a"] ** 2 + cvdrift = data["cvdrift"] * normalization + gbdrift = data["gbdrift"] * normalization + dPdrho = np.mean(-0.5 * (cvdrift - gbdrift) * data["|B|"] ** 2) + alpha_MHD = -0.5 * dPdrho / data["iota"] ** 2 + gds21 = ( + -np.sign(data["iota"]) + * data["shear"] + * dot(data["grad(psi)"], data["grad(alpha)"]) + / data["Bref"] + ) + gds21_analytic = -data["shear"] * ( + data["shear"] * theta_PEST - alpha_MHD / B**4 * np.sin(theta_PEST) + ) + gds21_analytic_low_order = -data["shear"] * ( + data["shear"] * theta_PEST - alpha_MHD / B0**4 * np.sin(theta_PEST) + ) + np.testing.assert_allclose(gds21, gds21_analytic, atol=2e-2) + np.testing.assert_allclose(gds21, gds21_analytic_low_order, atol=2.7e-2) + + fudge_1 = 0.19 + gbdrift_analytic = fudge_1 * ( + -data["shear"] + + np.cos(theta_PEST) + - gds21_analytic / data["shear"] * np.sin(theta_PEST) + ) + gbdrift_analytic_low_order = fudge_1 * ( + -data["shear"] + + np.cos(theta_PEST) + - gds21_analytic_low_order / data["shear"] * np.sin(theta_PEST) + ) + fudge_2 = 0.07 + cvdrift_analytic = gbdrift_analytic + fudge_2 * alpha_MHD / B**2 + cvdrift_analytic_low_order = ( + gbdrift_analytic_low_order + fudge_2 * alpha_MHD / B0**2 + ) + np.testing.assert_allclose(gbdrift, gbdrift_analytic, atol=1e-2) + np.testing.assert_allclose(cvdrift, cvdrift_analytic, atol=2e-2) + np.testing.assert_allclose(gbdrift, gbdrift_analytic_low_order, atol=1e-2) + np.testing.assert_allclose(cvdrift, cvdrift_analytic_low_order, atol=2e-2) + + # Exclude singularity not captured by analytic approximation for pitch near + # the maximum |B|. (This is captured by the numerical integration). 
+ pitch_inv = get_pitch_inv(np.min(B), np.max(B), 100)[:-1] + k2 = 0.5 * ((1 - B0 / pitch_inv) / (epsilon * B0 / pitch_inv) + 1) + I_0, I_1, I_2, I_3, I_4, I_5, I_6, I_7 = ( + TestBounce1DQuadrature.elliptic_incomplete(k2) + ) + y = np.sqrt(2 * epsilon * B0 / pitch_inv) + I_0, I_2, I_4, I_6 = map(lambda I: I / y, (I_0, I_2, I_4, I_6)) + I_1, I_3, I_5, I_7 = map(lambda I: I * y, (I_1, I_3, I_5, I_7)) + + drift_analytic_num = ( + fudge_2 * alpha_MHD / B0**2 * I_1 + - 0.5 + * fudge_1 + * ( + data["shear"] * (I_0 + I_1 - I_2 - I_3) + + alpha_MHD / B0**4 * (I_4 + I_5) + - (I_6 + I_7) + ) + ) / G0 + drift_analytic_den = I_0 / G0 + drift_analytic = drift_analytic_num / drift_analytic_den + return drift_analytic, cvdrift, gbdrift, pitch_inv + + @staticmethod + def drift_num_integrand(cvdrift, gbdrift, B, pitch): + """Integrand of numerator of bounce averaged binormal drift.""" + g = jnp.sqrt(1 - pitch * B) + return (cvdrift * g) - (0.5 * g * gbdrift) + (0.5 * gbdrift / g) + + @staticmethod + def drift_den_integrand(B, pitch): + """Integrand of denominator of bounce averaged binormal drift.""" + return 1 / jnp.sqrt(1 - pitch * B) + + @pytest.mark.unit + @pytest.mark.mpl_image_compare(remove_text=True, tolerance=tol_1d) + def test_binormal_drift_bounce1d(self): + """Test bounce-averaged drift with analytical expressions.""" + eq = Equilibrium.load(".//tests//inputs//low-beta-shifted-circle.h5") + psi_boundary = eq.Psi / (2 * np.pi) + psi = 0.25 * psi_boundary + rho = np.sqrt(psi / psi_boundary) + np.testing.assert_allclose(rho, 0.5) + + # Make a set of nodes along a single fieldline. + grid_fsa = LinearGrid(rho=rho, M=eq.M_grid, N=eq.N_grid, sym=eq.sym, NFP=eq.NFP) + data = eq.compute(["iota"], grid=grid_fsa) + iota = grid_fsa.compress(data["iota"]).item() + alpha = 0 + zeta = np.linspace(-np.pi / iota, np.pi / iota, (2 * eq.M_grid) * 4 + 1) + grid = get_rtz_grid( + eq, + rho, + alpha, + zeta, + coordinates="raz", + period=(np.inf, 2 * np.pi, np.inf), + iota=iota, + ) + data = eq.compute( + Bounce1D.required_names + + [ + "cvdrift", + "gbdrift", + "grad(psi)", + "grad(alpha)", + "shear", + "iota", + "psi", + "a", + ], + grid=grid, + ) + np.testing.assert_allclose(data["psi"], psi) + np.testing.assert_allclose(data["iota"], iota) + assert np.all(data["B^zeta"] > 0) + data["Bref"] = 2 * np.abs(psi_boundary) / data["a"] ** 2 + data["rho"] = rho + data["alpha"] = alpha + data["zeta"] = zeta + data["psi"] = grid.compress(data["psi"]) + data["iota"] = grid.compress(data["iota"]) + data["shear"] = grid.compress(data["shear"]) + + # Compute analytic approximation. + drift_analytic, cvdrift, gbdrift, pitch_inv = TestBounce1D.drift_analytic(data) + # Compute numerical result. + bounce = Bounce1D( + grid.source_grid, + data, + quad=leggauss(28), # converges to absolute and relative tolerance of 1e-7 + Bref=data["Bref"], + Lref=data["a"], + check=True, + ) + bounce.check_points(*bounce.points(pitch_inv), pitch_inv, plot=False) + + f = Bounce1D.reshape_data(grid.source_grid, cvdrift, gbdrift) + drift_numerical_num = bounce.integrate( + integrand=TestBounce1D.drift_num_integrand, + pitch_inv=pitch_inv, + f=f, + num_well=1, + check=True, + ) + drift_numerical_den = bounce.integrate( + integrand=TestBounce1D.drift_den_integrand, + pitch_inv=pitch_inv, + num_well=1, + weight=np.ones(zeta.size), + check=True, + ) + drift_numerical = np.squeeze(drift_numerical_num / drift_numerical_den) + msg = "There should be one bounce integral per pitch in this example." 
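+ # Editor's note: num_well=1 was passed above because the shifted-circle
+ # model field has a single magnetic well per field-line period, so each
+ # pitch should yield exactly one bounce integral.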
+ assert drift_numerical.size == drift_analytic.size, msg
+ np.testing.assert_allclose(
+ drift_numerical, drift_analytic, atol=5e-3, rtol=5e-2
+ )
+
+ TestBounce1D._test_bounce_autodiff(
+ bounce,
+ TestBounce1D.drift_num_integrand,
+ f=f,
+ weight=np.ones(zeta.size),
+ )
+
+ fig, ax = plt.subplots()
+ ax.plot(pitch_inv, drift_analytic)
+ ax.plot(pitch_inv, drift_numerical)
+ return fig
+
+ @staticmethod
+ def _test_bounce_autodiff(bounce, integrand, **kwargs):
+ """Make sure reverse mode AD works correctly on this algorithm.
+
+ Non-differentiable operations (e.g. ``take_mask``) are used in this computation.
+ See https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html
+ and https://jax.readthedocs.io/en/latest/faq.html#
+ why-are-gradients-zero-for-functions-based-on-sort-order.
+
+ If the AD tool works properly, then these operations should be assigned
+ zero gradients while the gradients wrt parameters of our physics computations
+ accumulate correctly. Less mature AD tools may have subtle bugs that cause
+ the gradients to not accumulate correctly. (There are a few
+ GitHub issues related to this that JAX has fixed in the past.)
+
+ This test first confirms that the gradients computed by reverse mode AD match
+ the analytic approximation of the true gradient. Then we confirm that the
+ partial gradients wrt the integrand and bounce points are correct.
+
+ Apply the Leibniz integral rule
+ https://en.wikipedia.org/wiki/Leibniz_integral_rule, with
+ the label w summing over the magnetic wells:
+
+ ∂_λ ∑_w ∫_ζ₁^ζ₂ f dζ (λ) = ∑_w [
+ ∫_ζ₁^ζ₂ (∂f/∂λ)(λ) dζ
+ + f(λ,ζ₂) (∂ζ₂/∂λ)(λ)
+ - f(λ,ζ₁) (∂ζ₁/∂λ)(λ)
+ ]
+ where (∂ζ₁/∂λ)(λ) = -λ² / (∂|B|/∂ζ|ρ,α)(ζ₁)
+ (∂ζ₂/∂λ)(λ) = -λ² / (∂|B|/∂ζ|ρ,α)(ζ₂)
+
+ All terms in these expressions are known analytically.
+ If we wanted, it's simple to check explicitly that AD takes each derivative
+ correctly because |w| = 1 is constant and our tokamak has symmetry
+ (∂|B|/∂ζ|ρ,α)(ζ₁) = - (∂|B|/∂ζ|ρ,α)(ζ₂).
+
+ After confirming the left hand side is correct, we just check that the
+ derivative wrt the bounce points of the right hand side doesn't vanish due
+ to some zero gradient issue mentioned above.
+
+ """
+
+ def integrand_grad(*args, **kwargs2):
+ grad_fun = jnp.vectorize(
+ grad(integrand, -1), signature="()," * len(kwargs["f"]) + "(),()->()"
+ )
+ return grad_fun(*args, *kwargs2.values())
+
+ def fun1(pitch):
+ return bounce.integrate(integrand, 1 / pitch, check=False, **kwargs).sum()
+
+ def fun2(pitch):
+ return bounce.integrate(
+ integrand_grad, 1 / pitch, check=True, **kwargs
+ ).sum()
+
+ pitch = 1.0
+ # can be obtained from the math or extrapolated from a plot of the
+ # analytic expression
+ analytic_approximation_of_gradient = 650
+ np.testing.assert_allclose(
+ grad(fun1)(pitch), analytic_approximation_of_gradient, rtol=1e-3
+ )
+ # It is expected that this is much larger because the integrand is singular
+ # wrt λ, while the boundary derivative term f(λ,ζ₂) (∂ζ₂/∂λ)(λ) - f(λ,ζ₁) (∂ζ₁/∂λ)(λ)
+ # smooths out because the bounce points ζ₁ and ζ₂ are smooth functions of λ.
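+ # Editor's toy analogue of the Leibniz check above (standalone sketch):
+ # for F(t) = ∫₀ᵗ x dx = t²/2, differentiating through a fixed Gauss rule
+ # mapped onto [0, t] reproduces F'(t) = t exactly.
+ xq, wq = leggauss(5)
+ Fq = lambda t: jnp.dot((xq + 1) * t / 2, wq) * t / 2 # maps [-1, 1] -> [0, t]
+ np.testing.assert_allclose(grad(Fq)(3.0), 3.0, rtol=1e-10)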
+ np.testing.assert_allclose(fun2(pitch), -131750, rtol=1e-1)
diff --git a/tests/test_interp_utils.py b/tests/test_interp_utils.py new file mode 100644 index 0000000000..606b0fe090 --- /dev/null +++ b/tests/test_interp_utils.py @@ -0,0 +1,103 @@
+"""Test interpolation utilities."""
+
+import numpy as np
+import pytest
+from numpy.polynomial.polynomial import polyvander
+
+from desc.integrals.interp_utils import polyder_vec, polyroot_vec, polyval_vec
+
+
+class TestPolyUtils:
+ """Test polynomial utilities used for local spline interpolation in integrals."""
+
+ @pytest.mark.unit
+ def test_polyroot_vec(self):
+ """Test vectorized computation of cubic polynomial exact roots."""
+ c = np.arange(-24, 24).reshape(4, 6, -1).transpose(-1, 1, 0)
+ # Ensure broadcasting won't hide error in implementation.
+ assert np.unique(c.shape).size == c.ndim
+
+ k = np.broadcast_to(np.arange(c.shape[-2]), c.shape[:-1])
+ # Now increase dimension so that shapes still broadcast, but stuff like
+ # ``c[...,-1]-=k`` is not allowed because it grows the dimension of ``c``.
+ # This functionality is needed in ``polyroot_vec`` and would require an
+ # awkward loop to obtain if using jnp.vectorize.
+ k = np.stack([k, k * 2 + 1])
+ r = polyroot_vec(c, k, sort=True)
+
+ for i in range(k.shape[0]):
+ d = c.copy()
+ d[..., -1] -= k[i]
+ # np.roots cannot be vectorized because it strips leading zeros and
+ # output shape is therefore dynamic.
+ for idx in np.ndindex(d.shape[:-1]):
+ np.testing.assert_allclose(
+ r[(i, *idx)],
+ np.sort(np.roots(d[idx])),
+ err_msg=f"Eigenvalue branch of polyroot_vec failed at {i, *idx}.",
+ )
+
+ # Now test the analytic formula branch. Ensure it filters distinct roots,
+ # and ensure zero coefficients don't break the computation due to
+ # singularities in the analytic formulae which are not present in the
+ # iterative eigenvalue scheme.
+ c = np.array(
+ [
+ [1, 0, 0, 0],
+ [0, 1, 0, 0],
+ [0, 0, 1, 0],
+ [0, 0, 0, 1],
+ [1, -1, -8, 12],
+ [1, -6, 11, -6],
+ [0, -6, 11, -2],
+ ]
+ )
+ r = polyroot_vec(c, sort=True, distinct=True)
+ for j in range(c.shape[0]):
+ root = r[j][~np.isnan(r[j])]
+ unique_root = np.unique(np.roots(c[j]))
+ assert root.size == unique_root.size
+ np.testing.assert_allclose(
+ root,
+ unique_root,
+ err_msg=f"Analytic branch of polyroot_vec failed at {j}.",
+ )
+ c = np.array([0, 1, -1, -8, 12])
+ r = polyroot_vec(c, sort=True, distinct=True)
+ r = r[~np.isnan(r)]
+ unique_r = np.unique(np.roots(c))
+ assert r.size == unique_r.size
+ np.testing.assert_allclose(r, unique_r)
+
+ @pytest.mark.unit
+ def test_polyder_vec(self):
+ """Test vectorized computation of polynomial derivative."""
+ c = np.arange(-18, 18).reshape(3, -1, 6)
+ # Ensure broadcasting won't hide error in implementation.
+ assert np.unique(c.shape).size == c.ndim
+ np.testing.assert_allclose(
+ polyder_vec(c),
+ np.vectorize(np.polyder, signature="(m)->(n)")(c),
+ )
+
+ @pytest.mark.unit
+ def test_polyval_vec(self):
+ """Test vectorized computation of polynomial evaluation."""
+
+ def test(x, c):
+ # Ensure broadcasting won't hide error in implementation.
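+ # Editor's note: the ``c[..., ::-1]`` reversal in the reference expression
+ # below is needed because polyvander builds ascending powers while
+ # polyval_vec (like np.polyval) takes descending powers, e.g.
+ assert np.polyval([2.0, 1.0], 3.0) == 7.0 # 2*x + 1 at x = 3, descending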
+ assert np.unique(x.shape).size == x.ndim + assert np.unique(c.shape).size == c.ndim + np.testing.assert_allclose( + polyval_vec(x=x, c=c), + np.sum(polyvander(x, c.shape[-1] - 1) * c[..., ::-1], axis=-1), + ) + + c = np.arange(-60, 60).reshape(-1, 5, 3) + x = np.linspace(0, 20, np.prod(c.shape[:-1])).reshape(c.shape[:-1]) + test(x, c) + + x = np.stack([x, x * 2], axis=0) + x = np.stack([x, x * 2, x * 3, x * 4], axis=0) + assert c.shape[:-1] == x.shape[x.ndim - (c.ndim - 1) :] + assert np.unique((c.shape[-1],) + x.shape[c.ndim - 1 :]).size == x.ndim - 1 + test(x, c) diff --git a/tests/test_magnetic_fields.py b/tests/test_magnetic_fields.py index 06e7b83800..86f4174547 100644 --- a/tests/test_magnetic_fields.py +++ b/tests/test_magnetic_fields.py @@ -6,10 +6,11 @@ from desc.backend import jit, jnp from desc.basis import DoubleFourierSeries -from desc.compute import rpz2xyz_vec, xyz2rpz_vec +from desc.compute import rpz2xyz, rpz2xyz_vec, xyz2rpz_vec from desc.compute.utils import get_params, get_transforms +from desc.derivatives import FiniteDiffDerivative as Derivative from desc.examples import get -from desc.geometry import FourierRZToroidalSurface +from desc.geometry import FourierRZToroidalSurface, FourierXYZCurve from desc.grid import LinearGrid from desc.io import load from desc.magnetic_fields import ( @@ -22,11 +23,13 @@ ScalarPotentialField, SplineMagneticField, ToroidalMagneticField, + VectorPotentialField, VerticalMagneticField, field_line_integrate, read_BNORM_file, ) from desc.magnetic_fields._dommaschk import CD_m_k, CN_m_k +from desc.utils import dot def phi_lm(R, phi, Z, a, m): @@ -59,8 +62,41 @@ def test_basic_fields(self): tfield = ToroidalMagneticField(2, 1) vfield = VerticalMagneticField(1) pfield = PoloidalMagneticField(2, 1, 2) + + def tfield_A(R, phi, Z, B0=2, R0=1): + az = -B0 * R0 * jnp.log(R) + arp = jnp.zeros_like(az) + A = jnp.array([arp, arp, az]).T + return A + + tfield_from_A = VectorPotentialField(tfield_A, params={"B0": 2, "R0": 1}) + + def vfield_A(R, phi, Z, B0=None): + coords_rpz = jnp.vstack([R, phi, Z]).T + coords_xyz = rpz2xyz(coords_rpz) + ax = B0 / 2 * coords_xyz[:, 1] + ay = -B0 / 2 * coords_xyz[:, 0] + + az = jnp.zeros_like(ax) + A = jnp.array([ax, -ay, az]).T + A = xyz2rpz_vec(A, phi=coords_rpz[:, 1]) + return A + + vfield_params = {"B0": 1} + vfield_from_A = VectorPotentialField(vfield_A, params=vfield_params) + np.testing.assert_allclose(tfield([1, 0, 0]), [[0, 2, 0]]) np.testing.assert_allclose((4 * tfield)([2, 0, 0]), [[0, 4, 0]]) + np.testing.assert_allclose(tfield_from_A([1, 0, 0]), [[0, 2, 0]]) + np.testing.assert_allclose( + tfield_A(1, 0, 0), + tfield_from_A.compute_magnetic_vector_potential([1, 0, 0]).squeeze(), + ) + np.testing.assert_allclose( + vfield_A(1, 0, 0, **vfield_params), + vfield_from_A.compute_magnetic_vector_potential([1, 0, 0]), + ) + np.testing.assert_allclose((tfield + vfield)([1, 0, 0]), [[0, 2, 1]]) np.testing.assert_allclose( (tfield + vfield - pfield)([1, 0, 0.1]), [[0.4, 2, 1]] @@ -104,17 +140,40 @@ def test_combined_fields(self): assert scaled_field.B0 == 2 assert scaled_field.scale == 3.1 np.testing.assert_allclose(scaled_field([1.0, 0, 0]), np.array([[0, 6.2, 0]])) + np.testing.assert_allclose( + scaled_field.compute_magnetic_vector_potential([2.0, 0, 0]), + np.array([[0, 0, -3.1 * 2 * 1 * np.log(2)]]), + ) + scaled_field.R0 = 1.3 scaled_field.scale = 1.0 np.testing.assert_allclose(scaled_field([1.3, 0, 0]), np.array([[0, 2, 0]])) + np.testing.assert_allclose( + 
scaled_field.compute_magnetic_vector_potential([2.0, 0, 0]),
+ np.array([[0, 0, -2 * 1.3 * np.log(2)]]),
+ )
 assert scaled_field.optimizable_params == ["B0", "R0", "scale"] assert hasattr(scaled_field, "B0") sum_field = vfield + pfield + tfield
+ sum_field_tv = vfield + tfield # to test A since pfield does not have A
 assert len(sum_field) == 3
+ assert len(sum_field_tv) == 2
+
 np.testing.assert_allclose( sum_field([1.3, 0, 0.0]), [[0.0, 2, 3.2 + 2 * 1.2 * 0.3]] )
+
+ tfield_A = np.array([[0, 0, -tfield.B0 * tfield.R0 * np.log(tfield.R0)]])
+ x = tfield.R0 * np.cos(np.pi / 4)
+ y = tfield.R0 * np.sin(np.pi / 4)
+ vfield_A = np.array([[vfield.B0 * y, -vfield.B0 * x, 0]]) / 2
+
+ np.testing.assert_allclose(
+ sum_field_tv.compute_magnetic_vector_potential([x, y, 0.0], basis="xyz"),
+ tfield_A + vfield_A,
+ )
+
 assert sum_field.optimizable_params == [ ["B0"], ["B0", "R0", "iota"], @@ -304,6 +363,87 @@ def test_current_potential_field(self): with pytest.raises(AssertionError): field.potential_dzeta = 1
+ @pytest.mark.unit
+ def test_current_potential_vector_potential(self):
+ """Test current potential field vector potential against analytic result."""
+ R0 = 10
+ a = 1
+ surface = FourierRZToroidalSurface(
+ R_lmn=jnp.array([R0, a]),
+ Z_lmn=jnp.array([0, -a]),
+ modes_R=jnp.array([[0, 0], [1, 0]]),
+ modes_Z=jnp.array([[0, 0], [-1, 0]]),
+ NFP=10,
+ )
+ # make a current potential corresponding to a purely poloidal current
+ G = 100 # net poloidal current
+ potential = lambda theta, zeta, G: G * zeta / 2 / jnp.pi
+ potential_dtheta = lambda theta, zeta, G: jnp.zeros_like(theta)
+ potential_dzeta = lambda theta, zeta, G: G * jnp.ones_like(theta) / 2 / jnp.pi
+
+ params = {"G": -G}
+
+ field = CurrentPotentialField(
+ potential,
+ R_lmn=surface.R_lmn,
+ Z_lmn=surface.Z_lmn,
+ modes_R=surface._R_basis.modes[:, 1:],
+ modes_Z=surface._Z_basis.modes[:, 1:],
+ params=params,
+ potential_dtheta=potential_dtheta,
+ potential_dzeta=potential_dzeta,
+ NFP=surface.NFP,
+ )
+ # test the loop integral of A around a curve encompassing the torus
+ # against the analytic result for the flux in an ideal toroidal solenoid
+ prefactors = mu_0 * G / 2 / jnp.pi
+ correct_flux = -2 * np.pi * prefactors * (np.sqrt(R0**2 - a**2) - R0)
+
+ curve = FourierXYZCurve() # curve to integrate A over
+ curve_grid = LinearGrid(zeta=20)
+ curve_data = curve.compute(["x", "x_s"], grid=curve_grid, basis="xyz")
+ curve_data_rpz = curve.compute(["x", "x_s"], grid=curve_grid, basis="rpz")
+
+ surface_grid = LinearGrid(M=60, N=60, NFP=10)
+
+ A_xyz = field.compute_magnetic_vector_potential(
+ curve_data["x"], basis="xyz", source_grid=surface_grid
+ )
+ A_rpz = field.compute_magnetic_vector_potential(
+ curve_data_rpz["x"], basis="rpz", source_grid=surface_grid
+ )
+
+ # integrate to get the flux
+ flux_xyz = jnp.sum(
+ dot(A_xyz, curve_data["x_s"], axis=-1) * curve_grid.spacing[:, 2]
+ )
+ flux_rpz = jnp.sum(
+ dot(A_rpz, curve_data_rpz["x_s"], axis=-1) * curve_grid.spacing[:, 2]
+ )
+
+ np.testing.assert_allclose(correct_flux, flux_xyz, rtol=1e-8)
+ np.testing.assert_allclose(correct_flux, flux_rpz, rtol=1e-8)
+
+ field.params["G"] = -2 * field.params["G"]
+
+ A_xyz = field.compute_magnetic_vector_potential(
+ curve_data["x"], basis="xyz", source_grid=surface_grid
+ )
+ A_rpz = field.compute_magnetic_vector_potential(
+ curve_data_rpz["x"], basis="rpz", source_grid=surface_grid
+ )
+
+ # integrate to get the flux
+ flux_xyz = jnp.sum(
+ dot(A_xyz, curve_data["x_s"], axis=-1) * curve_grid.spacing[:, 2]
+ )
+ flux_rpz = jnp.sum(
+ dot(A_rpz, curve_data_rpz["x_s"], axis=-1) * curve_grid.spacing[:, 2]
+ )
+
+ np.testing.assert_allclose(-2 * correct_flux, flux_xyz, rtol=1e-8)
+ np.testing.assert_allclose(-2 * correct_flux, flux_rpz, rtol=1e-8)
+
 @pytest.mark.unit def test_fourier_current_potential_field(self): """Test Fourier current potential magnetic field against analytic result.""" @@ -416,6 +556,124 @@ def test_fourier_current_potential_field(self): atol=1e-16, )
+ @pytest.mark.unit
+ def test_fourier_current_potential_vector_potential(self):
+ """Test Fourier current potential vector potential against analytic result."""
+ R0 = 10
+ a = 1
+ surface = FourierRZToroidalSurface(
+ R_lmn=jnp.array([R0, a]),
+ Z_lmn=jnp.array([0, -a]),
+ modes_R=jnp.array([[0, 0], [1, 0]]),
+ modes_Z=jnp.array([[0, 0], [-1, 0]]),
+ NFP=10,
+ )
+
+ basis = DoubleFourierSeries(M=2, N=2, sym="sin")
+ phi_mn = np.ones((basis.num_modes,))
+ # make a current potential corresponding to a purely poloidal current
+ G = 100 # net poloidal current
+
+ # test the loop integral of A around a curve encompassing the torus
+ # against the analytic result for the flux in an ideal toroidal solenoid
+ # expression for the flux inside a toroidal solenoid of radius a
+ prefactors = mu_0 * G / 2 / jnp.pi
+ correct_flux = -2 * np.pi * prefactors * (np.sqrt(R0**2 - a**2) - R0)
+
+ curve = FourierXYZCurve() # curve to integrate A over
+ curve_grid = LinearGrid(zeta=20)
+ curve_data = curve.compute(["x", "x_s"], grid=curve_grid)
+ curve_data_rpz = curve.compute(["x", "x_s"], grid=curve_grid, basis="rpz")
+
+ field = FourierCurrentPotentialField(
+ Phi_mn=phi_mn,
+ modes_Phi=basis.modes[:, 1:],
+ I=0,
+ G=-G, # to get a positive B_phi, we must put G negative
+ # since -G is the net poloidal current on the surface
+ # (with G=-(net_current) meaning that we have net_current
+ # flowing poloidally (in the clockwise direction) around the torus)
+ sym_Phi="sin",
+ R_lmn=surface.R_lmn,
+ Z_lmn=surface.Z_lmn,
+ modes_R=surface._R_basis.modes[:, 1:],
+ modes_Z=surface._Z_basis.modes[:, 1:],
+ NFP=10,
+ )
+ surface_grid = LinearGrid(M=60, N=60, NFP=10)
+
+ phi_mn = np.zeros((basis.num_modes,))
+
+ field.Phi_mn = phi_mn
+
+ field.change_resolution(3, 3)
+ field.change_Phi_resolution(2, 2)
+
+ A_xyz = field.compute_magnetic_vector_potential(
+ curve_data["x"], basis="xyz", source_grid=surface_grid
+ )
+ A_rpz = field.compute_magnetic_vector_potential(
+ curve_data_rpz["x"], basis="rpz", source_grid=surface_grid
+ )
+
+ # integrate to get the flux
+ flux_xyz = jnp.sum(
+ dot(A_xyz, curve_data["x_s"], axis=-1) * curve_grid.spacing[:, 2]
+ )
+ flux_rpz = jnp.sum(
+ dot(A_rpz, curve_data_rpz["x_s"], axis=-1) * curve_grid.spacing[:, 2]
+ )
+
+ np.testing.assert_allclose(correct_flux, flux_xyz, rtol=1e-8)
+ np.testing.assert_allclose(correct_flux, flux_rpz, rtol=1e-8)
+
+ field.G = -2 * field.G
+ field.I = 0
+
+ A_xyz = field.compute_magnetic_vector_potential(
+ curve_data["x"], basis="xyz", source_grid=surface_grid
+ )
+ A_rpz = field.compute_magnetic_vector_potential(
+ curve_data_rpz["x"], basis="rpz", source_grid=surface_grid
+ )
+
+ # integrate to get the flux
+ flux_xyz = jnp.sum(
+ dot(A_xyz, curve_data["x_s"], axis=-1) * curve_grid.spacing[:, 2]
+ )
+ flux_rpz = jnp.sum(
+ dot(A_rpz, curve_data_rpz["x_s"], axis=-1) * curve_grid.spacing[:, 2]
+ )
+
+ np.testing.assert_allclose(-2 * correct_flux, flux_xyz, rtol=1e-8)
+ np.testing.assert_allclose(-2 * correct_flux, flux_rpz, rtol=1e-8)
+
+ field = FourierCurrentPotentialField.from_surface(
+ surface=surface,
+ Phi_mn=phi_mn,
+
modes_Phi=basis.modes[:, 1:], + I=0, + G=-G, + ) + + A_xyz = field.compute_magnetic_vector_potential( + curve_data["x"], basis="xyz", source_grid=surface_grid + ) + A_rpz = field.compute_magnetic_vector_potential( + curve_data_rpz["x"], basis="rpz", source_grid=surface_grid + ) + + # integrate to get the flux + flux_xyz = jnp.sum( + dot(A_xyz, curve_data["x_s"], axis=-1) * curve_grid.spacing[:, 2] + ) + flux_rpz = jnp.sum( + dot(A_rpz, curve_data_rpz["x_s"], axis=-1) * curve_grid.spacing[:, 2] + ) + + np.testing.assert_allclose(correct_flux, flux_xyz, rtol=1e-8) + np.testing.assert_allclose(correct_flux, flux_rpz, rtol=1e-8) + @pytest.mark.unit def test_fourier_current_potential_field_symmetry(self): """Test Fourier current potential magnetic field Phi symmetry logic.""" @@ -644,7 +902,7 @@ def test_init_Phi_mn_fourier_current_field(self): @pytest.mark.slow @pytest.mark.unit - def test_spline_field(self): + def test_spline_field(self, tmpdir_factory): """Test accuracy of spline magnetic field.""" field1 = ScalarPotentialField(phi_lm, args) R = np.linspace(0.5, 1.5, 20) @@ -659,10 +917,65 @@ def test_spline_field(self): extcur = [4700.0, 1000.0] mgrid = "tests/inputs/mgrid_test.nc" field3 = SplineMagneticField.from_mgrid(mgrid, extcur) + # test saving and loading from mgrid + tmpdir = tmpdir_factory.mktemp("spline_mgrid_with_A") + path = tmpdir.join("spline_mgrid_with_A.nc") + field3.save_mgrid( + path, + Rmin=np.min(field3._R), + Rmax=np.max(field3._R), + Zmin=np.min(field3._Z), + Zmax=np.max(field3._Z), + nR=field3._R.size, + nZ=field3._Z.size, + nphi=field3._phi.size, + ) + # no need for extcur b/c is saved in "raw" format, no need to scale again + field4 = SplineMagneticField.from_mgrid(path) + attrs_4d = ["_AR", "_Aphi", "_AZ", "_BR", "_Bphi", "_BZ"] + for attr in attrs_4d: + np.testing.assert_allclose( + (getattr(field3, attr) * np.array(extcur)).sum(axis=-1), + getattr(field4, attr).squeeze(), + err_msg=attr, + ) + attrs_3d = ["_R", "_phi", "_Z"] + for attr in attrs_3d: + np.testing.assert_allclose(getattr(field3, attr), getattr(field4, attr)) + + r = 0.70 + p = 0 + z = 0 + # use finite diff derivatives to check A accuracy + tfield_A = lambda R, phi, Z: field3.compute_magnetic_vector_potential( + jnp.vstack([R, phi, Z]).T + ) + funR = lambda x: tfield_A(x, p, z) + funP = lambda x: tfield_A(r, x, z) + funZ = lambda x: tfield_A(r, p, x) + + ap = tfield_A(r, p, z)[:, 1] + + # these are the gradients of each component of A + dAdr = Derivative.compute_jvp(funR, 0, (jnp.ones_like(r),), r) + dAdp = Derivative.compute_jvp(funP, 0, (jnp.ones_like(p),), p) + dAdz = Derivative.compute_jvp(funZ, 0, (jnp.ones_like(z),), z) + + # form the B components with the appropriate combinations + B2 = jnp.array( + [ + dAdp[:, 2] / r - dAdz[:, 1], + dAdz[:, 0] - dAdr[:, 2], + dAdr[:, 1] + (ap - dAdp[:, 0]) / r, + ] + ).T np.testing.assert_allclose( field3([0.70, 0, 0]), np.array([[0, -0.671, 0.0858]]), rtol=1e-3, atol=1e-8 ) + + np.testing.assert_allclose(field3([0.70, 0, 0]), B2, rtol=1e-3, atol=5e-3) + field3.currents *= 2 np.testing.assert_allclose( field3([0.70, 0, 0]), @@ -697,14 +1010,20 @@ def test_spline_field_axisym(self): -2.430716e04, -2.380229e04, ] - field = SplineMagneticField.from_mgrid( - "tests/inputs/mgrid_d3d.nc", extcur=extcur - ) + with pytest.warns(UserWarning): + # user warning because saved mgrid no vector potential + field = SplineMagneticField.from_mgrid( + "tests/inputs/mgrid_d3d.nc", extcur=extcur + ) # make sure field is invariant to shift in phi B1 = 
field.compute_magnetic_field(np.array([1.75, 0.0, 0.0])) B2 = field.compute_magnetic_field(np.array([1.75, 1.0, 0.0])) np.testing.assert_allclose(B1, B2) + # test the error when no vec pot values exist + with pytest.raises(ValueError, match="no vector potential"): + field.compute_magnetic_vector_potential(np.array([1.75, 0.0, 0.0])) + @pytest.mark.unit def test_field_line_integrate(self): """Test field line integration.""" @@ -842,8 +1161,15 @@ def test_mgrid_io(self, tmpdir_factory): Rmax = 7 Zmin = -2 Zmax = 2 - save_field.save_mgrid(path, Rmin, Rmax, Zmin, Zmax) - load_field = SplineMagneticField.from_mgrid(path) + with pytest.raises(NotImplementedError): + # Raises error because poloidal field has no vector potential + # and so cannot save the vector potential + save_field.save_mgrid(path, Rmin, Rmax, Zmin, Zmax) + save_field.save_mgrid(path, Rmin, Rmax, Zmin, Zmax, save_vector_potential=False) + with pytest.warns(UserWarning): + # user warning because saved mgrid has no vector potential + # and so cannot load the vector potential + load_field = SplineMagneticField.from_mgrid(path) # check that the fields are the same num_nodes = 50 diff --git a/tests/test_objective_funs.py b/tests/test_objective_funs.py index 82f0dd337a..4711907ab9 100644 --- a/tests/test_objective_funs.py +++ b/tests/test_objective_funs.py @@ -24,12 +24,13 @@ from desc.compute import get_transforms from desc.equilibrium import Equilibrium from desc.examples import get -from desc.geometry import FourierRZToroidalSurface, FourierXYZCurve +from desc.geometry import FourierPlanarCurve, FourierRZToroidalSurface, FourierXYZCurve from desc.grid import ConcentricGrid, LinearGrid, QuadratureGrid from desc.io import load from desc.magnetic_fields import ( FourierCurrentPotentialField, OmnigenousField, + PoloidalMagneticField, SplineMagneticField, ToroidalMagneticField, VerticalMagneticField, @@ -367,6 +368,51 @@ def test_qh_boozer(self): # should have the same values up until then np.testing.assert_allclose(f[idx_f][:120], B_mn[idx_B][:120]) + @pytest.mark.unit + def test_qh_boozer_multiple_surfaces(self): + """Test for computing Boozer error on multiple surfaces.""" + eq = get("WISTELL-A") # WISTELL-A is optimized for QH symmetry + helicity = (1, -eq.NFP) + M_booz = eq.M + N_booz = eq.N + grid1 = LinearGrid(rho=0.5, M=2 * eq.M, N=2 * eq.N, NFP=eq.NFP, sym=False) + grid2 = LinearGrid(rho=1.0, M=2 * eq.M, N=2 * eq.N, NFP=eq.NFP, sym=False) + grid3 = LinearGrid( + rho=np.array([0.5, 1.0]), M=2 * eq.M, N=2 * eq.N, NFP=eq.NFP, sym=False + ) + + obj1 = QuasisymmetryBoozer( + helicity=helicity, + M_booz=M_booz, + N_booz=N_booz, + grid=grid1, + normalize=False, + eq=eq, + ) + obj2 = QuasisymmetryBoozer( + helicity=helicity, + M_booz=M_booz, + N_booz=N_booz, + grid=grid2, + normalize=False, + eq=eq, + ) + obj3 = QuasisymmetryBoozer( + helicity=helicity, + M_booz=M_booz, + N_booz=N_booz, + grid=grid3, + normalize=False, + eq=eq, + ) + obj1.build() + obj2.build() + obj3.build() + f1 = obj1.compute_unscaled(*obj1.xs(eq)) + f2 = obj2.compute_unscaled(*obj2.xs(eq)) + f3 = obj3.compute_unscaled(*obj3.xs(eq)) + np.testing.assert_allclose(f3, np.concatenate([f1, f2]), atol=1e-14) + @pytest.mark.unit def test_qs_twoterm(self): """Test calculation of two term QS metric.""" @@ -441,11 +487,6 @@ def test_qs_boozer_grids(self): with pytest.raises(ValueError): QuasisymmetryBoozer(eq=eq, grid=grid).build() - # multiple flux surfaces - grid = LinearGrid(M=eq.M, N=eq.N, NFP=eq.NFP, rho=[0.25, 0.5, 0.75, 1]) - with pytest.raises(ValueError): - 
QuasisymmetryBoozer(eq=eq, grid=grid).build() - @pytest.mark.unit def test_mercier_stability(self): """Test calculation of mercier stability criteria.""" @@ -869,6 +910,13 @@ def test(coil, grid=None): test(mixed_coils) test(nested_coils, grid=grid) + def test_coil_type_error(self): + """Tests error when objective is not passed a coil.""" + curve = FourierPlanarCurve(r_n=2, basis="rpz") + obj = CoilLength(curve) + with pytest.raises(TypeError): + obj.build() + @pytest.mark.unit def test_coil_min_distance(self): """Tests minimum distance between coils in a coilset.""" @@ -1113,10 +1161,14 @@ def test_quadratic_flux(self): @pytest.mark.unit def test_toroidal_flux(self): """Test calculation of toroidal flux from coils.""" - grid1 = LinearGrid(L=10, M=10, zeta=np.array(0.0)) + grid1 = LinearGrid(L=0, M=40, zeta=np.array(0.0)) def test(eq, field, correct_value, rtol=1e-14, grid=None): - obj = ToroidalFlux(eq=eq, field=field, eval_grid=grid) + obj = ToroidalFlux( + eq=eq, + field=field, + eval_grid=grid, + ) obj.build(verbose=2) torflux = obj.compute_unscaled(*obj.xs(field)) np.testing.assert_allclose(torflux, correct_value, rtol=rtol) @@ -1126,22 +1178,20 @@ def test(eq, field, correct_value, rtol=1e-14, grid=None): field = ToroidalMagneticField(B0=1, R0=1) # calc field Psi - data = eq.compute(["R", "phi", "Z", "|e_rho x e_theta|", "n_zeta"], grid=grid1) - field_B = field.compute_magnetic_field( + data = eq.compute(["R", "phi", "Z", "e_theta"], grid=grid1) + field_A = field.compute_magnetic_vector_potential( np.vstack([data["R"], data["phi"], data["Z"]]).T ) - B_dot_n_zeta = jnp.sum(field_B * data["n_zeta"], axis=1) + A_dot_e_theta = jnp.sum(field_A * data["e_theta"], axis=1) - psi_from_field = np.sum( - grid1.spacing[:, 0] - * grid1.spacing[:, 1] - * data["|e_rho x e_theta|"] - * B_dot_n_zeta - ) - eq.change_resolution(L_grid=10, M_grid=10) + psi_from_field = np.sum(grid1.spacing[:, 1] * A_dot_e_theta) + eq.change_resolution(L_grid=20, M_grid=20) test(eq, field, psi_from_field) + test(eq, field, psi_from_field, rtol=1e-3) + # test on field with no vector potential + test(eq, PoloidalMagneticField(1, 1, 1), 0.0) @pytest.mark.unit def test_signed_plasma_vessel_distance(self): @@ -2228,7 +2278,9 @@ def test_compute_scalar_resolution_heating_power(self): @pytest.mark.regression def test_compute_scalar_resolution_boundary_error(self): """BoundaryError.""" - ext_field = SplineMagneticField.from_mgrid(r"tests/inputs/mgrid_solovev.nc") + with pytest.warns(UserWarning): + # user warning because saved mgrid no vector potential + ext_field = SplineMagneticField.from_mgrid(r"tests/inputs/mgrid_solovev.nc") pres = PowerSeriesProfile([1.25e-1, 0, -1.25e-1]) iota = PowerSeriesProfile([-4.9e-1, 0, 3.0e-1]) @@ -2254,7 +2306,9 @@ def test_compute_scalar_resolution_boundary_error(self): @pytest.mark.regression def test_compute_scalar_resolution_vacuum_boundary_error(self): """VacuumBoundaryError.""" - ext_field = SplineMagneticField.from_mgrid(r"tests/inputs/mgrid_solovev.nc") + with pytest.warns(UserWarning): + # user warning because saved mgrid no vector potential + ext_field = SplineMagneticField.from_mgrid(r"tests/inputs/mgrid_solovev.nc") pres = PowerSeriesProfile([1.25e-1, 0, -1.25e-1]) iota = PowerSeriesProfile([-4.9e-1, 0, 3.0e-1]) @@ -2281,7 +2335,8 @@ def test_compute_scalar_resolution_vacuum_boundary_error(self): @pytest.mark.regression def test_compute_scalar_resolution_quadratic_flux(self): """VacuumBoundaryError.""" - ext_field = SplineMagneticField.from_mgrid(r"tests/inputs/mgrid_solovev.nc") + 
with pytest.warns(UserWarning): + ext_field = SplineMagneticField.from_mgrid(r"tests/inputs/mgrid_solovev.nc") pres = PowerSeriesProfile([1.25e-1, 0, -1.25e-1]) iota = PowerSeriesProfile([-4.9e-1, 0, 3.0e-1]) @@ -2305,7 +2360,25 @@ def test_compute_scalar_resolution_quadratic_flux(self): np.testing.assert_allclose(f, f[-1], rtol=5e-2) @pytest.mark.regression - def test_compute_scalar_resolution_toroidal_flux(self): + def test_compute_scalar_resolution_toroidal_flux_A(self): + """ToroidalFlux.""" + ext_field = ToroidalMagneticField(1, 1) + eq = get("precise_QA") + with pytest.warns(UserWarning, match="Reducing radial"): + eq.change_resolution(4, 4, 4, 8, 8, 8) + + f = np.zeros_like(self.res_array, dtype=float) + for i, res in enumerate(self.res_array): + eq.change_resolution( + L_grid=int(eq.L * res), M_grid=int(eq.M * res), N_grid=int(eq.N * res) + ) + obj = ObjectiveFunction(ToroidalFlux(eq, ext_field), use_jit=False) + obj.build(verbose=0) + f[i] = obj.compute_scalar(obj.x()) + np.testing.assert_allclose(f, f[-1], rtol=5e-2) + + @pytest.mark.regression + def test_compute_scalar_resolution_toroidal_flux_B(self): """ToroidalFlux.""" ext_field = ToroidalMagneticField(1, 1) eq = get("precise_QA") @@ -2579,7 +2652,9 @@ def test_objective_no_nangrad_heating_power(self): @pytest.mark.unit def test_objective_no_nangrad_boundary_error(self): """BoundaryError.""" - ext_field = SplineMagneticField.from_mgrid(r"tests/inputs/mgrid_solovev.nc") + with pytest.warns(UserWarning): + # user warning because saved mgrid no vector potential + ext_field = SplineMagneticField.from_mgrid(r"tests/inputs/mgrid_solovev.nc") pres = PowerSeriesProfile([1.25e-1, 0, -1.25e-1]) iota = PowerSeriesProfile([-4.9e-1, 0, 3.0e-1]) @@ -2600,7 +2675,9 @@ def test_objective_no_nangrad_boundary_error(self): @pytest.mark.unit def test_objective_no_nangrad_vacuum_boundary_error(self): """VacuumBoundaryError.""" - ext_field = SplineMagneticField.from_mgrid(r"tests/inputs/mgrid_solovev.nc") + with pytest.warns(UserWarning): + # user warning because saved mgrid no vector potential + ext_field = SplineMagneticField.from_mgrid(r"tests/inputs/mgrid_solovev.nc") pres = PowerSeriesProfile([1.25e-1, 0, -1.25e-1]) iota = PowerSeriesProfile([-4.9e-1, 0, 3.0e-1]) @@ -2623,7 +2700,9 @@ def test_objective_no_nangrad_vacuum_boundary_error(self): @pytest.mark.unit def test_objective_no_nangrad_quadratic_flux(self): """QuadraticFlux.""" - ext_field = SplineMagneticField.from_mgrid(r"tests/inputs/mgrid_solovev.nc") + with pytest.warns(UserWarning): + # user warning because saved mgrid no vector potential + ext_field = SplineMagneticField.from_mgrid(r"tests/inputs/mgrid_solovev.nc") pres = PowerSeriesProfile([1.25e-1, 0, -1.25e-1]) iota = PowerSeriesProfile([-4.9e-1, 0, 3.0e-1]) @@ -2654,7 +2733,12 @@ def test_objective_no_nangrad_toroidal_flux(self): obj = ObjectiveFunction(ToroidalFlux(eq, ext_field), use_jit=False) obj.build() g = obj.grad(obj.x(ext_field)) - assert not np.any(np.isnan(g)), "toroidal flux" + assert not np.any(np.isnan(g)), "toroidal flux A" + + obj = ObjectiveFunction(ToroidalFlux(eq, ext_field), use_jit=False) + obj.build() + g = obj.grad(obj.x(ext_field)) + assert not np.any(np.isnan(g)), "toroidal flux B" @pytest.mark.unit @pytest.mark.parametrize( diff --git a/tests/test_quad_utils.py b/tests/test_quad_utils.py new file mode 100644 index 0000000000..5a7c3d00e7 --- /dev/null +++ b/tests/test_quad_utils.py @@ -0,0 +1,103 @@ +"""Tests for quadrature utilities.""" + +import numpy as np +import pytest +from jax import grad + +from 
desc.backend import jnp +from desc.integrals.quad_utils import ( + automorphism_arcsin, + automorphism_sin, + bijection_from_disc, + bijection_to_disc, + composite_linspace, + grad_automorphism_arcsin, + grad_automorphism_sin, + grad_bijection_from_disc, + leggauss_lob, + tanh_sinh, +) +from desc.utils import only1 + + +@pytest.mark.unit +def test_composite_linspace(): + """Test this utility function which is used for integration over pitch.""" + B_min_tz = np.array([0.1, 0.2]) + B_max_tz = np.array([1, 3]) + breaks = np.linspace(B_min_tz, B_max_tz, num=5) + b = composite_linspace(breaks, num=3) + for i in range(breaks.shape[0]): + for j in range(breaks.shape[1]): + assert only1(np.isclose(breaks[i, j], b[:, j]).tolist()) + + +@pytest.mark.unit +def test_automorphism(): + """Test automorphisms.""" + a, b = -312, 786 + x = np.linspace(a, b, 10) + y = bijection_to_disc(x, a, b) + x_1 = bijection_from_disc(y, a, b) + np.testing.assert_allclose(x_1, x) + np.testing.assert_allclose(bijection_to_disc(x_1, a, b), y) + np.testing.assert_allclose(automorphism_arcsin(automorphism_sin(y)), y, atol=5e-7) + np.testing.assert_allclose(automorphism_sin(automorphism_arcsin(y)), y, atol=5e-7) + + np.testing.assert_allclose(grad_bijection_from_disc(a, b), 1 / (2 / (b - a))) + np.testing.assert_allclose( + grad_automorphism_sin(y), + 1 / grad_automorphism_arcsin(automorphism_sin(y)), + atol=2e-6, + ) + np.testing.assert_allclose( + 1 / grad_automorphism_arcsin(y), + grad_automorphism_sin(automorphism_arcsin(y)), + atol=2e-6, + ) + + # test that floating point error is acceptable + x = tanh_sinh(19)[0] + assert np.all(np.abs(x) < 1) + y = 1 / np.sqrt(1 - np.abs(x)) + assert np.isfinite(y).all() + y = 1 / np.sqrt(1 - np.abs(automorphism_sin(x))) + assert np.isfinite(y).all() + y = 1 / np.sqrt(1 - np.abs(automorphism_arcsin(x))) + assert np.isfinite(y).all() + + +@pytest.mark.unit +def test_leggauss_lobatto(): + """Test quadrature points and weights against known values.""" + with pytest.raises(ValueError): + x, w = leggauss_lob(1) + x, w = leggauss_lob(0, True) + assert x.size == w.size == 0 + + x, w = leggauss_lob(2) + np.testing.assert_allclose(x, [-1, 1]) + np.testing.assert_allclose(w, [1, 1]) + + x, w = leggauss_lob(3) + np.testing.assert_allclose(x, [-1, 0, 1]) + np.testing.assert_allclose(w, [1 / 3, 4 / 3, 1 / 3]) + np.testing.assert_allclose(leggauss_lob(x.size - 2, True), (x[1:-1], w[1:-1])) + + x, w = leggauss_lob(4) + np.testing.assert_allclose(x, [-1, -np.sqrt(1 / 5), np.sqrt(1 / 5), 1]) + np.testing.assert_allclose(w, [1 / 6, 5 / 6, 5 / 6, 1 / 6]) + np.testing.assert_allclose(leggauss_lob(x.size - 2, True), (x[1:-1], w[1:-1])) + + x, w = leggauss_lob(5) + np.testing.assert_allclose(x, [-1, -np.sqrt(3 / 7), 0, np.sqrt(3 / 7), 1]) + np.testing.assert_allclose(w, [1 / 10, 49 / 90, 32 / 45, 49 / 90, 1 / 10]) + np.testing.assert_allclose(leggauss_lob(x.size - 2, True), (x[1:-1], w[1:-1])) + + def fun(a): + x, w = leggauss_lob(a.size) + return jnp.dot(x * a, w) + + # make sure differentiable + # https://github.com/PlasmaControl/DESC/pull/854#discussion_r1733323161 + assert np.isfinite(grad(fun)(jnp.arange(10) * np.pi)).all() diff --git a/tests/test_utils.py b/tests/test_utils.py index 6bfadb4008..2812e8a01b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,11 +1,13 @@ """Tests for utility functions.""" +from functools import partial + import numpy as np import pytest -from desc.backend import tree_leaves, tree_structure +from desc.backend import flatnonzero, jnp, tree_leaves, 
tree_structure from desc.grid import LinearGrid -from desc.utils import broadcast_tree, isalmostequal, islinspaced +from desc.utils import broadcast_tree, isalmostequal, islinspaced, take_mask @pytest.mark.unit @@ -197,3 +199,35 @@ def test_broadcast_tree(): ] for leaf, leaf_correct in zip(tree_leaves(tree), tree_leaves(tree_correct)): np.testing.assert_allclose(leaf, leaf_correct) + + +@partial(jnp.vectorize, signature="(m)->()") +def _last_value(a): + """Return the last non-nan value in ``a``.""" + a = a[::-1] + idx = jnp.squeeze(flatnonzero(~jnp.isnan(a), size=1, fill_value=0)) + return a[idx] + + +@pytest.mark.unit +def test_take_mask(): + """Test custom masked array operation.""" + rows = 5 + cols = 7 + a = np.random.rand(rows, cols) + nan_idx = np.random.choice(rows * cols, size=(rows * cols) // 2, replace=False) + a.ravel()[nan_idx] = np.nan + taken = take_mask(a, ~np.isnan(a)) + last = _last_value(taken) + for i in range(rows): + desired = a[i, ~np.isnan(a[i])] + assert np.array_equal( + taken[i], + np.pad(desired, (0, cols - desired.size), constant_values=np.nan), + equal_nan=True, + ) + assert np.array_equal( + last[i], + desired[-1] if desired.size else np.nan, + equal_nan=True, + ) diff --git a/tests/test_vmec.py b/tests/test_vmec.py index d7ae22f2b4..0fef594b3c 100644 --- a/tests/test_vmec.py +++ b/tests/test_vmec.py @@ -368,14 +368,6 @@ def test_axis_surf_after_load(): f.close() -@pytest.mark.unit -def test_vmec_save_asym(TmpDir): - """Tests that saving a non-symmetric equilibrium runs without errors.""" - output_path = str(TmpDir.join("output.nc")) - eq = Equilibrium(L=2, M=2, N=2, NFP=3, pressure=np.array([[2, 0]]), sym=False) - VMECIO.save(eq, output_path) - - @pytest.mark.unit def test_vmec_save_kinetic(TmpDir): """Tests that saving an equilibrium with kinetic profiles runs without errors.""" @@ -874,6 +866,369 @@ def test_vmec_save_2(VMEC_save): np.testing.assert_allclose(currv_vmec, currv_desc, rtol=1e-2) +@pytest.mark.regression +@pytest.mark.slow +def test_vmec_save_asym(VMEC_save_asym): + """Tests that saving in NetCDF format agrees with VMEC.""" + vmec, desc, eq = VMEC_save_asym + # first, compare some quantities which don't require calculation + assert vmec.variables["version_"][:] == desc.variables["version_"][:] + assert vmec.variables["mgrid_mode"][:] == desc.variables["mgrid_mode"][:] + assert np.all( + np.char.compare_chararrays( + vmec.variables["mgrid_file"][:], + desc.variables["mgrid_file"][:], + "==", + False, + ) + ) + assert vmec.variables["ier_flag"][:] == desc.variables["ier_flag"][:] + assert ( + vmec.variables["lfreeb__logical__"][:] == desc.variables["lfreeb__logical__"][:] + ) + assert ( + vmec.variables["lrecon__logical__"][:] == desc.variables["lrecon__logical__"][:] + ) + assert vmec.variables["lrfp__logical__"][:] == desc.variables["lrfp__logical__"][:] + assert ( + vmec.variables["lasym__logical__"][:] == desc.variables["lasym__logical__"][:] + ) + assert vmec.variables["nfp"][:] == desc.variables["nfp"][:] + assert vmec.variables["ns"][:] == desc.variables["ns"][:] + assert vmec.variables["mpol"][:] == desc.variables["mpol"][:] + assert vmec.variables["ntor"][:] == desc.variables["ntor"][:] + assert vmec.variables["mnmax"][:] == desc.variables["mnmax"][:] + np.testing.assert_allclose(vmec.variables["xm"][:], desc.variables["xm"][:]) + np.testing.assert_allclose(vmec.variables["xn"][:], desc.variables["xn"][:]) + assert vmec.variables["mnmax_nyq"][:] == desc.variables["mnmax_nyq"][:] + np.testing.assert_allclose(vmec.variables["xm_nyq"][:], 
desc.variables["xm_nyq"][:]) + np.testing.assert_allclose(vmec.variables["xn_nyq"][:], desc.variables["xn_nyq"][:]) + assert vmec.variables["signgs"][:] == desc.variables["signgs"][:] + assert vmec.variables["gamma"][:] == desc.variables["gamma"][:] + assert vmec.variables["nextcur"][:] == desc.variables["nextcur"][:] + assert np.all( + np.char.compare_chararrays( + vmec.variables["pmass_type"][:], + desc.variables["pmass_type"][:], + "==", + False, + ) + ) + assert np.all( + np.char.compare_chararrays( + vmec.variables["piota_type"][:], + desc.variables["piota_type"][:], + "==", + False, + ) + ) + assert np.all( + np.char.compare_chararrays( + vmec.variables["pcurr_type"][:], + desc.variables["pcurr_type"][:], + "==", + False, + ) + ) + np.testing.assert_allclose( + vmec.variables["am"][:], desc.variables["am"][:], atol=1e-5 + ) + np.testing.assert_allclose( + vmec.variables["ai"][:], desc.variables["ai"][:], atol=1e-8 + ) + np.testing.assert_allclose( + vmec.variables["ac"][:], desc.variables["ac"][:], atol=3e-5 + ) + np.testing.assert_allclose( + vmec.variables["presf"][:], desc.variables["presf"][:], atol=2e-5 + ) + np.testing.assert_allclose(vmec.variables["pres"][:], desc.variables["pres"][:]) + np.testing.assert_allclose(vmec.variables["mass"][:], desc.variables["mass"][:]) + np.testing.assert_allclose( + vmec.variables["iotaf"][:], desc.variables["iotaf"][:], rtol=5e-4 + ) + np.testing.assert_allclose( + vmec.variables["q_factor"][:], desc.variables["q_factor"][:], rtol=5e-4 + ) + np.testing.assert_allclose( + vmec.variables["iotas"][:], desc.variables["iotas"][:], rtol=5e-4 + ) + np.testing.assert_allclose(vmec.variables["phi"][:], desc.variables["phi"][:]) + np.testing.assert_allclose(vmec.variables["phipf"][:], desc.variables["phipf"][:]) + np.testing.assert_allclose(vmec.variables["phips"][:], desc.variables["phips"][:]) + np.testing.assert_allclose( + vmec.variables["chi"][:], desc.variables["chi"][:], atol=3e-5, rtol=1e-3 + ) + np.testing.assert_allclose( + vmec.variables["chipf"][:], desc.variables["chipf"][:], atol=3e-5, rtol=1e-3 + ) + np.testing.assert_allclose( + vmec.variables["Rmajor_p"][:], desc.variables["Rmajor_p"][:] + ) + np.testing.assert_allclose( + vmec.variables["Aminor_p"][:], desc.variables["Aminor_p"][:] + ) + np.testing.assert_allclose(vmec.variables["aspect"][:], desc.variables["aspect"][:]) + np.testing.assert_allclose( + vmec.variables["volume_p"][:], desc.variables["volume_p"][:], rtol=1e-5 + ) + np.testing.assert_allclose( + vmec.variables["volavgB"][:], desc.variables["volavgB"][:], rtol=1e-5 + ) + np.testing.assert_allclose( + vmec.variables["betatotal"][:], desc.variables["betatotal"][:], rtol=1e-5 + ) + np.testing.assert_allclose( + vmec.variables["betapol"][:], desc.variables["betapol"][:], rtol=1e-5 + ) + np.testing.assert_allclose( + vmec.variables["betator"][:], desc.variables["betator"][:], rtol=1e-5 + ) + np.testing.assert_allclose( + vmec.variables["ctor"][:], + desc.variables["ctor"][:], + atol=1e-9, # it is a zero current solve + ) + np.testing.assert_allclose( + vmec.variables["rbtor"][:], desc.variables["rbtor"][:], rtol=1e-5 + ) + np.testing.assert_allclose( + vmec.variables["rbtor0"][:], desc.variables["rbtor0"][:], rtol=1e-5 + ) + np.testing.assert_allclose( + vmec.variables["b0"][:], desc.variables["b0"][:], rtol=4e-3 + ) + np.testing.assert_allclose( + vmec.variables["buco"][20:100], desc.variables["buco"][20:100], atol=1e-15 + ) + np.testing.assert_allclose( + vmec.variables["bvco"][20:100], desc.variables["bvco"][20:100], 
rtol=1e-5
+    )
+    np.testing.assert_allclose(
+        vmec.variables["vp"][20:100], desc.variables["vp"][20:100], rtol=3e-4
+    )
+    np.testing.assert_allclose(
+        vmec.variables["bdotb"][20:100], desc.variables["bdotb"][20:100], rtol=3e-4
+    )
+    np.testing.assert_allclose(
+        vmec.variables["jdotb"][20:100],
+        desc.variables["jdotb"][20:100],
+        atol=4e-3,  # nearly zero because this is a vacuum solve
+    )
+    np.testing.assert_allclose(
+        vmec.variables["jcuru"][20:100], desc.variables["jcuru"][20:100], atol=2
+    )
+    np.testing.assert_allclose(
+        vmec.variables["jcurv"][20:100], desc.variables["jcurv"][20:100], rtol=2
+    )
+    np.testing.assert_allclose(
+        vmec.variables["DShear"][20:100], desc.variables["DShear"][20:100], rtol=3e-2
+    )
+    np.testing.assert_allclose(
+        vmec.variables["DCurr"][20:100],
+        desc.variables["DCurr"][20:100],
+        atol=1e-4,  # nearly zero because this is a vacuum solve
+    )
+    np.testing.assert_allclose(
+        vmec.variables["DWell"][20:100], desc.variables["DWell"][20:100], rtol=1e-2
+    )
+    np.testing.assert_allclose(
+        vmec.variables["DGeod"][20:100],
+        desc.variables["DGeod"][20:100],
+        atol=4e-3,
+        rtol=1e-2,
+    )
+
+    # the Mercier stability is pretty far off,
+    # but these are not exactly similar solutions to each other
+    np.testing.assert_allclose(
+        vmec.variables["DMerc"][20:100], desc.variables["DMerc"][20:100], atol=4e-3
+    )
+    np.testing.assert_allclose(
+        vmec.variables["raxis_cc"][:],
+        desc.variables["raxis_cc"][:],
+        rtol=5e-5,
+        atol=4e-3,
+    )
+    np.testing.assert_allclose(
+        vmec.variables["zaxis_cs"][:],
+        desc.variables["zaxis_cs"][:],
+        rtol=5e-5,
+        atol=1e-3,
+    )
+    np.testing.assert_allclose(
+        vmec.variables["rmin_surf"][:], desc.variables["rmin_surf"][:], rtol=5e-3
+    )
+    np.testing.assert_allclose(
+        vmec.variables["rmax_surf"][:], desc.variables["rmax_surf"][:], rtol=5e-3
+    )
+    np.testing.assert_allclose(
+        vmec.variables["zmax_surf"][:], desc.variables["zmax_surf"][:], rtol=5e-3
+    )
+    np.testing.assert_allclose(
+        vmec.variables["beta_vol"][:], desc.variables["beta_vol"][:], rtol=5e-5
+    )
+    np.testing.assert_allclose(
+        vmec.variables["betaxis"][:], desc.variables["betaxis"][:], rtol=5e-5
+    )
+    # Next, calculate some quantities and compare
+    # the DESC wout -> DESC (should be very close)
+    # and the DESC wout -> VMEC wout (should be approximately close)
+    vol_grid = LinearGrid(
+        rho=np.sqrt(
+            abs(
+                vmec.variables["phi"][:].filled()
+                / np.max(np.abs(vmec.variables["phi"][:].filled()))
+            )
+        )[10::10],
+        M=15,
+        N=15,
+        NFP=eq.NFP,
+        axis=False,
+        sym=False,
+    )
+    bdry_grid = LinearGrid(rho=1.0, M=15, N=15, NFP=eq.NFP, axis=False, sym=False)
+
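For reference, the helper below leans on VMECIO.vmec_interpolate to evaluate double Fourier series from the wout arrays. A minimal numpy sketch of that evaluation, ignoring the radial interpolation between flux surfaces (the m*theta - n*phi sign convention follows the usual VMEC convention; illustrative only, not the DESC implementation):

    import numpy as np

    def eval_double_fourier(c_mn, s_mn, xm, xn, theta, phi):
        # f(theta, phi) = sum_mn [c_mn * cos(m*theta - n*phi)
        #                         + s_mn * sin(m*theta - n*phi)]
        # c_mn, s_mn, xm, xn have shape (mnmax,); theta, phi have shape (npts,)
        angle = xm[:, None] * theta[None, :] - xn[:, None] * phi[None, :]
        return c_mn @ np.cos(angle) + s_mn @ np.sin(angle)

The helper evaluates the same series twice, once with the DESC wout coefficients and once with the VMEC ones, mapping the poloidal angle (-theta for DESC, pi - theta for VMEC) so that all three descriptions are compared at the same physical points.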
+    def test(
+        nc_str,
+        desc_str,
+        negate_DESC_quant=False,
+        use_nyq=True,
+        convert_sqrt_g_or_B_rho=False,
+        atol_desc_desc_wout=5e-5,
+        rtol_desc_desc_wout=1e-5,
+        atol_vmec_desc_wout=1e-5,
+        rtol_vmec_desc_wout=1e-2,
+        grid=vol_grid,
+    ):
+        """Helper function to evaluate Fourier series from wout and compare to DESC."""
+        xm = desc.variables["xm_nyq"][:] if use_nyq else desc.variables["xm"][:]
+        xn = desc.variables["xn_nyq"][:] if use_nyq else desc.variables["xn"][:]
+
+        si = abs(vmec.variables["phi"][:] / np.max(np.abs(vmec.variables["phi"][:])))
+        rho = grid.nodes[:, 0]
+        s = rho**2
+        # some quantities must be negated before comparison because
+        # they are negative in the wout, e.g. B^theta
+        negate = -1 if negate_DESC_quant else 1
+
+        quant_from_desc_wout = VMECIO.vmec_interpolate(
+            desc.variables[nc_str + "c"][:],
+            desc.variables[nc_str + "s"][:],
+            xm,
+            xn,
+            theta=-grid.nodes[:, 1],  # -theta because we reverse theta when saving the wout
+            phi=grid.nodes[:, 2],
+            s=s,
+            sym=False,
+            si=si,
+        )
+
+        quant_from_vmec_wout = VMECIO.vmec_interpolate(
+            vmec.variables[nc_str + "c"][:],
+            vmec.variables[nc_str + "s"][:],
+            xm,
+            xn,
+            # pi - theta because VMEC, when given a clockwise boundary angle,
+            # changes the poloidal angle to theta -> pi - theta
+            theta=np.pi - grid.nodes[:, 1],
+            phi=grid.nodes[:, 2],
+            s=s,
+            sym=False,
+            si=si,
+        )
+
+        data = eq.compute(["rho", "sqrt(g)", desc_str], grid=grid)
+        # convert sqrt(g) or B_rho -> B_psi if needed
+        quant_desc = (
+            data[desc_str] / 2 / data["rho"]
+            if convert_sqrt_g_or_B_rho
+            else data[desc_str]
+        )
+
+        # add a sqrt(g) factor if currents are being compared
+        quant_desc = (
+            quant_desc * abs(data["sqrt(g)"]) / 2 / data["rho"]
+            if "J" in desc_str
+            else quant_desc
+        )
+
+        np.testing.assert_allclose(
+            negate * quant_desc,
+            quant_from_desc_wout,
+            atol=atol_desc_desc_wout,
+            rtol=rtol_desc_desc_wout,
+        )
+        np.testing.assert_allclose(
+            quant_from_desc_wout,
+            quant_from_vmec_wout,
+            atol=atol_vmec_desc_wout,
+            rtol=rtol_vmec_desc_wout,
+        )
+
+    # R & Z & lambda
+    test("rmn", "R", use_nyq=False)
+    test("zmn", "Z", use_nyq=False, atol_vmec_desc_wout=4e-2)
+
+    # |B|
+    test("bmn", "|B|", rtol_desc_desc_wout=7e-4)
+
+    # B^zeta
+    test("bsupvmn", "B^zeta")
+
+    # B_zeta
+    test("bsubvmn", "B_zeta", rtol_desc_desc_wout=3e-4)
+
+    # The currents are hard to compare to VMEC, since the VMEC force error is
+    # worse and the equilibria are not exactly similar, so just compare back to DESC.
+    test("currumn", "J^theta", atol_vmec_desc_wout=1e4)
+    test("currvmn", "J^zeta", negate_DESC_quant=True, atol_vmec_desc_wout=1e5)
+
+    # can only compare lambda, sqrt(g), B_psi, B^theta and B_theta at the boundary
+    test(
+        "lmn",
+        "lambda",
+        use_nyq=False,
+        negate_DESC_quant=True,
+        grid=bdry_grid,
+        atol_desc_desc_wout=4e-4,
+        atol_vmec_desc_wout=5e-2,
+    )
+    test(
+        "gmn",
+        "sqrt(g)",
+        convert_sqrt_g_or_B_rho=True,
+        negate_DESC_quant=True,
+        grid=bdry_grid,
+        rtol_desc_desc_wout=5e-4,
+        rtol_vmec_desc_wout=4e-2,
+    )
+    test(
+        "bsupumn",
+        "B^theta",
+        negate_DESC_quant=True,
+        grid=bdry_grid,
+        atol_vmec_desc_wout=6e-4,
+    )
+    test(
+        "bsubumn",
+        "B_theta",
+        negate_DESC_quant=True,
+        grid=bdry_grid,
+        atol_desc_desc_wout=1e-4,
+        atol_vmec_desc_wout=4e-4,
+    )
+    test(
+        "bsubsmn",
+        "B_rho",
+        grid=bdry_grid,
+        convert_sqrt_g_or_B_rho=True,
+        rtol_vmec_desc_wout=6e-2,
+        atol_vmec_desc_wout=9e-3,
+    )
+
+
 @pytest.mark.unit
 @pytest.mark.mpl_image_compare(tolerance=1)
 def test_plot_vmec_comparison():