export_int8_model.py size issue #91

Open

ljhyeok123 opened this issue Jul 11, 2024 · 1 comment

Comments

ljhyeok123 commented Jul 11, 2024

Hi, I'm having trouble with the results of the export_int8_model.py code and would like to ask a question.

The model on Hugging Face was fine, but I'm wondering what causes the size mismatch when loading the INT8 model saved by export_int8_model.py with Int8OPTForCausalLM.from_pretrained() in the examples/smoothquant_opt_real_int8_demo.ipynb code:

Evaluating SmoothQuant INT8 model...
Traceback (most recent call last):
  File "smoothquant_opt.py", line 84, in <module>
    model_smoothquant = Int8OPTForCausalLM.from_pretrained('/home/hyeok/smoothquant/output/opt-1.3b-smoothquant', torch_dtype=torch.float16, device_map='auto')
  File "/home/hyeok/.conda/envs/int/lib/python3.8/site-packages/transformers/modeling_utils.py", line 3307, in from_pretrained
    ) = cls._load_pretrained_model(
  File "/home/hyeok/.conda/envs/int/lib/python3.8/site-packages/transformers/modeling_utils.py", line 3695, in _load_pretrained_model
    new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
  File "/home/hyeok/.conda/envs/int/lib/python3.8/site-packages/transformers/modeling_utils.py", line 741, in _load_state_dict_into_meta_model
    set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
  File "/home/hyeok/.conda/envs/int/lib/python3.8/site-packages/accelerate/utils/modeling.py", line 362, in set_module_tensor_to_device
    raise ValueError(
ValueError: Trying to set a tensor of shape torch.Size([2048]) in "bias" (which has shape torch.Size([1, 2048])), this look incorrect.

So I modified the code of torch_int's W8A8B8O8Linear(torch.nn.Module) class as follows, and this is the result: SmoothQuant INT8 accuracy: 0.407, per-sample latency: 38.878 ms.

# int8_module.weight = int8_weight
int8_module.weight = torch.reshape(int8_weight, int8_module.weight.shape)
# int8_module.bias = int8_bias
int8_module.bias = torch.reshape(int8_bias, int8_module.bias.shape)
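
For context, here is a minimal, self-contained sketch of why the load fails (illustrative names and sizes, not the actual torch_int code): the module registers its bias buffer as (1, out_features) while the exported state dict stores it as (out_features,), and the strict shape check rejects the copy just like accelerate's set_module_tensor_to_device does in the traceback above.

import torch

out_features = 2048  # hidden size of opt-1.3b, as in the traceback

class Demo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Buffer registered as (1, out_features), mimicking the buggy layer.
        self.register_buffer('bias', torch.zeros((1, out_features), dtype=torch.int8))

# The exported checkpoint stores the bias flat, as (out_features,).
checkpoint_bias = torch.zeros((out_features,), dtype=torch.int8)

try:
    Demo().load_state_dict({'bias': checkpoint_bias})
except RuntimeError as e:
    # size mismatch for bias: copying a param with shape torch.Size([2048])
    # from checkpoint, the shape in current model is torch.Size([1, 2048]).
    print(e)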
@lzd19981105

That's because the bias is initialized with shape (1, self.out_features); you can change it to (self.out_features,) to solve the problem. In my case, the opt-1.3b INT8 accuracy is 0.698 and the per-sample latency is 35.392 ms on a 3090.

class W8A8B8O8Linear(torch.nn.Module):
    # For qkv_proj
    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.register_buffer('weight', torch.randint(-127, 127, (self.out_features,
                                                                 self.in_features), dtype=torch.int8, requires_grad=False))
        # Fixed: bias is (self.out_features,) instead of (1, self.out_features),
        # matching the shape stored by export_int8_model.py.
        self.register_buffer('bias', torch.zeros(
            (self.out_features,), dtype=torch.int8, requires_grad=False))
        self.register_buffer('a', torch.tensor(alpha))
        self.register_buffer('b', torch.tensor(beta))
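
A quick sanity check along these lines (illustrative; 2048 is the opt-1.3b hidden size from the traceback) confirms that the buffer now matches the exported tensor and loads without the shape error:

# With the fixed __init__, the bias buffer is flat and matches the checkpoint.
layer = W8A8B8O8Linear(in_features=2048, out_features=2048)
assert layer.bias.shape == torch.Size([2048])

# Loading a state dict with a flat int8 bias no longer raises a shape error.
sd = layer.state_dict()
sd['bias'] = torch.zeros(2048, dtype=torch.int8)
layer.load_state_dict(sd)
print('bias shape OK:', tuple(layer.bias.shape))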
