Skip to content

Commit

Permalink
Merge branch 'Jacob-Stevens-Haas-diff-anyaxis'
Browse files Browse the repository at this point in the history
  • Loading branch information
andgoldschmidt committed May 14, 2024
2 parents f0d566d + 394253e commit 99ee097
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 28 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pull-docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [ "3.7" ]
python-version: [ "3.9" ]
poetry-version: [ "1.2.1" ]
steps:
# ======
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/push-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.7"]
python-version: ["3.9"]
poetry-version: ["1.2.1"]

steps:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
name: release
strategy:
matrix:
python-version: ["3.7"]
python-version: ["3.9"]
poetry-version: ["1.2.1"]
steps:
# ======
Expand Down
2 changes: 1 addition & 1 deletion derivative/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__: str = '0.6.0'
__version__: str = '0.6.1'
40 changes: 19 additions & 21 deletions derivative/differentiation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import abc
import numpy as np
from numpy.typing import NDArray

from .utils import _memoize_arrays

Expand Down Expand Up @@ -216,41 +217,38 @@ def x(self, X, t, axis=1):
Returns:
:obj:`ndarray` of float: Returns dX/dt along axis.
"""
X, flat = _align_axes(X, t, axis)
X, orig_shape = _align_axes(X, t, axis)

if X.shape[1] == 1:
dX = X
else:
dX = np.array([list(self.compute_x_for(t, x, np.arange(len(t)))) for x in X])

return _restore_axes(dX, axis, flat)
return _restore_axes(dX, axis, orig_shape)


def _align_axes(X, t, axis):
# Cast
def _align_axes(X, t, axis) -> tuple[NDArray, tuple[int, ...]]:
X = np.array(X)
flat = False
# Check shape and axis
if len(X.shape) == 1:
orig_shape = X.shape
# By convention, differentiate axis 1
if len(orig_shape) == 1:
X = X.reshape(1, -1)
flat = True
elif len(X.shape) == 2:
if axis == 0:
X = X.T
elif axis == 1:
pass
else:
raise ValueError("Invalid axis.")
else:
raise ValueError("Invalid shape of X.")

ax_len = orig_shape[axis]
# order of operations coupled with _restore_axes. Move differentiation axis to
# zero so reshape does not skew differentiation axis
X = np.moveaxis(X, axis, 0).reshape((ax_len, -1)).T
if X.shape[1] != len(t):
raise ValueError("Desired X axis size does not match t size.")
return X, flat
return X, orig_shape


def _restore_axes(dX, axis, flat):
if flat:
def _restore_axes(dX: NDArray, axis: int, orig_shape: tuple[int, ...]) -> NDArray:
if len(orig_shape) == 1:
return dX.flatten()
else:
return dX if axis == 1 else dX.T
# order of operations coupled with _align_axes
extra_dims = tuple(length for ax, length in enumerate(orig_shape) if ax != axis)
moved_shape = (orig_shape[axis],) + extra_dims
dX = np.moveaxis(dX.T.reshape((moved_shape)), 0, axis)
return dX
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "derivative"
version = "0.6.0"
version = "0.6.1"
description = "Numerical differentiation in python."
repository = "https://github.com/andgoldschmidt/derivative"
documentation = "https://derivative.readthedocs.io/"
Expand All @@ -15,7 +15,7 @@ readme = "README.rst"


[tool.poetry.dependencies]
python = "^3.7"
python = "^3.9"
numpy = "^1.18.3"
scipy = "^1.4.1"
scikit-learn = "^1"
Expand All @@ -41,4 +41,4 @@ requires = ["poetry-core>=1.1.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry.plugins.'derivative.hyperparam_opt']
"kalman.default" = "derivative.utils:_default_kalman"
"kalman.default" = "derivative.utils:_default_kalman"
4 changes: 4 additions & 0 deletions tests/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ def test_axis():
shape2 = ts.shape
test3 = dxdt(ts, ts, axis=1).shape
assert shape2 == test3
n = 4
x3d = np.tile(np.arange(n), (n, n, 1))
test_4 = dxdt(x3d, np.arange(n), axis=2)
np.testing.assert_array_equal(test_4, np.ones((n,n,n)))


def test_empty():
Expand Down

0 comments on commit 99ee097

Please sign in to comment.