Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix chatglm2/3-32k/128k fp16 #11311

Merged
merged 2 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions python/llm/src/ipex_llm/transformers/models/chatglm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,15 @@ def chatglm2_model_forward(
dtype=torch.int64, device=inputs_embeds.device)
position_ids = position_ids.repeat(batch_size, 1)

if getattr(self.rotary_pos_emb, "cached_dtype", None) != inputs_embeds.dtype:
if not getattr(self.rotary_pos_emb, "cached", False):
rot_dim = self.rotary_pos_emb.dim
base = 10000 * getattr(self.rotary_pos_emb, "rope_ratio", 1)
inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2,
device=inputs_embeds.device,
dtype=inputs_embeds.dtype) / rot_dim))
dtype=torch.float,
device=inputs_embeds.device) / rot_dim))
inv_freq = inv_freq.to(inputs_embeds.dtype)
self.rotary_pos_emb.register_buffer("inv_freq", inv_freq, persistent=False)
self.rotary_pos_emb.cached_dtype = inputs_embeds.dtype
self.rotary_pos_emb.cached = True

# `full_attention_mask` is not None only when
# `past_key_values` is not None and `seq_length` > 1
Expand Down
43 changes: 4 additions & 39 deletions python/llm/src/ipex_llm/transformers/models/chatglm4.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,14 @@
#

import torch
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
import torch.nn.functional as F
from typing import Optional, Tuple, Union
from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, update_past_key_value
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, use_sdp_causal
from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
from ipex_llm.transformers.models.chatglm2 import repeat_kv
from transformers.modeling_outputs import BaseModelOutputWithPast
import math

import os

KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
KV_CACHE_ALLOC_MIN_LENGTH = 512


def chatglm4_model_forward(
self,
Expand All @@ -45,34 +39,6 @@ def chatglm4_model_forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
from ipex_llm.transformers.kv import DynamicFp8Cache
use_cache = use_cache if use_cache is not None else self.config.use_cache
return chatglm4_model_forward_internal(
self=self,
input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
full_attention_mask=full_attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)


def chatglm4_model_forward_internal(
self,
input_ids,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
full_attention_mask: Optional[torch.BoolTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]=None,
inputs_embeds: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else
self.config.output_hidden_states
Expand Down Expand Up @@ -104,16 +70,15 @@ def chatglm4_model_forward_internal(
dtype=torch.int64, device=inputs_embeds.device)
position_ids = position_ids.repeat(batch_size, 1)

if getattr(self.rotary_pos_emb, "cached_dtype", None) != inputs_embeds.dtype:
if not getattr(self.rotary_pos_emb, "cached", False):
rot_dim = self.rotary_pos_emb.dim
base = 10000 * getattr(self.rotary_pos_emb, "rope_ratio", 1)
# We should generate float inv_freq to avoid overflow, as base is too large.
inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2,
dtype=torch.float,
device=inputs_embeds.device) / rot_dim))
self.rotary_pos_emb.register_buffer("inv_freq",
inv_freq.to(inputs_embeds.dtype),
persistent=False)
inv_freq = inv_freq.to(inputs_embeds.dtype)
self.rotary_pos_emb.register_buffer("inv_freq", inv_freq, persistent=False)
self.rotary_pos_emb.cached = True

# `full_attention_mask` is not None only when
Expand Down
Loading