Skip to content

Commit

Permalink
Add JAIS model(s) (ggml-org#8118)
Browse files Browse the repository at this point in the history
* Add `JAIS` model(s)

* cleanup

* address review comments

* remove hack

* un-hardcode max-alibi-bias

* minor tweaks

---------

Co-authored-by: fmz <[email protected]>
  • Loading branch information
2 people authored and Nexesenex committed Jul 2, 2024
1 parent cef61ac commit 86e9f53
Show file tree
Hide file tree
Showing 6 changed files with 288 additions and 9 deletions.
1 change: 1 addition & 0 deletions convert-hf-to-gguf-update.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ class TOKENIZER_TYPE(IntEnum):
{"name": "poro-chat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Poro-34B-chat", },
{"name": "jina-v2-code", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-code", },
{"name": "viking", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Viking-7B", }, # Also used for Viking 13B and 33B
{"name": "jais", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/core42/jais-13b", },
]


Expand Down
93 changes: 93 additions & 0 deletions convert-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
if chkhsh == "7fc505bd3104ca1083b150b17d088b59534ede9bde81f0dd2090967d7fe52cee":
# ref: https://huggingface.co/LumiOpen/Viking-7B
res = "viking"
if chkhsh == "b53802fb28e26d645c3a310b34bfe07da813026ec7c7716883404d5e0f8b1901":
# ref: https://huggingface.co/core42/jais-13b
res = "jais"

if res is None:
logger.warning("\n")
Expand Down Expand Up @@ -2979,6 +2982,96 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
return [(self.map_tensor_name(name), data_torch)]


@Model.register("JAISLMHeadModel")
class JaisModel(Model):
model_arch = gguf.MODEL_ARCH.JAIS

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# SwigLU activation
assert self.hparams["activation_function"] == "swiglu"
# ALiBi position embedding
assert self.hparams["position_embedding_type"] == "alibi"

# Embeddings scale
self.embeddings_scale = 1.0
# note: For some JAIS flavors, output is tied to (same as) wte in original model
self.output_is_wte = False
if 'mup_embeddings_scale' in self.hparams:
self.output_is_wte = True # Hack (?)
self.embeddings_scale = self.hparams['mup_embeddings_scale']
elif 'embeddings_scale' in self.hparams:
self.embeddings_scale = self.hparams['embeddings_scale']
else:
assert False

self.width_scale = 1.0
if 'mup_output_alpha' in self.hparams:
assert 'mup_width_scale' in self.hparams
self.width_scale = self.hparams['mup_output_alpha'] * self.hparams['mup_width_scale']
elif 'width_scale' in self.hparams:
self.width_scale = self.hparams['width_scale']
else:
assert False

self.max_alibi_bias = 8.0

def set_vocab(self):
self._set_vocab_gpt2()

def set_gguf_parameters(self):
self.gguf_writer.add_name(self.dir_model.name)
self.gguf_writer.add_block_count(self.hparams["n_layer"])
self.gguf_writer.add_context_length(self.hparams["n_positions"])
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
self.gguf_writer.add_feed_forward_length(self.hparams["n_inner"])
self.gguf_writer.add_head_count(self.hparams["n_head"])
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
self.gguf_writer.add_file_type(self.ftype)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused

tensors: list[tuple[str, Tensor]] = []

# we don't need these
if name.endswith((".attn.bias")):
return tensors

if name.endswith(("relative_pe.slopes")):
# Calculate max ALiBi bias (this is the inverse of the ALiBi calculation)
# Some other models has max_alibi_bias spelled out explicitly in the hyperparams,
# but Jais's PyTorch model simply precalculates the slope values and places them
# in relative_pes.slopes
n_head_closest_log2 = 2 ** math.floor(math.log2(self.hparams["n_head"]))
first_val = float(data_torch._data[0])
self.max_alibi_bias = -round(math.log2(first_val) * n_head_closest_log2)

return tensors

if name.endswith((".c_attn.weight", ".c_proj.weight", ".c_fc.weight", ".c_fc2.weight")):
data_torch = data_torch.transpose(1, 0)

new_name = self.map_tensor_name(name)

if new_name == self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD):
tensors.append((new_name, data_torch * self.embeddings_scale))
if self.output_is_wte:
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch * self.width_scale))
elif new_name == self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT):
assert not self.output_is_wte
tensors.append((new_name, data_torch * self.width_scale))
else:
tensors.append((new_name, data_torch))

return tensors

def write_tensors(self):
super().write_tensors()
self.gguf_writer.add_max_alibi_bias(self.max_alibi_bias)


###### CONVERSION LOGIC ######


Expand Down
14 changes: 14 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ class MODEL_ARCH(IntEnum):
DEEPSEEK2 = auto()
BITNET = auto()
T5 = auto()
JAIS = auto()


class MODEL_TENSOR(IntEnum):
Expand Down Expand Up @@ -288,6 +289,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.DEEPSEEK2: "deepseek2",
MODEL_ARCH.BITNET: "bitnet",
MODEL_ARCH.T5: "t5",
MODEL_ARCH.JAIS: "jais",
}

TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
Expand Down Expand Up @@ -954,6 +956,18 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ENC_FFN_UP,
MODEL_TENSOR.ENC_OUTPUT_NORM,
],
MODEL_ARCH.JAIS: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_UP,
],
# TODO
}

Expand Down
19 changes: 10 additions & 9 deletions gguf-py/gguf/tensor_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class TensorNameMap:
# Token embeddings
MODEL_TENSOR.TOKEN_EMBD: (
"gpt_neox.embed_in", # gptneox
"transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx
"transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais
"transformer.word_embeddings", # falcon
"word_embeddings", # bloom
"model.embed_tokens", # llama-hf
Expand Down Expand Up @@ -49,7 +49,7 @@ class TensorNameMap:
# Output
MODEL_TENSOR.OUTPUT: (
"embed_out", # gptneox
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais
"output", # llama-pth bloom internlm2
"word_embeddings_for_head", # persimmon
"lm_head.linear", # phi2
Expand All @@ -58,7 +58,7 @@ class TensorNameMap:
# Output norm
MODEL_TENSOR.OUTPUT_NORM: (
"gpt_neox.final_layer_norm", # gptneox
"transformer.ln_f", # gpt2 gpt-j falcon
"transformer.ln_f", # gpt2 gpt-j falcon jais
"model.norm", # llama-hf baichuan internlm2
"norm", # llama-pth
"transformer.norm_f", # mpt dbrx
Expand All @@ -81,7 +81,7 @@ class TensorNameMap:
# Attention norm
MODEL_TENSOR.ATTN_NORM: (
"gpt_neox.layers.{bid}.input_layernorm", # gptneox
"transformer.h.{bid}.ln_1", # gpt2 gpt-j refact qwen
"transformer.h.{bid}.ln_1", # gpt2 gpt-j refact qwen jais
"transformer.blocks.{bid}.norm_1", # mpt
"transformer.h.{bid}.input_layernorm", # falcon7b
"h.{bid}.input_layernorm", # bloom
Expand Down Expand Up @@ -109,7 +109,7 @@ class TensorNameMap:
# Attention query-key-value
MODEL_TENSOR.ATTN_QKV: (
"gpt_neox.layers.{bid}.attention.query_key_value", # gptneox
"transformer.h.{bid}.attn.c_attn", # gpt2 qwen
"transformer.h.{bid}.attn.c_attn", # gpt2 qwen jais
"transformer.blocks.{bid}.attn.Wqkv", # mpt
"transformer.blocks.{bid}.norm_attn_norm.attn.Wqkv", # dbrx
"transformer.h.{bid}.self_attention.query_key_value", # falcon
Expand Down Expand Up @@ -160,7 +160,7 @@ class TensorNameMap:
# Attention output
MODEL_TENSOR.ATTN_OUT: (
"gpt_neox.layers.{bid}.attention.dense", # gptneox
"transformer.h.{bid}.attn.c_proj", # gpt2 refact qwen
"transformer.h.{bid}.attn.c_proj", # gpt2 refact qwen jais
"transformer.blocks.{bid}.attn.out_proj", # mpt
"transformer.h.{bid}.self_attention.dense", # falcon
"h.{bid}.self_attention.dense", # bloom
Expand Down Expand Up @@ -202,7 +202,7 @@ class TensorNameMap:
# Feed-forward norm
MODEL_TENSOR.FFN_NORM: (
"gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
"transformer.h.{bid}.ln_2", # gpt2 refact qwen
"transformer.h.{bid}.ln_2", # gpt2 refact qwen jais
"h.{bid}.post_attention_layernorm", # bloom
"transformer.blocks.{bid}.norm_2", # mpt
"model.layers.{bid}.post_attention_layernorm", # llama-hf
Expand Down Expand Up @@ -239,7 +239,7 @@ class TensorNameMap:
# Feed-forward up
MODEL_TENSOR.FFN_UP: (
"gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox
"transformer.h.{bid}.mlp.c_fc", # gpt2
"transformer.h.{bid}.mlp.c_fc", # gpt2 jais
"transformer.blocks.{bid}.ffn.up_proj", # mpt
"transformer.h.{bid}.mlp.dense_h_to_4h", # falcon
"h.{bid}.mlp.dense_h_to_4h", # bloom
Expand Down Expand Up @@ -285,6 +285,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.gate_proj", # llama-hf refact
"layers.{bid}.feed_forward.w1", # llama-pth
"transformer.h.{bid}.mlp.w2", # qwen
"transformer.h.{bid}.mlp.c_fc2", # jais
"model.layers.layers.{bid}.mlp.gate_proj", # plamo
"model.layers.{bid}.feed_forward.w1", # internlm2
"encoder.layers.{bid}.mlp.fc12", # nomic-bert
Expand All @@ -308,7 +309,7 @@ class TensorNameMap:
# Feed-forward down
MODEL_TENSOR.FFN_DOWN: (
"gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox
"transformer.h.{bid}.mlp.c_proj", # gpt2 refact qwen
"transformer.h.{bid}.mlp.c_proj", # gpt2 refact qwen jais
"transformer.blocks.{bid}.ffn.down_proj", # mpt
"transformer.h.{bid}.mlp.dense_4h_to_h", # falcon
"h.{bid}.mlp.dense_4h_to_h", # bloom
Expand Down
Loading

0 comments on commit 86e9f53

Please sign in to comment.