forked from projectM-visualizer/projectm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathHungarianMethod.hpp
192 lines (175 loc) · 7.97 KB
/
HungarianMethod.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
#ifndef HUNGARIAN_METHOD_HPP
#define HUNGARIAN_METHOD_HPP
//#include "Common.hpp"
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <limits>
/// A function object which calculates the maximum-weighted bipartite matching between
/// two sets via the Hungarian method.
///
/// All working storage is sized by the template parameter and stored inside the object.
/// Call operator() with a weight matrix, then query matching() / inverseMatching() for
/// the individual assignments.
///
/// \tparam N the maximum number of elements in each of the two sets (must be positive)
///
/// \note Edges are tested for membership in the equality graph with exact floating-point
///       comparison. Weights are expected to be finite; NaN or infinite weights are not
///       handled.
template <int N=20>
class HungarianMethod {
public :
    /// Largest problem size this instantiation can solve.
    static const std::size_t MAX_SIZE = N;

private:
    std::size_t n;          // logical problem size: n workers and n jobs (n <= N)
    std::size_t max_match;  // number of vertices matched so far
    double lx[N], ly[N];    // labels of the X and Y parts
    int xy[N];              // xy[x] - vertex matched with x, or -1 if x is unmatched
    int yx[N];              // yx[y] - vertex matched with y, or -1 if y is unmatched
    bool S[N], T[N];        // sets S (X side) and T (Y side) of the alternating tree
    double slack[N];        // slack[y] = min over x in S of lx[x] + ly[y] - cost[x][y]
    int slackx[N];          // slackx[y] - the x in S that attains slack[y], i.e.
                            // lx[slackx[y]] + ly[y] - cost[slackx[y]][y] == slack[y]
    int prev[N];            // prev[x] - X vertex preceding x on the alternating path,
                            // -2 for the tree root, -1 if x is not in the tree

    /// Step 0: builds the initial labeling.
    /// Every Y label is 0 and every X label is the largest weight in its row, but never
    /// less than 0 because the labels start at 0 before taking the maximum.
    void init_labels(const double cost[N][N])
    {
        std::fill(lx, lx + N, 0.0);
        std::fill(ly, ly + N, 0.0);
        for (std::size_t x = 0; x < n; x++)
            for (std::size_t y = 0; y < n; y++)
                lx[x] = std::max(lx[x], cost[x][y]);
    }

    /// Steps 1-3: grows the matching by one edge per pass until all n vertices of X
    /// are matched. Each pass builds an alternating tree from an unmatched X vertex,
    /// improving the labeling whenever the tree cannot be extended, and finally flips
    /// the edges along the augmenting path that was found.
    void augment(const double cost[N][N])
    {
        while (max_match < n)                   // stop once the matching is perfect
        {
            unsigned int x = 0, y = 0, root = 0;    // counters and root vertex
            int q[N];                               // queue for the BFS
            int wr = 0, rd = 0;                     // write and read positions in q

            std::fill(S, S + N, false);             // init set S
            std::fill(T, T + N, false);             // init set T
            std::fill(prev, prev + N, -1);          // init the alternating tree

            for (x = 0; x < n; x++)                 // find the root of the tree:
                if (xy[x] == -1)                    // the first unmatched X vertex
                {
                    root = x;
                    q[wr++] = static_cast<int>(root);
                    prev[x] = -2;
                    S[x] = true;
                    break;
                }

            for (y = 0; y < n; y++)                 // initialize the slack array
            {
                slack[y] = lx[root] + ly[y] - cost[root][y];
                slackx[y] = static_cast<int>(root);
            }

            while (true)                            // main cycle
            {
                while (rd < wr)                     // build the tree with a BFS
                {
                    x = static_cast<unsigned int>(q[rd++]); // current vertex from X
                    for (y = 0; y < n; y++)         // iterate over equality-graph edges
                        if (cost[x][y] == lx[x] + ly[y] && !T[y])
                        {
                            if (yx[y] == -1) break; // exposed Y vertex: augmenting
                                                    // path exists
                            T[y] = true;            // otherwise add y to T,
                            q[wr++] = yx[y];        // queue the X vertex matched with y
                            // and add edges (x,y) and (y,yx[y]) to the tree
                            add_to_tree(yx[y], static_cast<int>(x), cost);
                        }
                    if (y < n) break;               // augmenting path found
                }
                if (y < n) break;                   // augmenting path found

                update_labels();                    // no path yet: improve the labeling
                wr = rd = 0;
                // Add the edges that entered the equality graph through the label
                // update: edge (slackx[y], y) joins the tree if and only if
                // !T[y] && slack[y] == 0. update_labels() subtracts the minimum slack
                // from every slack outside T, so that minimum becomes exactly zero.
                for (y = 0; y < n; y++)
                    if (!T[y] && slack[y] == 0)
                    {
                        if (yx[y] == -1)            // exposed Y vertex: augmenting
                        {                           // path exists
                            x = static_cast<unsigned int>(slackx[y]);
                            break;
                        }
                        else
                        {
                            T[y] = true;            // otherwise add y to T,
                            if (!S[yx[y]])
                            {
                                q[wr++] = yx[y];    // queue the X vertex matched with y
                                // and add edges (slackx[y],y), (y,yx[y]) to the tree
                                add_to_tree(yx[y], slackx[y], cost);
                            }
                        }
                    }
                if (y < n) break;                   // augmenting path found
            }

            max_match++;                            // the matching grows by one edge
            // Invert the edges along the augmenting path, walking back to the root.
            for (int cx = static_cast<int>(x), cy = static_cast<int>(y), ty;
                 cx != -2; cx = prev[cx], cy = ty)
            {
                ty = xy[cx];
                yx[cy] = cx;
                xy[cx] = cy;
            }
        }
    }

    /// Improves the labeling when the alternating tree cannot be extended.
    /// Lowers the labels of S and raises the labels of T by the smallest slack among
    /// the Y vertices outside T, then reduces those slacks by the same amount.
    void update_labels()
    {
        std::size_t x, y;
        double delta = std::numeric_limits<double>::max();
        for (y = 0; y < n; y++)             // calculate delta using slack
            if (!T[y])
                delta = std::min(delta, slack[y]);
        for (x = 0; x < n; x++)             // update X labels
            if (S[x]) lx[x] -= delta;
        for (y = 0; y < n; y++)             // update Y labels
            if (T[y]) ly[y] += delta;
        for (y = 0; y < n; y++)             // update slack array
            if (!T[y])
                slack[y] -= delta;
    }

    /// Adds X vertex \p x to the alternating tree.
    /// \param x     the vertex being added to S
    /// \param prevx the X vertex before x on the alternating path, so the tree gains
    ///              the edges (prevx, xy[x]) and (xy[x], x)
    /// \param cost  the weight matrix passed to operator()
    void add_to_tree(int x, int prevx, const double cost[N][N])
    {
        S[x] = true;                        // add x to S
        prev[x] = prevx;                    // needed later when augmenting
        for (std::size_t y = 0; y < n; y++) // a new vertex in S may lower the slacks
            if (lx[x] + ly[y] - cost[x][y] < slack[y])
            {
                slack[y] = lx[x] + ly[y] - cost[x][y];
                slackx[y] = x;
            }
    }

public:
    /// Creates a solver that holds no matching yet: until operator() has been called,
    /// matching() and inverseMatching() return -1 for every index.
    HungarianMethod() : n(0), max_match(0)
    {
        std::fill(xy, xy + N, -1);
        std::fill(yx, yx + N, -1);
    }

    /// Computes the best matching of two sets given its cost matrix.
    /// See the matching() method to get the computed match result.
    /// \param cost a matrix of two sets I,J where cost[i][j] is the weight of edge i->j;
    ///             only the top-left logicalSize x logicalSize corner is read
    /// \param logicalSize the number of elements in both I and J; must not exceed N
    ///                    (checked with assert only)
    /// \returns the total cost of the best matching
    inline double operator()(const double cost[N][N], std::size_t logicalSize)
    {
        assert(logicalSize <= static_cast<std::size_t>(N));
        n = logicalSize;
        max_match = 0;                      // number of vertices in current matching
        std::fill(xy, xy + N, -1);
        std::fill(yx, yx + N, -1);
        init_labels(cost);                  // step 0
        augment(cost);                      // steps 1-3

        double ret = 0;                     // weight of the optimal matching
        for (std::size_t x = 0; x < n; x++)
            ret += cost[x][xy[x]];
        return ret;
    }

    /// Gets the matching element in 2nd set of the ith element in the first set
    /// \param i the index of the ith element in the first set (passed in operator())
    /// \returns an index j, denoting the matched jth element of the 2nd set, or -1 if
    ///          i is outside [0, N) or has no match
    inline int matching(int i) const {
        if (i < 0 || i >= N) return -1;
        return xy[i];
    }

    /// Gets the matching element in 1st set of the jth element in the 2nd set
    /// \param j the index of the jth element in the 2nd set (passed in operator())
    /// \returns an index i, denoting the matched ith element of the 1st set, or -1 if
    ///          j is outside [0, N) or has no match
    /// \note inverseMatching(matching(i)) == i
    inline int inverseMatching(int j) const {
        if (j < 0 || j >= N) return -1;
        return yx[j];
    }
};
#endif