From ba01b85c13ab204852a10ebc691e0371d4a965d0 Mon Sep 17 00:00:00 2001
From: Heyang Sun <60865256+Uxito-Ada@users.noreply.github.com>
Date: Fri, 26 Jul 2024 16:46:21 +0800
Subject: [PATCH] empty cache only for 1st token but rest token to speed up (#11665)

---
 python/llm/src/ipex_llm/transformers/pipeline_parallel.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
index 2e6e99b51e8..de51cbca5ca 100644
--- a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
+++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
@@ -959,9 +959,13 @@ def llama_causallm_forward_4_37_lowmem(
         logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]  # noqa
         logits = torch.cat(logits, dim=-1)
     else:
-        torch.xpu.empty_cache()
+        # Only empty cache for first token
+        if hidden_states.shape[1] > 1:
+            torch.xpu.empty_cache()
         logits = self.lm_head(hidden_states)
-        torch.xpu.empty_cache()
+        # Only empty cache for first token
+        if hidden_states.shape[1] > 1:
+            torch.xpu.empty_cache()
     # logits = logits.float()
 
     # ipex-llm change ends
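
Below is a minimal, self-contained sketch of the pattern this patch applies: `torch.xpu.empty_cache()` is called only around the prefill (first-token) step, where the sequence dimension is greater than 1 and the lm_head matmul dominates peak memory, and is skipped for single-token decode steps to avoid its per-step overhead. It assumes a PyTorch build where `torch.xpu` is available (e.g., via Intel Extension for PyTorch); `compute_logits_lowmem`, `lm_head`, and `hidden_states` are illustrative stand-ins, not the actual names from pipeline_parallel.py.

import torch

def compute_logits_lowmem(lm_head, hidden_states):
    # Prefill processes the whole prompt, so the sequence dimension is > 1;
    # decode steps feed one token at a time, so it is exactly 1.
    is_prefill = hidden_states.shape[1] > 1
    # Free cached XPU blocks only around the large prefill matmul, where the
    # memory benefit outweighs the cost of empty_cache(); skip it for decode.
    if is_prefill and hidden_states.device.type == "xpu":
        torch.xpu.empty_cache()
    logits = lm_head(hidden_states)
    if is_prefill and hidden_states.device.type == "xpu":
        torch.xpu.empty_cache()
    return logits

The design trade-off mirrored here is that empty_cache() returns unused cached blocks to the device allocator, which helps the memory-heavy first-token pass but adds latency if repeated on every small decode step.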