Skip to content

Commit

Permalink
llama : support for falcon-mamba architecture (ggerganov#9074)
Browse files Browse the repository at this point in the history
* feat: initial support for llama.cpp

* fix: lint

* refactor: better refactor

* Update src/llama.cpp

Co-authored-by: compilade <[email protected]>

* Update src/llama.cpp

Co-authored-by: compilade <[email protected]>

* fix: address comments

* Update convert_hf_to_gguf.py

Co-authored-by: compilade <[email protected]>

* fix: add more cleanup and harmonization

* fix: lint

* Update gguf-py/gguf/gguf_writer.py

Co-authored-by: compilade <[email protected]>

* fix: change name

* Apply suggestions from code review

Co-authored-by: compilade <[email protected]>

* add in operator

* fix: add `dt_b_c_rms` in `llm_load_print_meta`

* fix: correct printf format for bool

* fix: correct print format

* Update src/llama.cpp

Co-authored-by: compilade <[email protected]>

* llama : quantize more Mamba tensors

* llama : use f16 as the fallback of fallback quant types

---------

Co-authored-by: compilade <[email protected]>
  • Loading branch information
2 people authored and arthw committed Nov 15, 2024
1 parent 93b36e4 commit 8e94e91
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 24 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ Typically finetunes of the base models below are supported as well.
- [x] [ChatGLM3-6b](https://huggingface.co/THUDM/chatglm3-6b) + [ChatGLM4-9b](https://huggingface.co/THUDM/glm-4-9b)
- [x] [SmolLM](https://huggingface.co/collections/HuggingFaceTB/smollm-6695016cad7167254ce15966)
- [x] [EXAONE-3.0-7.8B-Instruct](https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct)
- [x] [FalconMamba Models](https://huggingface.co/collections/tiiuae/falconmamba-7b-66b9a580324dd1598b0f6d4a)

(instructions for supporting more models: [HOWTO-add-model.md](./docs/development/HOWTO-add-model.md))

Expand Down
32 changes: 10 additions & 22 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ def prepare_tensors(self):
gguf.MODEL_TENSOR.FFN_GATE_INP,
gguf.MODEL_TENSOR.POS_EMBD,
gguf.MODEL_TENSOR.TOKEN_TYPES,
gguf.MODEL_TENSOR.SSM_CONV1D,
)
)
or not name.endswith(".weight")
Expand Down Expand Up @@ -2711,7 +2712,7 @@ class StarCoder2Model(Model):
model_arch = gguf.MODEL_ARCH.STARCODER2


@Model.register("MambaForCausalLM", "MambaLMHeadModel")
@Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
class MambaModel(Model):
model_arch = gguf.MODEL_ARCH.MAMBA

Expand Down Expand Up @@ -2742,20 +2743,24 @@ def set_gguf_parameters(self):
# ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -(d_model // -16)
rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5

use_dt_b_c_norm = False
# For falconmamba we do apply RMS norm on B / DT and C layers
if self.find_hparam(["model_type"], optional=True) in ("falcon_mamba",):
use_dt_b_c_norm = True
# Fail early for models which don't have a block expansion factor of 2
assert d_inner == 2 * d_model

self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
self.gguf_writer.add_embedding_length(d_model)
self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
self.gguf_writer.add_block_count(self.hparams["n_layer"])
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_ssm_conv_kernel(d_conv)
self.gguf_writer.add_ssm_inner_size(d_inner)
self.gguf_writer.add_ssm_state_size(d_state)
self.gguf_writer.add_ssm_time_step_rank(dt_rank)
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
self.gguf_writer.add_ssm_dt_b_c_rms(use_dt_b_c_norm) # For classic Mamba we don't apply rms norm on B / DT layers
self.gguf_writer.add_file_type(self.ftype)

_tok_embd = None
Expand All @@ -2782,23 +2787,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter

return [(new_name, data_torch)]

def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
if bid is not None and new_name in (
self.format_tensor_name(
n, bid, ".weight" if name.endswith(".weight") else ""
)
for n in [
gguf.MODEL_TENSOR.SSM_CONV1D,
gguf.MODEL_TENSOR.SSM_X,
gguf.MODEL_TENSOR.SSM_DT,
gguf.MODEL_TENSOR.SSM_A,
gguf.MODEL_TENSOR.SSM_D,
]
):
return gguf.GGMLQuantizationType.F32

return super().tensor_force_quant(name, new_name, bid, n_dims)


@Model.register("CohereForCausalLM")
class CommandR2Model(Model):
Expand Down Expand Up @@ -3792,7 +3780,7 @@ class ExaoneModel(Model):
def set_gguf_parameters(self):
hparams = self.hparams

assert(hparams["activation_function"] == "silu")
assert (hparams["activation_function"] == "silu")

max_position_embeddings = hparams["max_position_embeddings"]
embed_dim = hparams["hidden_size"]
Expand Down Expand Up @@ -3855,8 +3843,8 @@ def prepare_tensors(self):

super().prepare_tensors()

###### CONVERSION LOGIC ######

###### CONVERSION LOGIC ######

# tree of lazy tensors
class LazyTorchTensor(gguf.LazyBase):
Expand Down
2 changes: 2 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ class SSM:
INNER_SIZE = "{arch}.ssm.inner_size"
STATE_SIZE = "{arch}.ssm.state_size"
TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms"

class Tokenizer:
MODEL = "tokenizer.ggml.model"
Expand Down Expand Up @@ -1372,6 +1373,7 @@ def get_type(val: Any) -> GGUFValueType:
KEY_SSM_INNER_SIZE = Keys.SSM.INNER_SIZE
KEY_SSM_STATE_SIZE = Keys.SSM.STATE_SIZE
KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK
KEY_SSM_DT_B_C_RMS = Keys.SSM.DT_B_C_RMS

# tokenization
KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL
Expand Down
3 changes: 3 additions & 0 deletions gguf-py/gguf/gguf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,6 +730,9 @@ def add_ssm_state_size(self, value: int) -> None:
def add_ssm_time_step_rank(self, value: int) -> None:
self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value)

def add_ssm_dt_b_c_rms(self, value: bool) -> None:
self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)

def add_tokenizer_model(self, model: str) -> None:
self.add_string(Keys.Tokenizer.MODEL, model)

Expand Down
22 changes: 20 additions & 2 deletions src/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@ enum llm_kv {
LLM_KV_SSM_CONV_KERNEL,
LLM_KV_SSM_STATE_SIZE,
LLM_KV_SSM_TIME_STEP_RANK,
LLM_KV_SSM_DT_B_C_RMS,

LLM_KV_TOKENIZER_MODEL,
LLM_KV_TOKENIZER_PRE,
Expand Down Expand Up @@ -426,6 +427,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" },
{ LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" },
{ LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" },
{ LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" },

{ LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
{ LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" },
Expand Down Expand Up @@ -2237,6 +2239,7 @@ struct llama_hparams {
uint32_t ssm_d_inner = 0;
uint32_t ssm_d_state = 0;
uint32_t ssm_dt_rank = 0;
bool ssm_dt_b_c_rms = false;

float f_clamp_kqv = 0.0f;
float f_max_alibi_bias = 0.0f;
Expand Down Expand Up @@ -2286,6 +2289,7 @@ struct llama_hparams {
if (this->ssm_d_inner != other.ssm_d_inner) return true;
if (this->ssm_d_state != other.ssm_d_state) return true;
if (this->ssm_dt_rank != other.ssm_dt_rank) return true;
if (this->ssm_dt_b_c_rms != other.ssm_dt_b_c_rms) return true;

if (this->dec_start_token_id != other.dec_start_token_id) return true;

Expand Down Expand Up @@ -5052,6 +5056,7 @@ static void llm_load_hparams(
ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
ml.get_key(LLM_KV_SSM_DT_B_C_RMS, hparams.ssm_dt_b_c_rms, false);

ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

Expand Down Expand Up @@ -5907,6 +5912,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank);
LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms);
}

LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type));
Expand Down Expand Up @@ -12165,6 +12171,10 @@ struct llm_build_context {
GGML_ASSERT(2 * d_model == d_inner);
const int64_t d_state = hparams.ssm_d_state;
const int64_t dt_rank = hparams.ssm_dt_rank;
// Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers)
const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
// Use the same RMS norm as the final layer norm
const float norm_rms_eps = hparams.f_norm_rms_eps;

struct ggml_tensor * cur;
struct ggml_tensor * inpL;
Expand Down Expand Up @@ -12245,6 +12255,13 @@ struct llm_build_context {
struct ggml_tensor * B = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank);
struct ggml_tensor * C = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));

// Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers
if (ssm_dt_b_c_rms) {
dt = ggml_rms_norm(ctx0, dt, norm_rms_eps);
B = ggml_rms_norm(ctx0, B, norm_rms_eps);
C = ggml_rms_norm(ctx0, C, norm_rms_eps);
}

// {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens}
dt = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_dt, dt);
dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
Expand Down Expand Up @@ -16109,6 +16126,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break;
default: throw std::runtime_error("\nUnsupported tensor size encountered\n");
}
if (tensor->ne[0] % ggml_blck_size(new_type) != 0) {
new_type = GGML_TYPE_F16;
}
LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type));
++qs.n_fallback;
}
Expand Down Expand Up @@ -16437,8 +16457,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
// do not quantize Mamba's small yet 2D weights
// NOTE: can't use LLM_TN here because the layer number is not known
quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
quantize &= name.find("ssm_x.weight") == std::string::npos;
quantize &= name.find("ssm_dt.weight") == std::string::npos;

// do not quantize relative position bias (T5)
quantize &= name.find("attn_rel_b.weight") == std::string::npos;
Expand Down

0 comments on commit 8e94e91

Please sign in to comment.