From fc9fe28615cd8398518d2f43f56e2abe7e349a40 Mon Sep 17 00:00:00 2001
From: Travis Johnson
Date: Mon, 13 May 2024 09:58:06 -0600
Subject: [PATCH] Squash IBM/vllm 22

---
Reviewer note (ignored by `git am`; not part of the commit message):
  * Touches only `_get_ranks` in vllm/model_executor/layers/sampler.py.
  * The boolean comparison `x > vals[:, None]` is no longer cast with
    `.long()` before `.sum(1)`; it is bound to `result` and summed directly.
  * `vals` is explicitly deleted before the reduction.
  * NOTE(review): presumably this lowers peak memory by not materialising an
    integer copy of the comparison mask -- the commit message does not state
    the intent, so confirm with the author.
  * NOTE(review): assumes `.sum(1)` on a bool tensor yields the same integer
    counts as the previous `.long().sum(1)` -- verify against the torch.sum
    documentation for the torch version in use.
  * NOTE(review): indentation of the hunk's context lines was reconstructed;
    verify against the target file before applying.

 vllm/model_executor/layers/sampler.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 1f19d2053..375d404a8 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -680,7 +680,9 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
     """
     vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
              indices]
-    return (x > vals[:, None]).long().sum(1).add_(1)
+    result = (x > vals[:, None])
+    del vals
+    return result.sum(1).add_(1)
 
 
 def _get_logprobs(