forked from IIPCVLAB/LCCNet
-
Notifications
You must be signed in to change notification settings - Fork 14
/
quaternion_distances.py
72 lines (61 loc) · 2.18 KB
/
quaternion_distances.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# -------------------------------------------------------------------
# Copyright (C) 2020 Università degli studi di Milano-Bicocca, iralab
# Author: Daniele Cattaneo ([email protected])
# Released under Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# http://creativecommons.org/licenses/by-nc-sa/4.0/
# -------------------------------------------------------------------
# Modified Author: Xudong Lv
# based on github.com/cattaneod/CMRNet/blob/master/quaternion_distances.py
import numpy as np
import torch
def quatmultiply(q, r, device='cpu'):
"""
Batch quaternion multiplication
Args:
q (torch.Tensor/np.ndarray): shape=[Nx4]
r (torch.Tensor/np.ndarray): shape=[Nx4]
device (str): 'cuda' or 'cpu'
Returns:
torch.Tensor: shape=[Nx4]
"""
if isinstance(q, torch.Tensor):
t = torch.zeros(q.shape[0], 4, device=device)
elif isinstance(q, np.ndarray):
t = np.zeros(q.shape[0], 4)
else:
raise TypeError("Type not supported")
t[:, 0] = r[:, 0] * q[:, 0] - r[:, 1] * q[:, 1] - r[:, 2] * q[:, 2] - r[:, 3] * q[:, 3]
t[:, 1] = r[:, 0] * q[:, 1] + r[:, 1] * q[:, 0] - r[:, 2] * q[:, 3] + r[:, 3] * q[:, 2]
t[:, 2] = r[:, 0] * q[:, 2] + r[:, 1] * q[:, 3] + r[:, 2] * q[:, 0] - r[:, 3] * q[:, 1]
t[:, 3] = r[:, 0] * q[:, 3] - r[:, 1] * q[:, 2] + r[:, 2] * q[:, 1] + r[:, 3] * q[:, 0]
return t
def quatinv(q):
"""
Batch quaternion inversion
Args:
q (torch.Tensor/np.ndarray): shape=[Nx4]
Returns:
torch.Tensor/np.ndarray: shape=[Nx4]
"""
if isinstance(q, torch.Tensor):
t = q.clone()
elif isinstance(q, np.ndarray):
t = q.copy()
else:
raise TypeError("Type not supported")
t *= -1
t[:, 0] *= -1
return t
def quaternion_distance(q, r, device):
"""
Batch quaternion distances, used as loss
Args:
q (torch.Tensor): shape=[Nx4]
r (torch.Tensor): shape=[Nx4]
device (str): 'cuda' or 'cpu'
Returns:
torch.Tensor: shape=[N]
"""
t = quatmultiply(q, quatinv(r), device)
return 2 * torch.atan2(torch.norm(t[:, 1:], dim=1), torch.abs(t[:, 0]))