Fix gptj kvcache & position id (#10141)
cyita authored Feb 18, 2024
1 parent 92c43d4 commit a7dd7e7
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions python/llm/src/bigdl/llm/transformers/speculative.py
@@ -468,7 +468,7 @@ def speculative_generate(self,
                     position_ids = torch.Tensor([[past_key_value_len + step_draft]]).long()
                     forward_args["position_ids"] = position_ids
                 elif self.config.model_type == "gptj":
-                    past_length = draft_past_key_values[0][0].size(1)
+                    past_length = draft_past_key_values[0][0].size(2)
                     position_ids = torch.Tensor([[past_length]]).long().to(self.device)
                     forward_args["position_ids"] = position_ids
                 draft_output = draft_model(**forward_args)
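
The first hunk fixes how the draft loop derives position_ids for GPT-J: the past sequence length is now read from dim 2 of the cached key tensor instead of dim 1. Below is a minimal sketch (not part of the patch), assuming the standard Hugging Face GPT-J cache layout of (batch_size, num_heads, seq_len, head_dim) and made-up sizes; it shows why size(2) is the cached token count while size(1) would return the head count.

import torch

# Hypothetical sizes for illustration only.
batch_size, num_heads, past_seq_len, head_dim = 1, 16, 12, 256
key_cache = torch.zeros(batch_size, num_heads, past_seq_len, head_dim)
value_cache = torch.zeros(batch_size, num_heads, past_seq_len, head_dim)
draft_past_key_values = [(key_cache, value_cache)]

past_length = draft_past_key_values[0][0].size(2)   # 12, tokens already cached
# .size(1) would have returned num_heads (16) and shifted every position id.
position_ids = torch.Tensor([[past_length]]).long()
print(past_length, position_ids)                    # 12 tensor([[12]])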
@@ -563,7 +563,7 @@ def speculative_generate(self,
                 position_ids = position_ids.unsqueeze(0).repeat(1, 1) + past_key_value_len
                 forward_args["position_ids"] = position_ids
             elif self.config.model_type == "gptj":
-                past_length = past_key_values[0][0].size(1)
+                past_length = past_key_values[0][0].size(2)
                 input_len = drafted_input_ids.shape[1]
                 position_ids = torch.arange(past_length, input_len + past_length,
                                             dtype=torch.long, device=drafted_input_ids.device)
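
The second hunk applies the same dim-2 fix when building position_ids for the verification pass over the drafted tokens. A hedged sketch with the same assumed cache layout and hypothetical sizes:

import torch

# Assumed (batch, num_heads, seq_len, head_dim) cache holding 12 tokens.
past_key_values = [(torch.zeros(1, 16, 12, 256), torch.zeros(1, 16, 12, 256))]
drafted_input_ids = torch.zeros(1, 5, dtype=torch.long)   # 5 drafted tokens

past_length = past_key_values[0][0].size(2)               # 12
input_len = drafted_input_ids.shape[1]                    # 5
position_ids = torch.arange(past_length, input_len + past_length,
                            dtype=torch.long, device=drafted_input_ids.device)
# The batch dimension below is added here only for display.
print(position_ids.unsqueeze(0))                          # tensor([[12, 13, 14, 15, 16]])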
@@ -644,7 +644,7 @@ def speculative_generate(self,
                 past_key_values = [[tmp, key_cache, value_cache, beam_idx]
                                     for _, key_cache, value_cache, beam_idx in past_key_values]
             else:
-                if self.config.model_type in ["qwen", "gptj"]:
+                if self.config.model_type in ["qwen"]:
                     past_key_values = [
                         (k[:, :-(max_of_max_matched - max_matched), :],
                          v[:, :-(max_of_max_matched - max_matched), :])
@@ -657,7 +657,7 @@
                          v[:-(max_of_max_matched - max_matched), :, :, :])
                         for k, v in past_key_values
                     ]
-                elif self.config.model_type == "baichuan":
+                elif self.config.model_type in ["baichuan", "gptj"]:
                     past_key_values = [
                         (k[:, :, :-(max_of_max_matched - max_matched), :],
                          v[:, :, :-(max_of_max_matched - max_matched), :])
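
The remaining hunks move GPT-J out of the qwen-style cache rollback, which trims the sequence on dim 1, and into the baichuan-style branch, which trims dim 2. The sketch below (hypothetical sizes, not from the patch) reproduces the dim-2 slicing to show it drops the rejected draft tokens rather than attention heads:

import torch

max_of_max_matched, max_matched = 5, 3        # assume 2 drafted tokens were rejected
k = torch.zeros(1, 16, 17, 256)               # (batch, num_heads, seq_len, head_dim)
v = torch.zeros(1, 16, 17, 256)
past_key_values = [(k, v)]

past_key_values = [
    (k[:, :, :-(max_of_max_matched - max_matched), :],
     v[:, :, :-(max_of_max_matched - max_matched), :])
    for k, v in past_key_values
]
print(past_key_values[0][0].shape)            # torch.Size([1, 16, 15, 256])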