Skip to content

Commit

Permalink
ENH: _restore_axes accomodates for negative axis arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
yb6599 committed Jun 5, 2024
1 parent 9474e7f commit 8cde00d
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion derivative/differentiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,10 @@ def _restore_axes(dX: NDArray, axis: int, orig_shape: tuple[int, ...]) -> NDArra
return dX.flatten()
else:
# order of operations coupled with _align_axes
extra_dims = tuple(length for ax, length in enumerate(orig_shape) if ax != axis)
orig_diff_axis = range(len(orig_shape))[axis] # to handle negative axis args
extra_dims = tuple(
length for ax, length in enumerate(orig_shape) if ax != orig_diff_axis
)
moved_shape = (orig_shape[axis],) + extra_dims
dX = np.moveaxis(dX.T.reshape((moved_shape)), 0, axis)
return dX

0 comments on commit 8cde00d

Please sign in to comment.