Fix Arc StarCoder wrong query_shape when input is long (#10268)
* Fix Arc StarCoder wrong query_shape when input is long

* Update gptbigcode.py
Uxito-Ada authored Feb 28, 2024
1 parent c8937ac commit bde8e5c
Showing 1 changed file with 15 additions and 28 deletions.
python/llm/src/bigdl/llm/transformers/models/gptbigcode.py
@@ -42,13 +42,7 @@ def gptbigcode_attention_forward(
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
**kwargs):
if "padding_mask" in kwargs:
logger.warning_once(
"Passing `padding_mask` is deprecated and will be removed in v4.37." +
"Please make sure use `attention_mask` instead.`"
)
output_attentions: Optional[bool] = False):

if encoder_hidden_states is not None:
if not hasattr(self, "q_attn") or not self.is_cross_attention:
@@ -60,45 +54,38 @@ def gptbigcode_attention_forward(
"Please make sure to instantiate class with " +
"`GPTBigCodeAttention(..., is_cross_attention=True)`."
)

query = self.q_attn(hidden_states)
key_value = self.c_attn(encoder_hidden_states)
attention_mask = encoder_attention_mask
elif self.multi_query:
query, key_value = self.c_attn(hidden_states).split(
(self.embed_dim, 2 * self.kv_dim), dim=2)
else:
# Note: We split as (self.num_heads, 3, self.head_dim)
# instead of (3, self.num_heads, self.head_dim),
# i.e., the memory layout is not the same as GPT2.
# This makes the concatenation with past_key_value more efficient.
query, key_value = (
self.c_attn(hidden_states)
.view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
.transpose(1, 2)
.split((self.head_dim, 2 * self.head_dim), dim=3)
)

if layer_past is not None:
if layer_past.shape[-2] == key_value.shape[-2]:
key_value = torch.cat((layer_past, key_value), dim=-2)
else:
fill_zeros = torch.zeros(layer_past.shape[0],
layer_past.shape[1],
key_value.shape[2] - layer_past.shape[2],
dtype=layer_past.dtype,
device=layer_past.device)
layer_past = torch.cat([layer_past, fill_zeros], dim=-1)
key_value = torch.cat((layer_past, key_value), dim=-2)

if layer_past is not None:
if layer_past.shape[-2] == key_value.shape[-2]:
key_value = torch.cat((layer_past, key_value), dim=-2)
else:
fill_zeros = torch.zeros(layer_past.shape[0],
layer_past.shape[1],
key_value.shape[2] - layer_past.shape[2],
dtype=layer_past.dtype,
device=layer_past.device)
layer_past = torch.cat([layer_past, fill_zeros], dim=-1)
key_value = torch.cat((layer_past, key_value), dim=-2)
present = key_value if use_cache else None

key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)

attn_output, attn_weights = self._attn(query,
key.transpose(-1, -2),
value,
attention_mask,
head_mask)
attn_output, attn_weights = self._attn(query, key.transpose(-1, -2),
value, attention_mask, head_mask)

if not self.multi_query:
attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
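
For context, the multi_query branch shown above splits the fused c_attn projection into query and key_value along the channel dimension. A minimal standalone sketch of that split with plain torch; the batch size, sequence length, embed_dim and kv_dim below are illustrative assumptions, not the model's real configuration:

import torch

# Toy sizes (illustrative only). In multi-query attention a single key/value
# head is shared by all query heads, so kv_dim equals head_dim.
batch, seq_len, embed_dim, head_dim = 2, 5, 16, 4
kv_dim = head_dim

# c_attn produces embed_dim + 2 * kv_dim channels per token:
# all query heads plus one key head and one value head.
c_attn_out = torch.randn(batch, seq_len, embed_dim + 2 * kv_dim)

# Same split as the multi_query branch in the diff above.
query, key_value = c_attn_out.split((embed_dim, 2 * kv_dim), dim=2)

print(query.shape)      # torch.Size([2, 5, 16])
print(key_value.shape)  # torch.Size([2, 5, 8])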
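
The fill_zeros branch above handles the case where the cached layer_past and the freshly projected key_value no longer line up: the cache is zero-padded along its last dimension to match key_value, then the two are concatenated along the sequence dimension. A minimal sketch of that alignment step with toy shapes and a simplified mismatch check, both of which are illustrative assumptions rather than the module's actual condition or dimensions:

import torch

def append_to_cache(layer_past, key_value):
    # Illustrative stand-in for the fill_zeros branch: pad the cached tensor
    # with zeros along its last dimension until it matches key_value, then
    # append the new entries along the sequence dimension (dim=-2).
    if layer_past.shape[-1] != key_value.shape[-1]:  # simplified mismatch check
        fill_zeros = torch.zeros(layer_past.shape[0],
                                 layer_past.shape[1],
                                 key_value.shape[-1] - layer_past.shape[-1],
                                 dtype=layer_past.dtype,
                                 device=layer_past.device)
        layer_past = torch.cat([layer_past, fill_zeros], dim=-1)
    return torch.cat((layer_past, key_value), dim=-2)

# Toy shapes: a cache whose last dimension is shorter than the new projection's.
cache = torch.randn(2, 6, 6)    # (batch, cached tokens, old width)
new_kv = torch.randn(2, 1, 8)   # (batch, new tokens, current width)
print(append_to_cache(cache, new_kv).shape)  # torch.Size([2, 7, 8])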