Skip to content

Commit

Permalink
Support pipeline parallel for glm-4-9b-chat (intel-analytics#11463)
Browse files Browse the repository at this point in the history
  • Loading branch information
plusbang authored and MeouSker77 committed Jul 19, 2024
1 parent 31245a9 commit 7f402ed
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 6 deletions.
4 changes: 3 additions & 1 deletion python/llm/example/GPU/Pipeline-Parallel-Inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ To run this example with IPEX-LLM on Intel GPUs, we have some recommended requir
- [Qwen/Qwen1.5-32B-Chat](./run_qwen1.5_arc_2_card.sh)
- [Qwen/Qwen1.5-MoE-A2.7B-Chat](./run_qwen1.5_arc_2_card.sh)
- [Qwen/CodeQwen1.5-7B-Chat](./run_qwen1.5_arc_2_card.sh)
- [THUDM/glm-4-9b-chat](./run_chatglm_arc_2_card.sh)
- [THUDM/chatglm3-6b](./run_chatglm_arc_2_card.sh)
- [baichuan-inc/Baichuan2-7B-Chat](./run_baichuan2_arc_2_card.sh)
- [baichuan-inc/Baichuan2-13B-Chat](./run_baichuan2_arc_2_card.sh)
Expand Down Expand Up @@ -116,11 +117,12 @@ bash run_qwen1.5_arc_2_card.sh
<details>
<summary> Show chatglm example </summary>

#### Run chatglm3-6B on two Intel Arc A770
#### Run glm-4-9b-chat / chatglm3-6B on two Intel Arc A770

You could specify `--repo-id-or-model-path` in the test script to be the huggingface repo id for chatglm to be downloaded, or the path to the huggingface checkpoint folder. Besides, you could change `NUM_GPUS` to the number of GPUs you have on your machine.

```bash
pip install transformers==4.37.0 "tiktoken>=0.7.0"
bash run_chatglm_arc_2_card.sh
```

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,7 @@ NUM_GPUS=2 # number of used GPU
# To run chatglm3-6b
CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS \
generate.py --repo-id-or-model-path 'THUDM/chatglm3-6b' --gpu-num $NUM_GPUS --low-bit 'sym_int4'

# # To run glm-4-9b-chat
# CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS \
# generate.py --repo-id-or-model-path 'THUDM/glm-4-9b-chat' --gpu-num $NUM_GPUS --low-bit 'sym_int4'
4 changes: 4 additions & 0 deletions python/llm/src/ipex_llm/transformers/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -1118,6 +1118,7 @@ def _optimize_post(model, lightweight_bmm=False):
from ipex_llm.transformers.models.chatglm4 import chatglm4_attention_forward
from ipex_llm.transformers.models.chatglm4 import chatglm4_model_forward
from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
from ipex_llm.transformers.models.chatglm4 import chatglm4_encoder_forward
convert_forward(model,
module.SelfAttention,
chatglm4_attention_forward)
Expand All @@ -1127,6 +1128,9 @@ def _optimize_post(model, lightweight_bmm=False):
convert_forward(model,
module.RMSNorm,
chatglm_rms_norm_forward)
convert_forward(model,
module.GLMTransformer,
chatglm4_encoder_forward)

elif "mpt" in model.config.model_type:
if model.config.architectures is not None:
Expand Down
2 changes: 2 additions & 0 deletions python/llm/src/ipex_llm/transformers/models/chatglm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ def chatglm2_model_forward(
else:
inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
seq_length, batch_size, _ = inputs_embeds.shape
input_ids = torch.empty((batch_size, seq_length),
dtype=inputs_embeds.dtype, device=inputs_embeds.device)

if full_attention_mask is None:
if (attention_mask is not None and not attention_mask.all()) or (
Expand Down
73 changes: 71 additions & 2 deletions python/llm/src/ipex_llm/transformers/models/chatglm4.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,13 @@ def chatglm4_model_forward(
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

batch_size, seq_length = input_ids.shape

if inputs_embeds is None:
batch_size, seq_length = input_ids.shape
inputs_embeds = self.embedding(input_ids)
else:
batch_size, seq_length, _ = inputs_embeds.shape
input_ids = torch.empty((batch_size, seq_length),
dtype=inputs_embeds.dtype, device=inputs_embeds.device)

if full_attention_mask is None:
if (attention_mask is not None and not attention_mask.all()) or\
Expand Down Expand Up @@ -234,3 +237,69 @@ def chatglm4_attention_forward(
output = self.dense(attn_output)

return output, past_key_value


def chatglm4_encoder_forward(
self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None,
use_cache: Optional[bool] = True,
output_hidden_states: Optional[bool] = False,
):
if not kv_caches:
kv_caches = [None for _ in range(self.num_layers)]
presents = () if use_cache else None
if self.gradient_checkpointing and self.training:
if use_cache:
use_cache = False

all_self_attentions = None
all_hidden_states = () if output_hidden_states else None
for index in range(self.num_layers):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

layer = self._get_layer(index)
if self.gradient_checkpointing and self.training:
layer_ret = torch.utils.checkpoint.checkpoint(
layer,
hidden_states,
attention_mask,
rotary_pos_emb,
kv_caches[index],
use_cache,
use_reentrant=False
)
else:
# if kv_caches[index] is not None:
layer_ret = layer(
hidden_states,
attention_mask,
rotary_pos_emb,
kv_cache=kv_caches[index],
use_cache=use_cache
)
hidden_states, kv_cache = layer_ret
if use_cache:
# token by token decoding, use tuple format
if kv_caches[0] is not None:
presents = presents + (kv_cache,)
# prefilling in decoding, use tensor format to save cuda memory
else:
if len(presents) == 0:
presents = kv_cache
else:
# bigdl-llm change starts
# to fix first token's kv cache error of tensor format in pipeline parallel
if isinstance(kv_cache, tuple):
kv_cache = torch.tensor(kv_cache,
dtype=hidden_states.dtype).to(hidden_states.device)
# bigdl-llm change ends
presents = torch.cat((presents, kv_cache.to(presents.device)), dim=0)

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

# Final layer norm.
if self.post_layer_norm:
hidden_states = self.final_layernorm(hidden_states)

return hidden_states, presents, all_hidden_states, all_self_attentions
33 changes: 30 additions & 3 deletions python/llm/src/ipex_llm/transformers/pipeline_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ def __init__(self, *args):
def forward(
self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,
):
if kv_cache is None:
return hidden_states, ()
return hidden_states, kv_cache


Expand Down Expand Up @@ -282,8 +284,20 @@ def pipeline_parallel_generate(self,
"make sure that `pad_token_id` is defined.")
next_ids = next_ids * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

# Temporarily specify as Baichuan and ChatGLM
if self.config.model_type in ["baichuan", "chatglm"] and local_rank != 0:
if self.config.model_type == "chatglm" and self.config.num_layers == 40:
# for glm-4-9b-chat
if step == 0:
value_placeholder = torch.empty_like((outputs.past_key_values)[-1][0])
past_key_values_placeholder = tuple(
(value_placeholder, value_placeholder) for _ in range(layer_start)
) + (outputs.past_key_values)[: layer_end - layer_start] + tuple(
(value_placeholder, value_placeholder) for _ in range(layer_end, num_layers)
)
_past_key_values = past_key_values_placeholder
else:
_past_key_values = outputs.past_key_values
elif self.config.model_type in ["baichuan", "chatglm"] and local_rank != 0:
# for baichuan2 and chatglm3
value_placeholder = torch.empty_like((outputs.past_key_values)[-1][0])
past_key_values_placeholder = tuple(
(value_placeholder, value_placeholder) for _ in range(layer_start)
Expand Down Expand Up @@ -421,7 +435,20 @@ def model_step(self, input, cur_batch):
attention_mask=attention_mask,
use_cache=True,)

if self.model.config.model_type in ["baichuan", "chatglm"] and self.rank > 0:
if self.model.config.model_type == "chatglm" and self.model.config.num_layers == 40:
# for glm-4-9b-chat
if self.past_key_values_dict.get(cur_id, None) is None:
value_placeholder = torch.empty_like((output.past_key_values)[-1][0])
past_key_values_placeholder = tuple(
(value_placeholder, value_placeholder) for _ in range(layer_start)
) + (output.past_key_values)[: layer_end - layer_start] + tuple(
(value_placeholder, value_placeholder) for _ in range(layer_end, num_layers)
)
_past_key_values = past_key_values_placeholder
else:
_past_key_values = output.past_key_values
elif self.model.config.model_type in ["baichuan", "chatglm"] and self.rank > 0:
# for baichuan2 and chatglm3
value_placeholder = torch.empty_like((output.past_key_values)[-1][0])
past_key_values_placeholder = tuple(
(value_placeholder, value_placeholder) for _ in range(layer_start)
Expand Down

0 comments on commit 7f402ed

Please sign in to comment.