forked from pollen-robotics/dtw
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdtw.py
120 lines (112 loc) · 4.32 KB
/
dtw.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from numpy import array, zeros, argmin, inf, equal, ndim
from scipy.spatial.distance import cdist
def dtw(x, y, dist):
    """
    Computes Dynamic Time Warping (DTW) of two sequences.

    :param array x: N1*M array (any indexable sequence whose elements ``dist``
        accepts also works, e.g. a list of strings)
    :param array y: N2*M array
    :param func dist: distance used as cost measure; called as
        ``dist(x[i], y[j])`` for every pair and expected to return a scalar.

    Returns the minimum distance, the cost matrix, the accumulated cost matrix, and the warp path.
    The minimum distance is the accumulated cost of the optimal alignment
    divided by N1 + N2. The warp path is a pair of integer index arrays
    ``(path_x, path_y)`` that can be used directly to index the matrices.
    """
    assert len(x), "x must be non-empty"
    assert len(y), "y must be non-empty"
    r, c = len(x), len(y)
    # Pad with an extra leading row/column of +inf (D0[0, 0] stays 0) so the
    # recurrence below needs no boundary checks.
    D0 = zeros((r + 1, c + 1))
    D0[0, 1:] = inf
    D0[1:, 0] = inf
    D1 = D0[1:, 1:]  # view: writing into D1 fills the interior of D0
    for i in range(r):
        for j in range(c):
            D1[i, j] = dist(x[i], y[j])
    C = D1.copy()
    for i in range(r):
        for j in range(c):
            # Diagonal, upper and left neighbours of D1[i, j] in padded coordinates.
            D1[i, j] += min(D0[i, j], D0[i, j + 1], D0[i + 1, j])
    if r == 1:
        # Integer index arrays, so the path can index arrays (e.g. ``C[path]``)
        # just like the arrays produced by the general traceback.
        path = zeros(c, dtype=int), array(range(c))
    elif c == 1:
        path = array(range(r)), zeros(r, dtype=int)
    else:
        path = _traceback(D0)
    return D1[-1, -1] / sum(D1.shape), C, D1, path
def fastdtw(x, y, dist):
    """
    Computes Dynamic Time Warping (DTW) of two sequences in a faster way.

    Instead of iterating through each element and calculating each distance,
    this uses the cdist function from scipy (https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.cdist.html)
    to build the whole cost matrix in one call. The accumulation step is
    still a nested O(N1*N2) Python loop.

    :param array x: N1*M array, or a 1-D sequence of N1 numbers (treated as N1*1)
    :param array y: N2*M array, or a 1-D sequence of N2 numbers (treated as N2*1)
    :param string or func dist: distance parameter for cdist. When a string is
        given, cdist uses its optimized implementation of that metric, e.g.
        'euclidean', 'cityblock', 'sqeuclidean', 'cosine', 'chebyshev',
        'minkowski', 'hamming'. The exact set of accepted names depends on the
        installed SciPy version (some older names such as 'wminkowski' and
        'kulsinski' have been deprecated or removed). When a callable is given,
        it is called as dist(u, v) on pairs of rows and must return a scalar.

    Returns the minimum distance, the cost matrix, the accumulated cost matrix, and the warp path.
    The minimum distance is the accumulated cost of the optimal alignment
    divided by N1 + N2. The warp path is a pair of integer index arrays.
    """
    assert len(x), "x must be non-empty"
    assert len(y), "y must be non-empty"
    # cdist requires 2-D input: promote 1-D sequences to a single column.
    # array() is applied first so plain Python lists (which have no
    # .reshape) are accepted as well as ndarrays.
    if ndim(x) == 1:
        x = array(x).reshape(-1, 1)
    if ndim(y) == 1:
        y = array(y).reshape(-1, 1)
    r, c = len(x), len(y)
    # Pad with an extra leading row/column of +inf (D0[0, 0] stays 0) so the
    # recurrence below needs no boundary checks.
    D0 = zeros((r + 1, c + 1))
    D0[0, 1:] = inf
    D0[1:, 0] = inf
    D1 = D0[1:, 1:]  # view into the interior of D0
    D0[1:, 1:] = cdist(x, y, dist)
    C = D1.copy()
    for i in range(r):
        for j in range(c):
            # Diagonal, upper and left neighbours of D1[i, j] in padded coordinates.
            D1[i, j] += min(D0[i, j], D0[i, j + 1], D0[i + 1, j])
    if r == 1:
        # Integer index arrays, so the path can index arrays (e.g. ``C[path]``).
        path = zeros(c, dtype=int), array(range(c))
    elif c == 1:
        path = array(range(r)), zeros(r, dtype=int)
    else:
        path = _traceback(D0)
    return D1[-1, -1] / sum(D1.shape), C, D1, path
def _traceback(D):
i, j = array(D.shape) - 2
p, q = [i], [j]
while ((i > 0) or (j > 0)):
tb = argmin((D[i, j], D[i, j+1], D[i+1, j]))
if (tb == 0):
i -= 1
j -= 1
elif (tb == 1):
i -= 1
else: # (tb == 2):
j -= 1
p.insert(0, i)
q.insert(0, j)
return array(p), array(q)
if __name__ == '__main__':
    # Demo: choose which example pair to align and plot.
    demo = 'strings'  # one of '1d', '2d', 'strings'
    if demo == '1d':  # 1-D numeric
        x = [0, 0, 1, 1, 2, 4, 2, 1, 2, 0]
        y = [1, 1, 1, 2, 2, 2, 2, 3, 2, 0]

        def dist_fun(a, b):
            # dtw calls dist on single elements (scalars here), so use a
            # scalar distance; sklearn's pairwise functions need 2-D arrays.
            return abs(a - b)
    elif demo == '2d':  # 2-D numeric
        from numpy.linalg import norm
        x = [[0, 0], [0, 1], [1, 1], [1, 2], [2, 2], [4, 3], [2, 3], [1, 1], [2, 2], [0, 1]]
        y = [[1, 0], [1, 1], [1, 1], [2, 1], [4, 3], [4, 3], [2, 3], [3, 1], [1, 2], [1, 0]]

        def dist_fun(a, b):
            # Euclidean distance between two points given as 1-D rows.
            return norm(array(a) - array(b))
    else:  # 1-D list of strings
        from nltk.metrics.distance import edit_distance
        #x = ['we', 'shelled', 'clams', 'for', 'the', 'chowder']
        #y = ['class', 'too']
        x = ['i', 'soon', 'found', 'myself', 'muttering', 'to', 'the', 'walls']
        y = ['see', 'drown', 'himself']
        #x = 'we talked about the situation'.split()
        #y = 'we talked about the situation'.split()
        dist_fun = edit_distance
    dist, cost, acc, path = dtw(x, y, dist_fun)

    # visualize the cost matrix and the warp path on top of it
    from matplotlib import pyplot as plt
    plt.imshow(cost.T, origin='lower', cmap=plt.cm.Reds, interpolation='nearest')
    plt.plot(path[0], path[1], '-o')  # relation
    plt.xticks(range(len(x)), x)
    plt.yticks(range(len(y)), y)
    plt.xlabel('x')
    plt.ylabel('y')
    plt.axis('tight')
    plt.title('Minimum distance: {}'.format(dist))
    plt.show()