Feat: Support for falcon-mamba architecture #9074

Merged: 18 commits, Aug 21, 2024
Changes from 3 commits
1 change: 1 addition & 0 deletions README.md
@@ -106,6 +106,7 @@ Typically finetunes of the base models below are supported as well.
- [x] [ChatGLM3-6b](https://huggingface.co/THUDM/chatglm3-6b) + [ChatGLM4-9b](https://huggingface.co/THUDM/glm-4-9b)
- [x] [SmolLM](https://huggingface.co/collections/HuggingFaceTB/smollm-6695016cad7167254ce15966)
- [x] [EXAONE-3.0-7.8B-Instruct](https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct)
- [x] [FalconMamba Models](https://huggingface.co/collections/tiiuae/falconmamba-7b-66b9a580324dd1598b0f6d4a)

(instructions for supporting more models: [HOWTO-add-model.md](./docs/development/HOWTO-add-model.md))

17 changes: 11 additions & 6 deletions convert_hf_to_gguf.py
@@ -2711,7 +2711,7 @@ class StarCoder2Model(Model):
model_arch = gguf.MODEL_ARCH.STARCODER2


@Model.register("MambaForCausalLM", "MambaLMHeadModel")
@Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
class MambaModel(Model):
model_arch = gguf.MODEL_ARCH.MAMBA

@@ -2731,7 +2731,7 @@ def set_vocab(self):
else:
# Use the GPT-NeoX tokenizer when no tokenizer files are present
self._set_vocab_builtin("gpt-neox", vocab_size)

def set_gguf_parameters(self):
d_model = self.find_hparam(["hidden_size", "d_model"])
d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
@@ -2742,20 +2742,25 @@ def set_gguf_parameters(self):
# ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -(d_model // -16)
rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5

num_hidden_layers = self.find_hparam(["n_layer", "num_hidden_layers"])
use_b_dt_norm = False
        # For FalconMamba we apply RMS norm on the B, C and DT layers
if self.find_hparam(["model_type"]) in ["falcon_mamba"]:
use_b_dt_norm = True
# Fail early for models which don't have a block expansion factor of 2
assert d_inner == 2 * d_model

self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
self.gguf_writer.add_embedding_length(d_model)
self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
self.gguf_writer.add_block_count(self.hparams["n_layer"])
self.gguf_writer.add_block_count(num_hidden_layers)
self.gguf_writer.add_ssm_conv_kernel(d_conv)
self.gguf_writer.add_ssm_inner_size(d_inner)
self.gguf_writer.add_ssm_state_size(d_state)
self.gguf_writer.add_ssm_time_step_rank(dt_rank)
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
self.gguf_writer.add_mamba_b_dt_rms(use_b_dt_norm) # For classic Mamba we don't apply rms norm on B / DT layers
self.gguf_writer.add_file_type(self.ftype)

_tok_embd = None
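As an aside, a minimal standalone sketch of the FalconMamba-specific parameter derivation above (a hypothetical helper, not part of this diff): it mirrors the ceil-division fallback for dt_rank and the model_type check that enables the new B/DT RMS-norm flag.

```python
# Hypothetical helper, not part of the diff: mirrors the derivation in
# MambaModel.set_gguf_parameters above.
def derive_ssm_params(hparams: dict) -> dict:
    d_model = hparams["hidden_size"]
    d_inner = hparams.get("intermediate_size") or 2 * d_model
    # -(a // -b) is ceiling division, so dt_rank defaults to ceil(d_model / 16)
    dt_rank = hparams.get("time_step_rank") or -(d_model // -16)
    # FalconMamba applies RMS norm on the B, C and DT projections
    use_b_dt_norm = hparams.get("model_type") == "falcon_mamba"
    assert d_inner == 2 * d_model, "Mamba block expansion factor must be 2"
    return {
        "ssm.inner_size": d_inner,
        "ssm.time_step_rank": dt_rank,
        "ssm.b_dt_rms": use_b_dt_norm,
    }

# e.g. with hidden_size=4096 (FalconMamba-7B-like): dt_rank=256, b_dt_rms=True
print(derive_ssm_params({"hidden_size": 4096, "model_type": "falcon_mamba"}))
```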
@@ -3792,7 +3797,7 @@ class ExaoneModel(Model):
def set_gguf_parameters(self):
hparams = self.hparams

assert(hparams["activation_function"] == "silu")
assert (hparams["activation_function"] == "silu")

max_position_embeddings = hparams["max_position_embeddings"]
embed_dim = hparams["hidden_size"]
@@ -3854,10 +3859,10 @@ def prepare_tensors(self):
self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32))

super().prepare_tensors()


###### CONVERSION LOGIC ######


# tree of lazy tensors
class LazyTorchTensor(gguf.LazyBase):
_tensor_type = torch.Tensor
2 changes: 2 additions & 0 deletions gguf-py/gguf/constants.py
@@ -130,6 +130,7 @@ class SSM:
INNER_SIZE = "{arch}.ssm.inner_size"
STATE_SIZE = "{arch}.ssm.state_size"
TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
B_DT_RMS = "{arch}.ssm.b_dt_rms"

class Tokenizer:
MODEL = "tokenizer.ggml.model"
@@ -1372,6 +1373,7 @@ def get_type(val: Any) -> GGUFValueType:
KEY_SSM_INNER_SIZE = Keys.SSM.INNER_SIZE
KEY_SSM_STATE_SIZE = Keys.SSM.STATE_SIZE
KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK
KEY_SSM_B_DT_RMS = Keys.SSM.B_DT_RMS

# tokenization
KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL
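For illustration (not part of the diff): the new SSM key is an architecture-templated string, so for the mamba architecture it expands as shown below.

```python
# Illustration only: GGUF metadata keys are templated per architecture.
B_DT_RMS = "{arch}.ssm.b_dt_rms"
print(B_DT_RMS.format(arch="mamba"))  # -> mamba.ssm.b_dt_rms
```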
3 changes: 3 additions & 0 deletions gguf-py/gguf/gguf_writer.py
@@ -714,6 +714,9 @@ def add_rope_scaling_orig_ctx_len(self, value: int) -> None:

def add_rope_scaling_finetuned(self, value: bool) -> None:
self.add_bool(Keys.Rope.SCALING_FINETUNED.format(arch=self.arch), value)

def add_mamba_b_dt_rms(self, value: bool) -> None:
self.add_bool(Keys.SSM.B_DT_RMS.format(arch=self.arch), value)

def add_rope_scaling_yarn_log_mul(self, value: float) -> None:
self.add_float32(Keys.Rope.SCALING_YARN_LOG_MUL.format(arch=self.arch), value)
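A rough usage sketch of the new writer method follows, assuming the gguf-py GGUFWriter API at the time of this PR (the output file name is arbitrary).

```python
import gguf

# Rough sketch, assuming gguf-py's GGUFWriter API at the time of this PR.
writer = gguf.GGUFWriter("falcon-mamba.gguf", arch="mamba")
writer.add_mamba_b_dt_rms(True)  # writes the boolean KV "mamba.ssm.b_dt_rms"
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.close()
```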
15 changes: 15 additions & 0 deletions src/llama.cpp
@@ -328,6 +328,7 @@ enum llm_kv {
LLM_KV_SSM_CONV_KERNEL,
LLM_KV_SSM_STATE_SIZE,
LLM_KV_SSM_TIME_STEP_RANK,
LLM_KV_SSM_B_DT_RMS,

LLM_KV_TOKENIZER_MODEL,
LLM_KV_TOKENIZER_PRE,
@@ -426,6 +427,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" },
{ LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" },
{ LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" },
{ LLM_KV_SSM_B_DT_RMS, "%s.ssm.b_dt_rms" },

{ LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
{ LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" },
@@ -2237,6 +2239,7 @@ struct llama_hparams {
uint32_t ssm_d_inner = 0;
uint32_t ssm_d_state = 0;
uint32_t ssm_dt_rank = 0;
bool ssm_b_dt_rms = false;

float f_clamp_kqv = 0.0f;
float f_max_alibi_bias = 0.0f;
@@ -5052,6 +5055,7 @@ static void llm_load_hparams(
ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
ml.get_key(LLM_KV_SSM_B_DT_RMS, hparams.ssm_b_dt_rms);

ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
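As a quick sanity check, a gguf-py sketch (the file path is hypothetical) to confirm that a converted FalconMamba GGUF actually carries the key read above:

```python
import gguf

# Sketch: verify a converted file carries the key that llm_load_hparams()
# reads into hparams.ssm_b_dt_rms (the path is hypothetical).
reader = gguf.GGUFReader("falcon-mamba.gguf")
print("has ssm.b_dt_rms:", reader.get_field("mamba.ssm.b_dt_rms") is not None)
```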

@@ -12161,6 +12165,10 @@ struct llm_build_context {
GGML_ASSERT(2 * d_model == d_inner);
const int64_t d_state = hparams.ssm_d_state;
const int64_t dt_rank = hparams.ssm_dt_rank;
        // Some Mamba arch variants (e.g. FalconMamba) apply layer norm on the B and Dt layers
const bool ssm_b_dt_rms = hparams.ssm_b_dt_rms;
// Use the same RMS norm as the final layer norm
const float norm_rms_eps = hparams.f_norm_rms_eps;

struct ggml_tensor * cur;
struct ggml_tensor * inpL;
@@ -12241,6 +12249,13 @@
struct ggml_tensor * B = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank);
struct ggml_tensor * C = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));

// Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers
if (ssm_b_dt_rms) {
dt = ggml_rms_norm(ctx0, dt, norm_rms_eps);
B = ggml_rms_norm(ctx0, B, norm_rms_eps);
C = ggml_rms_norm(ctx0, C, norm_rms_eps);
}
Collaborator comment:
This will eventually be rewritten to use llm_build_norm, because some Mamba-based architectures like Jamba use RMS norms with learnable parameters here.

But for now I think this is fine.


// {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens}
dt = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_dt, dt);
dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
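For reference, a NumPy sketch (not llama.cpp code) of the parameter-free RMS norm that ggml_rms_norm computes here, applied independently to the dt, B and C projections:

```python
import numpy as np

# x / sqrt(mean(x^2) + eps), with no learnable scale -- hence the note above
# about llm_build_norm for variants like Jamba that do carry learnable weights.
def rms_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

# Toy shapes: rows are tokens, columns are the dt_rank / d_state channels.
dt = rms_norm(np.random.randn(4, 16))
B = rms_norm(np.random.randn(4, 16))
C = rms_norm(np.random.randn(4, 16))
```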