Skip to content

Commit

Permalink
fix after rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
Pavel Geyn committed Jan 21, 2025
1 parent c407488 commit 4209616
Showing 1 changed file with 12 additions and 24 deletions.
36 changes: 12 additions & 24 deletions turbo_alignment/trainers/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,21 +683,12 @@ def _get_batch_logps(
return local_loss

def concatenated_forward(
self,
model: nn.Module,
batch: dict[str, Any],
get_from_dataset=False,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]
| tuple[torch.Tensor, torch.Tensor]
):
self, model: nn.Module, batch: dict[str, Any]
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]:
concatenated_batch = concatenated_inputs(batch, device=self.accelerator.device)

precomputed_margins: torch.Tensor | None = concatenated_batch.pop('margin', None)

if get_from_dataset:
return concatenated_batch.pop('ref_chosen_logps'), concatenated_batch.pop('ref_rejected_logps')

input_ids = concatenated_batch['input_ids']
attention_mask = concatenated_batch['attention_mask']
labels = concatenated_batch['labels']
Expand Down Expand Up @@ -747,19 +738,16 @@ def concatenated_forward(
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, precomputed_margins

def _get_logps(self, model: nn.Module | None, batch: dict[str, Any]) -> tuple[torch.Tensor, torch.Tensor]:
    """Return the chosen and rejected log-probabilities for ``batch``.

    The forward pass runs under ``torch.no_grad()``, so the returned
    tensors carry no gradient history.

    Args:
        model: Module to run ``concatenated_forward`` on. When ``None``,
            ``self.model`` is used instead, with its adapters temporarily
            disabled through ``disable_adapter()``.
        batch: Batch forwarded unchanged to ``concatenated_forward``.

    Returns:
        A ``(chosen_logps, rejected_logps)`` tuple: the first two values
        produced by ``concatenated_forward``; any further values it
        returns are discarded.
    """
    with torch.no_grad():
        if model is not None:
            chosen_logps, rejected_logps, *_ = self.concatenated_forward(model, batch)
        else:
            # No separate model was supplied: run self.model with its adapters
            # switched off. NOTE(review): presumably this recovers the base
            # (reference) policy of an adapter-tuned model -- confirm that
            # self.model always exposes `disable_adapter()` on this path.
            with self.accelerator.unwrap_model(self.model).disable_adapter():
                chosen_logps, rejected_logps, *_ = self.concatenated_forward(self.model, batch)

    return chosen_logps, rejected_logps

Expand Down

0 comments on commit 4209616

Please sign in to comment.