Skip to content

Commit

Permalink
revert lr dpo (#22)
Browse files Browse the repository at this point in the history
  • Loading branch information
lkevinzc authored Dec 1, 2024
1 parent 870e4a2 commit c638ce4
Showing 1 changed file with 1 addition and 4 deletions.
5 changes: 1 addition & 4 deletions oat/learners/dap.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,10 +223,7 @@ def get_batch_logps(

if self.algo != DAPAlgo.BNF:

- if self.algo == DAPAlgo.LR_DPO:
-     length = torch.min(loss_masks.sum(-1))
- else:
-     length = loss_masks.sum(-1)
+ length = loss_masks.sum(-1)

if average_log_prob:
return (target_logps * loss_masks).sum(-1) / length
Expand Down

0 comments on commit c638ce4

Please sign in to comment.