fix gemma for 4.41 #11531

Merged · 2 commits · Jul 18, 2024
@@ -21,7 +21,7 @@ conda activate llm
pip install --pre --upgrade ipex-llm[all] --extra-index-url https://download.pytorch.org/whl/cpu

# According to CodeGemma's requirement, please make sure you are using a stable version of Transformers, 4.38.1 or newer.
-pip install transformers==4.38.1
+pip install "transformers>=4.38.1"
```

On Windows:
@@ -32,7 +32,7 @@ conda activate llm

pip install --pre --upgrade ipex-llm[all]

-pip install transformers==4.38.1
+pip install "transformers>=4.38.1"
```

### 2. Run
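The same `==` → `>=` relaxation is repeated in each README touched below. A quick way to confirm an environment actually satisfies the stated lower bound (a minimal sketch, not part of this PR; the 4.38.1 floor comes from the README comment above):

```python
# Minimal sketch: check that the installed transformers meets the README's lower bound.
from packaging import version

import transformers

REQUIRED = "4.38.1"  # lower bound quoted in the README comment above

installed = version.parse(transformers.__version__)
if installed < version.parse(REQUIRED):
    raise RuntimeError(f"transformers {installed} found, but >= {REQUIRED} is required")
print(f"transformers {installed} satisfies >= {REQUIRED}")
```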
@@ -22,7 +22,7 @@ conda activate llm
pip install --pre --upgrade ipex-llm[all] --extra-index-url https://download.pytorch.org/whl/cpu

# According to Gemma's requirement, please make sure you are using a stable version of Transformers, 4.38.1 or newer.
-pip install transformers==4.38.1
+pip install "transformers>=4.38.1"
```

On Windows:
@@ -33,7 +33,7 @@ conda activate llm

pip install --pre --upgrade ipex-llm[all]

-pip install transformers==4.38.1
+pip install "transformers>=4.38.1"
```

### 2. Run
@@ -21,7 +21,7 @@ conda activate llm
# install the latest ipex-llm nightly build with 'all' option
pip install --pre --upgrade ipex-llm[all] --extra-index-url https://download.pytorch.org/whl/cpu
# According to CodeGemma's requirement, please make sure you are using a stable version of Transformers, 4.38.1 or newer.
-pip install transformers==4.38.1
+pip install "transformers>=4.38.1"
```

On Windows:
@@ -31,7 +31,7 @@ conda create -n llm python=3.11
conda activate llm

pip install --pre --upgrade ipex-llm[all]
-pip install transformers==4.38.1
+pip install "transformers>=4.38.1"
```

### 2. Run
4 changes: 2 additions & 2 deletions python/llm/example/GPU/HuggingFace/LLM/codegemma/README.md
@@ -20,7 +20,7 @@ conda activate llm
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/

# According to CodeGemma's requirement, please make sure you are using a stable version of Transformers, 4.38.1 or newer.
-pip install transformers==4.38.1
+pip install "transformers>=4.38.1"
```

#### 1.2 Installation on Windows
@@ -33,7 +33,7 @@ conda activate llm
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/

# According to CodeGemma's requirement, please make sure you are using a stable version of Transformers, 4.38.1 or newer.
-pip install transformers==4.38.1
+pip install "transformers>=4.38.1"
```

### 2. Configures OneAPI environment variables for Linux
4 changes: 2 additions & 2 deletions python/llm/example/GPU/HuggingFace/LLM/gemma/README.md
@@ -18,7 +18,7 @@ conda activate llm
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/

# According to Gemma's requirement, please make sure you are using a stable version of Transformers, 4.38.1 or newer.
-pip install transformers==4.38.1
+pip install "transformers>=4.38.1"
```

#### 1.2 Installation on Windows
@@ -31,7 +31,7 @@ conda activate llm
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/

# According to Gemma's requirement, please make sure you are using a stable version of Transformers, 4.38.1 or newer.
-pip install transformers==4.38.1
+pip install "transformers>=4.38.1"
```

### 2. Configures OneAPI environment variables for Linux
@@ -20,7 +20,7 @@ conda activate llm
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/

# According to CodeGemma's requirement, please make sure you are using a stable version of Transformers, 4.38.1 or newer.
-pip install transformers==4.38.1
+pip install "transformers>=4.38.1"
```

#### 1.2 Installation on Windows
@@ -33,7 +33,7 @@ conda activate llm
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/

# According to CodeGemma's requirement, please make sure you are using a stable version of Transformers, 4.38.1 or newer.
-pip install transformers==4.38.1
+pip install "transformers>=4.38.1"
```

### 2. Configures OneAPI environment variables for Linux
21 changes: 16 additions & 5 deletions python/llm/src/ipex_llm/transformers/convert.py
@@ -1480,21 +1480,32 @@ def _optimize_post(model, lightweight_bmm=False):
module.MistralMLP,
llama_mlp_forward)
     elif model.config.model_type == "gemma":
+        invalidInputError(version.parse(trans_version) >= version.parse("4.38.0"),
+                          "Please upgrade transformers to 4.38.0 or higher version "
+                          "to run Mixtral models.")
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
-        from ipex_llm.transformers.models.gemma import gemma_attention_forward
+        if version.parse(trans_version) >= version.parse("4.39.0"):
+            from ipex_llm.transformers.models.gemma import gemma_attention_forward_4_39
+            convert_forward(model,
+                            module.GemmaAttention,
+                            gemma_attention_forward_4_39
+                            )
+        else:
+            from ipex_llm.transformers.models.gemma import gemma_attention_forward
+            convert_forward(model,
+                            module.GemmaAttention,
+                            gemma_attention_forward,
+                            )
         from ipex_llm.transformers.models.gemma import gemma_rms_norm_forward
         from ipex_llm.transformers.models.gemma import gemma_mlp_forward
-        convert_forward(model,
-                        module.GemmaAttention,
-                        gemma_attention_forward,
-                        )
         convert_forward(model,
                         module.GemmaRMSNorm,
                         gemma_rms_norm_forward)
         convert_forward(model,
                         module.GemmaMLP,
                         gemma_mlp_forward)

elif model.config.model_type == "gemma2":
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
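The hunk above is the core of the fix: `_optimize_post` now dispatches on the installed transformers version and patches `GemmaAttention` with either the existing forward or the new `gemma_attention_forward_4_39`. For readers unfamiliar with this style of patching, here is a rough standalone sketch of the idea; the `replace_forward` helper is hypothetical and only stands in for ipex-llm's `convert_forward`, it is not the library's actual implementation:

```python
# Rough sketch of version-dispatched forward patching (hypothetical helper,
# standing in for ipex-llm's convert_forward; not the library's actual code).
import importlib
from types import MethodType

import torch.nn as nn
from packaging import version
from transformers import __version__ as trans_version


def replace_forward(model: nn.Module, target_cls: type, new_forward) -> None:
    """Rebind new_forward as the forward method of every target_cls submodule."""
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = MethodType(new_forward, module)


def patch_gemma_attention(model: nn.Module) -> None:
    # Import the modeling module that defines this model's attention class.
    modeling = importlib.import_module(model.__class__.__module__)
    if version.parse(trans_version) >= version.parse("4.39.0"):
        # transformers >= 4.39 changed the attention/cache interface,
        # so the 4.39-specific forward added in this PR is selected.
        from ipex_llm.transformers.models.gemma import gemma_attention_forward_4_39 as fwd
    else:
        from ipex_llm.transformers.models.gemma import gemma_attention_forward as fwd
    replace_forward(model, modeling.GemmaAttention, fwd)
```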
153 changes: 153 additions & 0 deletions python/llm/src/ipex_llm/transformers/models/gemma.py
@@ -267,3 +267,156 @@ def gemma_attention_forward(
attn_weights = None

return attn_output.to(original_dtype), attn_weights, past_key_value


def gemma_attention_forward_4_39(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor]=None,
position_ids: Optional[torch.LongTensor]=None,
past_key_value: Optional[Tuple[torch.Tensor]]=None,
output_attentions: bool=False,
use_cache: bool=False,
cache_position: Optional[torch.Tensor]=None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, hidden_size = hidden_states.size()
device = hidden_states.device
# for flash attention
original_dtype = hidden_states.dtype

use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
decoding_fast_path = use_decoding_fast_path(self.q_proj,
use_fuse_rope,
enough_kv_room,
bsz * q_len)

if decoding_fast_path:
hidden_states = hidden_states.view(1, -1)

cache_k = past_key_value.key_cache[self.layer_idx]
cache_v = past_key_value.value_cache[self.layer_idx]

kv_seq_len = cache_k.shape[-2]

import xe_linear
query_states, key_states, value_states = xe_linear.forward_qkv(hidden_states,
self.q_proj.weight,
self.k_proj.weight,
self.v_proj.weight,
position_ids,
cache_k, cache_v,
self.q_proj.weight.qtype,
self.v_proj.weight.qtype,
kv_seq_len,
self.head_dim)
kv_seq_len += 1

        # update past_key_value's seen_tokens and kv caches.
if self.layer_idx == 0:
past_key_value._seen_tokens = kv_seq_len
past_key_value.key_cache[self.layer_idx] = key_states
past_key_value.value_cache[self.layer_idx] = value_states

else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len,
self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len,
self.num_key_value_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]

if past_key_value is not None:
if self.layer_idx is None:
invalidInputError(False,
"The cache structure has changed since version v4.36. "
f"If you are using {self.__class__.__name__} for "
"auto-regressive decodingwith k/v caching, please make sure "
"to initialize the attention class with a layer index.")
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

if use_fuse_rope:
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states,
sin, cos, "gemma")
else:
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, None)

if past_key_value is not None:
# update the number of seen tokens
if self.layer_idx == 0:
past_key_value._seen_tokens += key_states.shape[-2]

# reuse k, v, self_attention
# update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
if len(past_key_value.key_cache) <= self.layer_idx:
past_key_value.key_cache.append(key_states)
past_key_value.value_cache.append(value_states)
else:
cache_k = past_key_value.key_cache[self.layer_idx]
cache_v = past_key_value.value_cache[self.layer_idx]

if not enough_kv_room:
# allocate new
new_c_k, new_c_v = extend_kv_cache(bsz,
self.num_key_value_heads, # Support GQA
self.head_dim,
cache_k.size(2),
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
dtype=cache_k.dtype,
device=device)

new_c_k[:] = cache_k
new_c_v[:] = cache_v
cache_k = new_c_k
cache_v = new_c_v

key_states, value_states = append_kv_cache(cache_k, cache_v,
key_states, value_states)

# update past_key_value
past_key_value.key_cache[self.layer_idx] = key_states
past_key_value.value_cache[self.layer_idx] = value_states

# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

if attention_mask is not None: # no matter the length, we just slice it
if cache_position is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
else:
causal_mask = attention_mask
attn_weights = attn_weights + causal_mask

# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout,
training=self.training)
attn_output = torch.matmul(attn_weights, value_states)

if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
invalidInputError(
False,
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)

attn_output = attn_output.transpose(1, 2).contiguous()

attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None

return attn_output.to(original_dtype), attn_weights, past_key_value
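
The new `gemma_attention_forward_4_39` is only exercised when a Gemma checkpoint is loaded through ipex-llm with transformers >= 4.39 installed, since the patching happens inside `_optimize_post` during model loading. A minimal usage sketch in the spirit of the README examples above (the model id and prompt are placeholders; an XPU environment set up as in the GPU READMEs is assumed):

```python
# Minimal sketch, mirroring the README examples: loading a Gemma model through
# ipex-llm triggers the optimization path patched in this PR.
import torch
from transformers import AutoTokenizer

from ipex_llm.transformers import AutoModelForCausalLM

model_path = "google/gemma-7b-it"  # placeholder checkpoint id

model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,    # INT4 weight quantization
                                             optimize_model=True,  # applies _optimize_post
                                             trust_remote_code=True,
                                             use_cache=True)
model = model.half().to("xpu")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

with torch.inference_mode():
    input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids.to("xpu")
    output = model.generate(input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))
```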