From 3ff6fdefb21532c2ca90616de93aaa184b0e1594 Mon Sep 17 00:00:00 2001 From: Sebastian Walter Date: Thu, 12 Dec 2024 23:27:40 +0100 Subject: [PATCH] update keep min --- python/text_utils/inference/utils.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/python/text_utils/inference/utils.py b/python/text_utils/inference/utils.py index a2982a7..9b361c6 100644 --- a/python/text_utils/inference/utils.py +++ b/python/text_utils/inference/utils.py @@ -270,28 +270,31 @@ def _top_k( return _top_k -def nucleus_masking(p: float) -> LogitFn: +def nucleus_masking(p: float, keep_min: int = 1) -> LogitFn: assert 0.0 <= p <= 1.0, "p must be in [0, 1]" + assert keep_min > 0, "keep_min must be positive" def _nuc( _input_ids: torch.Tensor, logits: torch.Tensor, _: list[Beam] ) -> torch.Tensor: + keep = min(keep_min, logits.shape[-1]) probs = torch.softmax(logits, dim=-1) - sorted_probs, indices = torch.sort(probs, dim=-1, descending=True) + sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True) cum_sum_probs = torch.cumsum(sorted_probs, dim=-1) nucleus = cum_sum_probs < p nucleus = torch.cat( - [nucleus.new_ones((len(nucleus), 1)), nucleus[:, :-1]], dim=-1 + [nucleus.new_ones((len(nucleus), keep)), nucleus[:, :-keep]], dim=-1 ) - sorted_logits = torch.gather(logits, -1, indices) + sorted_logits = torch.gather(logits, -1, sorted_indices) sorted_logits[torch.logical_not(nucleus)] = float("-inf") - return sorted_logits.gather(-1, indices.argsort(-1)) + return sorted_logits.gather(-1, sorted_indices.argsort(-1)) return _nuc def min_p_masking(min_p: float, keep_min: int = 1) -> LogitFn: assert 0.0 <= min_p <= 1.0, "min_p must be in [0, 1]" + assert keep_min > 0, "keep_min must be positive" def _min_p( _input_ids: torch.Tensor, logits: torch.Tensor, _: list[Beam]