diff --git a/derivative/differentiation.py b/derivative/differentiation.py index e633c68..dbf8657 100644 --- a/derivative/differentiation.py +++ b/derivative/differentiation.py @@ -248,7 +248,10 @@ def _restore_axes(dX: NDArray, axis: int, orig_shape: tuple[int, ...]) -> NDArra return dX.flatten() else: # order of operations coupled with _align_axes - extra_dims = tuple(length for ax, length in enumerate(orig_shape) if ax != axis) + orig_diff_axis = range(len(orig_shape))[axis] # to handle negative axis args + extra_dims = tuple( + length for ax, length in enumerate(orig_shape) if ax != orig_diff_axis + ) moved_shape = (orig_shape[axis],) + extra_dims dX = np.moveaxis(dX.T.reshape((moved_shape)), 0, axis) return dX diff --git a/derivative/utils.py b/derivative/utils.py index 707f866..616c88c 100644 --- a/derivative/utils.py +++ b/derivative/utils.py @@ -10,7 +10,11 @@ else: from importlib.metadata import entry_points -hyperparam_algorithms = entry_points(group="derivative.hyperparam_opt") +try: + hyperparam_algorithms = entry_points(group="derivative.hyperparam_opt") +except TypeError as exc: + import sys + raise TypeError(f"Oops, no 'group' kwarg. function imported from {entry_points.__module__}") def _load_hyperparam_func(func_key): diff --git a/pyproject.toml b/pyproject.toml index 238ebce..5e8e4f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "derivative" -version = "0.6.1" +version = "0.6.2" description = "Numerical differentiation in python." repository = "https://github.com/andgoldschmidt/derivative" documentation = "https://derivative.readthedocs.io/" @@ -20,6 +20,9 @@ numpy = "^1.18.3" scipy = "^1.4.1" scikit-learn = "^1" +# third-party access to the functionality of importlib.metadata +importlib-metadata = "^7.1.0" + # docs sphinx = {version = "^5", optional = true} nbsphinx = {version = "^0.6.1", optional = true} @@ -32,12 +35,12 @@ matplotlib = {version = "^3.2.1", optional = true} pytest = {version = "^7", optional = true} [tool.poetry.extras] -docs = ["sphinx", "nbsphinx", "ipykernel", "jupyter_client", "matplotlib", "pandoc"] +docs = ["sphinx", "nbsphinx", "ipykernel", "jupyter_client", "matplotlib"] dev = ["pytest"] [build-system] requires = ["poetry-core>=1.1.0"] build-backend = "poetry.core.masonry.api" -[tool.poetry.plugins.'derivative.hyperparam_opt'] +[tool.poetry.plugins."derivative.hyperparam_opt"] "kalman.default" = "derivative.utils:_default_kalman" diff --git a/tests/test_interface.py b/tests/test_interface.py index 49a88fc..b70c7f8 100644 --- a/tests/test_interface.py +++ b/tests/test_interface.py @@ -98,4 +98,12 @@ def test_hyperparam_entrypoint(): func = utils._load_hyperparam_func("kalman.default") expected = 1 result = func(None, None) - assert result == expected \ No newline at end of file + assert result == expected + + +def test_negative_axis(): + t = np.arange(3) + x = np.ones((2, 3, 2)) + axis = -2 + dx = dxdt(x, t, kind='finite_difference', axis=axis, k=1) + assert x.shape == dx.shape \ No newline at end of file