# dt.pyx
import itertools
from numpy cimport ndarray
import numpy as np
cimport numpy as np
cimport cython
# ----------------------------------------------------------------------------
# Distance Function
# ----------------------------------------------------------------------------
cdef class DistanceFunction(object):
    """Abstract base class for distance functions

    Any distance function supplied by the user has to derive from this class,
    which is what lets the Cython-compiled transform reach the C-level methods
    declared here. Both methods raise NotImplementedError until a subclass
    overrides them.

    Methods:
        intersection(x0, x1, y0, y1): called by the 1D transform with two
            sample positions (x0, x1) and the values stored for them (y0, y1)
        envelope(x, y): called by the 1D transform with a position offset (x)
            and the value stored for the selected sample (y)
    """

    cdef intersection(self, int x0, int x1, double y0, double y1):
        # abstract: concrete distance functions provide the implementation
        raise NotImplementedError()

    cdef envelope(self, int x, double y):
        # abstract: concrete distance functions provide the implementation
        raise NotImplementedError()
cdef class L2(DistanceFunction):
    """Squared Euclidean distance (L2)

    L2 expresses distance of the form:
        d(p,q) = a*(p - q)^2 + b*(p - q)

    Keyword Args:
        a (float): The quadratic slope (default: 1.0). It must be non-zero,
            because intersection() divides by 2*a*(x1 - x0).
        b (float): The quadratic offset (default: 0.0)
    """
    cdef double a, b

    def __init__(self, a=1.0, b=0.0):
        self.a = a
        self.b = b

    cdef intersection(self, int x0, int x1, double y0, double y1):
        # Position at which the parabola rooted at (x0, y0) and the parabola
        # rooted at (x1, y1) take the same value, i.e. the solution p of
        #   a*(p-x0)^2 + b*(p-x0) + y0 == a*(p-x1)^2 + b*(p-x1) + y1
        #
        # The positions are promoted to double BEFORE squaring: evaluating
        # x1*x1 - x0*x0 on C ints silently overflows once a position exceeds
        # ~46340 (sqrt(2^31)). The result is unchanged for smaller positions.
        cdef double p0 = x0
        cdef double p1 = x1
        return ((y1-y0) - self.b*(p1-p0) + self.a*(p1*p1 - p0*p0)) / (2*self.a*(p1-p0))

    cdef envelope(self, int x, double y):
        # Value of the parabola with offset y, evaluated x samples away from
        # its root: a*x^2 + b*x + y
        return self.a*x*x + self.b*x + y
# ----------------------------------------------------------------------------
# Distance Transform
# ----------------------------------------------------------------------------
def compute(x, axes=None, f=L2):
    """Compute the distance transform of a sampled function

    Compute the N-dimensional distance transform using the method described in:

        P. Felzenszwalb, D. Huttenlocher "Distance Transforms of Sampled Functions"

    The transform is applied as a sequence of 1D transforms, one per requested
    axis, over every 1D slice of the data along that axis.

    Args:
        x (ndarray): An n-dimensional array representing the data term. It is
            copied and converted to double precision; the input is not modified.

    Keyword Args:
        axes (tuple): The axes over which to perform the distance transforms. The
            order does not matter. (default all axes)
        f (DistanceFunction): The distance function to apply, given either as
            an instance or as a class to be instantiated with its default
            arguments (default L2)

    Returns:
        tuple: (min, arg) where min is a float array with the shape of x
        holding the transformed values, and arg is a tuple with one integer
        array (same shape as x) per entry of axes, holding for every element
        the index along that axis of the sample that produced the minimum.
    """
    # Work on a float64 copy: the 1D kernel only accepts double buffers, so
    # integer or single-precision input must be converted up front.
    minima = np.array(x, dtype=float, order='C')
    shape = minima.shape
    axes = axes if axes else tuple(range(minima.ndim))
    f = f() if isinstance(f, type) else f

    # one argument array per transformed axis
    arg = tuple(np.empty(shape, dtype=int) for axis in axes)

    # scratch space shared by all 1D transforms; one element longer than the
    # longest axis (max() of an empty shape would raise for 0-d input)
    longest = max(shape) if shape else 0
    v = np.empty((longest+1,), dtype=int)
    z = np.empty((longest+1,), dtype=float)

    # compute transforms over the given axes
    for n, axis in enumerate(axes):
        numel = shape[axis]
        minbuf = np.empty((numel,), dtype=float)
        argbuf = np.empty((numel,), dtype=int)

        # Enumerate every 1D slice along `axis`: all index combinations of the
        # other axes, with an Ellipsis standing in for the transformed axis.
        # This must be a real list (not a map iterator, as on Python 3) since
        # one entry is replaced by item assignment.
        slices = [range(extent) for extent in shape]
        slices[axis] = [Ellipsis]

        for index in itertools.product(*slices):
            # compute the optimal minima
            _compute1d(minima[index], f, minbuf, argbuf, z, v)
            minima[index] = minbuf
            arg[n][index] = argbuf

            # Same slice, but gathered at the minimising positions just found
            # (fancy indexing along `axis`).
            nindex = tuple(argbuf if i is Ellipsis else i for i in index)

            # update the optimal arguments across preceding axes
            for m in reversed(range(n)):
                arg[m][index] = arg[m][nindex]

    # return the minimum and the argument
    return minima, arg
# ----------------------------------------------------------------------------
# 1D Distance Transform (Cython)
# ----------------------------------------------------------------------------
@cython.boundscheck(False)
cdef _compute1d(
        ndarray[double] x, DistanceFunction f,   # input array and distance function
        ndarray[double] min, ndarray[long] arg,  # output arrays
        ndarray[double] z, ndarray[long] v):     # working buffers
    """Low-level 1D distance transform

    This Cython function provides the implementation of the 1D distance transform.
    It is compiled for speed - it is roughly 150x faster than the same Python
    implementation without type declarations. It optimizes:

        arg min f(p,q) + x(q)
             q

    The first pass builds the lower envelope of the parabolas rooted at each
    sample: v[0..k] holds the sample positions of the parabolas on the envelope
    and z[i]..z[i+1] the range over which parabola v[i] is the lowest. The
    second pass reads the envelope back at every sample position.

    Args:
        x (ndarray): The input, of length N
        f (DistanceFunction): The distance function
        min (ndarray): The minimum solution (written, length N)
        arg (ndarray): The argument of the minimum (written, length N)
        z (ndarray): A double-precision working buffer of length N+1
        v (ndarray): An integer-precision working buffer of length N

    Note:
        Bounds checking is disabled, so the buffers must be at least as long
        as stated above.
        NOTE(review): x is presumably expected to be finite; with L2, pairs of
        infinite samples produce NaN intersections - confirm with callers.
    """
    # predeclare object types
    cdef int N = x.shape[0]
    cdef int k, q
    cdef double s
    cdef double inf = np.inf

    z.fill(inf)

    # initial conditions: the parabola at sample 0 covers everything
    v[0] = 0
    z[0] = -inf

    # compute the intersection points
    k = 0
    for q in range(1, N):
        s = f.intersection(v[k], q, x[v[k]], x[q])
        # Pop parabolas that the new one hides completely. The k > 0 guard
        # keeps k from running below zero (and the unchecked buffers from
        # being indexed out of range) when s is -inf.
        while k > 0 and s <= z[k]:
            k -= 1
            s = f.intersection(v[k], q, x[v[k]], x[q])
        # push the new parabola onto the envelope
        k += 1
        v[k] = q
        z[k] = s
        # Close the envelope after the new top entry. Without this, stale
        # boundaries left behind by popped parabolas would still be followed
        # by the projection loop below.
        z[k+1] = inf

    # compute the projection onto the lower envelope
    k = 0
    for q in range(N):
        while z[k+1] < q:
            k += 1
        min[q] = f.envelope(q-v[k], x[v[k]])
        arg[q] = v[k]