diff --git a/alt_cuda_corr/correlation_kernel.cu b/alt_cuda_corr/correlation_kernel.cu
index 145e5804..e9719f24 100644
--- a/alt_cuda_corr/correlation_kernel.cu
+++ b/alt_cuda_corr/correlation_kernel.cu
@@ -119,144 +119,197 @@ __global__ void corr_forward_kernel(
 }
 
-template <typename scalar_t>
-__global__ void corr_backward_kernel(
-  const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
-  const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
-  const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,
-  const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr_grad,
-  torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1_grad,
-  torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2_grad,
-  torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords_grad,
-  int r)
-{
-
-  const int b = blockIdx.x;
-  const int h0 = blockIdx.y * blockDim.x;
-  const int w0 = blockIdx.z * blockDim.y;
-  const int tid = threadIdx.x * blockDim.y + threadIdx.y;
-
-  const int H1 = fmap1.size(1);
-  const int W1 = fmap1.size(2);
-  const int H2 = fmap2.size(1);
-  const int W2 = fmap2.size(2);
-  const int N = coords.size(1);
-  const int C = fmap1.size(3);
-
-  __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];
-  __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];
-
-  __shared__ scalar_t f1_grad[CHANNEL_STRIDE][BLOCK_HW+1];
-  __shared__ scalar_t f2_grad[CHANNEL_STRIDE][BLOCK_HW+1];
-
-  __shared__ scalar_t x2s[BLOCK_HW];
-  __shared__ scalar_t y2s[BLOCK_HW];
-
-  for (int c=0; c(floor(y2s[k1]))-r+iy;
-            int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
-            int c2 = tid % CHANNEL_STRIDE;
-
-            auto fptr = fmap2[b][h2][w2];
-            if (within_bounds(h2, w2, H2, W2))
-              f2[c2][k1] = fptr[c+c2];
-            else
-              f2[c2][k1] = 0.0;
-
-            f2_grad[c2][k1] = 0.0;
-          }
-
-          __syncthreads();
+// template <typename scalar_t>
+// __global__ void corr_backward_kernel(
+//   const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
+//   const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
+//   const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,
+//   const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr_grad,
+//   torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1_grad,
+//   torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2_grad,
+//   torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords_grad,
+//   int r)
+// {
+
+//   const int b = blockIdx.x;
+//   const int h0 = blockIdx.y * blockDim.x;
+//   const int w0 = blockIdx.z * blockDim.y;
+//   const int tid = threadIdx.x * blockDim.y + threadIdx.y;
+
+//   const int H1 = fmap1.size(1);
+//   const int W1 = fmap1.size(2);
+//   const int H2 = fmap2.size(1);
+//   const int W2 = fmap2.size(2);
+//   const int N = coords.size(1);
+//   const int C = fmap1.size(3);
+
+//   __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];
+//   __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];
+
+//   __shared__ scalar_t f1_grad[CHANNEL_STRIDE][BLOCK_HW+1];
+//   __shared__ scalar_t f2_grad[CHANNEL_STRIDE][BLOCK_HW+1];
+
+//   __shared__ scalar_t x2s[BLOCK_HW];
+//   __shared__ scalar_t y2s[BLOCK_HW];
+
+//   for (int c=0; c(floor(y2s[k1]))-r+iy;
+//             int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
+//             int c2 = tid % CHANNEL_STRIDE;
+
+//             auto fptr = fmap2[b][h2][w2];
+//             if (within_bounds(h2, w2, H2, W2))
+//               f2[c2][k1] = fptr[c+c2];
+//             else
+//               f2[c2][k1] = 0.0;
+
+//             f2_grad[c2][k1] = 0.0;
+//           }
+
+//           __syncthreads();
 
-          const scalar_t* grad_ptr = &corr_grad[b][n][0][h1][w1];
-          scalar_t g = 0.0;
+//           const scalar_t* grad_ptr = &corr_grad[b][n][0][h1][w1];
+//           scalar_t g = 0.0;
 
-          int ix_nw = H1*W1*((iy-1) + rd*(ix-1));
-          int ix_ne = H1*W1*((iy-1) + rd*ix);
-          int ix_sw = H1*W1*(iy + rd*(ix-1));
-          int ix_se = H1*W1*(iy + rd*ix);
+//           int ix_nw = H1*W1*((iy-1) + rd*(ix-1));
+//           int ix_ne = H1*W1*((iy-1) + rd*ix);
+//           int ix_sw = H1*W1*(iy + rd*(ix-1));
+//           int ix_se = H1*W1*(iy + rd*ix);
 
-          if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1))
-            g += *(grad_ptr + ix_nw) * dy * dx;
+//           if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1))
+//             g += *(grad_ptr + ix_nw) * dy * dx;
 
-          if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))
-            g += *(grad_ptr + ix_ne) * dy * (1-dx);
+//           if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))
+//             g += *(grad_ptr + ix_ne) * dy * (1-dx);
 
-          if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1))
-            g += *(grad_ptr + ix_sw) * (1-dy) * dx;
+//           if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1))
+//             g += *(grad_ptr + ix_sw) * (1-dy) * dx;
 
-          if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))
-            g += *(grad_ptr + ix_se) * (1-dy) * (1-dx);
+//           if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))
+//             g += *(grad_ptr + ix_se) * (1-dy) * (1-dx);
 
-          for (int k=0; k(floor(y2s[k1]))-r+iy;
+//             int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
+//             int c2 = tid % CHANNEL_STRIDE;
+
+//             scalar_t* fptr = &fmap2_grad[b][h2][w2][0];
+//             if (within_bounds(h2, w2, H2, W2))
+//               atomicAdd(fptr+c+c2, f2_grad[c2][k1]);
+//           }
+//         }
+//       }
+//     }
+//     __syncthreads();
+
+
+//     for (int k=0; k(floor(y2s[k1]))-r+iy;
-            int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
-            int c2 = tid % CHANNEL_STRIDE;
-            scalar_t* fptr = &fmap2_grad[b][h2][w2][0];
-            if (within_bounds(h2, w2, H2, W2))
-              atomicAdd(fptr+c+c2, f2_grad[c2][k1]);
+template <typename scalar_t>
+__global__ void corr_backward_kernel(
+  const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
+  const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
+  const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,
+  const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr_grad,
+  torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1_grad,
+  torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2_grad,
+  torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords_grad,
+  int r
+) {
+  const int B = fmap1.size(0);
+  const int H1 = fmap1.size(1);
+  const int W1 = fmap1.size(2);
+  const int idx = threadIdx.x + blockDim.x * blockIdx.x;
+  if (idx < B*H1*W1) {
+    const int H2 = fmap2.size(1);
+    const int W2 = fmap2.size(2);
+    const int N = coords.size(1);
+    const int C = fmap1.size(3);
+    const int rd = int(sqrt(corr_grad.size(2)));
+    const int b = idx / (H1 * W1);
+    const int h = (idx % (H1*W1)) / W1;
+    const int w = (idx % (H1*W1)) % W1;
+    const int r = (rd - 1) / 2;
+    for (int n = 0; n < N; n++) {
+      scalar_t coords_x = coords[b][n][h][w][0], coords_y = coords[b][n][h][w][1];
+      for (int iy = -r; iy <= r; iy++) {
+        for (int ix = -r; ix <= r; ix++) {
+          scalar_t x = coords_x + ix, y = coords_y + iy;
+          int x_floor = static_cast<int>(floor(x)), y_floor = static_cast<int>(floor(y));
+          scalar_t dx = x - x_floor, dy = y - y_floor;
+          scalar_t weights_ii[2] = {1 - dx, dx}, weights_jj[2] = {1 - dy, dy};
+          scalar_t g = corr_grad[b][n][(ix + r) * rd + iy + r][h][w];
+          for (int c = 0; c < C; c++) {
+            scalar_t f = fmap1[b][h][w][c];
+            scalar_t gf = 0;
+            for (int ii = 0; ii < 2; ii++) {
+              for (int jj = 0; jj < 2; jj++) {
+                int x0 = x_floor + ii, y0 = y_floor + jj;
+                if (within_bounds(y0, x0, H2, W2)) {
+                  gf += g * fmap2[b][y0][x0][c] * weights_ii[ii] * weights_jj[jj];
+                  atomicAdd(&fmap2_grad[b][y0][x0][c], g * f * weights_ii[ii] * weights_jj[jj]);
+                }
+              }
+            }
+            fmap1_grad[b][h][w][c] += gf;
+          }
+        }
-        }
-      }
-      __syncthreads();
-
-
-      for (int k=0; k corr_cuda_forward(
   torch::Tensor fmap1,
   torch::Tensor fmap2,
@@ -305,12 +358,11 @@ std::vector<torch::Tensor> corr_cuda_backward(
   auto fmap1_grad = torch::zeros({B, H1, W1, C}, opts);
   auto fmap2_grad = torch::zeros({B, H2, W2, C}, opts);
   auto coords_grad = torch::zeros({B, N, H1, W1, 2}, opts);
-
-  const dim3 blocks(B, (H1+BLOCK_H-1)/BLOCK_H, (W1+BLOCK_W-1)/BLOCK_W);
-  const dim3 threads(BLOCK_H, BLOCK_W);
+  int threadsPerBlock = 256;
+  int blocksPerGrid = (B*H1*W1 + threadsPerBlock - 1) / threadsPerBlock;
 
-  corr_backward_kernel<float><<<blocks, threads>>>(
+  corr_backward_kernel<float><<<blocksPerGrid, threadsPerBlock>>>(
     fmap1.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
     fmap2.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
     coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
diff --git a/core/corr.py b/core/corr.py
index cffcbc82..0ab82162 100644
--- a/core/corr.py
+++ b/core/corr.py
@@ -59,6 +59,23 @@ def corr(fmap1, fmap2):
     corr = corr.view(batch, ht, wd, 1, ht, wd)
     return corr / torch.sqrt(torch.tensor(dim).float())
 
+class AltCudaCorr(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, fmap1, fmap2_i, coords, r):
+        ctx.save_for_backward(fmap1, fmap2_i, coords)
+        ctx.r = r
+        corr, = alt_cuda_corr.forward(fmap1, fmap2_i, coords, r)
+        return corr,
+        # this should be different from return alt_cuda_corr.forward(...
+
+    @staticmethod
+    def backward(ctx, corr_grad):
+        fmap1, fmap2_i, coords = ctx.saved_tensors
+        corr_grad = corr_grad.contiguous()
+        fmap1_grad, fmap2_grad, coords_grad = alt_cuda_corr.backward(fmap1, fmap2_i, coords, corr_grad, ctx.r)
+        return fmap1_grad, fmap2_grad, coords_grad, None
+
+
 class AlternateCorrBlock:
     def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
@@ -83,7 +100,7 @@ def __call__(self, coords):
             fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
             coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
 
-            corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
+            corr, = AltCudaCorr.apply(fmap1_i, fmap2_i, coords_i, r)
             corr_list.append(corr.squeeze(1))
 
         corr = torch.stack(corr_list, dim=1)
diff --git a/train.py b/train.py
index 30757309..de10b9c9 100644
--- a/train.py
+++ b/train.py
@@ -173,7 +173,7 @@ def train(args):
             loss, metrics = sequence_loss(flow_predictions, flow, valid, args.gamma)
             scaler.scale(loss).backward()
-            scaler.unscale_(optimizer)
+            scaler.unscale_(optimizer)
             torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
 
             scaler.step(optimizer)
@@ -236,6 +236,7 @@
     parser.add_argument('--dropout', type=float, default=0.0)
     parser.add_argument('--gamma', type=float, default=0.8, help='exponential weighting')
     parser.add_argument('--add_noise', action='store_true')
+    parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
 
     args = parser.parse_args()
     torch.manual_seed(1234)
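
For reference, a minimal smoke test of the new `AltCudaCorr` autograd wrapper might look like the sketch below. It is not part of the diff: the tensor shapes, the `core.corr` import path, and the check itself (one forward/backward pass, then inspecting the feature-map gradients) are assumptions, and it presumes the `alt_cuda_corr` extension has been built and a CUDA device is available.

```python
# Sketch of a smoke test for AltCudaCorr (not part of this change).
import torch

from core.corr import AltCudaCorr  # import path assumed; adjust to the local layout

B, H, W, C, radius = 1, 8, 8, 32, 4

# Feature maps in the channels-last (B, H, W, C) layout the extension expects.
fmap1 = torch.randn(B, H, W, C, device='cuda', requires_grad=True)
fmap2 = torch.randn(B, H, W, C, device='cuda', requires_grad=True)

# One (x, y) lookup center per pixel, shape (B, N, H, W, 2) with N = 1.
xs = torch.arange(W).view(1, W).expand(H, W)
ys = torch.arange(H).view(H, 1).expand(H, W)
coords = torch.stack([xs, ys], dim=-1).float().cuda()
coords = coords.view(B, 1, H, W, 2).contiguous()

# Forward through the custom Function, then backpropagate a simple scalar.
corr, = AltCudaCorr.apply(fmap1, fmap2, coords, radius)
corr.sum().backward()

# Both feature maps should now carry finite, non-zero gradients.
print(corr.shape)
print(fmap1.grad.abs().sum().item(), fmap2.grad.abs().sum().item())
```

A fuller check could compare these gradients against the pure-PyTorch correlation path on the same inputs, but even this quick pass confirms that the hand-written backward kernel is reached and populates both `fmap1_grad` and `fmap2_grad`.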