From 42096162bf5e20baf77816244ca66d8d074df4f5 Mon Sep 17 00:00:00 2001
From: Pavel Geyn
Date: Mon, 20 Jan 2025 17:13:53 +0500
Subject: [PATCH] fix after rebase

---
 turbo_alignment/trainers/dpo.py | 36 ++++++++++++------------------------
 1 file changed, 12 insertions(+), 24 deletions(-)

diff --git a/turbo_alignment/trainers/dpo.py b/turbo_alignment/trainers/dpo.py
index 5805dcf1..01c25281 100755
--- a/turbo_alignment/trainers/dpo.py
+++ b/turbo_alignment/trainers/dpo.py
@@ -683,21 +683,12 @@ def _get_batch_logps(
         return local_loss

     def concatenated_forward(
-        self,
-        model: nn.Module,
-        batch: dict[str, Any],
-        get_from_dataset=False,
-    ) -> (
-        tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]
-        | tuple[torch.Tensor, torch.Tensor]
-    ):
+        self, model: nn.Module, batch: dict[str, Any]
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]:
         concatenated_batch = concatenated_inputs(batch, device=self.accelerator.device)

         precomputed_margins: torch.Tensor | None = concatenated_batch.pop('margin', None)

-        if get_from_dataset:
-            return concatenated_batch.pop('ref_chosen_logps'), concatenated_batch.pop('ref_rejected_logps')
-
         input_ids = concatenated_batch['input_ids']
         attention_mask = concatenated_batch['attention_mask']
         labels = concatenated_batch['labels']
@@ -747,19 +738,16 @@ def concatenated_forward(
         return chosen_logps, rejected_logps, chosen_logits, rejected_logits, precomputed_margins

     def _get_logps(self, model: nn.Module | None, batch: dict[str, Any]) -> tuple[torch.Tensor, torch.Tensor]:
-        if self.ref_model is None:
-            chosen_logps, rejected_logps = self.concatenated_forward(model, batch, get_from_dataset=True)
-        else:
-            with torch.no_grad():
-                if model is not None:
-                    (chosen_logps, rejected_logps, *_) = self.concatenated_forward(model, batch)
-                else:
-                    with self.accelerator.unwrap_model(self.model).disable_adapter():
-                        (
-                            chosen_logps,
-                            rejected_logps,
-                            *_,
-                        ) = self.concatenated_forward(self.model, batch)
+        with torch.no_grad():
+            if model is not None:
+                (chosen_logps, rejected_logps, *_) = self.concatenated_forward(model, batch)
+            else:
+                with self.accelerator.unwrap_model(self.model).disable_adapter():
+                    (
+                        chosen_logps,
+                        rejected_logps,
+                        *_,
+                    ) = self.concatenated_forward(self.model, batch)
         return chosen_logps, rejected_logps