Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
Signed-off-by: Mayank Mishra <[email protected]>
  • Loading branch information
mayank31398 committed Feb 11, 2025
1 parent efec5f5 commit 8e49724
Showing 1 changed file with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,15 @@ def backward(ctx, output_grad: torch.Tensor) -> tuple[torch.Tensor | None]:
x_grad, weight_grad = _backward(
x=x,
weight=weight,
eps=ctx.eps,
eps=eps,
rmsnorm_denominator=rmsnorm_denominator,
output_grad=output_grad,
kernel_backend=CutoTuneParameter(),
BLOCK_SIZE_B=CutoTuneParameter(),
BLOCK_SIZE_H=CutoTuneParameter(),
)

if ctx.is_x_1d:
if is_x_1d:
x_grad = x_grad.squeeze(0)

return x_grad, weight_grad, None
Expand Down

0 comments on commit 8e49724

Please sign in to comment.