Commit: fix

plusbang committed Dec 12, 2024
1 parent e6bf295 commit fcaf71b
Showing 3 changed files with 41 additions and 104 deletions.
python/llm/src/ipex_llm/transformers/npu_model.py (23 additions & 23 deletions)

@@ -301,29 +301,29 @@ def optimize_npu_model(cls, *args, **kwargs):
         model.share_memory()

         if not pipeline:
-            # if (not hasattr(model, 'llm') and
-            #         model.config.model_type in ["qwen2", "llama", "minicpm"]):
-            #     from ipex_llm.transformers.npu_models.convert import optimize_llm_single_process
-            #     optimize_llm_single_process(
-            #         llm,
-            #         kv_len=max_context_len,
-            #         max_prompt_len=max_prompt_len,
-            #         transpose_value_cache=transpose_value_cache,
-            #         group_size=quantization_group_size,
-            #         qtype=qtype,
-            #         save_directory=save_directory,
-            #         fuse_layers=fuse_layers
-            #     )
-            # else:
-            optimize_llm(
-                llm,
-                max_context_len=max_context_len,
-                max_prompt_len=max_prompt_len,
-                inter_pp=inter_pp,
-                intra_pp=intra_pp,
-                transpose_value_cache=transpose_value_cache,
-                group_size=quantization_group_size
-            )
+            if (not hasattr(model, 'llm') and
+                    model.config.model_type in ["qwen2", "llama", "minicpm"]):
+                from ipex_llm.transformers.npu_models.convert import optimize_llm_single_process
+                optimize_llm_single_process(
+                    llm,
+                    kv_len=max_context_len,
+                    max_prompt_len=max_prompt_len,
+                    transpose_value_cache=transpose_value_cache,
+                    group_size=quantization_group_size,
+                    qtype=qtype,
+                    save_directory=save_directory,
+                    fuse_layers=fuse_layers
+                )
+            else:
+                optimize_llm(
+                    llm,
+                    max_context_len=max_context_len,
+                    max_prompt_len=max_prompt_len,
+                    inter_pp=inter_pp,
+                    intra_pp=intra_pp,
+                    transpose_value_cache=transpose_value_cache,
+                    group_size=quantization_group_size
+                )
         else:
             from ipex_llm.transformers.npu_pipeline_model.convert_pipeline \
                 import convert_llm
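With the commented-out branch restored, a non-pipeline run of a Qwen2, Llama, or MiniCPM model (one without an `llm` sub-module) is converted by optimize_llm_single_process, while every other case keeps the multi-process optimize_llm path. A rough usage sketch of how these knobs are typically passed in, assuming the AutoModelForCausalLM wrapper exported by ipex_llm.transformers.npu_model; keyword names follow the variables visible in this hunk and are not guaranteed to match every release:

import torch
from ipex_llm.transformers.npu_model import AutoModelForCausalLM

# Hypothetical call: the checkpoint path and argument values are placeholders, and
# some keyword names are assumed from the variables used in optimize_npu_model above;
# the exact argument set can differ between ipex-llm versions.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",   # placeholder checkpoint
    torch_dtype=torch.float16,
    optimize_model=True,               # assumed to route into optimize_npu_model
    pipeline=False,                    # take the non-pipeline branch restored above
    load_in_low_bit="sym_int4",
    max_context_len=1024,
    max_prompt_len=512,
    quantization_group_size=0,
    save_directory="./npu_converted_model",
    trust_remote_code=True,
)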
python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py (16 additions & 80 deletions)

@@ -188,32 +188,18 @@ def __init__(
         curr_key_values = []
         cos_condition = cached_cos is not None or (mode == "prefill" and keep_position_ids)
         for i in range(num_layers):
-            if self.mode == "prefill":
-                hidden_states, new_key_states, new_value_states, temp = self.build_decoder(
-                    hidden_states=hidden_states,
-                    attention_mask=attention_mask,
-                    position_ids=position_ids if cos_condition else None,
-                    input_layernorm_weight=input_layernorm_weights[i],
-                    post_attention_layernorm_weight=post_attn_layernorm_weights[i],
-                    past_key=past_keys[i],
-                    past_value=past_values[i],
-                    use_prefill_sdp=use_prefill_sdp,
-                    cos=self.cos,
-                    sin=self.sin,
-                )
-            else:
-                hidden_states, new_key_states, new_value_states = self.build_decoder(
-                    hidden_states=hidden_states,
-                    attention_mask=attention_mask,
-                    position_ids=position_ids if cos_condition else None,
-                    input_layernorm_weight=input_layernorm_weights[i],
-                    post_attention_layernorm_weight=post_attn_layernorm_weights[i],
-                    past_key=past_keys[i],
-                    past_value=past_values[i],
-                    use_prefill_sdp=use_prefill_sdp,
-                    cos=self.cos,
-                    sin=self.sin,
-                )
+            hidden_states, new_key_states, new_value_states = self.build_decoder(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids if cos_condition else None,
+                input_layernorm_weight=input_layernorm_weights[i],
+                post_attention_layernorm_weight=post_attn_layernorm_weights[i],
+                past_key=past_keys[i],
+                past_value=past_values[i],
+                use_prefill_sdp=use_prefill_sdp,
+                cos=self.cos,
+                sin=self.sin,
+            )
             curr_key_values.append((new_key_states, new_value_states))

         # define outputs
@@ -223,9 +209,6 @@ def __init__(
             new_key_states = self.convert_to_fp16(curr_key_values[i][0])
             new_value_states = self.convert_to_fp16(curr_key_values[i][1])

-        if self.mode == "prefill":
-            temp = self.convert_to_fp16(temp)
-
         print("start compiling")
         if mode == "prefill" and os.environ.get("IPEX_LLM_NPU_DISABLE_COMPILE_OPT", "0") != "1":
             self.compile(npu_dpu_groups=6)
@@ -247,40 +230,8 @@ def build_decoder(
     ):

         residual = hidden_states
-        # if self.mode == "prefill":
-        #     temp = hidden_states  # first exec is okay
-        #     hidden_states = self.layer_norm(hidden_states, input_layernorm_weight)
-        #     # if self.mode == "prefill":
-        #     #     temp = hidden_states  # first exec is wrong
-        if self.mode == "prefill":
-            # temp = hidden_states  # first exec is okay
-            hidden_states = self.convert_to_fp32(hidden_states)
-            variance = self.reduce_mean(
-                self.power(hidden_states, self.constant(np.array([[2]], dtype=np.float32))),
-                -1,
-                keep_dims=True,
-            )
-            # temp = variance  # first exec is okay
-            eps = self.constant(self.rms_norm_eps)
-            hidden_states = self.eltwise_div(hidden_states, self.sqrt(self.eltwise_add(variance, eps)))
-            # temp = hidden_states  # first exec is okay
-            input_layernorm_weight = self.convert_to_fp32(input_layernorm_weight)
-
-            mul_res = self.eltwise_mul(input_layernorm_weight, hidden_states)
-            temp = mul_res
-            hidden_states = temp
-
-            # mul_res = self.eltwise_mul(input_layernorm_weight, hidden_states)
-            # hidden_states = mul_res
-            # temp = hidden_states
-
-            # temp = self.eltwise_mul(input_layernorm_weight, hidden_states)
-
-            # hidden_states = self.eltwise_mul(input_layernorm_weight, hidden_states)
-            hidden_states = self.convert_to_fp16(hidden_states)
-        else:
-            hidden_states = self.layer_norm(hidden_states, input_layernorm_weight)
         input_2d = self.reshape(hidden_states, (self.batch_size * self.seq_len, self.hidden_size))
+        input_2d = self.layer_norm(input_2d, input_layernorm_weight)
         attn_output, new_key_states, new_value_states = self.attention(
             hidden_states=input_2d,
             position_ids=position_ids,
@@ -303,11 +254,7 @@ def build_decoder(
         hidden_states = self.eltwise_add(residual, hidden_states)
         hidden_states = self.convert_to_fp16(hidden_states)

-        if self.mode == "prefill":
-            return hidden_states, new_key_states, new_value_states, temp
-        else:
-            return hidden_states, new_key_states, new_value_states
-        # return hidden_states, new_key_states, new_value_states
+        return hidden_states, new_key_states, new_value_states


 class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
@@ -553,13 +500,9 @@ def forward(
         if self.cached_cos is None:
             inputs += (cos.to(torch.float32), sin.to(torch.float32),)
         inputs += (self.layer_norm_0, self.layer_norm_1)
-        hidden_states, past_key, past_value, temp = run_model(
+        hidden_states, past_key, past_value = run_model(
             inputs, self.op_parameters, backend_cls, self.op_id, replica=2
         )
-        print(f'[DEBUG INFO OF TEMP VAR] temp is {temp}')
-        # hidden_states, past_key, past_value = run_model(
-        #     inputs, self.op_parameters, backend_cls, self.op_id, replica=2
-        # )
         cache_kwargs = {
             "cache_position": cache_position,
             "max_seq_len": self.max_seq_len,
@@ -928,11 +871,7 @@ def run_prefill(
     result_queue.put("loading finish")

     with torch.inference_mode():
-        # for decoder_layer in deocderlayers:
-        for idx, decoder_layer in enumerate(deocderlayers):
-            print(f'Before running {idx} decoder layer, hidden_states is {hidden_states}')
-            print(f'Before running {idx} decoder layer, causal_mask is {causal_mask}')
-            print(f'Before running {idx} decoder layer, position_ids is {position_ids}')
+        for decoder_layer in deocderlayers:
             layer_outputs = decoder_layer(
                 hidden_states,
                 attention_mask=causal_mask,
@@ -946,11 +885,8 @@ def run_prefill(
             )

             hidden_states = layer_outputs[0]
-            print(f'After running {idx} decoder layer, hidden_states is {hidden_states}')
-            print('===============')
             next_decoder_cache = layer_outputs[1]

-    print(f'run_prefill result, hidden_states is {hidden_states}')
     result_queue.put((hidden_states, next_decoder_cache))

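With the temp plumbing and per-layer debug prints removed, build_decoder returns the same (hidden_states, new_key_states, new_value_states) triple in both modes, and run_prefill is back to a plain loop over the decoder layers. A toy PyTorch sketch of that loop shape, using a stand-in layer rather than the real NPU-backed decoder; all names and shapes below are illustrative only:

import torch

class DummyDecoderLayer(torch.nn.Module):
    """Stand-in for the NPU-backed decoder layer: returns (hidden_states, kv_cache)."""
    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, attention_mask=None):
        hidden_states = self.proj(hidden_states)
        new_cache = (hidden_states, hidden_states)  # placeholder key/value pair
        return hidden_states, new_cache

decoder_layers = [DummyDecoderLayer(16) for _ in range(4)]
hidden_states = torch.randn(1, 8, 16)
next_decoder_cache = None

with torch.inference_mode():
    for decoder_layer in decoder_layers:  # single code path, no per-layer debug prints
        layer_outputs = decoder_layer(hidden_states, attention_mask=None)
        hidden_states = layer_outputs[0]
        next_decoder_cache = layer_outputs[1]

print(hidden_states.shape, len(next_decoder_cache))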
@@ -470,7 +470,8 @@ def layer_norm(self, hidden_states, layernorm_weight):
         )
         eps = self.constant(self.rms_norm_eps)
         hidden_states = self.eltwise_div(hidden_states, self.sqrt(self.eltwise_add(variance, eps)))
-        layernorm_weight = self.convert_to_fp32(layernorm_weight)
+        # layernorm_weight = self.convert_to_fp32(layernorm_weight)
+        hidden_states = self.convert_to_fp16(hidden_states)
         hidden_states = self.eltwise_mul(layernorm_weight, hidden_states)
         hidden_states = self.convert_to_fp16(hidden_states)
         return hidden_states
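For reference, a minimal PyTorch sketch of the arithmetic the updated layer_norm now expresses with NPU factory ops: variance and normalization stay in fp32, and the normalized activations are cast to fp16 before the element-wise multiply with the already-fp16 weight, instead of upcasting the weight to fp32. Tensor shapes and the eps value are assumptions for illustration:

import torch

def rms_norm_mixed_precision(hidden_states: torch.Tensor,
                             weight: torch.Tensor,
                             eps: float = 1e-6) -> torch.Tensor:
    # Variance and normalization in fp32 (mirrors the convert_to_fp32 / reduce_mean /
    # eltwise_div ops above).
    hs32 = hidden_states.to(torch.float32)
    variance = hs32.pow(2).mean(dim=-1, keepdim=True)
    hs32 = hs32 / torch.sqrt(variance + eps)
    # Cast the normalized activations to fp16 and multiply by the fp16 weight,
    # matching the new convert_to_fp16 placed before eltwise_mul (the weight is no
    # longer upcast to fp32).
    return weight * hs32.to(torch.float16)

x = torch.randn(1, 8, 64, dtype=torch.float16)
w = torch.ones(64, dtype=torch.float16)
print(rms_norm_mixed_precision(x, w).dtype)  # torch.float16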
