[WIP]: Autotune Chunk Size #395

Draft · wants to merge 3 commits into base: main
77 changes: 66 additions & 11 deletions src/liger_kernel/chunked_loss/fused_linear_preference.py
@@ -27,6 +27,7 @@ def forward(
bias=None,
loss_fn=None,
chunk_size=1,
auto_tune_chunk_size=False,
compute_nll_loss=True,
ignore_index=-100,
alpha=1.0,
@@ -44,22 +45,22 @@
bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs).
auto_tune_chunk_size (bool): Whether to auto-tune the chunk size.
compute_nll_loss (bool): Whether to compute NLL loss.
ignore_index (int): Index to ignore for loss computation.
alpha (float): Weight for the NLL loss.
beta (float): Weight for the odds ratio loss.
compiled (bool): Whether to use torch compile for chunk accumulation.
"""
# TODO: Tune CHUNK_SIZE to fully utilize the GPU
CHUNK_SIZE = chunk_size

grad_weight = torch.zeros_like(weight)
grad_chosen_inputs = []
grad_rejected_inputs = []
grad_bias = torch.zeros_like(bias) if bias is not None else None
loss_acc = torch.zeros((), device=_input.device)
# Index at which the rejected responses start
len_chosen = target.shape[0] // 2

chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
loss_func_to_call = partial(
LigerFusedLinearPreferenceBase._compute_loss,
preference_loss_fn=loss_fn,
@@ -94,11 +95,67 @@ def accumulate_chunk(input_chunk, target_chunk):
loss_acc.add_(chunk_loss)
return chunk_grad_input

len_chosen = target.shape[0] // 2
_chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0)
_chosen_target_chunks = torch.chunk(target[:len_chosen], chunks=chunks, dim=0)
_rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
_rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)
if compiled:
accumulate_chunk = torch.compile(accumulate_chunk)

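# Auto-tune: run a single probe chunk (the first chosen row and the first rejected
# row), measure how much extra memory its forward/backward needs, and pick
# CHUNK_SIZE so that the remaining rows fit into the GPU memory still free after
# the probe. The probe's gradients are kept, so those two rows are skipped below.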
if auto_tune_chunk_size and _input.shape[0] > 2:
total_gpu_memory = torch.cuda.get_device_properties(
_input.device
).total_memory
memory_allocated = torch.cuda.memory_allocated(device=_input.device)
torch.cuda.reset_peak_memory_stats(device=_input.device)
# print(
# f"Total GPU memory: {total_gpu_memory}, Memory allocated: {memory_allocated}"
# )
auto_tune_input_chunk = torch.cat(
[_input[0].unsqueeze(0), _input[len_chosen].unsqueeze(0)], dim=0
)
auto_tune_target_chunk = torch.cat(
[target[0].unsqueeze(0), target[len_chosen].unsqueeze(0)], dim=0
)
grad_input = accumulate_chunk(auto_tune_input_chunk, auto_tune_target_chunk)

grad_chosen_inputs.append(grad_input[0].unsqueeze(0))
grad_rejected_inputs.append(grad_input[1].unsqueeze(0))
peak_memory = torch.cuda.max_memory_allocated(device=_input.device)
# print(f"Peak memory: {peak_memory}")
memory_of_forward = peak_memory - memory_allocated
total_free_memory_available = (
total_gpu_memory - memory_allocated - memory_of_forward
)
CHUNK_SIZE = max(1, total_free_memory_available // memory_of_forward)
chunks = max(1, (_input.shape[0] - 2) // CHUNK_SIZE)
else:
CHUNK_SIZE = chunk_size
chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))

if auto_tune_chunk_size and _input.shape[0] > 2:
# Skip the first chosen and rejected input since they were used for auto-tuning
_chosen_input_chunks = torch.chunk(
_input[1:len_chosen], chunks=chunks, dim=0
)
_chosen_target_chunks = torch.chunk(
target[1:len_chosen], chunks=chunks, dim=0
)
_rejected_input_chunks = torch.chunk(
_input[len_chosen + 1 :], chunks=chunks, dim=0
)
_rejected_target_chunks = torch.chunk(
target[len_chosen + 1 :], chunks=chunks, dim=0
)
else:
_chosen_input_chunks = torch.chunk(
_input[:len_chosen], chunks=chunks, dim=0
)
_chosen_target_chunks = torch.chunk(
target[:len_chosen], chunks=chunks, dim=0
)
_rejected_input_chunks = torch.chunk(
_input[len_chosen:], chunks=chunks, dim=0
)
_rejected_target_chunks = torch.chunk(
target[len_chosen:], chunks=chunks, dim=0
)

for (
chosen_input_chunk,
@@ -116,8 +173,6 @@ def accumulate_chunk(input_chunk, target_chunk):
[chosen_target_chunk, rejected_target_chunk], dim=0
)

if compiled:
accumulate_chunk = torch.compile(accumulate_chunk)
grad_input = accumulate_chunk(input_chunk, target_chunk)

grad_chosen_inputs.append(grad_input[: chosen_target_chunk.shape[0]])
@@ -131,7 +186,7 @@ def accumulate_chunk(input_chunk, target_chunk):
grad_weight,
grad_bias,
)
return loss_acc
return loss_acc, CHUNK_SIZE

@staticmethod
def backward(ctx, grad_output):
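For context, a minimal standalone sketch of the memory-probing heuristic the new branch above implements. The helper name `estimate_chunk_size` and the `probe_fn` callback are illustrative only and are not part of this PR, which uses the first chosen and first rejected rows as the probe and `accumulate_chunk` as the probed function:

```python
import torch

def estimate_chunk_size(probe_fn, probe_batch, device):
    """Run one small probe batch, measure its peak memory use, and estimate how
    many probe-sized batches fit in the GPU memory that remains free."""
    total = torch.cuda.get_device_properties(device).total_memory
    allocated = torch.cuda.memory_allocated(device)
    torch.cuda.reset_peak_memory_stats(device)

    probe_fn(probe_batch)  # forward + gradient accumulation on the probe chunk

    peak = torch.cuda.max_memory_allocated(device)
    per_probe = max(1, peak - allocated)   # extra memory one probe needed
    free = total - allocated - per_probe   # memory left for the real chunks
    return max(1, free // per_probe)
```

The estimate assumes memory use grows roughly linearly with chunk size, so a chunk of `CHUNK_SIZE` probe-sized units is expected to peak at about `CHUNK_SIZE` times the probe's peak.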
15 changes: 13 additions & 2 deletions src/liger_kernel/chunked_loss/orpo_loss.py
@@ -35,14 +35,21 @@ def forward(
beta=0.1,
compute_nll_loss=True,
compiled=True,
auto_tune_chunk_size=False,
chunk_size=1,
):
"""
Fused linear layer with ORPO (Odds-Ratio Preference Optimization) loss.
Handles both the forward and backward pass of the final linear layer with ORPO loss.
Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss.
"""
if hasattr(ctx, "chunk_size"):
chunk_size = ctx.chunk_size
auto_tune_chunk_size = False
else:
chunk_size = chunk_size

return LigerFusedLinearPreferenceBase.forward(
loss, chunk_size = LigerFusedLinearPreferenceBase.forward(
ctx=ctx,
_input=_input,
weight=weight,
@@ -53,11 +60,15 @@
ignore_index=ignore_index,
beta=beta,
compiled=compiled,
chunk_size=chunk_size,
auto_tune_chunk_size=auto_tune_chunk_size,
)
ctx.chunk_size = chunk_size
return loss

@staticmethod
def backward(ctx, grad_output):
# Get gradients for _input, weight, bias, and target from the base class
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
# Return these gradients, followed by None for the remaining inputs
return *grads, None, None, None, None
return *grads, None, None, None, None, None, None
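As a usage sketch (shapes and sizes are illustrative, and a CUDA device is assumed since the tensors are allocated on the GPU), the new flag is passed positionally after `compiled`, mirroring the updated test below:

```python
import torch
from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction

B, T, H, V = 2, 8, 32, 64  # illustrative sizes
_input = torch.randn(2 * B, T, H, device="cuda", requires_grad=True)  # chosen rows first, then rejected
weight = torch.randn(V, H, device="cuda", requires_grad=True)
target = torch.randint(0, V, (2 * B, T), device="cuda")
bias = None

loss = LigerFusedLinearORPOFunction.apply(
    _input, weight, target, bias,
    -100,   # ignore_index
    0.1,    # beta
    True,   # compute_nll_loss
    True,   # compiled
    True,   # auto_tune_chunk_size
)
loss.backward()
```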
110 changes: 76 additions & 34 deletions test/chunked_loss/test_orpo_loss.py
@@ -57,6 +57,39 @@ def alignment_loss(
return losses


def get_random_input_tensors(B, T, H, V, scalar, dtype, ignore_index, bias):
"""
B: batch size
T: sequence length
H: hidden size
V: vocab size
scalar: scale factor for input
dtype: data type for input
ignore_index: ignore index for loss
bias: whether to include bias in the linear layer
"""
B = 2 * B # orpo loss requires B to be even

_input = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar
input1 = _input.detach().clone().requires_grad_(True)
input2 = _input.detach().clone().requires_grad_(True)

target = torch.randint(0, V, (B, T), device="cuda", dtype=torch.long)
num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
target.view(-1)[indices_to_assign] = ignore_index

_weight = torch.randn(V, H, device="cuda", dtype=dtype)
weight1 = _weight.detach().clone().requires_grad_(True)
weight2 = _weight.detach().clone().requires_grad_(True)

_bias = torch.randn(V, device="cuda", dtype=dtype)
bias1 = _bias.detach().clone().requires_grad_(True) if bias else None
bias2 = _bias.detach().clone().requires_grad_(True) if bias else None

return input1, weight1, target, bias1, input2, weight2, bias2


@pytest.mark.parametrize(
"B, T, H, V",
[
@@ -74,46 +107,15 @@
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("ignore_index, beta", [(-100, 0.1), (42, 0.2)])
def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta):
B = 2 * B # orpo loss requires B to be even

_input = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar
input1 = _input.detach().clone().requires_grad_(True)
input2 = _input.detach().clone().requires_grad_(True)

target = torch.randint(
0,
V,
(
B,
T,
),
device="cuda",
dtype=torch.long,
input1, weight1, target, bias1, input2, weight2, bias2 = get_random_input_tensors(
B, T, H, V, scalar, dtype, ignore_index, bias
)
# Assign some random number of elements as ignore_index
num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
target.view(-1)[indices_to_assign] = ignore_index

_weight = torch.randn(V, H, device="cuda", dtype=dtype)
weight1 = _weight.detach().clone().requires_grad_(True)
weight2 = _weight.detach().clone().requires_grad_(True)

_bias = torch.randn(V, device="cuda", dtype=dtype) if bias else None
bias1 = _bias.detach().clone().requires_grad_(True) if bias else None
bias2 = _bias.detach().clone().requires_grad_(True) if bias else None

loss1 = HFORPOLoss(ignore_index=ignore_index, beta=beta).get_batch_loss_metrics(
input1, weight1, target, bias1
)
loss2 = LigerFusedLinearORPOFunction.apply(
input2,
weight2,
target,
bias2,
ignore_index,
beta,
True,
input2, weight2, target, bias2, ignore_index, beta, True, True
)

assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)
@@ -125,3 +127,43 @@ def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index,
assert_verbose_allclose(weight1.grad, weight2.grad, atol=atol, rtol=rtol)
if bias:
assert_verbose_allclose(bias1.grad, bias2.grad, atol=atol, rtol=rtol)


@pytest.mark.parametrize(
"B, T, H, V",
[
(1, 12, 36, 128),
(3, 47, 31, 123), # random shape
],
)
@pytest.mark.parametrize(
"scalar, dtype, atol, rtol",
[
(1.0, torch.bfloat16, 5e-2, 5e-1),
(1.0, torch.float32, 1e-5, 5e-4),
],
)
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("ignore_index, beta", [(-100, 0.1)])
def test_autotune(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta):

input1, weight1, target, bias1, input2, weight2, bias2 = get_random_input_tensors(
B, T, H, V, scalar, dtype, ignore_index, bias
)

loss1 = HFORPOLoss(ignore_index=ignore_index, beta=beta).get_batch_loss_metrics(
input1, weight1, target, bias1
)
loss2 = LigerFusedLinearORPOFunction.apply(
input2, weight2, target, bias2, ignore_index, beta, True, True, True
)

assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)

loss1.backward()
loss2.backward()

assert torch.allclose(input1.grad, input2.grad, atol=atol, rtol=rtol)
assert torch.allclose(weight1.grad, weight2.grad, atol=atol, rtol=rtol)
if bias:
assert torch.allclose(bias1.grad, bias2.grad, atol=atol, rtol=rtol)
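The new case can be run on its own with `pytest test/chunked_loss/test_orpo_loss.py -k test_autotune`; like the existing tests it allocates tensors on `"cuda"`, so a CUDA-capable GPU is assumed.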