
IBM Granite MoE Architecture #9438

Merged
merged 10 commits into ggerganov:master from GraniteMoE on Sep 25, 2024
Conversation

@gabe-l-hart (Contributor) commented on Sep 11, 2024:

Dependencies

Description

This PR introduces the granitemoe model architecture from IBM. It mirrors the corresponding changes made in the transformers granitemoe PR.

The granitemoe architecture follows a very similar pattern to the granite architecture and its changes relative to llama. For the MoE variant, the base architecture is mixtral (MoE branch of llama here in llama.cpp). The same four additional multipliers are added (embeddings_multiplier, attention_multiplier, residual_multiplier, and logits_scale).
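
For intuition, here is a minimal sketch of where these four scalars enter an otherwise llama/mixtral-style forward pass. It is illustrative only; the names, call signatures, and placeholder values below are assumptions, not the actual llama.cpp or transformers implementation.

```python
import torch

# Placeholder values for illustration only; the real scalars come from the model
# config / GGUF metadata of the converted checkpoint.
EMBEDDINGS_MULTIPLIER = 1.0  # scales token embeddings before the first layer
ATTENTION_MULTIPLIER = 1.0   # used as the attention-score scale (instead of 1/sqrt(head_dim))
RESIDUAL_MULTIPLIER = 1.0    # scales each residual branch before it is added back
LOGITS_SCALE = 1.0           # divides the final logits

def decoder_block_sketch(h: torch.Tensor, attn_branch, moe_branch) -> torch.Tensor:
    """One decoder block with Granite-style residual scaling.

    attn_branch(hidden, scale) and moe_branch(hidden) stand in for the attention
    and mixtral-style MoE FFN sub-layers (norms included).
    """
    h = h + attn_branch(h, ATTENTION_MULTIPLIER) * RESIDUAL_MULTIPLIER
    h = h + moe_branch(h) * RESIDUAL_MULTIPLIER
    return h

def forward_sketch(token_embeddings: torch.Tensor, blocks, lm_head) -> torch.Tensor:
    h = token_embeddings * EMBEDDINGS_MULTIPLIER
    for attn_branch, moe_branch in blocks:
        h = decoder_block_sketch(h, attn_branch, moe_branch)
    return lm_head(h) / LOGITS_SCALE
```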

Testing

This PR can be tested using ibm/PowerMoE-3b from Hugging Face, following the same testing steps used for the granite architecture (described in the earlier granite PR).

@gabe-l-hart force-pushed the GraniteMoE branch 2 times, most recently from 5f37be3 to 3219f58 on September 11, 2024 16:29
@github-actions bot added the `python` label (python script changes) on Sep 11, 2024
@gabe-l-hart force-pushed the GraniteMoE branch 3 times, most recently from 1b235d0 to 2615459 on September 17, 2024 12:46
@gabe-l-hart marked this pull request as ready for review on September 17, 2024 12:46
@gabe-l-hart (Contributor, Author) commented:

Hi @compilade @ggerganov! This PR is now ready for full review.

We're eager to get the granitemoe architecture fully supported in llama.cpp (and then follow up with support in ollama). I'm sure you are perpetually swamped, so I just want a quick check on whether this is in your review queue at this point and whether you have any target for merging support.

(also, thanks for the great project and all the work you do!)

convert_hf_to_gguf.py (review comment, outdated, resolved)
@gabe-l-hart (Contributor, Author) commented:

It looks like the failing test is the Windows server's `Erase Slot` server-logs scenario, which seems like it should be unrelated to this PR. Without knowing the tests well, is there any likelihood that this is a false negative? I can dig further if needed.

convert_hf_to_gguf.py (review comment, outdated, resolved)
@ggerganov (Owner) replied:

> is there any likelihood that this is a false negative?

Yes, this is unrelated to the PR, no need to investigate.

gguf-py/gguf/tensor_mapping.py (two review comments, outdated, resolved)
convert_hf_to_gguf.py (review comment, outdated, resolved)
gguf-py/gguf/constants.py (review comment, resolved)
src/llama.cpp (review comment, resolved)
@gabe-l-hart (Contributor, Author) commented:

Thanks for the detailed review @compilade! I believe I have all of the comments addressed at this point.

@compilade (Collaborator) left a review:

From the first few chunks of wikitext-2-raw with llama-perplexity and https://huggingface.co/ibm/PowerMoE-3b at Q8_0, I get [1]4.4570,[2]5.1116,[3]5.3469,[4]5.9955, so this does appear to work correctly.

gabe-l-hart and others added 9 commits September 24, 2024 10:24
This includes the addition of new tensor names for the new moe layers.
These may not be correct at this point due to the need for the hack in
gguf_writer.py to double-check the length of the shape for these layers.

Branch: GraniteMoE

Signed-off-by: Gabe Goodhart <[email protected]>
GraniteMoe has the same configuration deltas as Granite

Branch: GraniteMoE

Signed-off-by: Gabe Goodhart <[email protected]>
fix(granitemoe convert): Split the double-sized input layer into gate and up

After a lot of staring and squinting, it's clear that the standard mixtral
expert implementation is equivalent to the vectorized parallel experts in
granite. The difference is that in granite, the w1 and w3 are concatenated
into a single tensor "input_linear." Rather than reimplementing all of the
math on the llama.cpp side, the much simpler route is to just split this
tensor during conversion and follow the standard mixtral route.

Branch: GraniteMoE

Co-Authored-By: [email protected]

Signed-off-by: Gabe Goodhart <[email protected]>
GraniteMoE follows the mixtral architecture (once the input_linear layers
are split into gate_exps/up_exps). The main delta is the addition of the
same four multipliers used in Granite.

Branch: GraniteMoE

Signed-off-by: Gabe Goodhart <[email protected]>
Co-Authored-By: [email protected]

Co-authored-by: Georgi Gerganov <[email protected]>
Signed-off-by: Gabe Goodhart <[email protected]>
Branch: GraniteMoE

Co-Authored-By: [email protected]
Signed-off-by: Gabe Goodhart <[email protected]>
Branch: GraniteMoE

Co-Authored-By: [email protected]
Signed-off-by: Gabe Goodhart <[email protected]>
Branch: GraniteMoE

Co-Authored-By: [email protected]
Signed-off-by: Gabe Goodhart <[email protected]>
This is a fix for the previous `granite` architecture PR. Recent snapshots
have included this (`lm_head.weights`) as part of the architecture

Branch: GraniteMoE

Signed-off-by: Gabe Goodhart <[email protected]>
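
To make the "Split the double-sized input layer into gate and up" commit above more concrete, here is a standalone sketch of the conversion-time split. It is not the actual convert_hf_to_gguf.py code; the fused tensor layout (experts first, with the w1/gate half stacked before the w3/up half) is an assumption that has to match the checkpoint.

```python
import torch

def split_input_linear(input_linear: torch.Tensor, ffn_hidden_size: int):
    """Split GraniteMoE's fused expert weight into mixtral-style gate/up tensors.

    Assumes input_linear has shape (n_experts, 2 * ffn_hidden_size, hidden_size),
    with the w1 (gate) half stacked before the w3 (up) half.
    """
    assert input_linear.shape[-2] == 2 * ffn_hidden_size, "unexpected fused FFN shape"
    gate_exps = input_linear[..., :ffn_hidden_size, :]  # maps to ffn_gate_exps
    up_exps = input_linear[..., ffn_hidden_size:, :]    # maps to ffn_up_exps
    return gate_exps, up_exps

# Toy example: 4 experts, ffn_hidden_size=8, hidden_size=16
fused = torch.randn(4, 2 * 8, 16)
gate, up = split_input_linear(fused, ffn_hidden_size=8)
print(gate.shape, up.shape)  # torch.Size([4, 8, 16]) torch.Size([4, 8, 16])
```

Splitting at conversion time is what lets the runtime reuse the existing mixtral expert path unchanged instead of reimplementing the fused math in llama.cpp.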
@gabe-l-hart (Contributor, Author) commented:

@compilade After you pointed out that I was missing output in the recent granitemoe snapshots, I dug a little deeper and it seems that the model team has added this for granite (dense) as well. I've added another commit to this PR to fix that too. I'm not sure of the preferred PR hygiene, so I'm happy to move it to a separate fix PR if the preference is for more well-encapsulated changes.
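
For context on what allowing the output layer amounts to, the sketch below shows one simplified way an optional lm_head.weight could be mapped and, when absent, tied back to the token embeddings. It is a hypothetical illustration, not the actual gguf-py TensorNameMap or conversion code.

```python
import torch

# Hypothetical, simplified HF -> GGUF name handling for the two tensors discussed here
# (not the actual gguf-py TensorNameMap or convert_hf_to_gguf.py logic).
HF_TO_GGUF = {
    "model.embed_tokens.weight": "token_embd.weight",
    "lm_head.weight": "output.weight",  # optional: only newer granite / granitemoe snapshots ship it
}

def map_tensor_name(hf_name: str) -> str | None:
    """Return the GGUF-side name, or None for tensors this sketch does not cover."""
    return HF_TO_GGUF.get(hf_name)

def resolve_output_weight(tensors: dict[str, torch.Tensor]) -> torch.Tensor:
    """Prefer an explicit output head; otherwise fall back to tied token embeddings."""
    return tensors.get("lm_head.weight", tensors["model.embed_tokens.weight"])
```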

@compilade added the `merge ready` label (indicates that this may be ready to merge soon and is just holding out in case of objections) on Sep 24, 2024
@ggerganov merged commit 3d6bf69 into ggerganov:master on Sep 25, 2024 (55 checks passed)
@gabe-l-hart deleted the GraniteMoE branch on September 25, 2024 12:45
ericcurtin added a commit to containers/ramalama that referenced this pull request Oct 21, 2024
This was added recently to llama.cpp:

ggerganov/llama.cpp#9438

Signed-off-by: Eric Curtin <[email protected]>
dsx1986 pushed a commit to dsx1986/llama.cpp that referenced this pull request Oct 29, 2024
* feat(gguf-py): Add granitemoe architecture

This includes the addition of new tensor names for the new moe layers.
These may not be correct at this point due to the need for the hack in
gguf_writer.py to double-check the length of the shape for these layers.

Branch: GraniteMoE

Signed-off-by: Gabe Goodhart <[email protected]>

* feat(convert_hf_to_gguf): Add GraniteMoeModel

GraniteMoe has the same configuration deltas as Granite

Branch: GraniteMoE

Signed-off-by: Gabe Goodhart <[email protected]>

* fix(granitemoe convert): Split the double-sized input layer into gate and up

After a lot of staring and squinting, it's clear that the standard mixtral
expert implementation is equivalent to the vectorized parallel experts in
granite. The difference is that in granite, the w1 and w3 are concatenated
into a single tensor "input_linear." Rather than reimplementing all of the
math on the llama.cpp side, the much simpler route is to just split this
tensor during conversion and follow the standard mixtral route.

Branch: GraniteMoE

Co-Authored-By: [email protected]

Signed-off-by: Gabe Goodhart <[email protected]>

* feat(granitemoe): Implement granitemoe

GraniteMoE follows the mixtral architecture (once the input_linear layers
are split into gate_exps/up_exps). The main delta is the addition of the
same four multipliers used in Granite.

Branch: GraniteMoE

Signed-off-by: Gabe Goodhart <[email protected]>

* Typo fix in docstring

Co-Authored-By: [email protected]

Co-authored-by: Georgi Gerganov <[email protected]>
Signed-off-by: Gabe Goodhart <[email protected]>

* fix(conversion): Simplify tensor name mapping in conversion

Branch: GraniteMoE

Co-Authored-By: [email protected]
Signed-off-by: Gabe Goodhart <[email protected]>

* fix(convert): Remove unused tensor name mappings

Branch: GraniteMoE

Co-Authored-By: [email protected]
Signed-off-by: Gabe Goodhart <[email protected]>

* fix(convert): Sanity check on merged FFN tensor sizes

Branch: GraniteMoE

Co-Authored-By: [email protected]
Signed-off-by: Gabe Goodhart <[email protected]>

* fix: Allow "output" layer in granite moe architecture (convert and cpp)

Branch: GraniteMoE

Co-Authored-By: [email protected]
Signed-off-by: Gabe Goodhart <[email protected]>

* fix(granite): Add missing 'output' tensor for Granite

This is a fix for the previous `granite` architecture PR. Recent snapshots
have included this (`lm_head.weights`) as part of the architecture

Branch: GraniteMoE

Signed-off-by: Gabe Goodhart <[email protected]>

---------

Signed-off-by: Gabe Goodhart <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
gabe-l-hart added a commit to gabe-l-hart/llamafile that referenced this pull request Nov 4, 2024
This is a port of the work done in llama.cpp directly
ggerganov/llama.cpp#9438

Branch: GraniteThreeSupport

Signed-off-by: Gabe Goodhart <[email protected]>
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Nov 15, 2024
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Nov 18, 2024
gabe-l-hart added a commit to gabe-l-hart/llamafile that referenced this pull request Dec 10, 2024
This is a port of the work done in llama.cpp directly
ggerganov/llama.cpp#9438

Branch: GraniteThreeSupport

Signed-off-by: Gabe Goodhart <[email protected]>
Labels: merge ready, python
3 participants