
Commit

debug
plusbang committed Dec 11, 2024
1 parent 74d01dc commit e6bf295
Showing 3 changed files with 112 additions and 41 deletions.
python/llm/src/ipex_llm/transformers/npu_model.py (46 changes: 23 additions & 23 deletions)
@@ -301,29 +301,29 @@ def optimize_npu_model(cls, *args, **kwargs):
model.share_memory()

if not pipeline:
if (not hasattr(model, 'llm') and
model.config.model_type in ["qwen2", "llama", "minicpm"]):
from ipex_llm.transformers.npu_models.convert import optimize_llm_single_process
optimize_llm_single_process(
llm,
kv_len=max_context_len,
max_prompt_len=max_prompt_len,
transpose_value_cache=transpose_value_cache,
group_size=quantization_group_size,
qtype=qtype,
save_directory=save_directory,
fuse_layers=fuse_layers
)
else:
optimize_llm(
llm,
max_context_len=max_context_len,
max_prompt_len=max_prompt_len,
inter_pp=inter_pp,
intra_pp=intra_pp,
transpose_value_cache=transpose_value_cache,
group_size=quantization_group_size
)
# if (not hasattr(model, 'llm') and
# model.config.model_type in ["qwen2", "llama", "minicpm"]):
# from ipex_llm.transformers.npu_models.convert import optimize_llm_single_process
# optimize_llm_single_process(
# llm,
# kv_len=max_context_len,
# max_prompt_len=max_prompt_len,
# transpose_value_cache=transpose_value_cache,
# group_size=quantization_group_size,
# qtype=qtype,
# save_directory=save_directory,
# fuse_layers=fuse_layers
# )
# else:
optimize_llm(
llm,
max_context_len=max_context_len,
max_prompt_len=max_prompt_len,
inter_pp=inter_pp,
intra_pp=intra_pp,
transpose_value_cache=transpose_value_cache,
group_size=quantization_group_size
)
else:
from ipex_llm.transformers.npu_pipeline_model.convert_pipeline \
import convert_llm
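The net effect of the hunk above is that the non-pipeline branch no longer chooses between optimize_llm_single_process and optimize_llm: with the dispatch commented out, optimize_llm always runs. A minimal sketch of the now-disabled dispatch condition, rewritten as a standalone helper purely for illustration (not code from the repository):

def uses_single_process_path(has_llm_attr: bool, model_type: str) -> bool:
    """Pre-debug rule for picking optimize_llm_single_process over optimize_llm."""
    return (not has_llm_attr) and model_type in ["qwen2", "llama", "minicpm"]

# With this commit applied the check is bypassed, so qwen2/llama/minicpm models
# (without an `llm` attribute) also go through optimize_llm in the non-pipeline branch.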
python/llm/src/ipex_llm/transformers/npu_models/llama.py (8 changes: 7 additions & 1 deletion)
@@ -128,7 +128,7 @@ def llama_model_forward(
all_self_attns = () if output_attentions else None
next_decoder_cache = None

for decoder_layer in self.layers:
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)

@@ -144,6 +144,9 @@ def llama_model_forward(
cache_position,
)
else:
print(f'Before running {idx} decoder layer, hidden_states is {hidden_states}')
print(f'Before running {idx} decoder layer, causal_mask is {causal_mask}')
print(f'Before running {idx} decoder layer, position_ids is {position_ids}')
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
@@ -155,13 +158,16 @@ def llama_model_forward(
)

hidden_states = layer_outputs[0]
print(f'After running {idx} decoder layer, hidden_states is {hidden_states}')
print('===============')

if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]

if output_attentions:
all_self_attns += (layer_outputs[1],)

print(f'run_prefill result, hidden_states is {hidden_states}')
hidden_states = self.norm(hidden_states)

# add hidden states from the last decoder layer
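The prints added to llama_model_forward dump the full hidden_states, causal_mask, and position_ids tensors before and after every decoder layer, which becomes very verbose for long prompts. A possible alternative, sketched below with a hypothetical _summary helper (not part of this commit or the repository), is to log one line of statistics per tensor instead:

import torch

def _summary(name: str, t) -> str:
    """One-line summary (shape, dtype, mean, abs-max, NaN flag) of a tensor."""
    if not isinstance(t, torch.Tensor):
        return f"{name}: {t}"
    f = t.detach().float()
    return (f"{name}: shape={tuple(t.shape)} dtype={t.dtype} "
            f"mean={f.mean().item():.6f} absmax={f.abs().max().item():.6f} "
            f"nan={torch.isnan(f).any().item()}")

# Hypothetical usage inside the loop above:
# print(_summary(f'layer {idx} input hidden_states', hidden_states))
# print(_summary(f'layer {idx} causal_mask', causal_mask))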
python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py (99 changes: 82 additions & 17 deletions)
@@ -188,18 +188,32 @@ def __init__(
curr_key_values = []
cos_condition = cached_cos is not None or (mode == "prefill" and keep_position_ids)
for i in range(num_layers):
hidden_states, new_key_states, new_value_states = self.build_decoder(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids if cos_condition else None,
input_layernorm_weight=input_layernorm_weights[i],
post_attention_layernorm_weight=post_attn_layernorm_weights[i],
past_key=past_keys[i],
past_value=past_values[i],
use_prefill_sdp=use_prefill_sdp,
cos=self.cos,
sin=self.sin,
)
if self.mode == "prefill":
hidden_states, new_key_states, new_value_states, temp = self.build_decoder(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids if cos_condition else None,
input_layernorm_weight=input_layernorm_weights[i],
post_attention_layernorm_weight=post_attn_layernorm_weights[i],
past_key=past_keys[i],
past_value=past_values[i],
use_prefill_sdp=use_prefill_sdp,
cos=self.cos,
sin=self.sin,
)
else:
hidden_states, new_key_states, new_value_states = self.build_decoder(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids if cos_condition else None,
input_layernorm_weight=input_layernorm_weights[i],
post_attention_layernorm_weight=post_attn_layernorm_weights[i],
past_key=past_keys[i],
past_value=past_values[i],
use_prefill_sdp=use_prefill_sdp,
cos=self.cos,
sin=self.sin,
)
curr_key_values.append((new_key_states, new_value_states))

# define outputs
@@ -209,6 +223,9 @@ def __init__(
new_key_states = self.convert_to_fp16(curr_key_values[i][0])
new_value_states = self.convert_to_fp16(curr_key_values[i][1])

if self.mode == "prefill":
temp = self.convert_to_fp16(temp)

print("start compiling")
if mode == "prefill" and os.environ.get("IPEX_LLM_NPU_DISABLE_COMPILE_OPT", "0") != "1":
self.compile(npu_dpu_groups=6)
@@ -230,9 +247,42 @@ def build_decoder(
):

residual = hidden_states
hidden_states = self.layer_norm(hidden_states, input_layernorm_weight)
# if self.mode == "prefill":
# temp = hidden_states # first exec is okay
# hidden_states = self.layer_norm(hidden_states, input_layernorm_weight)
# # if self.mode == "prefill":
# # temp = hidden_states # first exec is wrong
if self.mode == "prefill":
# temp = hidden_states # first exec is okay
hidden_states = self.convert_to_fp32(hidden_states)
variance = self.reduce_mean(
self.power(hidden_states, self.constant(np.array([[2]], dtype=np.float32))),
-1,
keep_dims=True,
)
# temp = variance # first exec is okay
eps = self.constant(self.rms_norm_eps)
hidden_states = self.eltwise_div(hidden_states, self.sqrt(self.eltwise_add(variance, eps)))
# temp = hidden_states # first exec is okay
input_layernorm_weight = self.convert_to_fp32(input_layernorm_weight)

mul_res = self.eltwise_mul(input_layernorm_weight, hidden_states)
temp = mul_res
hidden_states = temp

# mul_res = self.eltwise_mul(input_layernorm_weight, hidden_states)
# hidden_states = mul_res
# temp = hidden_states

# temp = self.eltwise_mul(input_layernorm_weight, hidden_states)

# hidden_states = self.eltwise_mul(input_layernorm_weight, hidden_states)
hidden_states = self.convert_to_fp16(hidden_states)
else:
hidden_states = self.layer_norm(hidden_states, input_layernorm_weight)
input_2d = self.reshape(hidden_states, (self.batch_size * self.seq_len, self.hidden_size))
attn_output, new_key_states, new_value_states = self.attention(
hidden_states=hidden_states,
hidden_states=input_2d,
position_ids=position_ids,
attention_mask=attention_mask,
past_key=past_key,
@@ -253,7 +303,11 @@ def build_decoder(
hidden_states = self.eltwise_add(residual, hidden_states)
hidden_states = self.convert_to_fp16(hidden_states)

return hidden_states, new_key_states, new_value_states
if self.mode == "prefill":
return hidden_states, new_key_states, new_value_states, temp
else:
return hidden_states, new_key_states, new_value_states
# return hidden_states, new_key_states, new_value_states


class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
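In the prefill branch of build_decoder above, self.layer_norm is unrolled into its constituent graph ops: convert to FP32, square, reduce-mean over the last axis, add rms_norm_eps, take the square root, divide, multiply by the FP32 layernorm weight, then convert back to FP16, with temp capturing the result of the final multiply. That sequence is an RMSNorm. A plain PyTorch version of the same computation, offered only as a sketch for cross-checking the NPU graph on CPU (the eps default is an assumption; use the model's rms_norm_eps):

import torch

def rms_norm_reference(hidden_states: torch.Tensor,
                       weight: torch.Tensor,
                       rms_norm_eps: float = 1e-6) -> torch.Tensor:
    """CPU reference for the decomposed prefill layer norm: weight * x / sqrt(mean(x^2) + eps)."""
    h = hidden_states.to(torch.float32)              # convert_to_fp32
    variance = h.pow(2).mean(dim=-1, keepdim=True)   # reduce_mean(power(x, 2), -1, keep_dims=True)
    h = h / torch.sqrt(variance + rms_norm_eps)      # eltwise_div(x, sqrt(eltwise_add(variance, eps)))
    out = weight.to(torch.float32) * h               # eltwise_mul(weight, x) -> the `temp` value
    return out.to(torch.float16)                     # convert_to_fp16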
@@ -499,9 +553,13 @@ def forward(
if self.cached_cos is None:
inputs += (cos.to(torch.float32), sin.to(torch.float32),)
inputs += (self.layer_norm_0, self.layer_norm_1)
hidden_states, past_key, past_value = run_model(
hidden_states, past_key, past_value, temp = run_model(
inputs, self.op_parameters, backend_cls, self.op_id, replica=2
)
print(f'[DEBUG INFO OF TEMP VAR] temp is {temp}')
# hidden_states, past_key, past_value = run_model(
# inputs, self.op_parameters, backend_cls, self.op_id, replica=2
# )
cache_kwargs = {
"cache_position": cache_position,
"max_seq_len": self.max_seq_len,
@@ -870,7 +928,11 @@ def run_prefill(
result_queue.put("loading finish")

with torch.inference_mode():
for decoder_layer in deocderlayers:
# for decoder_layer in deocderlayers:
for idx, decoder_layer in enumerate(deocderlayers):
print(f'Before running {idx} decoder layer, hidden_states is {hidden_states}')
print(f'Before running {idx} decoder layer, causal_mask is {causal_mask}')
print(f'Before running {idx} decoder layer, position_ids is {position_ids}')
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
@@ -884,8 +946,11 @@ def run_prefill(
)

hidden_states = layer_outputs[0]
print(f'After running {idx} decoder layer, hidden_states is {hidden_states}')
print('===============')
next_decoder_cache = layer_outputs[1]

print(f'run_prefill result, hidden_states is {hidden_states}')
result_queue.put((hidden_states, next_decoder_cache))


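The '[DEBUG INFO OF TEMP VAR]' print in forward and the per-layer hidden_states prints in run_prefill appear intended for comparison against a CPU run, to locate where the first prefill layer starts to diverge. A small helper in that spirit, purely illustrative (the variable names in the commented usage are hypothetical):

import torch

def compare_tensors(npu_out: torch.Tensor, cpu_ref: torch.Tensor,
                    atol: float = 1e-2, rtol: float = 1e-2) -> None:
    """Print max/mean absolute difference and whether the two tensors match within tolerance."""
    a = npu_out.detach().float()
    b = cpu_ref.detach().float()
    diff = (a - b).abs()
    print(f"max_abs_diff={diff.max().item():.6f} "
          f"mean_abs_diff={diff.mean().item():.6f} "
          f"allclose={torch.allclose(a, b, atol=atol, rtol=rtol)}")

# Hypothetical usage: compare the first layer's `temp` output with the CPU RMSNorm reference above.
# compare_tensors(temp_from_npu, rms_norm_reference(hidden_states_cpu, input_layernorm_weight_cpu))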
