vLLM error with omniquant Llama-3-8b #256

Open
vengdeng opened this issue Dec 12, 2024 · 7 comments

@vengdeng

Hi developer,

I found another params_dict error when using vLLM with an OmniQuant-compressed Llama-3-8b (which was saved with save_vllm: True):

File ~/delta_p/lib/python3.10/site-packages/vllm/model_executor/models/utils.py:175, in AutoWeightsLoader._load_module(self, base_prefix, module, weights)
173 module_load_weights = getattr(module, "load_weights", None)
174 if callable(module_load_weights):
--> 175 module_load_weights(weights)
176 return
178 child_modules = dict(module.named_children())

File ~/delta_p/lib/python3.10/site-packages/vllm/model_executor/models/llama.py:407, in LlamaModel.load_weights(self, weights)
404 if is_pp_missing_parameter(name, self):
405 continue
--> 407 param = params_dict[name]
408 weight_loader = getattr(param, "weight_loader",
409 default_weight_loader)
410 weight_loader(param, loaded_weight)

KeyError: 'layers.0.fc1_smooth_scale'

Thank you for your help!
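
One way to narrow this down is to list the tensor names in the exported checkpoint and see which ones carry the OmniQuant smoothing-scale suffix that vLLM's params_dict does not know about. A minimal diagnostic sketch, assuming the export directory contains safetensors shards (the directory name below is only a placeholder taken from save_path):

```python
# Diagnostic sketch (not part of llmc or vLLM): list tensor names in the
# exported checkpoint and print the ones containing "smooth_scale", i.e.
# the names that LlamaModel.load_weights cannot find in params_dict.
from pathlib import Path
from safetensors import safe_open

export_dir = Path('vllm_omni_llama3_8b_w8a8')  # assumed export location (save_path)
for shard in sorted(export_dir.glob('*.safetensors')):
    with safe_open(shard, framework='pt') as f:
        for name in f.keys():
            if 'smooth_scale' in name:
                print(shard.name, name)
```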

@gushiqiao
Contributor

Can you provide the configuration file?

@vengdeng
Author

Thank you for your reply. Here is the configuration file. I also tried Qwen2 and hit the same error:

base:
    seed: &seed 2
model:
    type: Llama
    path: meta-llama/Meta-Llama-3-8B-Instruct
    torch_dtype: auto
calib:
    name: wikitext2
    download: True
    path: /save_files/data
    n_samples: 128
    bs: 10
    seq_len: 2048
    preproc: wikitext2_gptq
    seed: *seed
eval:
    eval_pos: []
    name: wikitext2
    download: True
    path: /save_files/data
    seq_len: 2048
    # For 7B / 13B model eval, bs can be set to "1", and inference_per_block can be set to "False".
    # For 70B model eval, bs can be set to "20", and inference_per_block can be set to "True".
    bs: 10
    inference_per_block: False
    # Consistency of tokens between original and fake-quantized model output.
    eval_token_consist: True
quant:
    method: OmniQuant
    weight:
        bit: 8
        symmetric: True
        granularity: per_channel
        calib_algo: learnable
        ste: True
    act:
        bit: 8
        symmetric: True
        granularity: per_token
        ste: True
    special:
        aug_loss: False
        let: True
        lwc: True
        lwc_lr: 0.01
        # Set "let_lr: 0.001" for w4a4 quantization.
        let_lr: 0.005
        # Set to "True" if the model has bias (e.g. Opt).
        use_shift: False
        # Use "0.75" for w4a4 quantization.
        alpha: 0.5
        deactive_amp: True
        epochs: 1
        wd: 0
        robust_weight: 0
    quant_out: True
save:
    save_trans: False
    save_vllm: True
    save_path: vllm_omni_llama3_8b_w8a8
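
For reference, the KeyError above is raised while vLLM loads the directory written by save_path. A minimal repro sketch, assuming a plain single-GPU load with no extra engine arguments:

```python
# Repro sketch: point vLLM at the llmc export. The KeyError from the first
# traceback is raised inside LLM(...) while LlamaModel.load_weights runs.
from vllm import LLM, SamplingParams

llm = LLM(model='vllm_omni_llama3_8b_w8a8')  # directory from save_path
outputs = llm.generate(['Hello, my name is'], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```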

@gushiqiao
Contributor

Okay, I'll try to reproduce the problem you encountered.

@vengdeng
Author

Thank you very much!

@gushiqiao
Contributor

You can try the latest code, which has already fixed the mentioned issue.

@vengdeng
Author

Thank you for your help. However, it seems the added code affects the block_forward function and causes the following error. May I know your configuration?

[rank0]: Traceback (most recent call last):
[rank0]: File "/data/dwenlong/llmc/llmc/main.py", line 317, in <module>
[rank0]: main(config)
[rank0]: File "/data/dwenlong/llmc/llmc/main.py", line 121, in main
[rank0]: blockwise_opt.run_block_loop()
[rank0]: File "/data/dwenlong/llmc/llmc/compression/blockwise_optimization.py", line 40, in run_block_loop
[rank0]: self.block_opt(self.blocks[self.block_idx])
[rank0]: File "/data/dwenlong/llmc/llmc/compression/quantization/base_blockwise_quantization.py", line 418, in block_opt
[rank0]: self.run(block, input_feat, handles)
[rank0]: File "/data/dwenlong/llmc/llmc/compression/quantization/base_blockwise_quantization.py", line 463, in run
[rank0]: self.input['data'] = self.block_forward(block)
[rank0]: File "/data/dwenlong/llmc/llmc/compression/quantization/omniq.py", line 132, in block_forward
[rank0]: input_data[i] = input_data[i].to(device=next(block.parameters()).device)
[rank0]: StopIteration
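
For context, next(block.parameters()) raises StopIteration whenever the block exposes no parameters at that point, which is what the traceback shows. A minimal sketch of a more defensive device lookup (a hypothetical helper, not the llmc implementation):

```python
# Sketch of a defensive device lookup: fall back to a default device
# instead of letting next() raise StopIteration on a parameter-less block.
import torch
import torch.nn as nn

def get_block_device(block: nn.Module, fallback: str = 'cuda') -> torch.device:
    param = next(block.parameters(), None)  # returns None instead of raising
    return param.device if param is not None else torch.device(fallback)
```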

@gushiqiao
Contributor

The latest code should not affect the block_forward function. You can clone the repository again, update the running environment (especially the transformers version), and try again.
