diff --git a/pytransform3d/trajectories.py b/pytransform3d/trajectories.py index 7778f7614..e7cde8129 100644 --- a/pytransform3d/trajectories.py +++ b/pytransform3d/trajectories.py @@ -12,6 +12,7 @@ transform_from_exponential_coordinates, screw_axis_from_exponential_coordinates, screw_parameters_from_screw_axis, screw_axis_from_screw_parameters) +from ._array_api_compat import array_namespace def invert_transforms(A2Bs): @@ -27,14 +28,17 @@ def invert_transforms(A2Bs): B2As : array, shape (..., 4, 4) Transforms from frames B to frames A """ - A2Bs = np.asarray(A2Bs) + xp = array_namespace(A2Bs) + + A2Bs = xp.asarray(A2Bs) instances_shape = A2Bs.shape[:-2] - B2As = np.empty_like(A2Bs) + B2As = xp.empty_like(A2Bs) # ( R t )^-1 ( R^T -R^T*t ) # ( 0 1 ) = ( 0 1 ) - B2As[..., :3, :3] = A2Bs[..., :3, :3].transpose( + B2As[..., :3, :3] = xp.permute_dims( + A2Bs[..., :3, :3], list(range(A2Bs.ndim - 2)) + [A2Bs.ndim - 1, A2Bs.ndim - 2]) - B2As[..., :3, 3] = np.einsum( + B2As[..., :3, 3] = xp.einsum( "nij,nj->ni", -B2As[..., :3, :3].reshape(-1, 3, 3), A2Bs[..., :3, 3].reshape(-1, 3)).reshape(