Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix chatglm3 npu output #11590

Merged
merged 1 commit into from
Jul 16, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 24 additions & 27 deletions python/llm/src/ipex_llm/transformers/npu_models/chatglm.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,16 @@ def chatglm2_model_forward(
rotary_pos_emb = rotary_pos_emb[position_ids]
else:
rotary_pos_emb = rotary_pos_emb[None, :seq_length]
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
# ipex-llm change start: change rope cache shape
# rotary_pos_emb: [bsz, seq_len, rot_dim//2, 2]
cos, sin = rotary_pos_emb.permute(3, 0, 1, 2).chunk(2, dim=0)
cos = cos.squeeze(0).unsqueeze(1)
sin = sin.squeeze(0).unsqueeze(1)
cos = cos.repeat_interleave(2, dim=-1)
sin = sin.repeat_interleave(2, dim=-1)
# cos, sin: [bsz, 1, seq_len, rot_dim]
rotary_pos_emb = (cos, sin)
# ipex-llm change end

# ipex-llm changes begin:
# generate `causal_mask` and replace `full_attention_mask` with it
Expand All @@ -76,14 +85,6 @@ def chatglm2_model_forward(
dtype=inputs_embeds.dtype, device=inputs_embeds.device)
mask_value = torch.finfo(inputs_embeds.dtype).min
causal_mask.masked_fill_(full_attention_mask, mask_value)
elif self.training or (inputs_embeds.device.type != "xpu" and past_key_values is None):
full_attention_mask = self.get_masks(input_ids,
past_key_values,
padding_mask=attention_mask)
causal_mask = torch.zeros([batch_size, 1, seq_length, full_attention_mask.size(-1)],
dtype=inputs_embeds.dtype, device=inputs_embeds.device)
mask_value = torch.finfo(inputs_embeds.dtype).min
causal_mask.masked_fill_(full_attention_mask, mask_value)
else:
causal_mask = None

Expand Down Expand Up @@ -174,24 +175,20 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:


@torch.jit.script
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
# x: [sq, b, np, hn]
sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
rot_dim = rope_cache.shape[-2] * 2
def rotate_every_two(x: torch.Tensor):
x1 = x[:, :, :, ::2]
x2 = x[:, :, :, 1::2]
x = torch.stack((-x2, x1), dim=-1)
return x.flatten(-2)


def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: Tuple[torch.Tensor]) -> torch.Tensor:
# x: [bsz, n_head, seq_len, head_dim]
cos, sin = rope_cache
rot_dim = cos.size(-1)
x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
# truncate to support variable sizes
rope_cache = rope_cache[:sq]
xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
x_out2 = torch.stack(
[
xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
],
-1,
)
x_out2 = x_out2.flatten(3)
return torch.cat((x_out2, x_pass), dim=-1)
x_out = x * cos + rotate_every_two(x) * sin
return torch.cat([x_out, x_pass], dim=-1)


def chatglm2_attention_forward(
Expand Down Expand Up @@ -246,7 +243,7 @@ def chatglm2_attention_forward(
key_states,
value_states,
attn_mask=attention_mask,
is_causal=q_len > 1 and bsz == 1,
is_causal=attention_mask is None and q_len > 1 and bsz == 1,
)
attn_weights = None
else:
Expand Down
Loading