Support pipeline parallel for glm-4-9b-chat #11463

Merged (6 commits) on Jul 3, 2024
4 changes: 3 additions & 1 deletion python/llm/example/GPU/Pipeline-Parallel-Inference/README.md
@@ -15,6 +15,7 @@ To run this example with IPEX-LLM on Intel GPUs, we have some recommended requirements …
- [Qwen/Qwen1.5-32B-Chat](./run_qwen1.5_arc_2_card.sh)
- [Qwen/Qwen1.5-MoE-A2.7B-Chat](./run_qwen1.5_arc_2_card.sh)
- [Qwen/CodeQwen1.5-7B-Chat](./run_qwen1.5_arc_2_card.sh)
- [THUDM/glm-4-9b-chat](./run_chatglm_arc_2_card.sh)
- [THUDM/chatglm3-6b](./run_chatglm_arc_2_card.sh)
- [baichuan-inc/Baichuan2-7B-Chat](./run_baichuan2_arc_2_card.sh)
- [baichuan-inc/Baichuan2-13B-Chat](./run_baichuan2_arc_2_card.sh)
@@ -116,11 +117,12 @@ bash run_qwen1.5_arc_2_card.sh
<details>
<summary> Show chatglm example </summary>

#### Run chatglm3-6B on two Intel Arc A770
#### Run glm-4-9b-chat / chatglm3-6B on two Intel Arc A770

You could specify `--repo-id-or-model-path` in the test script to be the Hugging Face repo id of the chatglm model to download, or the path to a local Hugging Face checkpoint folder. You could also change `NUM_GPUS` to the number of GPUs available on your machine.

```bash
pip install transformers==4.37.0 "tiktoken>=0.7.0"
bash run_chatglm_arc_2_card.sh
```

python/llm/example/GPU/Pipeline-Parallel-Inference/run_chatglm_arc_2_card.sh
@@ -29,3 +29,7 @@ NUM_GPUS=2 # number of used GPU
# To run chatglm3-6b
CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS \
generate.py --repo-id-or-model-path 'THUDM/chatglm3-6b' --gpu-num $NUM_GPUS --low-bit 'sym_int4'

# # To run glm-4-9b-chat
# CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS \
# generate.py --repo-id-or-model-path 'THUDM/glm-4-9b-chat' --gpu-num $NUM_GPUS --low-bit 'sym_int4'
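
For orientation, here is a minimal sketch of how each of the `NUM_GPUS` processes launched by `torchrun` above could select its own XPU device. It assumes `intel_extension_for_pytorch` is installed and relies on the `LOCAL_RANK` environment variable that `torchrun` normally sets; it is illustrative only, not the example's actual code.

```python
import os

import torch
import intel_extension_for_pytorch as ipex  # noqa: F401  # registers the 'xpu' device type

# torchrun sets LOCAL_RANK for each of the --nproc-per-node processes it spawns.
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
device = torch.device(f"xpu:{local_rank}")
print(f"process with LOCAL_RANK={local_rank} would place its pipeline stage on {device}")
```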
4 changes: 4 additions & 0 deletions python/llm/src/ipex_llm/transformers/convert.py
@@ -1118,6 +1118,7 @@ def _optimize_post(model, lightweight_bmm=False):
from ipex_llm.transformers.models.chatglm4 import chatglm4_attention_forward
from ipex_llm.transformers.models.chatglm4 import chatglm4_model_forward
from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
from ipex_llm.transformers.models.chatglm4 import chatglm4_encoder_forward
convert_forward(model,
                module.SelfAttention,
                chatglm4_attention_forward)
@@ -1127,6 +1128,9 @@
convert_forward(model,
                module.RMSNorm,
                chatglm_rms_norm_forward)
convert_forward(model,
                module.GLMTransformer,
                chatglm4_encoder_forward)

elif "mpt" in model.config.model_type:
if model.config.architectures is not None:
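
The convert.py change above registers `chatglm4_encoder_forward` for `module.GLMTransformer` via `convert_forward`. As a rough, hedged sketch of what that kind of helper does (the real ipex_llm implementation may differ in details), the pattern is to rebind `forward` on every instance of the target class:

```python
import types

import torch

def convert_forward_sketch(model: torch.nn.Module, target_cls: type, new_forward) -> None:
    # Walk all submodules and rebind `forward` on instances of the target class,
    # so calls go through the replacement implementation instead of the original one.
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = types.MethodType(new_forward, module)

# Usage would mirror the diff above, e.g.:
# convert_forward_sketch(model, module.GLMTransformer, chatglm4_encoder_forward)
```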
2 changes: 2 additions & 0 deletions python/llm/src/ipex_llm/transformers/models/chatglm2.py
@@ -80,6 +80,8 @@ def chatglm2_model_forward(
    else:
        inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
        seq_length, batch_size, _ = inputs_embeds.shape
        input_ids = torch.empty((batch_size, seq_length),
                                dtype=inputs_embeds.dtype, device=inputs_embeds.device)

    if full_attention_mask is None:
        if (attention_mask is not None and not attention_mask.all()) or (
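
The two lines added to chatglm2.py give the model a placeholder `input_ids` tensor when callers pass only `inputs_embeds`; the same fix appears in chatglm4.py below, where the review comments explain that `self.get_masks()` only needs the `(batch_size, seq_length)` shape. A standalone sketch of that reasoning, using a `get_masks`-like stub modeled on modeling_chatglm (the stub is illustrative, not the upstream code):

```python
import torch

def get_masks_stub(input_ids, past_key_values=None, padding_mask=None):
    # Mirrors the upstream pattern: only the shape of input_ids is read here,
    # so any tensor with the right (batch_size, seq_length) shape is enough.
    batch_size, seq_length = input_ids.shape
    return torch.ones(batch_size, 1, seq_length, seq_length, dtype=torch.bool)

inputs_embeds = torch.randn(3, 2, 32)                 # (seq_length, batch_size, hidden_size)
seq_length, batch_size, _ = inputs_embeds.shape
input_ids = torch.empty((batch_size, seq_length),      # placeholder, never read element-wise
                        dtype=inputs_embeds.dtype, device=inputs_embeds.device)
mask = get_masks_stub(input_ids)                       # works; with input_ids=None it would fail
```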
73 changes: 71 additions & 2 deletions python/llm/src/ipex_llm/transformers/models/chatglm4.py
@@ -46,10 +46,13 @@ def chatglm4_model_forward(
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    batch_size, seq_length = input_ids.shape

    if inputs_embeds is None:
        batch_size, seq_length = input_ids.shape
        inputs_embeds = self.embedding(input_ids)
    else:
        batch_size, seq_length, _ = inputs_embeds.shape
Review thread on this change:

Contributor: If attention_mask is given a not-None and not-all-True value, input_ids is needed at line 58 by self.get_masks(), and it will raise an error if input_ids is still None. Maybe add an empty tensor here?

    input_ids = torch.empty((batch_size, seq_length), device=inputs_embeds.device)

Contributor: Relevant code in modeling_chatglm:

    def get_masks(self, input_ids, past_key_values, padding_mask=None):
        batch_size, seq_length = input_ids.shape

Contributor Author (@plusbang, Jul 2, 2024), replying to the suggestion above: Have updated in chatglm2.py and chatglm4.py.

        input_ids = torch.empty((batch_size, seq_length),
                                dtype=inputs_embeds.dtype, device=inputs_embeds.device)

    if full_attention_mask is None:
        if (attention_mask is not None and not attention_mask.all()) or\
@@ -234,3 +237,69 @@ def chatglm4_attention_forward(
    output = self.dense(attn_output)

    return output, past_key_value


def chatglm4_encoder_forward(
        self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None,
        use_cache: Optional[bool] = True,
        output_hidden_states: Optional[bool] = False,
):
    if not kv_caches:
        kv_caches = [None for _ in range(self.num_layers)]
    presents = () if use_cache else None
    if self.gradient_checkpointing and self.training:
        if use_cache:
            use_cache = False

    all_self_attentions = None
    all_hidden_states = () if output_hidden_states else None
    for index in range(self.num_layers):
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        layer = self._get_layer(index)
        if self.gradient_checkpointing and self.training:
            layer_ret = torch.utils.checkpoint.checkpoint(
                layer,
                hidden_states,
                attention_mask,
                rotary_pos_emb,
                kv_caches[index],
                use_cache,
                use_reentrant=False
            )
        else:
            # if kv_caches[index] is not None:
            layer_ret = layer(
                hidden_states,
                attention_mask,
                rotary_pos_emb,
                kv_cache=kv_caches[index],
                use_cache=use_cache
            )
        hidden_states, kv_cache = layer_ret
        if use_cache:
            # token by token decoding, use tuple format
            if kv_caches[0] is not None:
                presents = presents + (kv_cache,)
            # prefilling in decoding, use tensor format to save cuda memory
            else:
                if len(presents) == 0:
                    presents = kv_cache
                else:
                    # bigdl-llm change starts
                    # to fix first token's kv cache error of tensor format in pipeline parallel
                    if isinstance(kv_cache, tuple):
                        kv_cache = torch.tensor(kv_cache,
                                                dtype=hidden_states.dtype).to(hidden_states.device)
                    # bigdl-llm change ends
                    presents = torch.cat((presents, kv_cache.to(presents.device)), dim=0)

    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)

    # Final layer norm.
    if self.post_layer_norm:
        hidden_states = self.final_layernorm(hidden_states)

    return hidden_states, presents, all_hidden_states, all_self_attentions
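
For readers following `chatglm4_encoder_forward` above, here is a simplified, standalone sketch of the two KV-cache accumulation formats it handles: the tuple format used during token-by-token decoding and the concatenated-tensor format used during prefill. The shapes and stand-in cache tensors are illustrative only.

```python
import torch

num_layers = 2
layer_caches = [torch.zeros(1, 4, 8) for _ in range(num_layers)]  # stand-ins for per-layer caches

# Decoding: keep one entry per layer in a tuple.
presents_decode = ()
for kv_cache in layer_caches:
    presents_decode = presents_decode + (kv_cache,)

# Prefill: concatenate layer caches into a single tensor along dim 0 to save memory.
presents_prefill = ()
for kv_cache in layer_caches:
    if len(presents_prefill) == 0:           # first layer: adopt the tensor directly
        presents_prefill = kv_cache
    else:                                    # later layers: append along dim 0
        presents_prefill = torch.cat((presents_prefill, kv_cache), dim=0)

assert len(presents_decode) == num_layers
assert presents_prefill.shape[0] == num_layers  # one slice per layer after concatenation
```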
33 changes: 30 additions & 3 deletions python/llm/src/ipex_llm/transformers/pipeline_parallel.py
@@ -90,6 +90,8 @@ def __init__(self, *args):
    def forward(
            self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,
    ):
        if kv_cache is None:
            return hidden_states, ()
        return hidden_states, kv_cache

@@ -278,8 +280,20 @@ def pipeline_parallel_generate(self,
                                 "make sure that `pad_token_id` is defined.")
            next_ids = next_ids * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # Temporarily specify as Baichuan and ChatGLM
        if self.config.model_type in ["baichuan", "chatglm"] and local_rank != 0:
        if self.config.model_type == "chatglm" and self.config.num_layers == 40:
            # for glm-4-9b-chat
            if step == 0:
                value_placeholder = torch.empty_like((outputs.past_key_values)[-1][0])
                past_key_values_placeholder = tuple(
                    (value_placeholder, value_placeholder) for _ in range(layer_start)
                ) + (outputs.past_key_values)[: layer_end - layer_start] + tuple(
                    (value_placeholder, value_placeholder) for _ in range(layer_end, num_layers)
                )
                _past_key_values = past_key_values_placeholder
            else:
                _past_key_values = outputs.past_key_values
        elif self.config.model_type in ["baichuan", "chatglm"] and local_rank != 0:
            # for baichuan2 and chatglm3
            value_placeholder = torch.empty_like((outputs.past_key_values)[-1][0])
            past_key_values_placeholder = tuple(
                (value_placeholder, value_placeholder) for _ in range(layer_start)
@@ -417,7 +431,20 @@ def model_step(self, input, cur_batch):
                            attention_mask=attention_mask,
                            use_cache=True,)

        if self.model.config.model_type in ["baichuan", "chatglm"] and self.rank > 0:
        if self.model.config.model_type == "chatglm" and self.model.config.num_layers == 40:
            # for glm-4-9b-chat
            if self.past_key_values_dict.get(cur_id, None) is None:
                value_placeholder = torch.empty_like((output.past_key_values)[-1][0])
                past_key_values_placeholder = tuple(
                    (value_placeholder, value_placeholder) for _ in range(layer_start)
                ) + (output.past_key_values)[: layer_end - layer_start] + tuple(
                    (value_placeholder, value_placeholder) for _ in range(layer_end, num_layers)
                )
                _past_key_values = past_key_values_placeholder
            else:
                _past_key_values = output.past_key_values
        elif self.model.config.model_type in ["baichuan", "chatglm"] and self.rank > 0:
            # for baichuan2 and chatglm3
            value_placeholder = torch.empty_like((output.past_key_values)[-1][0])
            past_key_values_placeholder = tuple(
                (value_placeholder, value_placeholder) for _ in range(layer_start)
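
Finally, a standalone sketch of the `past_key_values` placeholder construction used in both hunks above: a rank that owns only layers `[layer_start, layer_end)` pads the per-layer cache tuple with dummy (key, value) pairs so that downstream code still sees `num_layers` entries. Layer counts and tensor shapes here are illustrative, not the model's real dimensions.

```python
import torch

num_layers, layer_start, layer_end = 4, 1, 3

# Pretend this rank computed caches only for its own layers.
local_past = tuple((torch.randn(1, 2, 8, 16), torch.randn(1, 2, 8, 16))
                   for _ in range(layer_end - layer_start))

# Dummy tensor with the same shape/dtype as a real cache entry, reused for the other layers.
value_placeholder = torch.empty_like(local_past[-1][0])
past_key_values = (
    tuple((value_placeholder, value_placeholder) for _ in range(layer_start))
    + local_past
    + tuple((value_placeholder, value_placeholder) for _ in range(layer_end, num_layers))
)

assert len(past_key_values) == num_layers  # 1 placeholder + 2 real + 1 placeholder
```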