Batch normalization and ReLU operations fused into a single CUDA kernel, together with a backpropagation implementation.
The forward kernel can be broken up into five steps (a sketch of such a fused kernel is shown after the list):

- Compute the mean:
  $$\mu = \frac{1}{N} \sum_{i=1}^N x_i$$
- Compute the variance:
  $$\sigma^2 = \frac{1}{N} \sum_{i=1}^N (x_i - \mu)^2$$
- Normalize the input:
  $$\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$$
- Scale and shift the normalized value with the learnable parameters γ and β:
  $$y_i = \gamma \hat{x}_i + \beta$$
- Apply the ReLU activation function:
  $$a_i = \text{ReLU}(y_i) = \max(y_i, 0)$$
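
The sketch below shows one way these five steps might be fused into a single CUDA kernel. It is illustrative only: the kernel name, the [N, C] row-major layout, the one-block-per-feature launch, and the saved `x_hat` / `inv_std` buffers are assumptions rather than this repository's actual interface.

```cuda
// Hypothetical sketch of a fused BatchNorm + ReLU forward kernel.
// Input x is [N, C] row-major; one block handles one feature c, and
// blockDim.x is assumed to be a power of two for the shared-memory reductions.
#include <cuda_runtime.h>

__global__ void fused_bn_relu_forward(const float* __restrict__ x,       // [N, C]
                                      const float* __restrict__ gamma,   // [C]
                                      const float* __restrict__ beta,    // [C]
                                      float* __restrict__ a,             // ReLU output, [N, C]
                                      float* __restrict__ x_hat,         // saved for backward, [N, C]
                                      float* __restrict__ inv_std,       // saved for backward, [C]
                                      int N, int C, float eps) {
    int c = blockIdx.x;
    extern __shared__ float smem[];   // blockDim.x floats

    // Step 1: mean over the batch dimension.
    float sum = 0.0f;
    for (int i = threadIdx.x; i < N; i += blockDim.x)
        sum += x[i * C + c];
    smem[threadIdx.x] = sum;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) smem[threadIdx.x] += smem[threadIdx.x + s];
        __syncthreads();
    }
    float mu = smem[0] / N;
    __syncthreads();                  // all threads read smem[0] before it is reused

    // Step 2: variance over the batch dimension.
    float sq = 0.0f;
    for (int i = threadIdx.x; i < N; i += blockDim.x) {
        float d = x[i * C + c] - mu;
        sq += d * d;
    }
    smem[threadIdx.x] = sq;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) smem[threadIdx.x] += smem[threadIdx.x + s];
        __syncthreads();
    }
    float istd = rsqrtf(smem[0] / N + eps);
    if (threadIdx.x == 0) inv_std[c] = istd;

    // Steps 3-5: normalize, scale/shift, ReLU -- fused into a single pass.
    float g = gamma[c], b = beta[c];
    for (int i = threadIdx.x; i < N; i += blockDim.x) {
        float xh = (x[i * C + c] - mu) * istd;
        x_hat[i * C + c] = xh;
        a[i * C + c] = fmaxf(g * xh + b, 0.0f);
    }
}

// Example launch: one block per feature, 256 threads, 256 floats of shared memory.
// fused_bn_relu_forward<<<C, 256, 256 * sizeof(float)>>>(x, gamma, beta, a, x_hat, inv_std, N, C, 1e-5f);
```

Saving $\hat{x}_i$ and the inverse standard deviation in this sketch avoids recomputing the batch statistics during the backward pass.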
The backward pass needs to compute the gradients for the two sets of learnable parameters, γ and β, as well as the gradient of the loss with respect to the layer input, which is passed on to the previous layer (a sketch of such a backward kernel appears at the end of this section). Throughout, $\frac{da_i}{dy_i}$ is the ReLU derivative, i.e. $1$ if $y_i > 0$ and $0$ otherwise.

- Compute the gradient for the learnable γ parameters:
  $$\frac{dL}{d\gamma} = \sum_{i=1}^N \frac{dL}{da_i} \cdot \frac{da_i}{dy_i} \cdot \frac{dy_i}{d\gamma}, \qquad \frac{dy_i}{d\gamma} = \hat{x}_i$$
- Compute the gradient for the learnable β parameters:
  $$\frac{dL}{d\beta} = \sum_{i=1}^N \frac{dL}{da_i} \cdot \frac{da_i}{dy_i} \cdot \frac{dy_i}{d\beta}, \qquad \frac{dy_i}{d\beta} = 1$$
- Compute the gradient of the loss with respect to the input:
  $$\frac{dL}{dx_i} = \frac{dL}{da_i} \cdot \frac{da_i}{dy_i} \cdot \frac{dy_i}{d\hat{x}_i} \cdot \frac{d\hat{x}_i}{dx_i}, \qquad \frac{dy_i}{d\hat{x}_i} = \gamma$$
where the final term must account for $\hat{x}_i$ depending on $x_i$ both directly and through the batch statistics $\mu$ and $\sigma^2$, and can be written as:
$$\frac{d\hat{x}_i}{dx_i} = \frac{1 - \frac{1}{N}}{\sqrt{\sigma^2 + \epsilon}} - \frac{(x_i - \mu)^2}{N (\sigma^2 + \epsilon)^{3/2}}$$
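
Because $\mu$ and $\sigma^2$ couple every sample in the batch, the complete input gradient also accumulates the cross terms $\sum_{j \neq i} \frac{dL}{d\hat{x}_j} \cdot \frac{\partial \hat{x}_j}{\partial x_i}$. Carrying out that sum and simplifying gives the closed form that fused batch-norm backward kernels typically compute; it is restated here for reference (whether the kernel in this repository uses this exact form is an assumption):
$$\frac{dL}{dx_i} = \frac{\gamma}{N \sqrt{\sigma^2 + \epsilon}} \left( N \frac{dL}{dy_i} - \sum_{j=1}^N \frac{dL}{dy_j} - \hat{x}_i \sum_{j=1}^N \frac{dL}{dy_j} \hat{x}_j \right), \qquad \frac{dL}{dy_i} = \frac{dL}{da_i} \cdot \frac{da_i}{dy_i}$$
The two summations are exactly $\frac{dL}{d\beta}$ and $\frac{dL}{d\gamma}$ from above, which is what makes producing all three gradients in a single pass over the batch convenient.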
The implementation can be found in `fused_kernel_backward.cu`.
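
For orientation, here is a comparable sketch of a fused backward kernel under the same assumed layout and saved buffers as the forward sketch above. It is an illustrative reconstruction, not the contents of `fused_kernel_backward.cu`; the ReLU mask is recovered from the saved activations ($a_i > 0$), and the closed-form input gradient above is applied in the second loop.

```cuda
// Hypothetical sketch of a fused BatchNorm + ReLU backward kernel, matching
// the forward sketch above (assumed [N, C] row-major layout, one block per
// feature, power-of-two blockDim.x).
#include <cuda_runtime.h>

__global__ void fused_bn_relu_backward(const float* __restrict__ grad_a,   // dL/da, [N, C]
                                       const float* __restrict__ a,        // forward ReLU output, [N, C]
                                       const float* __restrict__ x_hat,    // saved normalized input, [N, C]
                                       const float* __restrict__ inv_std,  // saved 1/sqrt(var + eps), [C]
                                       const float* __restrict__ gamma,    // [C]
                                       float* __restrict__ grad_x,         // dL/dx, [N, C]
                                       float* __restrict__ grad_gamma,     // dL/dgamma, [C]
                                       float* __restrict__ grad_beta,      // dL/dbeta, [C]
                                       int N, int C) {
    int c = blockIdx.x;
    extern __shared__ float smem[];          // 2 * blockDim.x floats
    float* s_dg = smem;                      // partial sums for dL/dgamma
    float* s_db = smem + blockDim.x;         // partial sums for dL/dbeta

    // dL/dy_i = dL/da_i * ReLU'(y_i); the mask is recovered from a_i > 0.
    // Accumulate dL/dgamma = sum_i dL/dy_i * x_hat_i and dL/dbeta = sum_i dL/dy_i.
    float dg = 0.0f, db = 0.0f;
    for (int i = threadIdx.x; i < N; i += blockDim.x) {
        int idx = i * C + c;
        float dy = a[idx] > 0.0f ? grad_a[idx] : 0.0f;
        dg += dy * x_hat[idx];
        db += dy;
    }
    s_dg[threadIdx.x] = dg;
    s_db[threadIdx.x] = db;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) {
            s_dg[threadIdx.x] += s_dg[threadIdx.x + s];
            s_db[threadIdx.x] += s_db[threadIdx.x + s];
        }
        __syncthreads();
    }
    float sum_dg = s_dg[0];                  // dL/dgamma for this feature
    float sum_db = s_db[0];                  // dL/dbeta for this feature
    if (threadIdx.x == 0) {
        grad_gamma[c] = sum_dg;
        grad_beta[c]  = sum_db;
    }

    // dL/dx_i = gamma * inv_std / N * (N * dL/dy_i - dL/dbeta - x_hat_i * dL/dgamma)
    float scale = gamma[c] * inv_std[c] / N;
    for (int i = threadIdx.x; i < N; i += blockDim.x) {
        int idx = i * C + c;
        float dy = a[idx] > 0.0f ? grad_a[idx] : 0.0f;
        grad_x[idx] = scale * (N * dy - sum_db - x_hat[idx] * sum_dg);
    }
}

// Example launch, mirroring the forward kernel:
// fused_bn_relu_backward<<<C, 256, 2 * 256 * sizeof(float)>>>(grad_a, a, x_hat, inv_std, gamma, grad_x, grad_gamma, grad_beta, N, C);
```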