
[WIP] Enable GPTQModel to handle GraniteMoeParallelExperts #122

Draft · wants to merge 12 commits into base: main
@@ -22,6 +22,7 @@
from .gpt_bigcode import GPTBigCodeGPTQ
from .gpt_neox import GPTNeoXGPTQ
from .granite import GraniteGPTQ
from .granitemoe import GraniteMoeGPTQ
from .llama import LlamaGPTQ
from .mistral import MistralGPTQ
from .mixtral import MixtralGPTQ
@@ -28,6 +28,7 @@
"granite",
"gemma",
"dbrx_converted",
"granitemoe",
]

EXLLAMA_DEFAULT_MAX_INPUT_LENGTH = 2048
@@ -29,6 +29,7 @@
from .gpt_bigcode import GPTBigCodeGPTQ
from .gpt_neox import GPTNeoXGPTQ
from .granite import GraniteGPTQ
from .granitemoe import GraniteMoeGPTQ
from .llama import LlamaGPTQ
from .mistral import MistralGPTQ
from .mixtral import MixtralGPTQ
@@ -43,6 +44,7 @@
"granite": GraniteGPTQ,
"dbrx": DbrxGPTQ,
"dbrx_converted": DbrxConvertedGPTQ,
"granitemoe": GraniteMoeGPTQ,
}

at_least_one_cuda_v6 = any(
@@ -41,6 +41,7 @@
import accelerate
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers

# Local
@@ -61,6 +62,7 @@
    convert_gptq_v1_to_v2_format,
    convert_gptq_v2_to_v1_format,
    find_layers,
    get_all_modules_by_name_suffix,
    get_checkpoints,
    get_device,
    get_module_by_name_prefix,
@@ -91,6 +93,9 @@ class BaseGPTQModel(nn.Module):
    # does not include the node which holds all the repeating layers
    base_modules: List[str] = None

    # module name suffixes whose 3D expert parameters should be converted to a
    # ModuleList of per-expert Linear modules before quantization
    convert_3d_modulelist: List[str] = None

    # name of lm_head
    lm_head: str = "lm_head"

@@ -223,6 +228,22 @@ def quantize(
        if len(calibration_dataset) == 0:
            raise ValueError("Calibration dataset must not be empty.")

        ##### SWAP 3D MODULES TO MODULELIST #####
        if self.convert_3d_modulelist:
            for name in self.convert_3d_modulelist:
                matches = get_all_modules_by_name_suffix(self.model, name)
                for parent, module, full_name in matches:
                    # Modify the matched module
                    if parent is not None:
                        new_module = self.swap_3d_tensors(module)

                        # Replace the old module with the new one;
                        # derive the child attribute name from the tail of full_name
                        child_name = full_name.split(".")[-1]
                        setattr(parent, child_name, new_module)

        min_calibration_dataset_size = 256
        min_calibration_dataset_input_ids_avg_length = 256

@@ -558,7 +579,7 @@ def save_quantized(
        self.quantize_config.meta_set_versionable(
            key=META_FIELD_QUANTIZER,
            value=META_QUANTIZER_GPTQMODEL,
-           version=__version__,
+           version="1.0.0",
Contributor:
why does this need to be changed?

        )

        # The config, quantize_config and model may be edited in place in save_quantized.
@@ -1211,5 +1232,48 @@ def __getattr__(self, item):
        except Exception:
            return getattr(self.model, item)

    def swap_3d_tensors(self, module: nn.Module) -> nn.ModuleList:
        """Swap a module holding 3D expert parameters for a ModuleList of per-expert Linear modules."""
        num_experts = module.num_experts
        input_size = module.input_size
        output_size = module.output_size
        new_module = MoE3DModuleList(
            [nn.Linear(input_size, output_size, bias=False) for _ in range(num_experts)]
        )
        # Copy the original expert weights into the new Linear modules; module.weight
        # is expected to be a 3D parameter of shape (num_experts, output_size, input_size).
        with torch.no_grad():
            for i, expert in enumerate(new_module):
                expert.weight.copy_(module.weight[i])
        return new_module


class MoE3DModuleList(nn.ModuleList):
    def forward(self, inputs: torch.Tensor, expert_size: List[int]) -> torch.Tensor:
        """
        Forward pass of the MoE3DModuleList module.

        Args:
            inputs (Tensor):
                Input tensor of shape (total_tokens, input_size), with rows grouped by expert.
            expert_size (List[int]):
                Number of rows of `inputs` routed to each expert, one entry per expert.

        Returns:
            Tensor: Output tensor of shape (total_tokens, output_size).
        """
        input_list = inputs.split(expert_size, dim=0)
        output_list = []

        # Apply each expert to its slice of the input
        for i in range(len(expert_size)):
            # Extract weight and bias from the Linear module
            weight = self[i].weight.to(device=inputs.device, dtype=inputs.dtype)
            bias = (
                self[i].bias.to(device=inputs.device, dtype=inputs.dtype)
                if self[i].bias is not None
                else None
            )
            expert_output = F.linear(input_list[i], weight, bias)
            output_list.append(expert_output)

        # Concatenate the per-expert outputs along the first dimension
        results = torch.cat(output_list, dim=0)
        return results


__all__ = ["BaseGPTQModel"]
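
For context, here is a minimal, self-contained sketch of what the swap is meant to achieve. The Fake3DExperts class below is a hypothetical stand-in for a 3D expert module such as GraniteMoeParallelExperts (a single parameter of shape (num_experts, output_size, input_size)), not the actual transformers class:

import torch
import torch.nn as nn

class Fake3DExperts(nn.Module):
    """Hypothetical stand-in for a module holding one 3D expert parameter."""
    def __init__(self, num_experts: int, input_size: int, output_size: int):
        super().__init__()
        self.num_experts = num_experts
        self.input_size = input_size
        self.output_size = output_size
        self.weight = nn.Parameter(torch.randn(num_experts, output_size, input_size))

experts = Fake3DExperts(num_experts=4, input_size=8, output_size=16)

# After the swap, the same computation is expressed as four independent
# nn.Linear modules, which GPTQ can then quantize one by one.
swapped = nn.ModuleList(
    [nn.Linear(8, 16, bias=False) for _ in range(experts.num_experts)]
)
with torch.no_grad():
    for i, linear in enumerate(swapped):
        linear.weight.copy_(experts.weight[i])  # one 2D slice per expert

# Tokens arrive grouped by expert: 3 rows for expert 0, 1 for expert 1, etc.
inputs = torch.randn(6, 8)
expert_size = [3, 1, 0, 2]
outputs = torch.cat(
    [linear(x) for linear, x in zip(swapped, inputs.split(expert_size, dim=0))],
    dim=0,
)
print(outputs.shape)  # torch.Size([6, 16])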
@@ -0,0 +1,34 @@
###############################################################################
# Adapted from https://github.com/ModelCloud/GPTQModel
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
# Local
from .base import BaseGPTQModel


class GraniteMoeGPTQ(BaseGPTQModel):
    base_modules = ["model.embed_tokens", "model.norm"]
    convert_3d_modulelist = [
        "block_sparse_moe.input_linear",
        "block_sparse_moe.output_linear",
    ]

    layers_node = "model.layers"
    layer_type = "GraniteMoeDecoderLayer"
Contributor:
I suggest you add some simple key to inform the format of input_linear and output_linear, i.e. that these are 3D tensors.

Also, in the granitemoe case, another complication is that input_linear fuses w1 and w3. It might be OK for a first cut just to leave them fused.

Contributor @fabianlim commented on Jan 28, 2025:

So basically the simple key needs to say what to look for when converting the 3D tensor, and then when you write layer_modules you write it as though they have already been converted:

class GraniteMoeGPTQ(BaseGPTQModel):
    
    convert3dToModuleList = ["block_sparse_moe.input_linear", "block_sparse_moe.output_linear"]

    layer_modules = [

        [
             "block_sparse_moe.input_linear.0.weight",
              "block_sparse_moe.input_linear.1.weight",
              ...
        ], [
             "block_sparse_moe.output_linear.0.weight",
              "block_sparse_moe.output_linear.1.weight",
              ...
        ]
    ]

    layer_modules = [
        ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
        ["self_attn.o_proj"],
        [f"block_sparse_moe.input_linear.{i}" for i in range(40)],
        [f"block_sparse_moe.output_linear.{i}" for i in range(40)],
    ]
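
For reference on the fusion mentioned above: in the Hugging Face GraniteMoe implementation, input_linear projects to 2 * intermediate_size, and its output is chunked into gate (w1) and up (w3) halves before the gated activation. A simplified sketch of that pattern (not the actual modeling code; shapes are illustrative):

import torch
import torch.nn.functional as F

def fused_glu_forward(x: torch.Tensor, fused_weight: torch.Tensor) -> torch.Tensor:
    # fused_weight stacks w1 (gate) and w3 (up): shape (2 * intermediate, hidden)
    h = F.linear(x, fused_weight)
    gate, up = h.chunk(2, dim=-1)  # split back into the w1 and w3 halves
    return F.silu(gate) * up       # SwiGLU-style gating

x = torch.randn(4, 8)                  # 4 tokens, hidden size 8
fused_weight = torch.randn(2 * 16, 8)  # intermediate size 16
print(fused_glu_forward(x, fused_weight).shape)  # torch.Size([4, 16])

Quantizing the fused weight as a single Linear, as suggested for a first cut, avoids having to split and re-stitch w1 and w3.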
@@ -15,7 +15,7 @@
###############################################################################
# Standard
from logging import getLogger
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union
import functools
import hashlib
import json
@@ -128,6 +128,25 @@ def get_module_by_name_suffix(model, module_name: str):
    return module


def get_all_modules_by_name_suffix(
    model: nn.Module, target_suffix: str
) -> List[Tuple[Optional[nn.Module], nn.Module, str]]:
    """Find all modules in the model whose names end with the given suffix, along with their parent modules."""
    name_to_module = dict(model.named_modules())
    results = []
    for full_name, mod in name_to_module.items():
        if full_name.endswith(target_suffix):
            split_name = full_name.split(".")
            if len(split_name) > 1:
                parent_name = ".".join(split_name[:-1])
            else:
                parent_name = ""

            parent_module = name_to_module.get(parent_name, None)
            results.append((parent_module, mod, full_name))
    return results
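
A quick illustration of the return value (the toy module tree below is made up for the example):

import torch.nn as nn

model = nn.Module()
model.block_sparse_moe = nn.Module()
model.block_sparse_moe.input_linear = nn.Linear(8, 16)

for parent, module, full_name in get_all_modules_by_name_suffix(
    model, "block_sparse_moe.input_linear"
):
    print(type(parent).__name__, type(module).__name__, full_name)
    # prints: Module Linear block_sparse_moe.input_linear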


def make_quant(
    module,
    names,