
[WIP] Enable GPTQModel to handle GraniteMoeParallelExperts #122

Draft: wants to merge 12 commits into base: main

Conversation

@Abhishek-TAMU (Collaborator) commented Jan 28, 2025:

ISSUE: #112

Major changes made to enable the GraniteMoeParallelExperts layer in GPTQModel:

1- find_layers function in plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/utils/model.py
2- class GPTQ in plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/quantization/gptq.py

  • __init__ and add_batch functions
  • TODO: fasterquant function: need to first understand what happens inside it.

@@ -558,7 +558,7 @@ def save_quantized(
        self.quantize_config.meta_set_versionable(
            key=META_FIELD_QUANTIZER,
            value=META_QUANTIZER_GPTQMODEL,
-           version=__version__,
+           version="1.0.0",

Contributor:

why does this need to be changed?

self.H += inp.matmul(inp.t())
# Update entire H_list and nsamples_list
if not self.is_moe:
# print("INSIDE ADD_BATCH FOR 2D")

Contributor:

This looks a bit overcomplicated. I think let's not have an else clause and assume that it is for "inside an MoE". I'm worried that this logic is rather fragile and could break.

Collaborator Author:

I think let's not have an else clause and assume that it is for "inside an MoE"

I didn't understand this. The current logic of add_batch for 2D deals with self.H, but for 3D we have self.H_list, hence I assume that when registering the forward hook, all the H's in self.H_list would be updated. Though I'm not sure whether the H's are supposed to be updated equally for all experts, which is what I have done here.

@fabianlim (Contributor) commented Jan 28, 2025:

Yes, let's not introduce this logic, because we know this function already works for Mixtral, which is an MoE. So the design of this function does not really need the introduction of H_list semantics to handle MoEs; it is a product of our earlier decision to use ParameterList, which in hindsight may not be such a good idea, given that so many extensive logic changes are required in these internals.
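
To illustrate the point, here is a minimal sketch (not code from the PR) of why the existing 2D path suffices once each expert is a plain nn.Linear: every expert gets its own GPTQ object and forward hook, each with its own self.H, so no H_list is needed. SimpleGPTQ below is a simplified stand-in for the real GPTQ class, and its scaling only roughly mirrors the library's add_batch.

import torch
import torch.nn as nn

class SimpleGPTQ:
    """Simplified stand-in for GPTQ: one Hessian accumulator per 2D layer."""

    def __init__(self, layer: nn.Linear):
        self.layer = layer
        self.columns = layer.weight.shape[1]
        self.H = torch.zeros((self.columns, self.columns))
        self.nsamples = 0

    def add_batch(self, inp: torch.Tensor):
        # inp: (..., in_features) -> flatten to (n, in_features)
        inp = inp.reshape(-1, inp.shape[-1]).float()
        n = inp.shape[0]
        # running average of the input Gram matrix, as in the 2D path above
        self.H *= self.nsamples / (self.nsamples + n)
        self.nsamples += n
        inp = inp.t() * (2.0 / self.nsamples) ** 0.5
        self.H += inp.matmul(inp.t())

# with experts stored as a ModuleList of 2D Linears, each expert gets its own
# GPTQ object and hook; no MoE special-casing is needed inside add_batch
experts = nn.ModuleList(nn.Linear(64, 128, bias=False) for _ in range(4))
gptq_per_expert = [SimpleGPTQ(e) for e in experts]
for i, e in enumerate(experts):
    e.register_forward_hook(
        lambda mod, args, out, i=i: gptq_per_expert[i].add_batch(args[0])
    )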

gidx_list = []
loss_list = []

for i in range(self.num_experts):

Contributor:

I feel there is quite a lot of complication, but in actuality the MoE design is quite simple. Assuming you don't quantize the router, the only things that need to be quantized are the experts.

  • Right now the experts are put into a ParameterList, so all this custom logic to loop over them needs to be written, which is not ideal.
  • However, maybe you can consider a different design. For example, if you create a ModuleList instead and put each expert as a 2D Linear module inside that list, then I think perhaps a lot of this custom logic does not need to be written.

Contributor:

I just checked, and this is what Mixtral is doing:

MixtralDecoderLayer(
  (self_attn): MixtralSdpaAttention(
    (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
    (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
    (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (rotary_emb): MixtralRotaryEmbedding()
  )
  (block_sparse_moe): MixtralSparseMoeBlock(
    (gate): Linear(in_features=4096, out_features=8, bias=False)
    (experts): ModuleList(
      (0-7): 8 x MixtralBlockSparseTop2MLP(
        (w1): Linear(in_features=4096, out_features=14336, bias=False)
        (w2): Linear(in_features=14336, out_features=4096, bias=False)
        (w3): Linear(in_features=4096, out_features=14336, bias=False)
        (act_fn): SiLU()
      )
    )
  )
  (input_layernorm): MixtralRMSNorm((4096,), eps=1e-05)
  (post_attention_layernorm): MixtralRMSNorm((4096,), eps=1e-05)
)

base_modules = ["model.embed_tokens", "model.norm"]

layers_node = "model.layers"
layer_type = "GraniteMoeDecoderLayer"

Contributor:

I suggest you add some simple key to indicate the format of input_linear and output_linear, i.e. that these are 3D tensors.

Also, in the granitemoe case, another complication is that input_linear fuses w1 and w3. It might be OK for a first cut just to leave them as fused.

@fabianlim (Contributor) commented Jan 28, 2025:

So basically the simple key needs to say what to look for in order to do the 3D-tensor conversion, and then when you write layer_modules you write it as though they have already been converted:

class GraniteMoeGPTQ(BaseGPTQModel):

    convert3dToModuleList = ["block_sparse_moe.input_linear", "block_sparse_moe.output_linear"]

    layer_modules = [
        [
            "block_sparse_moe.input_linear.0.weight",
            "block_sparse_moe.input_linear.1.weight",
            ...
        ], [
            "block_sparse_moe.output_linear.0.weight",
            "block_sparse_moe.output_linear.1.weight",
            ...
        ]
    ]

layer_modules = [
    ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
    ["self_attn.o_proj"],
    ["block_sparse_moe.input_linear", "block_sparse_moe.output_linear"],

@fabianlim (Contributor) commented Jan 28, 2025:

Reference MixtralGPTQ: you will see that it splits up w1 + w3 and w2, which means we should split "block_sparse_moe.input_linear" and "block_sparse_moe.output_linear" into separate groups; see above.
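
As a rough sketch of that split (illustrative only, not the PR's final recipe), the MoE entries would become separate groups so that input_linear and output_linear are processed separately, mirroring MixtralGPTQ's w1/w3 vs w2 grouping:

layer_modules = [
    ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
    ["self_attn.o_proj"],
    ["block_sparse_moe.input_linear"],   # fused w1 + w3
    ["block_sparse_moe.output_linear"],  # w2
]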

layers = [transformers.pytorch_utils.Conv1D, nn.Conv2d, nn.Linear]
# Can add MOE.GraniteMoeRMSNorm here if want to include Linear Norm layer ["input_layernorm", "post_attention_layernorm"]
# MOE.GraniteMoeParallelExperts is torch.nn.Module for layer ["block_sparse_moe.input_linear", "block_sparse_moe.output_linear"]
layers = [transformers.pytorch_utils.Conv1D, nn.Conv2d, nn.Linear, MOE.GraniteMoeParallelExperts]

@fabianlim (Contributor) commented Jan 28, 2025:

If you can transform the module and break the 3D tensors into a ModuleList, you don't need to modify this function.

@fabianlim (Contributor) commented Jan 28, 2025:

That means input_linear is replaced with a ModuleList, and then we emulate its behavior.

That is, input_linear(input) gives a 3D tensor, so even when it is a ModuleList we try to emulate that and make it also give a 3D tensor.

So maybe a direct replacement of input_linear with a plain ModuleList is not possible, but we can replace it with a class inherited from ModuleList:

class ThreeDTensorModuleList(nn.ModuleList):

    def forward(self, inputs):
        # impl this
        ...
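
A minimal sketch of what that subclass could look like, assuming every expert sees the same 2D input and the per-expert outputs are stacked along a leading expert dimension (the real GraniteMoeParallelExperts forward routes different input rows to different experts, so the emulation in the PR would need to match that behavior; the helper experts_from_3d_weight is illustrative and not part of the PR):

import torch
import torch.nn as nn

class ThreeDTensorModuleList(nn.ModuleList):
    """ModuleList of per-expert 2D nn.Linear layers emulating a module that
    holds a 3D weight of shape (num_experts, out_features, in_features)."""

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # inputs: (tokens, in_features)
        # returns: (num_experts, tokens, out_features), i.e. a 3D tensor,
        # so callers written against the 3D module keep working
        return torch.stack([expert(inputs) for expert in self], dim=0)

def experts_from_3d_weight(weight: torch.Tensor) -> ThreeDTensorModuleList:
    # build the list from an existing 3D expert weight tensor
    experts = ThreeDTensorModuleList()
    for w in weight:  # w: (out_features, in_features)
        lin = nn.Linear(w.shape[1], w.shape[0], bias=False)
        lin.weight = nn.Parameter(w.detach().clone())
        experts.append(lin)
    return experts

Each entry is then an ordinary 2D Linear, so find_layers and the GPTQ class can treat the experts exactly like any other linear layer.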

@@ -25,46 +26,113 @@ class GPTQ:
def __init__(self, layer):

Contributor:

If the layer has been modified such that all 3D tensors are replaced with module lists, then no change to this function is needed.

Losses[:, i1:i2] = Losses1 / 2

W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
if not self.is_moe:

Contributor:

Same as above: if the module-list replacement is done, then no logic change is needed.

@Abhishek-TAMU (Collaborator Author) commented:

@fabianlim Based on your comments, I made an incremental change to use a ModuleList of 2D nn.Linear layers to store the 3D weight of the expert layer. Just wanted to ask: is this the direction we want to go forward with?

@@ -52,6 +54,24 @@
logger.addHandler(handler)
logger.setLevel(logging.INFO)

class ThreeDTensorModuleList(nn.ModuleList):
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # Shape of input: (num_selected_experts * batch_size (expert_size), input_features_size)

Contributor:

This module is called ThreeDTensorModuleList, but it is written quite specifically for MoE. I think that's fine, but then its design should assume that it is an MoE, and not a generic module with 3D tensors.

def find_layers(module, layers=None, name=""):
# print("1- INSIDE find_layers module", module)
module = check3DTensor(module, name)

Contributor:

What is missing here is the logic for the model recipe to pass in convert3dToModuleList. How do you plan to handle this?

One alternative is to handle the module swap completely outside of this function; then you don't need this logic.

def check3DTensor(module, name, convert3dToModuleList=["block_sparse_moe.input_linear", "block_sparse_moe.output_linear"]):
    if convert3dToModuleList and name in convert3dToModuleList:
        # print("INSIDE check3DTensor module, name, convert3dToModuleList", module, name, convert3dToModuleList)
        num_experts = module.num_experts

Contributor:

Same comment as above: it's written very specifically for an MoE, and it also assumes attributes like num_experts which may not exist in general.

If it's too difficult to write a robust check function, then I suggest writing a hook mechanism that takes in convert3dToModuleList from GraniteMoeGPTQ and just does all the module swaps outside first, in a prescribed manner.
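
One possible shape for that, sketched entirely outside find_layers and driven by the model recipe (the helper name swap_3d_experts and the assumption that the expert module stores its weights as a (num_experts, out_features, in_features) parameter called weight are mine, not the PR's):

import torch.nn as nn

def swap_3d_experts(model: nn.Module, convert3dToModuleList) -> None:
    # replace each matching 3D expert module with a ModuleList of 2D Linears
    # before quantization, so find_layers needs no check3DTensor logic at all
    for name, module in list(model.named_modules()):
        if not any(name.endswith(key) for key in convert3dToModuleList):
            continue
        experts = nn.ModuleList()
        for w in module.weight:  # assumes weight shape (num_experts, out, in)
            lin = nn.Linear(w.shape[1], w.shape[0], bias=False)
            lin.weight = nn.Parameter(w.detach().clone())
            experts.append(lin)
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, experts)

This would be called once on the model before the quantization loop, using the convert3dToModuleList declared on GraniteMoeGPTQ. In practice the replacement container would need a forward (e.g. the ThreeDTensorModuleList sketched earlier) so calibration forward passes still run; a plain ModuleList is used here only to keep the sketch self-contained.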


class GraniteMoeGPTQ(BaseGPTQModel):
    base_modules = ["model.embed_tokens", "model.norm"]
    convert3dToModuleList = ["block_sparse_moe.input_linear", "block_sparse_moe.output_linear"]

Contributor:

convert3dToModuleList has to be part of the BaseGPTQModel design, so it needs to exist in the parent class.
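
A minimal sketch of how that could look, with an empty default on the parent class so existing recipes are unaffected and GraniteMoeGPTQ simply overrides it (illustrative only; the real BaseGPTQModel carries many more fields):

class BaseGPTQModel:
    # modules whose 3D expert weights should be swapped for ModuleLists of
    # 2D Linear layers before quantization; empty for ordinary models
    convert3dToModuleList: list = []

class GraniteMoeGPTQ(BaseGPTQModel):
    base_modules = ["model.embed_tokens", "model.norm"]
    convert3dToModuleList = [
        "block_sparse_moe.input_linear",
        "block_sparse_moe.output_linear",
    ]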

Signed-off-by: Abhishek <[email protected]>