forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
SparseMatMul.cpp
279 lines (235 loc) · 8.62 KB
/
SparseMatMul.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#include <ATen/Dispatch.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/StridedRandomAccessor.h>
#include <ATen/native/CompositeRandomAccessor.h>
#include <c10/util/irange.h>
#include <unordered_map>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_sparse_sparse_matmul_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like_native.h>
#endif
namespace at { namespace native {
using namespace at::sparse;
/*
This is an implementation of the SMMP algorithm:
"Sparse Matrix Multiplication Package (SMMP)"
Randolph E. Bank and Craig C. Douglas
https://doi.org/10.1007/BF02070824
*/
namespace {
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
void csr_to_coo(const int64_t n_row, const int64_t Ap[], int64_t Bi[]) {
  /*
    Turns a CSR row-pointer array into an explicit per-entry row index array.

    Inputs:
      `n_row` - number of rows described by `Ap`
      `Ap`    - row pointer; entries `Ap[0] .. Ap[n_row]` are read
    Output:
      `Bi`    - row indices; every position in `[Ap[r], Ap[r + 1])`
                is overwritten with `r`
  */
  for (int64_t row = 0; row < n_row; ++row) {
    const int64_t row_begin = Ap[row];
    const int64_t row_end = Ap[row + 1];
    // Every stored entry of this row gets tagged with the row number.
    for (int64_t pos = row_begin; pos < row_end; ++pos) {
      Bi[pos] = row;
    }
  }
}
int64_t _csr_matmult_maxnnz(
    const int64_t n_row,
    const int64_t n_col,
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    const int64_t Ap[],
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    const int64_t Aj[],
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    const int64_t Bp[],
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    const int64_t Bj[]) {
  /*
    Counts the entries that the product `C = A@B` will hold, i.e. the buffer
    size to reserve for `C`.
    Both operands are expected to be well-formed CSR matrices with
    compatible dimensions.

    Inputs:
      `n_row`    - number of rows of `A` that are visited
      `n_col`    - number of columns of `B` (size of the scratch marker array)
      `Ap`, `Aj` - row pointer and column indices of `A`
      `Bp`, `Bj` - row pointer and column indices of `B`
    Returns:
      the sum over all rows of the number of distinct output columns reached.
  */
  // last_row_seen[c] holds the most recent output row that already counted
  // column c; -1 means the column has not been reached by any row yet.
  std::vector<int64_t> last_row_seen(n_col, -1);
  int64_t total = 0;
  for (int64_t row = 0; row < n_row; ++row) {
    for (int64_t a_pos = Ap[row]; a_pos < Ap[row + 1]; ++a_pos) {
      const int64_t inner = Aj[a_pos];
      for (int64_t b_pos = Bp[inner]; b_pos < Bp[inner + 1]; ++b_pos) {
        int64_t& marker = last_row_seen[Bj[b_pos]];
        if (marker == row) {
          // Column already accounted for in this output row.
          continue;
        }
        marker = row;
        ++total;
      }
    }
  }
  return total;
}
template<class scalar_t>
void _csr_matmult(
    const int64_t n_row,
    const int64_t n_col,
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    const int64_t Ap[],
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    const int64_t Aj[],
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    const scalar_t Ax[],
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    const int64_t Bp[],
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    const int64_t Bj[],
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    const scalar_t Bx[],
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    int64_t Cp[],
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    int64_t Cj[],
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    scalar_t Cx[]) {
  /*
    Fills in the CSR representation of the product C = A@B.

    `A` and `B` are expected to be well-formed CSR matrices whose
    dimensions are compatible.

    Inputs:
      `n_row`       - number of rows in `A`
      `n_col`       - number of columns in `B`
      `Ap[n_row+1]` - row pointer of `A`
      `Aj[nnz(A)]`  - column indices of `A`
      `Ax[nnz(A)]`  - nonzeros of `A`
      `Bp[?]`       - row pointer of `B` (looked up with the column indices of `A`)
      `Bj[nnz(B)]`  - column indices of `B`
      `Bx[nnz(B)]`  - nonzeros of `B`
    Outputs:
      `Cp[n_row+1]` - row pointer of `C`
      `Cj[nnz(C)]`  - column indices of `C`
      `Cx[nnz(C)]`  - nonzeros of `C`

    Note:
      The caller must preallocate `Cp`, `Cj` and `Cx`.
  */

  // Scratch state reused by every output row:
  //  * `next` chains the columns reached so far into a singly linked list
  //    (-1: column not in the list, -2: end of the list);
  //  * `sums` accumulates the value of each reached column.
  std::vector<int64_t> next(n_col, -1);
  std::vector<scalar_t> sums(n_col, 0);

  int64_t nnz = 0;
  Cp[0] = 0;

  for (int64_t row = 0; row < n_row; ++row) {
    int64_t head = -2;
    int64_t length = 0;

    // Accumulate row `row` of A against the matching rows of B.
    const int64_t a_begin = Ap[row];
    const int64_t a_end = Ap[row + 1];
    for (int64_t a_pos = a_begin; a_pos < a_end; ++a_pos) {
      const int64_t a_col = Aj[a_pos];
      const scalar_t a_val = Ax[a_pos];

      const int64_t b_begin = Bp[a_col];
      const int64_t b_end = Bp[a_col + 1];
      for (int64_t b_pos = b_begin; b_pos < b_end; ++b_pos) {
        const int64_t out_col = Bj[b_pos];
        sums[out_col] += a_val * Bx[b_pos];
        if (next[out_col] != -1) {
          // Column is already part of this row's list.
          continue;
        }
        next[out_col] = head;
        head = out_col;
        ++length;
      }
    }

    // Drain the list into the output arrays and reset the scratch state on
    // the way. NOTE: the list order is not guaranteed to be sorted.
    const int64_t row_begin = nnz;
    for (int64_t emitted = 0; emitted < length; ++emitted) {
      Cj[nnz] = head;
      Cx[nnz] = sums[head];
      ++nnz;

      const int64_t visited = head;
      head = next[visited];
      next[visited] = -1; // clear arrays
      sums[visited] = 0;
    }

    // Sort this row's entries by column index, keeping each value attached
    // to its column.
    // TODO: a better approach is to implement a CSR @ CSC kernel.
    auto col_indices_accessor = StridedRandomAccessor<int64_t>(Cj + row_begin, 1);
    auto val_accessor = StridedRandomAccessor<scalar_t>(Cx + row_begin, 1);
    auto kv_accessor = CompositeRandomAccessorCPU<
      decltype(col_indices_accessor), decltype(val_accessor)
    >(col_indices_accessor, val_accessor);
    const auto by_column = [](const auto& lhs, const auto& rhs) -> bool {
      return get<0>(lhs) < get<0>(rhs);
    };
    std::sort(kv_accessor, kv_accessor + length, by_column);

    Cp[row + 1] = nnz;
  }
}
template <typename scalar_t>
void sparse_matmul_kernel(
    Tensor& output,
    const Tensor& mat1,
    const Tensor& mat2) {
  /*
    Sparse-sparse matrix product of `mat1` and `mat2` (sparse tensors in COO
    format), written into `output`.

    Both operands are converted to CSR, the product is computed row by row
    with `_csr_matmult`, and the resulting row pointer is expanded back into
    COO row indices stored in `output`.
  */
  const auto n_rows = mat1.size(0);
  const auto n_cols = mat2.size(1);

  const auto lhs_csr = mat1.to_sparse_csr();
  const auto rhs_csr = mat2.to_sparse_csr();

  // Hold on to the CSR components so their raw pointers can be shared by the
  // sizing pass and the multiplication pass.
  const Tensor lhs_crow = lhs_csr.crow_indices();
  const Tensor lhs_col = lhs_csr.col_indices();
  const Tensor lhs_val = lhs_csr.values();
  const Tensor rhs_crow = rhs_csr.crow_indices();
  const Tensor rhs_col = rhs_csr.col_indices();
  const Tensor rhs_val = rhs_csr.values();

  const int64_t* lhs_crow_ptr = lhs_crow.data_ptr<int64_t>();
  const int64_t* lhs_col_ptr = lhs_col.data_ptr<int64_t>();
  const int64_t* rhs_crow_ptr = rhs_crow.data_ptr<int64_t>();
  const int64_t* rhs_col_ptr = rhs_col.data_ptr<int64_t>();

  // First pass: how many entries will the product hold?
  const auto nnz = _csr_matmult_maxnnz(
      n_rows, n_cols, lhs_crow_ptr, lhs_col_ptr, rhs_crow_ptr, rhs_col_ptr);

  auto output_indices = output._indices();
  auto output_values = output._values();

  Tensor output_indptr = at::empty({n_rows + 1}, kLong);
  at::native::resize_output(output_indices, {2, nnz});
  at::native::resize_output(output_values, nnz);

  // Row 0 of the index tensor holds row indices, row 1 holds column indices.
  Tensor output_row_indices = output_indices.select(0, 0);
  Tensor output_col_indices = output_indices.select(0, 1);

  // Second pass: compute the product in CSR form.
  // TODO: replace with a CSR @ CSC kernel for better performance.
  _csr_matmult(
      n_rows,
      n_cols,
      lhs_crow_ptr,
      lhs_col_ptr,
      lhs_val.data_ptr<scalar_t>(),
      rhs_crow_ptr,
      rhs_col_ptr,
      rhs_val.data_ptr<scalar_t>(),
      output_indptr.data_ptr<int64_t>(),
      output_col_indices.data_ptr<int64_t>(),
      output_values.data_ptr<scalar_t>());

  // Expand the CSR row pointer into explicit COO row indices.
  csr_to_coo(
      n_rows,
      output_indptr.data_ptr<int64_t>(),
      output_row_indices.data_ptr<int64_t>());

  output._coalesced_(true);
}
} // end anonymous namespace
/*
  CPU entry point for the sparse @ sparse matrix product.

  Inputs:
    `mat1_`, `mat2_` - sparse tensors; both must be 2-dimensional, have no
                       dense dimensions, share the same dtype, and satisfy
                       `mat1_.size(1) == mat2_.size(0)`.
  Returns:
    a new sparse tensor of shape `{mat1_.size(0), mat2_.size(1)}` created
    from `mat1_` via `empty_like` and filled by `sparse_matmul_kernel`.

  Any violated precondition is reported through `TORCH_CHECK`. Only floating
  point and complex dtypes are dispatched.
*/
Tensor sparse_sparse_matmul_cpu(const Tensor& mat1_, const Tensor& mat2_) {
  TORCH_INTERNAL_ASSERT(mat1_.is_sparse());
  TORCH_INTERNAL_ASSERT(mat2_.is_sparse());
  // Give the dimensionality checks an explicit message, like every other
  // check below, instead of falling back to the generic TORCH_CHECK text.
  TORCH_CHECK(mat1_.dim() == 2, "sparse_sparse_matmul_cpu: expected mat1 to be a 2D tensor, got ", mat1_.dim(), "D tensor");
  TORCH_CHECK(mat2_.dim() == 2, "sparse_sparse_matmul_cpu: expected mat2 to be a 2D tensor, got ", mat2_.dim(), "D tensor");
  TORCH_CHECK(mat1_.dense_dim() == 0, "sparse_sparse_matmul_cpu: scalar values expected, got ", mat1_.dense_dim(), "D values");
  TORCH_CHECK(mat2_.dense_dim() == 0, "sparse_sparse_matmul_cpu: scalar values expected, got ", mat2_.dense_dim(), "D values");

  TORCH_CHECK(
      mat1_.size(1) == mat2_.size(0), "mat1 and mat2 shapes cannot be multiplied (",
      mat1_.size(0), "x", mat1_.size(1), " and ", mat2_.size(0), "x", mat2_.size(1), ")");

  TORCH_CHECK(mat1_.scalar_type() == mat2_.scalar_type(),
      "mat1 dtype ", mat1_.scalar_type(), " does not match mat2 dtype ", mat2_.scalar_type());

  // Start from an empty sparse tensor with the result shape and no dense dims.
  auto output = at::native::empty_like(mat1_);
  output.sparse_resize_and_clear_({mat1_.size(0), mat2_.size(1)}, mat1_.sparse_dim(), 0);

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(mat1_.scalar_type(), "sparse_matmul", [&] {
    // The kernel receives coalesced copies of both operands.
    sparse_matmul_kernel<scalar_t>(output, mat1_.coalesce(), mat2_.coalesce());
  });
  return output;
}
} // namespace native
} // namespace at