diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 1f19d2053..375d404a8 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -680,7 +680,9 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
     """
     vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
              indices]
-    return (x > vals[:, None]).long().sum(1).add_(1)
+    result = (x > vals[:, None])
+    del vals
+    return result.sum(1).add_(1)
 
 
 def _get_logprobs(