
IBM Granite Architecture #9412

Merged: 9 commits into ggerganov:master on Sep 17, 2024
Conversation

@gabe-l-hart (Contributor) commented on Sep 10, 2024
Description

This PR introduces the granite model architecture from IBM. It emulates the changes made in the corresponding transformers PR.

The granite architecture is closely related to llama and therefore reuses many of the existing llama code paths. The primary differences are the four scale factors below (see the sketch after the list):

  1. embeddings_multiplier: A scaling factor applied to the input embeddings
  2. attention_multiplier: A configurable scaling factor used inside each attention block in place of the default 1 / sqrt(n_embd_head)
  3. residual_multiplier: A scaling factor applied when recombining the attention output with the residual and again when combining the FFN output with the residual
  4. logits_scale: A scaling factor applied to the final logits before decoding
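
As a rough illustration only (a hypothetical Python sketch, not the llama.cpp or transformers implementation; the function and attribute names here are made up), the four factors enter a llama-style forward pass roughly like this:

def granite_forward(tokens, embed, layers, lm_head, hp):
    # (1) embeddings_multiplier scales the input embeddings
    x = embed(tokens) * hp.embeddings_multiplier
    for layer in layers:
        # (2) attention_multiplier is used inside attention in place of the
        #     default 1 / sqrt(n_embd_head) scaling
        attn_out = layer.attn(layer.attn_norm(x), scale=hp.attention_multiplier)
        # (3) residual_multiplier scales the branch output when it is added
        #     back to the residual stream, for both attention and FFN
        x = x + attn_out * hp.residual_multiplier
        ffn_out = layer.ffn(layer.ffn_norm(x))
        x = x + ffn_out * hp.residual_multiplier
    # (4) logits_scale scales the final logits before decoding (shown here as a
    #     simple multiply; the exact convention follows the reference model)
    return lm_head(x) * hp.logits_scale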

Design Decisions/Questions

Since this is my first model addition, there were a few design choices where I took my best guess, but I would like to point them out for further discussion:

  • In convert_hf_to_gguf.py, I derived GraniteModel from LlamaModel to avoid copy-pasting the defaults there (see the sketch after this list). I see that other models derive directly from Model and avoid inheritance, so I would be open to following that pattern if it's preferred.
  • Parameter naming (llama.cpp):
    • There was no clear existing parameter for embeddings_multiplier/attention_multiplier/residual_multiplier, so I opted to use the same names that are used in transformers. This does, however, leave an ugly asymmetry with logits_scale, so I thought about *_scale as well.
    • I noticed that in llama_hparams, some variables have a prefix character, while others don't. I opted to follow the f_ prefix notation, but am happy to remove it if that doesn't match convention.
  • build_llama: Similar to using LlamaModel as a base class, I opted to make small tweaks to the build_llama function rather than copy-paste into a clean-room build_granite function with the small deltas.
  • Unit Testing: I see that there are no unit tests for each architecture, and other model addition PRs don't seem to include them, so I did not include them either.
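
For reference, a hypothetical sketch of the inheritance pattern in question. The decorator, base class, and writer call follow the general conventions of convert_hf_to_gguf.py (so it assumes the Model and LlamaModel classes defined there), but the "granite.*" metadata key names are illustrative placeholders rather than what the PR actually writes:

@Model.register("GraniteForCausalLM")
class GraniteModel(LlamaModel):
    """Reuse the LlamaModel defaults and add only the Granite-specific bits."""
    model_arch = gguf.MODEL_ARCH.GRANITE

    def set_gguf_parameters(self):
        # Inherit all of the llama hyperparameter handling ...
        super().set_gguf_parameters()
        # ... then write the Granite-specific scale factors on top.
        # NOTE: the "granite.*" key names here are illustrative placeholders.
        for key in ("embeddings_multiplier", "attention_multiplier",
                    "residual_multiplier", "logits_scale"):
            if key in self.hparams:
                self.gguf_writer.add_float32(f"granite.{key}", self.hparams[key])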

Testing

To test the model architecture, I am working with IBM's preview model ibm/PowerLM-3b.

huggingface-cli download ibm/PowerLM-3b --local-dir $HOME/models/powerlm-3b

Conversion

python convert_hf_to_gguf.py --verbose $HOME/models/powerlm-3b/

Simple Execution

./build/bin/llama-simple -m $HOME/models/powerlm-3b/powerlm-3B-F16.gguf -p "Write a code to find the maximum value in a list of numbers." -n 100

ASSISTANT: Here's a code to find the maximum value in a list of numbers:

def find_max(numbers):
    max_value = numbers[0]
    for num in numbers:
        if num > max_value:
            max_value = num
    return max_value

This code defines a function called find_max that takes a list of numbers as

Quantization

./build/bin/llama-quantize $HOME/models/powerlm-3b/powerlm-3B-F16.gguf Q4_K_M

# Simple with the quantized model
./build/bin/llama-simple -m /Users/ghart/models/powerlm-3b/ggml-model-Q4_K_M.gguf -p "Write a code to find the maximum value in a list of numbers." -n 100

ASSISTANT: Sure, here's a code to find the maximum value in a list of numbers:

def find_max(numbers):
    max_value = numbers[0]
    for num in numbers:
        if num > max_value:
            max_value = num
    return max_value

This code defines a function called find_max that takes a list of

@github-actions github-actions bot added the python python script changes label Sep 10, 2024
@gabe-l-hart (Contributor, Author) commented:
NOTE: I've adjusted the self-reported complexity on this one to Low. It was definitely Medium for me as I ramped on the codebase, but for those of you already deeply familiar, it should be relatively straightforward.

@ggerganov (Owner) left a comment

Nice PR - good job!

In convert_hf_to_gguf.py, I derived GraniteModel from LlamaModel to avoid copy-pasting the defaults there. I see that other models derive directly from Model and avoid inheritance, so I would be open to following that pattern if it's preferred.

I believe this is OK. Ping @compilade to confirm.

There was no clear existing parameter for embeddings_multiplier/attention_multiplier/residual_multiplier, so I opted to use the same names that are used in transformers. This does, however, leave an ugly asymmetry with logits_scale, so I thought about *_scale as well.

Yes, it's a tough call. Just from llama.cpp PoV, it's best to have all factors with the same suffix/prefix for consistency. But on the other hand, matching the names with existing frameworks like transformers is also valuable. We haven't been consistent either way, so I guess we can leave it as proposed.

I noticed that in llama_hparams, some variables have a prefix character, while others don't. I opted to follow the f_ prefix notation, but am happy to remove that if it's not to convention.

Again, we have not been very consistent in this regard. Adding the f_ prefix for now is OK and in the future we will try to normalize the names.

build_llama: Similar to using LlamaModel as a base class, I opted to make small tweaks to the build_llama function rather than copy-paste into a clean-room build_granite function with the small deltas.

Good call.

Unit Testing: I see that there are not unit tests for each architecture and other model addition PRs don't seem to include them, so I did not either

We don't have unit tests yet for the build code.

@ggerganov mentioned this pull request on Sep 12, 2024
@compilade (Collaborator) left a comment

In convert_hf_to_gguf.py, I derived GraniteModel from LlamaModel to avoid copy-pasting the defaults there. I see that other models derive directly from Model and avoid inheritance, so I would be open to following that pattern if it's preferred.

This is okay; inheritance where it makes sense is useful and is preferred in some cases (e.g. Bert-like models also make use of it).

Comment on lines +3949 to +4099
if head_dim := self.hparams.pop("head_dim", None):
logger.warning("Ignoring head_dim (%s) from config for Granite", head_dim)
@compilade (Collaborator) commented on Sep 11, 2024

I don't see why head_dim should be actively removed (with pop) when present instead of simply not requiring it.

Is there a particular reason why this is done here? Does head_dim cause problems for Granite?

@gabe-l-hart (Contributor, Author) replied:

Good question. I'm not intimately familiar with why it was done this way, but in the transformers implementation a ValueError is raised if head_dim is not derived directly from num_heads and hidden_size, and head_dim is not supported as a parameter to the model's config. I put this here as a guard against an invalid config.json that has it present.
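
To make the guard concrete, here is a loose paraphrase (not the actual transformers source; names and message text are illustrative) of the kind of validation being described:

def validate_head_dim(hidden_size, num_attention_heads, head_dim=None):
    # Granite derives the head dimension from hidden_size and num_heads; an
    # explicit head_dim that disagrees is treated as an invalid config.
    derived = hidden_size // num_attention_heads
    if head_dim is not None and head_dim != derived:
        raise ValueError(
            f"head_dim={head_dim} is not supported for Granite; it must equal "
            f"hidden_size // num_attention_heads ({derived})"
        )
    return derived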

Two review comments on src/llama.cpp (outdated, resolved).

@gabe-l-hart (Contributor, Author) commented:
Thanks for the reviews @ggerganov and @compilade! All of your comments make sense. I'll plan to address them first thing on Monday.

Commits pushed to branch GraniteLM (each signed off by Gabe Goodhart <[email protected]>) include:

  • feat(llama.cpp): First pass at full port of granite deviations from llama. Something is still not working right since the results are mostly terrible, but on occasion it's producing relevant results at this point, so _something_ is working.
  • fix(convert_hf_to_gguf): Use LlamaModel as base for GraniteModel. The defaults in LlamaModel are needed for Granite as well.
  • fix(llama.cpp): Switch Granite param names to use _scale for consistency. Other scalar multipliers are called *_scale, so this provides a more consistent naming convention.
  • fix(convert_hf_to_gguf/gguf-py): _multiplier -> _scale. The transformers names with _multiplier will now be converted to the _scale equivalent during conversion.

@gabe-l-hart (Contributor, Author) commented:
Thanks again for the quick reviews! I've addressed them with the following updates:

  • Switch from *_multiplier to *_scale, converting from the transformers names during model conversion (see the sketch below). I like keeping this consistent with the other hparams in llama.cpp.
  • Split out the LLM_ARCH_GRANITE clause in llm_load_hparams to avoid the added conditionals in the LLM_ARCH_LLAMA clause.
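
To illustrate the rename, here is a minimal hypothetical helper (not the PR's actual code; the exact key spellings in config.json and in llama.cpp differ from these placeholders) showing the conversion-time mapping:

# Hypothetical mapping from transformers-style names to the *_scale names
# used on the llama.cpp side; key spellings are illustrative.
MULTIPLIER_TO_SCALE = {
    "embeddings_multiplier": "embedding_scale",
    "attention_multiplier":  "attention_scale",
    "residual_multiplier":   "residual_scale",
    "logits_scale":          "logit_scale",
}

def rename_granite_hparams(hf_config):
    """Return Granite scale factors keyed by their llama.cpp-style names."""
    return {
        scale_name: hf_config[hf_name]
        for hf_name, scale_name in MULTIPLIER_TO_SCALE.items()
        if hf_name in hf_config
    }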

@ggerganov merged commit 0d2ec43 into ggerganov:master on Sep 17, 2024
54 checks passed
dsx1986 pushed a commit to dsx1986/llama.cpp that referenced this pull request Oct 29, 2024
* feat(gguf-py): Add Granite model and params to gguf-py

Branch: GraniteLM

Signed-off-by: Gabe Goodhart <[email protected]>

* feat(convert_hf_to_gguf): Add registration and param setup for Granite

Branch: GraniteLM

Signed-off-by: Gabe Goodhart <[email protected]>

* feat(llama.cpp): Add config parsing for Granite multiplier params

Branch: GraniteLM

Signed-off-by: Gabe Goodhart <[email protected]>

* feat(llama.cpp): First pass at full port of granite deviations from llama

Something is still not working right since the results are mostly terrible,
but on occasion it's producing relevant results at this point, so
_something_ is working.

Branch: GraniteLM

Signed-off-by: Gabe Goodhart <[email protected]>

* fix(llama.cpp): Determine granite language 3b instruct by vocab size

Branch: GraniteLM

Signed-off-by: Gabe Goodhart <[email protected]>

* fix(convert_hf_to_gguf): Use LlamaModel as base for GraniteModel

The defaults in LlamaModel are needed for Granite as well

Branch: GraniteLM

Signed-off-by: Gabe Goodhart <[email protected]>

* fix(llama.cpp): Switch Granite param names to use _scale for consistency

Other scalar multipliers are called *_scale, so this provides a more
consistent naming convention.

Branch: GraniteLM

Signed-off-by: Gabe Goodhart <[email protected]>

* fix(convert_hf_to_gguf/gguf-py): _multiplier -> _scale

The transformers names with _multiplier will now be converted to the _scale
equivalent during conversion.

Branch: GraniteLM

Signed-off-by: Gabe Goodhart <[email protected]>

* fix(llama.cpp): Use separate switch clause for granite in llm_load_hparams

Branch: GraniteLM

Signed-off-by: Gabe Goodhart <[email protected]>

---------

Signed-off-by: Gabe Goodhart <[email protected]>
gabe-l-hart added a commit to gabe-l-hart/llamafile that referenced this pull request Nov 4, 2024
This is a port of the work done in llama.cpp directly
ggerganov/llama.cpp#9412

Branch: GraniteThreeSupport

Signed-off-by: Gabe Goodhart <[email protected]>
gabe-l-hart added a commit to gabe-l-hart/llamafile that referenced this pull request Nov 4, 2024 (same message as above)
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Nov 15, 2024 (same squashed commit message as above)
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Nov 18, 2024 (same squashed commit message as above)
gabe-l-hart added a commit to gabe-l-hart/llamafile that referenced this pull request Dec 10, 2024 (same message as above)

Labels: python (python script changes)

3 participants