From 8ee7bb7e8b978ea8e8e2f0f3bd1d7e114da82d0c Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Thu, 13 Jun 2024 17:37:58 +0800
Subject: [PATCH 1/2] fix chatglm2/3-32k/128k fp16

---
 python/llm/src/ipex_llm/transformers/models/chatglm2.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm2.py b/python/llm/src/ipex_llm/transformers/models/chatglm2.py
index 375bfa475ca..c83675c78e2 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm2.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm2.py
@@ -98,14 +98,15 @@ def chatglm2_model_forward(
                                         dtype=torch.int64, device=inputs_embeds.device)
             position_ids = position_ids.repeat(batch_size, 1)
 
-    if getattr(self.rotary_pos_emb, "cached_dtype", None) != inputs_embeds.dtype:
+    if not getattr(self.rotary_pos_emb, "cached", False):
         rot_dim = self.rotary_pos_emb.dim
         base = 10000 * getattr(self.rotary_pos_emb, "rope_ratio", 1)
         inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2,
-                                                device=inputs_embeds.device,
-                                                dtype=inputs_embeds.dtype) / rot_dim))
+                                                dtype=torch.float,
+                                                device=inputs_embeds.device) / rot_dim))
+        inv_freq = inv_freq.to(inputs_embeds.dtype)
         self.rotary_pos_emb.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.rotary_pos_emb.cached_dtype = inputs_embeds.dtype
+        self.rotary_pos_emb.cached = True
 
     # `full_attention_mask` is not None only when
     #   `past_key_values` is not None and `seq_length` > 1

From 1b6a72f9cf5b943344275590af269a5d9374cda0 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Fri, 14 Jun 2024 08:36:47 +0800
Subject: [PATCH 2/2] update

---
 .../ipex_llm/transformers/models/chatglm4.py | 43 ++-----------------
 1 file changed, 4 insertions(+), 39 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm4.py b/python/llm/src/ipex_llm/transformers/models/chatglm4.py
index cf936105774..5f0bd6082ce 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm4.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm4.py
@@ -18,8 +18,7 @@
 #
 
 import torch
-from typing import Optional, Tuple, Union, List, Callable, Dict, Any
-import torch.nn.functional as F
+from typing import Optional, Tuple, Union
 from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, update_past_key_value
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, use_sdp_causal
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
@@ -27,11 +26,6 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
 import math
-import os
-
-KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
-KV_CACHE_ALLOC_MIN_LENGTH = 512
-
 
 def chatglm4_model_forward(
         self,
         input_ids,
         position_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.BoolTensor] = None,
         full_attention_mask: Optional[torch.BoolTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]=None,
         inputs_embeds: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
-    from ipex_llm.transformers.kv import DynamicFp8Cache
-    use_cache = use_cache if use_cache is not None else self.config.use_cache
-    return chatglm4_model_forward_internal(
-        self=self,
-        input_ids=input_ids,
-        position_ids=position_ids,
-        attention_mask=attention_mask,
-        full_attention_mask=full_attention_mask,
-        past_key_values=past_key_values,
-        inputs_embeds=inputs_embeds,
-        use_cache=use_cache,
-        output_hidden_states=output_hidden_states,
-        return_dict=return_dict,
-    )
-
-
-def chatglm4_model_forward_internal(
-        self,
-        input_ids,
-        position_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.BoolTensor] = None,
-        full_attention_mask: Optional[torch.BoolTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]=None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        use_cache: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-):
     output_hidden_states = (
         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
@@ -104,16 +70,15 @@ def chatglm4_model_forward_internal(
                                         dtype=torch.int64, device=inputs_embeds.device)
             position_ids = position_ids.repeat(batch_size, 1)
 
-    if getattr(self.rotary_pos_emb, "cached_dtype", None) != inputs_embeds.dtype:
+    if not getattr(self.rotary_pos_emb, "cached", False):
         rot_dim = self.rotary_pos_emb.dim
         base = 10000 * getattr(self.rotary_pos_emb, "rope_ratio", 1)
         # We should generate float inv_freq to avoid overflow, as base is too large.
         inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2, dtype=torch.float,
                                                 device=inputs_embeds.device) / rot_dim))
-        self.rotary_pos_emb.register_buffer("inv_freq",
-                                            inv_freq.to(inputs_embeds.dtype),
-                                            persistent=False)
+        inv_freq = inv_freq.to(inputs_embeds.dtype)
+        self.rotary_pos_emb.register_buffer("inv_freq", inv_freq, persistent=False)
         self.rotary_pos_emb.cached = True
 
     # `full_attention_mask` is not None only when
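
Note on the fix (not part of the patches above): the inv_freq changes in both files address the same issue called out by the in-code comment, namely that `base = 10000 * rope_ratio` is very large for the 32k/128k long-context checkpoints, so building the rotary inverse frequencies in the model dtype (fp16) overflows. The patches therefore generate `inv_freq` in float32 and only cast the finished, small values to the model dtype. The snippet below is a minimal standalone sketch of that numeric effect; `rot_dim = 64` and `rope_ratio = 500` are illustrative placeholders, not values taken from the diff.

import torch

# Illustrative values only (not from the patch): long-context ChatGLM checkpoints
# configure a large rope_ratio, which is what makes `base` huge.
rot_dim = 64
rope_ratio = 500
base = 10000 * rope_ratio                     # 5_000_000, far above the fp16 max (~65504)

exponent = torch.arange(0, rot_dim, 2, dtype=torch.float) / rot_dim

# Holding the power in fp16 (as the old code effectively did by using the model
# dtype) turns the larger values into inf, so the tail of inv_freq collapses to 0.
bad = 1.0 / (base ** exponent).to(torch.float16)
print(bad[-1])                                # tensor(0., dtype=torch.float16)

# The patched code computes inv_freq in float32 and only casts the small final
# values (all <= 1) to the model dtype, so nothing overflows.
good = (1.0 / (base ** exponent)).to(torch.float16)
print(good[-1])                               # tiny but non-zero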