
Commit 1b6a72f

update
MeouSker77 committed Jun 14, 2024
1 parent 8ee7bb7 commit 1b6a72f
Showing 1 changed file with 4 additions and 39 deletions.
43 changes: 4 additions & 39 deletions python/llm/src/ipex_llm/transformers/models/chatglm4.py
@@ -18,20 +18,14 @@
#

import torch
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
import torch.nn.functional as F
from typing import Optional, Tuple, Union
from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, update_past_key_value
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, use_sdp_causal
from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
from ipex_llm.transformers.models.chatglm2 import repeat_kv
from transformers.modeling_outputs import BaseModelOutputWithPast
import math

import os

KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
KV_CACHE_ALLOC_MIN_LENGTH = 512


def chatglm4_model_forward(
self,
@@ -45,34 +39,6 @@ def chatglm4_model_forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
from ipex_llm.transformers.kv import DynamicFp8Cache
use_cache = use_cache if use_cache is not None else self.config.use_cache
return chatglm4_model_forward_internal(
self=self,
input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
full_attention_mask=full_attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)


def chatglm4_model_forward_internal(
self,
input_ids,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
full_attention_mask: Optional[torch.BoolTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]=None,
inputs_embeds: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else
self.config.output_hidden_states
@@ -104,16 +70,15 @@ def chatglm4_model_forward_internal(
dtype=torch.int64, device=inputs_embeds.device)
position_ids = position_ids.repeat(batch_size, 1)

if getattr(self.rotary_pos_emb, "cached_dtype", None) != inputs_embeds.dtype:
if not getattr(self.rotary_pos_emb, "cached", False):
rot_dim = self.rotary_pos_emb.dim
base = 10000 * getattr(self.rotary_pos_emb, "rope_ratio", 1)
# We should generate float inv_freq to avoid overflow, as base is too large.
inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2,
dtype=torch.float,
device=inputs_embeds.device) / rot_dim))
self.rotary_pos_emb.register_buffer("inv_freq",
inv_freq.to(inputs_embeds.dtype),
persistent=False)
inv_freq = inv_freq.to(inputs_embeds.dtype)
self.rotary_pos_emb.register_buffer("inv_freq", inv_freq, persistent=False)
self.rotary_pos_emb.cached = True

# `full_attention_mask` is not None only when

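Judging from the last hunk, the change caches the rotary inverse frequencies once per module behind a simple `cached` flag, and keeps computing them in float32 before casting down, because the RoPE base (`10000 * rope_ratio`) is far too large to represent in float16. The standalone sketch below was written for this note and is not code from the repository; the `RotaryCacheDemo` module and the values `dim=64`, `rope_ratio=500` are illustrative assumptions that only mimic the pattern shown in the diff.

import torch

# Editor's illustrative sketch -- not code from ipex-llm. It mimics the pattern in
# the hunk above: compute RoPE inverse frequencies once, in float32, then cast and
# cache them on the module as a non-persistent buffer guarded by a `cached` flag.

class RotaryCacheDemo(torch.nn.Module):
    def __init__(self, dim: int, rope_ratio: float = 500.0):
        super().__init__()
        self.dim = dim
        self.rope_ratio = rope_ratio  # illustrative value; real models define their own
        self.cached = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.cached:
            base = 10000 * self.rope_ratio  # 5e6 here -- far above float16's ~65504 max
            # Exponentiate in float32: `base ** exp` would overflow to inf in float16.
            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2,
                                                    dtype=torch.float,
                                                    device=x.device) / self.dim))
            # The resulting frequencies are all <= 1, so casting down afterwards is safe.
            self.register_buffer("inv_freq", inv_freq.to(x.dtype), persistent=False)
            self.cached = True
        return self.inv_freq


demo = RotaryCacheDemo(dim=64)
x = torch.zeros(1, dtype=torch.float16)
print(demo(x)[:4])  # finite float16 frequencies
# Representing the base directly in float16 shows the overflow the code comment warns about:
print(torch.tensor(10000.0 * 500.0, dtype=torch.float16))  # tensor(inf, dtype=torch.float16)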