diff --git a/oat/learners/dap.py b/oat/learners/dap.py index dfa67ff..cc11a15 100644 --- a/oat/learners/dap.py +++ b/oat/learners/dap.py @@ -223,10 +223,7 @@ def get_batch_logps( if self.algo != DAPAlgo.BNF: - if self.algo == DAPAlgo.LR_DPO: - length = torch.min(loss_masks.sum(-1)) - else: - length = loss_masks.sum(-1) + length = loss_masks.sum(-1) if average_log_prob: return (target_logps * loss_masks).sum(-1) / length