Skip to content

Commit

Permalink
Fix update_kv_cache in Pipeline-Parallel-Serving for glm4-9b model (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
xiangyuT authored Jul 9, 2024
1 parent fa81dbe commit a1cede9
Showing 1 changed file with 11 additions and 19 deletions.
30 changes: 11 additions & 19 deletions python/llm/src/ipex_llm/transformers/pipeline_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,14 +509,14 @@ def cat_kv_cache(self, model_type, kv_cache_1, kv_cache_2):

return kv_cache_1

def update_kv_cache(self, kv_cache, cur_id):
def update_kv_cache(self, kv_cache, prefill=False):
layer_start = self.model.layer_start
layer_end = self.model.layer_end
num_layers = self.model.num_layers

if self.model.config.model_type == "chatglm" and self.model.config.num_layers == 40:
# for glm-4-9b-chat
if self.past_key_values_dict.get(cur_id, None) is None:
if prefill:
value_placeholder = torch.empty_like((kv_cache)[-1][0])
past_key_values_placeholder = tuple(
(value_placeholder, value_placeholder) for _ in range(layer_start)
Expand All @@ -528,13 +528,10 @@ def update_kv_cache(self, kv_cache, cur_id):
pass
elif self.model.config.model_type in ["baichuan", "chatglm"] and self.rank > 0:
value_placeholder = torch.empty_like((kv_cache)[-1][0])
kv_cache = tuple((value_placeholder, value_placeholder)) + \
tuple(None for _ in range(layer_start)) + \
(kv_cache)[layer_start:]
# past_key_values_placeholder = tuple(
# (value_placeholder, value_placeholder) for _ in range(layer_start)
# ) + (kv_cache)[layer_start:]
# kv_cache = past_key_values_placeholder
past_key_values_placeholder = tuple(
(value_placeholder, value_placeholder) for _ in range(layer_start)
) + (kv_cache)[layer_start:]
kv_cache = past_key_values_placeholder
else:
pass

Expand Down Expand Up @@ -590,7 +587,7 @@ def model_step(self, input, cur_batch):
# torch.xpu.empty_cache()

if cur_batch.prefilled_index == cur_batch.batch_size:
tmp_past_key_values = self.update_kv_cache(tmp_past_key_values, cur_id)
tmp_past_key_values = self.update_kv_cache(tmp_past_key_values, True)

self.past_key_values_dict[cur_id] = tmp_past_key_values

Expand All @@ -604,7 +601,8 @@ def model_step(self, input, cur_batch):
_pre_output = torch.cat((_pre_output, tmp_output), dim=0)
self.partial_output_dict[cur_id] = _pre_output
else:
_past_key_values = self.update_kv_cache(output.past_key_values, cur_id)
_prefill = self.past_key_values_dict.get(cur_id, None) is None
_past_key_values = self.update_kv_cache(output.past_key_values, prefill=_prefill)
self.past_key_values_dict[cur_id] = _past_key_values
torch.xpu.synchronize()
if not self.pp_config.is_tail:
Expand Down Expand Up @@ -687,7 +685,6 @@ async def process_step(self, tokenizer, result_dict):

if (cur_batch is not None) and (not cur_batch.stopped) and (cur_input is None):
cur_id = cur_batch.batch_id
# cur_batch = self.prepare_batch(cur_batch)
if cur_batch.prefilled_index >= cur_batch.batch_size:
cur_batch.partial_prefilling = 0
if cur_batch.partial_prefilling > 0:
Expand Down Expand Up @@ -810,14 +807,9 @@ async def process_step(self, tokenizer, result_dict):
dist.recv(cur_input, src=self.pre_rank)

output, cur_batch = self.model_step(cur_input, cur_batch)
# if output is not None and self.rank == self.world_size - 1:
# output = torch.argmax(output[:, -1:, :], dim=-1)

if output is not None:
# dist.send(output, dst=self.next_rank)
self.send_buff = output
else:
self.send_buff = None
self.send_buff = output

if self.rank == 0:
self.on_going_batches[:-1] = self.on_going_batches[1:]
self.on_going_batches[self.world_size - 1] = cur_batch
Expand Down

0 comments on commit a1cede9

Please sign in to comment.