diff --git a/derivative/differentiation.py b/derivative/differentiation.py index e633c68..dbf8657 100644 --- a/derivative/differentiation.py +++ b/derivative/differentiation.py @@ -248,7 +248,10 @@ def _restore_axes(dX: NDArray, axis: int, orig_shape: tuple[int, ...]) -> NDArra return dX.flatten() else: # order of operations coupled with _align_axes - extra_dims = tuple(length for ax, length in enumerate(orig_shape) if ax != axis) + orig_diff_axis = range(len(orig_shape))[axis] # to handle negative axis args + extra_dims = tuple( + length for ax, length in enumerate(orig_shape) if ax != orig_diff_axis + ) moved_shape = (orig_shape[axis],) + extra_dims dX = np.moveaxis(dX.T.reshape((moved_shape)), 0, axis) return dX