From 42096162bf5e20baf77816244ca66d8d074df4f5 Mon Sep 17 00:00:00 2001
From: Pavel Geyn
Date: Mon, 20 Jan 2025 17:13:53 +0500
Subject: [PATCH] fix after rebase

Drop the get_from_dataset escape hatch from concatenated_forward and the
precomputed reference-logps branch from _get_logps: reference chosen and
rejected log-probabilities are now always recomputed under torch.no_grad(),
either from the model passed in or, when no model is given, from the policy
model with its adapters disabled.
---
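A minimal sketch of the reference-logps flow after this change, assuming the
trainer object keeps the model, accelerator and concatenated_forward members
seen in this file; reference_logps is a hypothetical standalone helper, shown
for illustration only:

    from typing import Any

    import torch
    import torch.nn as nn

    def reference_logps(
        trainer, model: nn.Module | None, batch: dict[str, Any]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Mirrors the simplified _get_logps: reference log-probs are always
        # recomputed under no_grad instead of being read from the dataset.
        with torch.no_grad():
            if model is not None:
                # A dedicated (frozen) reference model was passed in.
                chosen_logps, rejected_logps, *_ = trainer.concatenated_forward(model, batch)
            else:
                # No separate reference model: reuse the policy weights with
                # adapters temporarily disabled.
                with trainer.accelerator.unwrap_model(trainer.model).disable_adapter():
                    chosen_logps, rejected_logps, *_ = trainer.concatenated_forward(trainer.model, batch)
        return chosen_logps, rejected_logps
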
turbo_alignment/trainers/dpo.py | 36 +++++++++++----------------------
1 file changed, 12 insertions(+), 24 deletions(-)
diff --git a/turbo_alignment/trainers/dpo.py b/turbo_alignment/trainers/dpo.py
index 5805dcf1..01c25281 100755
--- a/turbo_alignment/trainers/dpo.py
+++ b/turbo_alignment/trainers/dpo.py
@@ -683,21 +683,12 @@ def _get_batch_logps(
return local_loss
def concatenated_forward(
- self,
- model: nn.Module,
- batch: dict[str, Any],
- get_from_dataset=False,
- ) -> (
- tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]
- | tuple[torch.Tensor, torch.Tensor]
- ):
+ self, model: nn.Module, batch: dict[str, Any]
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]:
concatenated_batch = concatenated_inputs(batch, device=self.accelerator.device)
precomputed_margins: torch.Tensor | None = concatenated_batch.pop('margin', None)
- if get_from_dataset:
- return concatenated_batch.pop('ref_chosen_logps'), concatenated_batch.pop('ref_rejected_logps')
-
input_ids = concatenated_batch['input_ids']
attention_mask = concatenated_batch['attention_mask']
labels = concatenated_batch['labels']
@@ -747,19 +738,16 @@ def concatenated_forward(
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, precomputed_margins
def _get_logps(self, model: nn.Module | None, batch: dict[str, Any]) -> tuple[torch.Tensor, torch.Tensor]:
- if self.ref_model is None:
- chosen_logps, rejected_logps = self.concatenated_forward(model, batch, get_from_dataset=True)
- else:
- with torch.no_grad():
- if model is not None:
- (chosen_logps, rejected_logps, *_) = self.concatenated_forward(model, batch)
- else:
- with self.accelerator.unwrap_model(self.model).disable_adapter():
- (
- chosen_logps,
- rejected_logps,
- *_,
- ) = self.concatenated_forward(self.model, batch)
+ with torch.no_grad():
+ if model is not None:
+ (chosen_logps, rejected_logps, *_) = self.concatenated_forward(model, batch)
+ else:
+ with self.accelerator.unwrap_model(self.model).disable_adapter():
+ (
+ chosen_logps,
+ rejected_logps,
+ *_,
+ ) = self.concatenated_forward(self.model, batch)
return chosen_logps, rejected_logps