IBM Granite Architecture #9412
Conversation
Branch force-pushed from 234387d to 2ed6a17.
NOTE: I've adjusted the self-reported complexity on this one to
Nice PR - good job!
In convert_hf_to_gguf.py, I derived GraniteModel from LlamaModel to avoid copy-pasting the defaults there. I see that other models derive directly from Model and avoid inheritance, so I would be open to following that pattern if it's preferred.
I believe this is OK. Ping @compilade to confirm.
There was no clear existing parameter for embeddings_multiplier/attention_multiplier/residual_multiplier, so I opted to use the same names that are used in transformers. This does, however, leave an ugly asymmetry with logits_scale, so I thought about *_scale as well.
Yes, it's a tough call. Just from `llama.cpp` PoV, it's best to have all factors with the same suffix/prefix for consistency. But on the other hand, matching the names with existing frameworks like `transformers` is also valuable. We haven't been consistent either way, so I guess we can leave it as proposed.
I noticed that in llama_hparams, some variables have a prefix character, while others don't. I opted to follow the f_ prefix notation, but am happy to remove that if it's not to convention.
Again, we have not been very consistent in this regard. Adding the `f_` prefix for now is OK and in the future we will try to normalize the names.
build_llama: Similar to using LlamaModel as a base class, I opted to make small tweaks to the build_llama function rather than copy-paste into a clean-room build_granite function with the small deltas.
Good call.
Unit Testing: I see that there are no unit tests for each architecture, and other model addition PRs don't seem to include them, so I did not add any either.
We don't have unit tests yet for the build code.
In convert_hf_to_gguf.py, I derived GraniteModel from LlamaModel to avoid copy-pasting the defaults there. I see that other models derive directly from Model and avoid inheritance, so I would be open to following that pattern if it's preferred.
This is okay; inheritance where it makes sense is useful and is preferred in some cases (e.g. Bert-like models also make use of it).
```python
if head_dim := self.hparams.pop("head_dim", None):
    logger.warning("Ignoring head_dim (%s) from config for Granite", head_dim)
```
I don't see why `head_dim` should be actively removed (with `pop`) when present instead of simply not requiring it. Is there a particular reason why this is done here? Does `head_dim` cause problems for Granite?
Good question. I'm not intimately familiar with why it was done this way, but in the `transformers` implementation, there is a `ValueError` raised if `head_dim` is not directly computed from `num_heads` and `hidden_size`, and it is not supported as a parameter to the model's config, so I put this here as a guard against an invalid `config.json` that had it present.
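To make that concrete, the constraint being guarded against is roughly of this shape (a paraphrase for illustration only; the function below is hypothetical, not the actual `transformers` source):

```python
# Paraphrased illustration of the transformers-side constraint described above;
# not the actual transformers source. Granite derives head_dim from hidden_size
# and the number of attention heads, so a config carrying a conflicting
# explicit head_dim is rejected rather than silently honored.
def validate_head_dim(hidden_size: int, num_attention_heads: int,
                      head_dim: int | None) -> int:
    expected = hidden_size // num_attention_heads
    if head_dim is not None and head_dim != expected:
        raise ValueError(
            f"head_dim ({head_dim}) must equal "
            f"hidden_size // num_attention_heads ({expected})"
        )
    return expected
```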
Thanks for the reviews @ggerganov and @compilade! All of your comments make sense. I'll plan to address them first thing on Monday.
feat(gguf-py): Add Granite model and params to gguf-py Branch: GraniteLM Signed-off-by: Gabe Goodhart <[email protected]>
feat(convert_hf_to_gguf): Add registration and param setup for Granite Branch: GraniteLM Signed-off-by: Gabe Goodhart <[email protected]>
feat(llama.cpp): Add config parsing for Granite multiplier params Branch: GraniteLM Signed-off-by: Gabe Goodhart <[email protected]>
feat(llama.cpp): First pass at full port of granite deviations from llama Something is still not working right since the results are mostly terrible, but on occasion it's producing relevant results at this point, so _something_ is working. Branch: GraniteLM Signed-off-by: Gabe Goodhart <[email protected]>
fix(llama.cpp): Determine granite language 3b instruct by vocab size Branch: GraniteLM Signed-off-by: Gabe Goodhart <[email protected]>
fix(convert_hf_to_gguf): Use LlamaModel as base for GraniteModel The defaults in LlamaModel are needed for Granite as well Branch: GraniteLM Signed-off-by: Gabe Goodhart <[email protected]>
Branch force-pushed from 2ed6a17 to 8086380.
fix(llama.cpp): Switch Granite param names to use _scale for consistency Other scalar multipliers are called *_scale, so this provides a more consistent naming convention. Branch: GraniteLM Signed-off-by: Gabe Goodhart <[email protected]>
fix(convert_hf_to_gguf/gguf-py): _multiplier -> _scale The transformers names with _multiplier will now be converted to the _scale equivalent during conversion. Branch: GraniteLM Signed-off-by: Gabe Goodhart <[email protected]>
fix(llama.cpp): Use separate switch clause for granite in llm_load_hparams Branch: GraniteLM Signed-off-by: Gabe Goodhart <[email protected]>
Thanks again for the quick reviews! I've addressed them with the following updates:
* feat(gguf-py): Add Granite model and params to gguf-py Branch: GraniteLM Signed-off-by: Gabe Goodhart <[email protected]>
* feat(convert_hf_to_gguf): Add registration and param setup for Granite Branch: GraniteLM Signed-off-by: Gabe Goodhart <[email protected]>
* feat(llama.cpp): Add config parsing for Granite multiplier params Branch: GraniteLM Signed-off-by: Gabe Goodhart <[email protected]>
* feat(llama.cpp): First pass at full port of granite deviations from llama Something is still not working right since the results are mostly terrible, but on occasion it's producing relevant results at this point, so _something_ is working. Branch: GraniteLM Signed-off-by: Gabe Goodhart <[email protected]>
* fix(llama.cpp): Determine granite language 3b instruct by vocab size Branch: GraniteLM Signed-off-by: Gabe Goodhart <[email protected]>
* fix(convert_hf_to_gguf): Use LlamaModel as base for GraniteModel The defaults in LlamaModel are needed for Granite as well Branch: GraniteLM Signed-off-by: Gabe Goodhart <[email protected]>
* fix(llama.cpp): Switch Granite param names to use _scale for consistency Other scalar multipliers are called *_scale, so this provides a more consistent naming convention. Branch: GraniteLM Signed-off-by: Gabe Goodhart <[email protected]>
* fix(convert_hf_to_gguf/gguf-py): _multiplier -> _scale The transformers names with _multiplier will now be converted to the _scale equivalent during conversion. Branch: GraniteLM Signed-off-by: Gabe Goodhart <[email protected]>
* fix(llama.cpp): Use separate switch clause for granite in llm_load_hparams Branch: GraniteLM Signed-off-by: Gabe Goodhart <[email protected]>
---------
Signed-off-by: Gabe Goodhart <[email protected]>
This is a port of the work done in llama.cpp directly ggerganov/llama.cpp#9412 Branch: GraniteThreeSupport Signed-off-by: Gabe Goodhart <[email protected]>
Description
This PR introduces the `granite` model architecture from IBM. It emulates the `transformers` changes in this PR. The `granite` architecture is closely related to `llama` and therefore reuses many of the existing `llama` code paths. The primary differences are (a rough sketch of where each factor applies follows the list):

* `embeddings_multiplier`: A scaling factor applied to the input embeddings
* `attention_multiplier`: A configurable scaling factor applied to the output of each attention block that replaces `1 / sqrt(n_embd_head)`
* `residual_multiplier`: A scaling factor applied when recombining the attention output with the residual and again when combining the FFN output with the residual
* `logits_scale`: A scaling factor applied to the final logits before decoding
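A minimal sketch of where these factors act, assuming a llama-style decoder; this is illustrative Python pseudocode, not llama.cpp code, and the hparam keys and helper callables are made up for the example:

```python
# Minimal sketch of where the Granite scaling factors act relative to a plain
# llama forward pass. Illustrative only: the hparam keys and the embed/attn/
# ffn/lm_head callables are assumptions, not llama.cpp APIs.
def granite_forward(tokens, embed, blocks, lm_head, hp):
    # embeddings_multiplier: scale the input embeddings once, up front
    x = [e * hp["embeddings_multiplier"] for e in embed(tokens)]
    for attn, ffn in blocks:
        # attention_multiplier: used as the KQ scale inside attention,
        # replacing the usual 1 / sqrt(n_embd_head)
        attn_out = attn(x, kq_scale=hp["attention_multiplier"])
        # residual_multiplier: scale each branch before adding it back to
        # the residual stream (attention branch first, then the FFN branch)
        x = [r + a * hp["residual_multiplier"] for r, a in zip(x, attn_out)]
        x = [r + f * hp["residual_multiplier"] for r, f in zip(x, ffn(x))]
    # logits_scale: scale the final logits before decoding
    return [logit * hp["logits_scale"] for logit in lm_head(x)]
```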
Design Decisions/Questions

Since this is my first model addition, there were a couple of design choices that I took my best guess on, but would like to point out for further discussion:

* In `convert_hf_to_gguf.py`, I derived `GraniteModel` from `LlamaModel` to avoid copy-pasting the defaults there. I see that other models derive directly from `Model` and avoid inheritance, so I would be open to following that pattern if it's preferred. (A sketch of this pattern follows the list.)
* Naming (`llama.cpp`): There was no clear existing parameter for `embeddings_multiplier`/`attention_multiplier`/`residual_multiplier`, so I opted to use the same names that are used in `transformers`. This does, however, leave an ugly asymmetry with `logits_scale`, so I thought about `*_scale` as well.
* I noticed that in `llama_hparams`, some variables have a prefix character, while others don't. I opted to follow the `f_` prefix notation, but am happy to remove that if it's not to convention.
* `build_llama`: Similar to using `LlamaModel` as a base class, I opted to make small tweaks to the `build_llama` function rather than copy-paste into a clean-room `build_granite` function with the small deltas.
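As a toy illustration of the inheritance choice in the first bullet: the base class supplies the shared llama defaults and the subclass only layers the Granite-specific parameters on top. The class and method names below are hypothetical stand-ins, not the real `convert_hf_to_gguf.py` API:

```python
# Self-contained toy showing why GraniteModel subclasses LlamaModel in the
# converter. Names and values here are hypothetical, not the real script's API.
class ToyLlamaModel:
    def set_params(self) -> dict:
        # shared llama defaults (rope settings, head counts, etc. in reality)
        return {"arch": "llama", "rope_freq_base": 10000.0}

class ToyGraniteModel(ToyLlamaModel):
    def set_params(self) -> dict:
        params = super().set_params()  # reuse the llama defaults verbatim
        # add only the Granite-specific scaling factors on top
        params.update({
            "arch": "granite",
            "attention_scale": 0.0078125,  # example value, not from the PR
            "residual_scale": 0.22,        # example value, not from the PR
        })
        return params
```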
Testing

To test the model architecture, I am working with IBM's preview model ibm/PowerLM-3b.

```sh
huggingface-cli download ibm/PowerLM-3b --local-dir $HOME/models/powerlm-3b
```

Conversion

```sh
python convert_hf_to_gguf.py --verbose $HOME/models/powerlm-3b/
```
Simple Execution
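The exact command used for this step didn't survive the page extraction; a typical llama.cpp invocation for a quick generation test would look roughly like this (the GGUF filename, prompt, and sampling flags are assumptions, not taken from the PR):

```sh
# Illustrative only; the output filename depends on the conversion step above.
./llama-cli -m $HOME/models/powerlm-3b/PowerLM-3b-F16.gguf \
    -p "Tell me about IBM Granite." -n 64
```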
Quantization
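Likewise, the quantization command is not shown above; a typical llama-quantize call would be along these lines (filenames and the Q4_K_M type are illustrative assumptions):

```sh
# Illustrative only; input is the F16 GGUF produced by the conversion step.
./llama-quantize $HOME/models/powerlm-3b/PowerLM-3b-F16.gguf \
    $HOME/models/powerlm-3b/PowerLM-3b-Q4_K_M.gguf Q4_K_M
```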