diff --git a/.github/workflows/pull-docs.yml b/.github/workflows/pull-docs.yml index db86adb..73446dc 100644 --- a/.github/workflows/pull-docs.yml +++ b/.github/workflows/pull-docs.yml @@ -7,7 +7,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [ "3.7" ] + python-version: [ "3.9" ] poetry-version: [ "1.2.1" ] steps: # ====== diff --git a/.github/workflows/push-test.yml b/.github/workflows/push-test.yml index b3b1659..323ab4b 100644 --- a/.github/workflows/push-test.yml +++ b/.github/workflows/push-test.yml @@ -9,7 +9,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.7"] + python-version: ["3.9"] poetry-version: ["1.2.1"] steps: diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index aeb86f5..cc27b21 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -14,7 +14,7 @@ jobs: name: release strategy: matrix: - python-version: ["3.7"] + python-version: ["3.9"] poetry-version: ["1.2.1"] steps: # ====== diff --git a/derivative/__version__.py b/derivative/__version__.py index 3775afe..1e25219 100644 --- a/derivative/__version__.py +++ b/derivative/__version__.py @@ -1 +1 @@ -__version__: str = '0.6.0' +__version__: str = '0.6.1' diff --git a/derivative/differentiation.py b/derivative/differentiation.py index 4942f85..e633c68 100644 --- a/derivative/differentiation.py +++ b/derivative/differentiation.py @@ -1,5 +1,6 @@ import abc import numpy as np +from numpy.typing import NDArray from .utils import _memoize_arrays @@ -216,41 +217,38 @@ def x(self, X, t, axis=1): Returns: :obj:`ndarray` of float: Returns dX/dt along axis. """ - X, flat = _align_axes(X, t, axis) + X, orig_shape = _align_axes(X, t, axis) if X.shape[1] == 1: dX = X else: dX = np.array([list(self.compute_x_for(t, x, np.arange(len(t)))) for x in X]) - return _restore_axes(dX, axis, flat) + return _restore_axes(dX, axis, orig_shape) -def _align_axes(X, t, axis): - # Cast +def _align_axes(X, t, axis) -> tuple[NDArray, tuple[int, ...]]: X = np.array(X) - flat = False - # Check shape and axis - if len(X.shape) == 1: + orig_shape = X.shape + # By convention, differentiate axis 1 + if len(orig_shape) == 1: X = X.reshape(1, -1) - flat = True - elif len(X.shape) == 2: - if axis == 0: - X = X.T - elif axis == 1: - pass - else: - raise ValueError("Invalid axis.") else: - raise ValueError("Invalid shape of X.") - + ax_len = orig_shape[axis] + # order of operations coupled with _restore_axes. Move differentiation axis to + # zero so reshape does not skew differentiation axis + X = np.moveaxis(X, axis, 0).reshape((ax_len, -1)).T if X.shape[1] != len(t): raise ValueError("Desired X axis size does not match t size.") - return X, flat + return X, orig_shape -def _restore_axes(dX, axis, flat): - if flat: +def _restore_axes(dX: NDArray, axis: int, orig_shape: tuple[int, ...]) -> NDArray: + if len(orig_shape) == 1: return dX.flatten() else: - return dX if axis == 1 else dX.T + # order of operations coupled with _align_axes + extra_dims = tuple(length for ax, length in enumerate(orig_shape) if ax != axis) + moved_shape = (orig_shape[axis],) + extra_dims + dX = np.moveaxis(dX.T.reshape((moved_shape)), 0, axis) + return dX diff --git a/pyproject.toml b/pyproject.toml index 107e0bd..8c2510b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "derivative" -version = "0.6.0" +version = "0.6.1" description = "Numerical differentiation in python." repository = "https://github.com/andgoldschmidt/derivative" documentation = "https://derivative.readthedocs.io/" @@ -15,7 +15,7 @@ readme = "README.rst" [tool.poetry.dependencies] -python = "^3.7" +python = "^3.9" numpy = "^1.18.3" scipy = "^1.4.1" scikit-learn = "^1" @@ -41,4 +41,4 @@ requires = ["poetry-core>=1.1.0"] build-backend = "poetry.core.masonry.api" [tool.poetry.plugins.'derivative.hyperparam_opt'] -"kalman.default" = "derivative.utils:_default_kalman" \ No newline at end of file +"kalman.default" = "derivative.utils:_default_kalman" diff --git a/tests/test_interface.py b/tests/test_interface.py index 00e5085..49a88fc 100644 --- a/tests/test_interface.py +++ b/tests/test_interface.py @@ -33,6 +33,10 @@ def test_axis(): shape2 = ts.shape test3 = dxdt(ts, ts, axis=1).shape assert shape2 == test3 + n = 4 + x3d = np.tile(np.arange(n), (n, n, 1)) + test_4 = dxdt(x3d, np.arange(n), axis=2) + np.testing.assert_array_equal(test_4, np.ones((n,n,n))) def test_empty():