Skip to content

Commit

Permalink
add basic llama 3.2 vision support (#12163)
Browse files Browse the repository at this point in the history
  • Loading branch information
MeouSker77 authored Oct 8, 2024
1 parent 9b75806 commit 644af2a
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 0 deletions.
10 changes: 10 additions & 0 deletions python/llm/src/ipex_llm/transformers/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -1279,6 +1279,16 @@ def _optimize_post(model, lightweight_bmm=False):
convert_forward(model, module.LlamaMLP, mlp_silu_forward)
convert_forward(model, module.LlamaModel, llama_model_forward)
convert_forward(model, module.LlamaAttention, llama_attention_forward)
elif model.config.model_type == "mllama":
# llama 3.2 vision
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
from ipex_llm.transformers.models.common import rms_norm_forward
from ipex_llm.transformers.models.common import mlp_silu_forward
from ipex_llm.transformers.models.mllama import mllama_vision_attention_forward
convert_forward(model, module.MllamaVisionAttention, mllama_vision_attention_forward)
convert_forward(model, module.MllamaTextRMSNorm, rms_norm_forward)
convert_forward(model, module.MllamaTextMLP, mlp_silu_forward)
elif model.config.model_type == "llama":
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from transformers.models.llama.modeling_llama import LlamaMLP
Expand Down
78 changes: 78 additions & 0 deletions python/llm/src/ipex_llm/transformers/models/mllama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/mllama/modeling_mllama.py
# which is licensed under Apache License 2.0:
#
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import math
import torch

from typing import Optional


def mllama_vision_attention_forward(
self,
hidden_state: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: bool = None,
):
query = self.q_proj(hidden_state)
key = self.k_proj(hidden_state)
value = self.v_proj(hidden_state)

batch_size, q_seq_len, _ = query.shape
_, kv_seq_len, _ = key.shape

query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2)

attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(self.head_dim)

if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask

# upcast attention to fp32
from ipex_llm.transformers.models.common import attention_softmax
attn_weights = attention_softmax(attn_weights, self.training)

attn_output = torch.matmul(attn_weights, value)

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_seq_len, -1)

output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None

return output, attn_weights

0 comments on commit 644af2a

Please sign in to comment.