forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathSoftMax.cu
677 lines (581 loc) · 25.6 KB
/
SoftMax.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/TensorUtils.h>
#include <ATen/NativeFunctions.h>
#include <ATen/WrapDimUtils.h>
#include <THC/THCTensorMathReduce.cuh>
#include <THC/THCTensorSort.cuh>
#include <THC/THCThrustAllocator.cuh>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/NumericLimits.cuh>
#include <type_traits>
#include <ATen/native/cuda/PersistentSoftmax.cuh>
namespace at {
namespace native {
namespace {
template<typename T, typename AccumT, typename OutT>
struct LogSoftMaxForwardEpilogue {
  // Built once per reduced row from the row maximum and the sum of
  // exp(x - max); folds both into a single offset so that each element only
  // needs one subtraction: log_softmax(x) = x - (max + log(sum)).
  __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT row_max, AccumT row_sum)
    : logsum(row_max + std::log(row_sum)) {}

  // Maps one input element to its log-softmax value, converted to OutT.
  __device__ __forceinline__ OutT operator()(T input) const {
    const auto shifted = input - logsum;
    return static_cast<OutT>(shifted);
  }

  const AccumT logsum;
};
template<typename T, typename AccumT, typename OutT>
struct LogSoftMaxBackwardEpilogue {
  // `row_grad_sum` is the sum of the incoming gradient over the reduced dimension.
  __device__ __forceinline__ LogSoftMaxBackwardEpilogue(AccumT row_grad_sum)
    : sum(row_grad_sum) {}

  // `output` is the forward log-softmax value, so exp(output) recovers the
  // probability; gradInput = gradOutput - softmax * sum(gradOutput).
  __device__ __forceinline__ T operator()(OutT gradOutput, OutT output) const {
    const AccumT prob = std::exp(static_cast<AccumT>(output));
    return static_cast<T>(gradOutput - prob * sum);
  }

  const AccumT sum;
};
template<typename T, typename AccumT, typename OutT>
struct SoftMaxForwardEpilogue {
  // Captures the row maximum and the sum of exp(x - max) for one reduced row.
  __device__ __forceinline__ SoftMaxForwardEpilogue(AccumT row_max, AccumT row_sum)
    : max_input(row_max)
    , sum(row_sum) {}

  // softmax(x) = exp(x - max) / sum, converted to OutT.
  __device__ __forceinline__ OutT operator()(T input) const {
    const auto numerator = std::exp(input - max_input);
    return static_cast<OutT>(numerator / sum);
  }

  const AccumT max_input;
  const AccumT sum;
};
template<typename T, typename AccumT, typename OutT>
struct SoftMaxBackwardEpilogue {
  // `row_sum` is the sum over the reduced dimension of the (pre-multiplied)
  // gradient described below.
  __device__ __forceinline__ SoftMaxBackwardEpilogue(AccumT row_sum)
    : sum(row_sum) {}

  // XXX: the gradOutput received here is really gradOutput * output;
  // softmax_backward_cuda multiplies them before calling the backward driver.
  // Hence gradInput = (gradOutput * output) - output * sum.
  __device__ __forceinline__ T operator()(OutT gradOutput, OutT output) const {
    const auto correction = output * sum;
    return static_cast<T>(gradOutput - correction);
  }

  const AccumT sum;
};
////////////////////////////////////////////////////////////////////////////////
// Spatial kernel (fast with large inner_size and small dim_size)
////////////////////////////////////////////////////////////////////////////////
// The input is viewed as three flattened dimensions: outer x dim x inner.
// The spatial algorithm parallelizes along all of them:
//  * within a 2D block, threadIdx.y walks inner positions, and threads that
//    share a threadIdx.y cooperate on the reduction over dim (along x);
//  * the 2D grid spreads the inner dimension over y and outer over x.

// Picks a grid for cunn_SpatialSoftMax{Forward,Backward}.
//   block             - block shape from SpatialSoftMax_getBlockSize.
//   max_active_blocks - device-wide estimate of co-resident blocks.
//   outer_size/inner_size - flattened extents; dim_size is unused here and
//                       kept only for signature symmetry with the block-size helper.
// Returns dim3(outer_blocks, inner_blocks). All intermediate arithmetic is done
// in 64 bits and clamped *before* narrowing to 32 bits, so a very large
// inner_size cannot wrap to 0. A zero occupancy estimate (or inner_size == 0)
// is clamped to one block so neither division below can divide by zero.
inline dim3 SpatialSoftMax_getGridSize(
    dim3 block, uint32_t max_active_blocks,
    uint64_t outer_size, uint64_t dim_size, uint64_t inner_size) {
  const uint64_t active = std::max<uint64_t>(max_active_blocks, 1);
  // First, tile as many blocks as we can over the y axis.
  const uint64_t inner_needed = (inner_size + block.y - 1) / block.y;
  const uint64_t inner_blocks =
      std::max<uint64_t>(1, std::min<uint64_t>(inner_needed, active));
  // Fill the x axis with as many blocks as we can fit (a little more is ok too).
  const uint64_t outer_blocks =
      std::min<uint64_t>((active + inner_blocks - 1) / inner_blocks, outer_size);
  return dim3(static_cast<uint32_t>(outer_blocks),
              static_cast<uint32_t>(inner_blocks));
}
// Upper bound on threads per block used by every launcher in this file.
const int max_threads = 1024;

// Picks a block shape for the spatial kernels: dim3(dim_threads, inner_threads).
// inner_threads covers the inner dimension, up to max_threads. When at most
// 64 threads are used along inner and dim_size >= 64, the remaining budget
// goes to x: dim_threads becomes the largest power of two with
// inner_threads * dim_threads <= max_threads and dim_threads <= dim_size.
// The clamp is done in 64 bits before narrowing. Assigning the uint64_t
// inner_size straight to a uint32_t would wrap huge values (e.g. 2^32 + 5
// -> 5) and produce a tiny block.
inline dim3 SpatialSoftMax_getBlockSize(
    uint64_t outer_size, uint64_t dim_size, uint64_t inner_size) {
  uint32_t inner_threads = static_cast<uint32_t>(
      std::min(inner_size, static_cast<uint64_t>(max_threads)));
  uint32_t dim_threads = 1;
  if (inner_threads <= 64 && dim_size >= 64) {
    // Overshoot by one doubling, then step back to the last valid power of two.
    while (inner_threads * dim_threads <= max_threads && dim_threads <= dim_size)
      dim_threads *= 2;
    dim_threads /= 2;
  }
  return dim3(dim_threads, inner_threads);
}
// Computes block, grid and dynamic shared-memory size for launching kernel `k`
// (a cunn_SpatialSoftMax* instantiation) over an outer x dim x inner view.
// Shared memory (one accscalar_t per thread) is requested only when more than
// one thread cooperates along x, since only then spatialBlockReduceX runs.
// The occupancy query is checked. On failure it throws through THCudaCheck
// instead of silently sizing the grid from an uninitialized value.
template<typename accscalar_t, typename Kernel>
void SpatialSoftMax_getLaunchSizes(
    Kernel k,
    uint64_t outer_size, uint64_t dim_size, uint64_t inner_size,
    dim3& grid, dim3& block, uint32_t& smem_size) {
  block = SpatialSoftMax_getBlockSize(outer_size, dim_size, inner_size);
  uint32_t block_threads = block.x * block.y;
  smem_size = block.x == 1 ? 0 : block_threads * sizeof(accscalar_t);
  int max_active_blocks = 0;
#ifdef __HIP_PLATFORM_HCC__
  max_active_blocks = 16;
#else
  THCudaCheck(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks,
                                                            k, block_threads, smem_size));
#endif
  // Per-SM estimate -> device-wide estimate.
  max_active_blocks *= at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
  grid = SpatialSoftMax_getGridSize(block, max_active_blocks, outer_size, dim_size, inner_size);
}
// Block size for the one-block-per-row kernels (cunn_SoftMaxForward/Backward):
// the smallest power of two >= min(dim_size / ILP, max_threads), raised to at
// least one full warp because blockReduce assumes whole warps.
inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) {
  const uint64_t cap = std::min(dim_size / ILP, static_cast<uint64_t>(max_threads));
  uint64_t threads = 1;
  while (threads < cap)
    threads <<= 1;
  const uint64_t kMinThreads = 32;
  return dim3(threads < kMinThreads ? kMinThreads : threads);
}
// Binary sum functor; the combine step for block-level sum reductions.
template<typename T>
struct Add {
  __device__ __forceinline__ T operator()(T lhs, T rhs) const {
    return lhs + rhs;
  }
};
// Binary max functor for block-level max reductions. Returns `lhs` unless it
// compares strictly less than `rhs` (so a NaN `lhs` is kept, a NaN `rhs` is not).
template<typename T>
struct Max {
  __device__ __forceinline__ T operator()(T lhs, T rhs) const {
    if (lhs < rhs) {
      return rhs;
    }
    return lhs;
  }
};
// Note that it's not a complete block-wide reduction.
// Only threads that share threadIdx.y reduce values.
//
// Reduces `val` across the blockDim.x threads of the caller's threadIdx.y row
// with ReduceOp<T> and returns the row result to every thread of that row.
//  * `shared` must hold blockDim.x * blockDim.y elements: each row uses its own
//    blockDim.x-wide slice (offset threadIdx.y * blockDim.x).
//  * The halving tree only combines every slot when blockDim.x is a power of
//    two; SpatialSoftMax_getBlockSize only produces powers of two for x.
//  * Contains __syncthreads(), so every thread in the block must reach each call.
template<typename T, template<typename> class ReduceOp>
__forceinline__ __device__
T spatialBlockReduceX(T *shared, T val) {
  ReduceOp<T> r;
  shared += threadIdx.y * blockDim.x;
  // Barrier before writing: a previous call's threads may still be reading
  // shared[0] (the return below), so overwriting now would race.
  __syncthreads();
  shared[threadIdx.x] = val;
  // NOTE: loop starts with __syncthreads()
  int offset = blockDim.x / 2;
  while (offset > 0) {
    __syncthreads();
    if (threadIdx.x < offset)
      shared[threadIdx.x] = r(shared[threadIdx.x], shared[threadIdx.x + offset]);
    offset /= 2;
  }
  // Make the final shared[0] visible to the whole row before it is read.
  __syncthreads();
  return shared[0];
}
// Softmax / log-softmax forward over the middle axis of a contiguous
// outer x dim x inner view (element (o, d, i) lives at
// o * dim_size * inner_size + d * inner_size + i).
//  * blockIdx.x strides over outer positions; blockIdx.y/threadIdx.y stride
//    over inner positions; threadIdx.x splits the dim axis when blockDim.x > 1.
//  * Dynamic shared memory: blockDim.x * blockDim.y accscalar_t when
//    blockDim.x > 1 (used by spatialBlockReduceX), none otherwise.
//  * The row max is subtracted before exp, and Epilogue(max, sum) produces
//    each output element.
// NOTE(review): spatialBlockReduceX issues __syncthreads() inside the inner
// loop, so all threads of a block must run the same number of inner
// iterations. The launch helpers only make blockDim.x > 1 when
// blockDim.y == inner_size <= 64 (which yields gridDim.y == 1), so each
// thread runs exactly one inner iteration there. Keep that coupling if the
// launch logic changes.
// NOTE(review): offsets are uint32_t; assumes outer*dim*inner < 2^32.
template <typename scalar_t, typename accscalar_t, typename outscalar_t, template<typename, typename, typename> class Epilogue>
__global__ void cunn_SpatialSoftMaxForward(
    outscalar_t *output, scalar_t *input,
    uint32_t outer_size, uint32_t dim_size, uint32_t inner_size)
{
  extern __shared__ unsigned char smem[];
  auto sdata = reinterpret_cast<accscalar_t*>(smem);
  const uint32_t outer_stride = inner_size * dim_size;
  const uint32_t dim_stride = inner_size;
  for (uint32_t outer_index = blockIdx.x; outer_index < outer_size; outer_index += gridDim.x) {
    const uint32_t outer_offset = outer_index * outer_stride;
    for (uint32_t inner_index = blockIdx.y * blockDim.y + threadIdx.y; inner_index < inner_size; inner_index += blockDim.y * gridDim.y) {
      const uint32_t data_offset = outer_offset + inner_index;
      ////////////////////////////////////////////////////////////
      // These two blocks are really equivalent, but specializing on
      // blockDim.x == 1 makes the kernel faster when it's unused.
      // I didn't want to thread an extra template parameter, and nvcc
      // seems to be smart enough to hoist the if outside of the loops.
      ////////////////////////////////////////////////////////////
      if (blockDim.x > 1) {
        // Pass 1: per-thread partial max over this thread's dim slice, then row-wide max.
        accscalar_t max_input = at::numeric_limits<accscalar_t>::lowest();
        for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
          const accscalar_t value = static_cast<accscalar_t>(input[data_offset + d * dim_stride]);
          max_input = Max<accscalar_t>()(max_input, value);
        }
        max_input = spatialBlockReduceX<accscalar_t, Max>(sdata,max_input);
        // Pass 2: sum of exp(x - max), reduced across the row.
        accscalar_t sum = 0;
        for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
          sum += std::exp(static_cast<accscalar_t>(input[data_offset + d * dim_stride])
                 - max_input);
        sum = spatialBlockReduceX<accscalar_t, Add>(sdata, sum);
        // Pass 3: write each element through the epilogue.
        Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_input, sum);
        for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
          output[data_offset + d * dim_stride] = epilogue(input[data_offset + d * dim_stride]);
      } else {
        // Same three passes, done by a single thread per inner position (no smem).
        accscalar_t max_input = at::numeric_limits<accscalar_t>::lowest();
        for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
          const accscalar_t value = static_cast<accscalar_t>(input[data_offset + d * dim_stride]);
          max_input = Max<accscalar_t>()(max_input, value);
        }
        accscalar_t sum = 0;
        for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
          sum += std::exp(static_cast<accscalar_t>(input[data_offset + d * dim_stride])
                 - max_input);
        Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_input, sum);
        for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
          output[data_offset + d * dim_stride] = epilogue(input[data_offset + d * dim_stride]);
      }
    }
  }
}
// Softmax / log-softmax backward over the middle axis of a contiguous
// outer x dim x inner view; same grid/block/shared-memory layout as
// cunn_SpatialSoftMaxForward.
// For each (outer, inner) position it sums gradOutput over dim, then writes
// gradInput[d] = Epilogue(sum)(gradOutput[d], output[d]).
// NOTE(review): the same uniform-inner-trip-count requirement as the forward
// kernel applies to the spatialBlockReduceX call (it contains __syncthreads()).
// NOTE(review): offsets are uint32_t; assumes outer*dim*inner < 2^32.
template <typename scalar_t, typename accscalar_t, typename outscalar_t, template<typename, typename, typename> class Epilogue>
__global__ void cunn_SpatialSoftMaxBackward(
    scalar_t *gradInput, outscalar_t *output, outscalar_t *gradOutput,
    uint32_t outer_size, uint32_t dim_size, uint32_t inner_size)
{
  extern __shared__ unsigned char smem[];
  auto sdata = reinterpret_cast<accscalar_t*>(smem);
  const uint32_t outer_stride = inner_size * dim_size;
  const uint32_t dim_stride = inner_size;
  for (uint32_t outer_index = blockIdx.x; outer_index < outer_size; outer_index += gridDim.x) {
    const uint32_t outer_offset = outer_index * outer_stride;
    for (uint32_t inner_index = blockIdx.y * blockDim.y + threadIdx.y; inner_index < inner_size; inner_index += blockDim.y * gridDim.y) {
      const uint32_t data_offset = outer_offset + inner_index;
      // See the comment in forward kernel
      if (blockDim.x > 1) {
        // Threads sharing threadIdx.y split the dim axis and combine partial sums.
        accscalar_t sum = 0;
        for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
          sum += gradOutput[data_offset + d * dim_stride];
        sum = spatialBlockReduceX<accscalar_t, Add>(sdata, sum);
        Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(sum);
        for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
          gradInput[data_offset + d * dim_stride] =
            epilogue(gradOutput[data_offset + d * dim_stride],
                     output[data_offset + d * dim_stride]);
        }
      } else {
        // blockDim.x == 1, so threadIdx.x == 0: one thread walks the whole dim axis.
        accscalar_t sum = 0;
        for (uint32_t d = 0; d < dim_size; d++)
          sum += gradOutput[data_offset + d * dim_stride];
        Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(sum);
        for (uint32_t d = 0; d < dim_size; d++) {
          gradInput[data_offset + d * dim_stride] =
            epilogue(gradOutput[data_offset + d * dim_stride],
                     output[data_offset + d * dim_stride]);
        }
      }
    }
  }
}
////////////////////////////////////////////////////////////////////////////////
// Regular kernel (fast when dim_size is large; requires inner_size == 1)
////////////////////////////////////////////////////////////////////////////////
// ilpReduce step: folds one element of type T into a running maximum held in AccumT.
template <typename T, typename AccumT>
struct MaxFloat
{
  __device__ __forceinline__ AccumT operator()(AccumT running, T v) const {
    const AccumT candidate = (AccumT)v;
    return ::max(running, candidate);
  }
};
// ilpReduce step: adds one element of type T to an AccumT running sum.
template<typename T, typename AccumT>
struct AddFloat
{
  __device__ __forceinline__ AccumT operator()(AccumT acc, T value) const {
    return acc + value;
  }
};
// ilpReduce step: accumulates exp(v - max_k). max_k is fixed at construction;
// cunn_SoftMaxForward passes the row maximum so every term is exp(<= 0).
template<typename T, typename AccumT>
struct SumExpFloat
{
  __device__ __forceinline__ SumExpFloat(AccumT row_max)
    : max_k(row_max) {}

  __device__ __forceinline__ AccumT operator()(AccumT acc, T value) const {
    const auto centered = value - max_k;
    return acc + std::exp(centered);
  }

  const AccumT max_k;
};
// Block-wide reduction of one value per thread along x, returned to every thread.
//   smem       - at least blockDim.x AccumT elements (launchers pass
//                block.x * sizeof(accscalar_t) bytes of dynamic shared memory).
//   r          - binary combine functor.
//   defaultVal - identity element for r.
// Preconditions: blockDim.x is a multiple of 32 and at most 1024 (so there are at
// most 32 warps, and one lane of warp 0 can handle each warp). Values beyond
// the last full warp would be dropped. Contains __syncthreads(): every thread
// in the block must call it.
template <template<typename> class Reduction, typename AccumT>
__device__ __forceinline__ AccumT
blockReduce(AccumT* smem, AccumT val,
            const Reduction<AccumT>& r,
            AccumT defaultVal)
{
  // Chained calls race without this barrier: threads of the previous call may
  // still be reading smem[0] when this call starts overwriting smem.
  __syncthreads();
  smem[threadIdx.x] = val;
  __syncthreads();
  AccumT warpVal = defaultVal;
  // First warp will perform per-warp reductions for the remaining warps
  // Lanes [0, blockDim.x / 32) participate. The shift is done in 64 bits so a
  // full 32-warp block does not shift a 32-bit 1 by 32 (UB); the truncated
  // result is then 0xffffffff.
  uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1;
  if (threadIdx.x < 32) {
    int lane = threadIdx.x % 32;
    if (lane < blockDim.x / 32) {
      // Lane k serially reduces warp k's 32 values.
#pragma unroll
      for (int i = 0; i < 32; ++i) {
        warpVal = r(warpVal, smem[lane * 32 + i]);
      }
      // All participating lanes must finish reading smem[0..32*nwarps) before
      // any lane overwrites smem[lane] below (lane 0 reads smem[0..31]).
#if CUDA_VERSION >= 9000
      __syncwarp(mask);
#endif
      smem[lane] = warpVal;
    }
  }
  __syncthreads();
  // First thread will perform a reduction of the above per-warp reductions
  AccumT blockVal = defaultVal;
  if (threadIdx.x == 0) {
    for (int i = 0; i < blockDim.x / 32; ++i) {
      blockVal = r(blockVal, smem[i]);
    }
    smem[0] = blockVal;
  }
  // Sync and broadcast
  __syncthreads();
  return smem[0];
}
// Per-thread partial reduction over data[0, size), striding by blockDim.x.
// Thread t visits t, t + blockDim.x, t + 2*blockDim.x, ... The main loop loads
// ILP elements into registers before combining them (instruction-level
// parallelism); the tail loop handles the last size % (ILP * blockDim.x)
// elements one at a time.
// Returns only this thread's partial value. Callers combine partials across
// the block with blockReduce. No shared memory or barriers are used here.
template <template<typename, typename> class Reduction, int ILP, typename T, typename AccumT>
__device__ __forceinline__ AccumT
ilpReduce(T* data,
          int size,
          const Reduction<T, AccumT>& r,
          AccumT defaultVal)
{
  AccumT threadVal = defaultVal;
  int offset = threadIdx.x;
  int last = size % (ILP * blockDim.x);
  // Body (unroll by ILP times)
  for (; offset < size - last; offset += blockDim.x * ILP) {
    T tmp[ILP];
#pragma unroll
    for (int j = 0; j < ILP; ++j)
      tmp[j] = data[offset + j * blockDim.x];
#pragma unroll
    for (int j = 0; j < ILP; ++j)
      threadVal = r(threadVal, tmp[j]);
  }
  // Epilogue
  for (; offset < size; offset += blockDim.x)
    threadVal = r(threadVal, data[offset]);
  return threadVal;
}
// Softmax / log-softmax forward for inner_size == 1: block blockIdx.x handles
// one contiguous row of `classes` elements.
//  * Launch: 1D grid with one block per row; blockDim.x satisfies blockReduce's
//    preconditions (SoftMax_getBlockSize yields a power of two >= 32).
//  * Dynamic shared memory: blockDim.x accscalar_t.
//  * Passes: row max -> sum of exp(x - max) -> Epilogue(max, sum) per element.
// NOTE(review): the row offset blockIdx.x * classes is computed in 32-bit
// unsigned arithmetic; assumes rows * classes < 2^32.
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t, template <typename, typename, typename> class Epilogue>
__global__ void
cunn_SoftMaxForward(outscalar_t *output, scalar_t *input, int classes)
{
  extern __shared__ unsigned char smem[];
  auto sdata = reinterpret_cast<accscalar_t*>(smem);
  // forward pointers to batch[blockIdx.x]
  // each block handles a sample in the mini-batch
  input += blockIdx.x * classes;
  output += blockIdx.x * classes;
  // find the max
  accscalar_t threadMax = ilpReduce<MaxFloat, ILP, scalar_t, accscalar_t>(
      input, classes, MaxFloat<scalar_t, accscalar_t>(), -at::numeric_limits<accscalar_t>::max());
  accscalar_t max_k = blockReduce<Max, accscalar_t>(
      sdata, threadMax, Max<accscalar_t>(), -at::numeric_limits<accscalar_t>::max());
  // reduce all values
  accscalar_t threadExp = ilpReduce<SumExpFloat, ILP, scalar_t, accscalar_t>(
      input, classes, SumExpFloat<scalar_t, accscalar_t>(max_k), static_cast<accscalar_t>(0));
  accscalar_t sumAll = blockReduce<Add, accscalar_t>(
      sdata, threadExp, Add<accscalar_t>(), static_cast<accscalar_t>(0));
  Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_k, sumAll);
  // Write pass: same ILP-strided traversal as ilpReduce, then the tail.
  int offset = threadIdx.x;
  int last = classes % (ILP * blockDim.x);
  for (; offset < classes - last; offset += blockDim.x * ILP) {
    scalar_t tmp[ILP];
#pragma unroll
    for (int j = 0; j < ILP; ++j)
      tmp[j] = input[offset + j * blockDim.x];
#pragma unroll
    for (int j = 0; j < ILP; ++j)
      output[offset + j * blockDim.x] = epilogue(tmp[j]);
  }
  for (; offset < classes; offset += blockDim.x)
    output[offset] = epilogue(input[offset]);
}
// Softmax / log-softmax backward for inner_size == 1: block blockIdx.x handles
// one contiguous row of `classes` elements. The launch and shared-memory
// layout match cunn_SoftMaxForward.
// Sums gradOutput over the row, then writes
// gradInput[i] = Epilogue(sum)(gradOutput[i], output[i]).
// NOTE(review): the reduction below uses a hard-coded unroll factor of 4,
// while the write loop uses the template ILP. This is harmless for
// correctness, but inconsistent.
// NOTE(review): the row offset is computed in 32-bit unsigned arithmetic;
// assumes rows * classes < 2^32.
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t, template<typename, typename, typename> class Epilogue>
__global__ void
cunn_SoftMaxBackward(scalar_t *gradInput, outscalar_t *output, outscalar_t *gradOutput, int classes)
{
  extern __shared__ unsigned char smem[];
  auto sdata = reinterpret_cast<accscalar_t*>(smem);
  gradInput += blockIdx.x * classes;
  output += blockIdx.x * classes;
  gradOutput += blockIdx.x * classes;
  accscalar_t threadSum = ilpReduce<AddFloat, 4, outscalar_t, accscalar_t>(
      gradOutput, classes, AddFloat<outscalar_t, accscalar_t>(), accscalar_t(0));
  accscalar_t sum_k = blockReduce<Add, accscalar_t>(
      sdata, threadSum, Add<accscalar_t>(), accscalar_t(0));
  Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(sum_k);
  int offset = threadIdx.x;
  int last = classes % (ILP * blockDim.x);
  for (; offset < classes - last; offset += blockDim.x * ILP) {
    outscalar_t tmpGradOutput[ILP];
    outscalar_t tmpOutput[ILP];
#pragma unroll
    for (int j = 0; j < ILP; ++j) {
      tmpGradOutput[j] = gradOutput[offset + j * blockDim.x];
      tmpOutput[j] = output[offset + j * blockDim.x];
    }
#pragma unroll
    for (int j = 0; j < ILP; ++j)
      gradInput[offset + j * blockDim.x] = epilogue(tmpGradOutput[j], tmpOutput[j]);
  }
  for (; offset < classes; offset += blockDim.x)
    gradInput[offset] = epilogue(gradOutput[offset], output[offset]);
}
// Host driver shared by softmax_cuda and log_softmax_cuda.
// Makes the input contiguous, views it as outer x dim x inner around `dim_`,
// and picks a kernel:
//  * inner_size == 1 and short rows (dim_size <= 1024 and
//    dim_size * sizeof(scalar_t) <= 4096): dispatch_softmax_forward from
//    PersistentSoftmax.cuh (presumably args are dst, src, elements, stride,
//    batch - confirm there);
//  * inner_size == 1 otherwise: cunn_SoftMaxForward, one block per row;
//  * inner_size > 1: cunn_SpatialSoftMaxForward.
// half_to_float: only valid for Half input; the result is allocated as Float,
// and kernels read scalar_t (Half) and write accscalar_t (float).
// An empty input returns the freshly allocated (empty) output without a launch.
template<template<typename, typename, typename> class Epilogue, bool is_log_softmax>
Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_to_float){
  if (half_to_float) AT_ASSERTM(input_.scalar_type() == ScalarType::Half,"conversion is supported for Half type only");
  auto input = input_.contiguous();
  Tensor output = half_to_float ? at::empty_like(input, input.options().dtype(ScalarType::Float)) : at::empty_like(input);
  static_assert(std::is_same<acc_type<at::Half, true>, float>::value, "accscalar_t for half should be float");
  // Treat a 0-dim tensor as a 1-element vector (output shares the same storage layout).
  if (input.dim() == 0) input = input.view(1);
  int64_t dim = maybe_wrap_dim(dim_, input.dim());
  TORCH_CHECK(dim >=0 && dim < input.dim(), "dim must be non-negative and less than input dimensions");
  int64_t outer_size = 1;
  int64_t dim_size = input.size(dim);
  if (input.numel() > 0) {
    int64_t inner_size = 1;
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    for (int64_t i = 0; i < dim; ++i)
      outer_size *= input.size(i);
    for (int64_t i = dim + 1; i < input.dim(); ++i)
      inner_size *= input.size(i);
    // This kernel spawns a block per each element in the batch.
    // XXX: it assumes that inner_size == 1
    if (inner_size == 1) {
      const int ILP = 2;
      dim3 grid(outer_size);
      dim3 block = SoftMax_getBlockSize(ILP, dim_size);
      AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "host_softmax", [&] {
        using accscalar_t = acc_type<scalar_t, true>;
        if (!half_to_float) {
          if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) {
            dispatch_softmax_forward<scalar_t, scalar_t, accscalar_t, is_log_softmax>(
                output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), dim_size, dim_size, outer_size);
          } else {
            // Dynamic smem: one accscalar_t per thread for blockReduce.
            cunn_SoftMaxForward<ILP, scalar_t, accscalar_t, scalar_t, Epilogue>
              <<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
                output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), dim_size
            );
          }
        } else {
          // Half in, Float (accscalar_t) out.
          if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) {
            dispatch_softmax_forward<scalar_t, accscalar_t, accscalar_t, is_log_softmax>(
                output.data_ptr<accscalar_t>(), input.data_ptr<scalar_t>(), dim_size, dim_size, outer_size);
          } else {
            cunn_SoftMaxForward<ILP, scalar_t, accscalar_t, accscalar_t, Epilogue>
              <<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
                output.data_ptr<accscalar_t>(), input.data_ptr<scalar_t>(), dim_size
            );
          }
        }
      });
    } else {
      // Spatial kernel: 2D grid with x striding over outer_size and y over
      // inner_size. Reductions over dim run in a single thread per inner
      // position, unless SpatialSoftMax_getBlockSize assigns several x-threads
      // (small inner_size, dim_size >= 64); those then cooperate through
      // shared memory.
      uint32_t smem_size;
      dim3 grid, block;
      AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "host_softmax", [&] {
        using accscalar_t = acc_type<scalar_t, true>;
        if (!half_to_float) {
          SpatialSoftMax_getLaunchSizes<accscalar_t>(
              &cunn_SpatialSoftMaxForward<scalar_t, accscalar_t, scalar_t, Epilogue>,
              outer_size, dim_size, inner_size,
              grid, block, smem_size);
          cunn_SpatialSoftMaxForward<scalar_t, accscalar_t, scalar_t, Epilogue>
            <<<grid, block, smem_size, stream>>>(
              output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), outer_size, dim_size, inner_size
          );
        } else {
          SpatialSoftMax_getLaunchSizes<accscalar_t>(
              &cunn_SpatialSoftMaxForward<scalar_t, accscalar_t, accscalar_t, Epilogue>,
              outer_size, dim_size, inner_size,
              grid, block, smem_size);
          cunn_SpatialSoftMaxForward<scalar_t, accscalar_t, accscalar_t, Epilogue>
            <<<grid, block, smem_size, stream>>>(
              output.data_ptr<accscalar_t>(), input.data_ptr<scalar_t>(), outer_size, dim_size, inner_size
          );
        }
      });
    }
    // Surface launch-configuration errors from whichever kernel ran.
    THCudaCheck(cudaGetLastError());
  }
  return output;
}
// Host driver shared by softmax_backward_cuda and log_softmax_backward_cuda.
// Computes the gradient w.r.t. the softmax input along `dim_` from `grad_`
// (the incoming gradient; softmax_backward_cuda passes grad * output) and
// `output_` (the forward result), using Epilogue for the per-element step.
//
// half_to_float: the forward ran on a Half input and produced a Float output,
// so grad_/output_ are Float while the returned gradient gI is Half. Every
// dispatch keys off gI's dtype (the Half type). Dispatching on grad's Float
// dtype would instantiate scalar_t = float and then throw in
// gI.data_ptr<float>() on the Half gI tensor. The spatial branch previously
// did exactly that.
//
// Kernel selection mirrors the forward pass:
//  * inner_size == 1: persistent softmax for short rows (dim_size <= 1024 and
//    dim_size * sizeof(scalar_t) <= 4096), otherwise cunn_SoftMaxBackward;
//  * inner_size > 1:  cunn_SpatialSoftMaxBackward.
// An empty grad returns the (empty) gI without launching anything.
template<template<typename, typename, typename> class Epilogue, bool is_log_softmax>
Tensor host_softmax_backward(const Tensor &grad_, const Tensor &output_, int64_t dim_, bool half_to_float){
  int64_t dim = maybe_wrap_dim(dim_, grad_.dim());
  Tensor gI = half_to_float ? at::empty_like(grad_, grad_.options().dtype(ScalarType::Half)) : at::empty_like(grad_);
  if (grad_.numel() == 0) {
    return gI;
  }
  auto grad = grad_.contiguous();
  static_assert(std::is_same<acc_type<at::Half, true>, float>::value, "accscalar_t for half should be float");
  if (grad.dim() == 0) grad = grad.view(1);
  TORCH_CHECK(dim >=0 && dim < grad.dim(), "dim must be non-negative and less than input dimensions");
  auto output = output_.contiguous();
  if (output.dim() == 0) output = output.view(1);
  // Collapse to outer x dim x inner around the softmax dimension.
  int64_t outer_size = 1;
  int64_t dim_size = output.size(dim);
  int64_t inner_size = 1;
  for (int64_t i = 0; i < dim; ++i)
    outer_size *= output.size(i);
  for (int64_t i = dim + 1; i < output.dim(); ++i)
    inner_size *= output.size(i);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if (inner_size == 1) {
    const int ILP = 2;
    dim3 grid(outer_size);
    dim3 block = SoftMax_getBlockSize(ILP, dim_size);
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(gI.scalar_type(), "host_softmax_backward", [&] {
      using accscalar_t = acc_type<scalar_t, true>;
      if (!half_to_float) {
        if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) {
          dispatch_softmax_backward<scalar_t, scalar_t, accscalar_t, is_log_softmax>(
              gI.data_ptr<scalar_t>(), grad.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(), dim_size, dim_size, outer_size);
        } else {
          cunn_SoftMaxBackward<ILP, scalar_t, accscalar_t, scalar_t, Epilogue>
            <<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
              gI.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(), grad.data_ptr<scalar_t>(), dim_size
          );
        }
      } else {
        // Float (accscalar_t) grad/output in, Half (scalar_t) gradient out.
        if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) {
          dispatch_softmax_backward<accscalar_t, scalar_t, accscalar_t, is_log_softmax>(
              gI.data_ptr<scalar_t>(), grad.data_ptr<accscalar_t>(), output.data_ptr<accscalar_t>(), dim_size, dim_size, outer_size);
        } else {
          cunn_SoftMaxBackward<ILP, scalar_t, accscalar_t, accscalar_t, Epilogue>
            <<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
              gI.data_ptr<scalar_t>(), output.data_ptr<accscalar_t>(), grad.data_ptr<accscalar_t>(), dim_size
          );
        }
      }
    });
  } else {
    uint32_t smem_size;
    dim3 grid, block;
    // Dispatch on gI (not grad): with half_to_float, grad is Float but gI is Half.
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(gI.scalar_type(), "host_softmax_backward", [&] {
      using accscalar_t = acc_type<scalar_t, true>;
      if (!half_to_float) {
        SpatialSoftMax_getLaunchSizes<accscalar_t>(
            &cunn_SpatialSoftMaxBackward<scalar_t, accscalar_t, scalar_t, Epilogue>,
            outer_size, dim_size, inner_size,
            grid, block, smem_size);
        cunn_SpatialSoftMaxBackward<scalar_t, accscalar_t, scalar_t, Epilogue>
          <<<grid, block, smem_size, stream>>>(
            gI.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(), grad.data_ptr<scalar_t>(),
            outer_size, dim_size, inner_size
        );
      } else {
        SpatialSoftMax_getLaunchSizes<accscalar_t>(
            &cunn_SpatialSoftMaxBackward<scalar_t, accscalar_t, accscalar_t, Epilogue>,
            outer_size, dim_size, inner_size,
            grid, block, smem_size);
        cunn_SpatialSoftMaxBackward<scalar_t, accscalar_t, accscalar_t, Epilogue>
          <<<grid, block, smem_size, stream>>>(
            gI.data_ptr<scalar_t>(), output.data_ptr<accscalar_t>(), grad.data_ptr<accscalar_t>(),
            outer_size, dim_size, inner_size
        );
      }
    });
  }
  // Surface launch-configuration errors from whichever kernel ran.
  THCudaCheck(cudaGetLastError());
  return gI;
}
}
// CUDA log_softmax forward: log(exp(x - max) / sum) along `dim`. Optionally
// Half -> Float output when half_to_float is set.
Tensor log_softmax_cuda(const Tensor &input, const int64_t dim, const bool half_to_float){
  constexpr bool kIsLogSoftmax = true;
  auto result = host_softmax<LogSoftMaxForwardEpilogue, kIsLogSoftmax>(input, dim, half_to_float);
  return result;
}
// CUDA log_softmax backward. A dtype mismatch between grad and input is only
// allowed for the half_to_float case (Half input, Float grad).
Tensor log_softmax_backward_cuda(const Tensor &grad, const Tensor &output, int64_t dim, const Tensor &input){
  const bool half_to_float = grad.scalar_type() != input.scalar_type();
  if (half_to_float) {
    AT_ASSERTM((grad.scalar_type() == ScalarType::Float && input.scalar_type() == ScalarType::Half), "expected input and grad types to match, or input to be at::Half and grad to be at::Float");
  }
  constexpr bool kIsLogSoftmax = true;
  return host_softmax_backward<LogSoftMaxBackwardEpilogue, kIsLogSoftmax>(grad, output, dim, half_to_float);
}
// CUDA softmax forward: exp(x - max) / sum along `dim`. Optionally Half ->
// Float output when half_to_float is set.
Tensor softmax_cuda(const Tensor &input, const int64_t dim, const bool half_to_float){
  constexpr bool kIsLogSoftmax = false;
  auto result = host_softmax<SoftMaxForwardEpilogue, kIsLogSoftmax>(input, dim, half_to_float);
  return result;
}
// CUDA softmax backward. SoftMaxBackwardEpilogue expects the incoming gradient
// already multiplied by the forward output, so the product is formed here.
// A dtype mismatch is only allowed for half_to_float (Half input, Float grad).
Tensor softmax_backward_cuda(const Tensor &grad, const Tensor &output, int64_t dim, const Tensor &input){
  const bool half_to_float = grad.scalar_type() != input.scalar_type();
  if (half_to_float) {
    AT_ASSERTM((grad.scalar_type() == ScalarType::Float && input.scalar_type() == ScalarType::Half), "expected input and grad types to match, or input to be at::Half and grad to be at::Float");
  }
  const Tensor grad_times_output = grad * output;
  constexpr bool kIsLogSoftmax = false;
  return host_softmax_backward<SoftMaxBackwardEpilogue, kIsLogSoftmax>(grad_times_output, output, dim, half_to_float);
}
}
}