Skip to content

Commit

Permalink
Merge branch 'master' into compilade/refactor-kv-cache
Browse files Browse the repository at this point in the history
  • Loading branch information
compilade committed Sep 2, 2024
2 parents fcb889c + 8f1d81a commit a03e32a
Show file tree
Hide file tree
Showing 9 changed files with 1,273 additions and 107 deletions.
84 changes: 83 additions & 1 deletion convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from __future__ import annotations

import ast
import logging
import argparse
import contextlib
Expand Down Expand Up @@ -298,9 +299,12 @@ def prepare_tensors(self):
gguf.MODEL_TENSOR.POS_EMBD,
gguf.MODEL_TENSOR.TOKEN_TYPES,
gguf.MODEL_TENSOR.SSM_CONV1D,
gguf.MODEL_TENSOR.TIME_MIX_FIRST,
gguf.MODEL_TENSOR.TIME_MIX_W1,
gguf.MODEL_TENSOR.TIME_MIX_W2,
)
)
or not name.endswith(".weight")
or not new_name.endswith(".weight")
):
data_qtype = gguf.GGMLQuantizationType.F32

Expand Down Expand Up @@ -2716,6 +2720,84 @@ class StarCoder2Model(Model):
model_arch = gguf.MODEL_ARCH.STARCODER2


@Model.register("Rwkv6ForCausalLM")
class Rwkv6Model(Model):
model_arch = gguf.MODEL_ARCH.RWKV6

def set_vocab(self):
assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
vocab_size = self.hparams.get("vocab_size", 65536)

tokens: list[bytes] = ['<s>'.encode("utf-8")]
toktypes: list[int] = [gguf.TokenType.CONTROL]

with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
lines = f.readlines()
for line in lines:
parts = line.split(' ')
assert len(parts) >= 3
token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
token = token.encode("utf-8") if isinstance(token, str) else token
assert isinstance(token, bytes)
assert len(token) == token_len
token_text: str = repr(token)[2:-1] # "b'\xff'" -> "\xff"
tokens.append(token_text.encode("utf-8"))
toktypes.append(gguf.TokenType.NORMAL)
remainder = vocab_size - len(tokens)
assert remainder >= 0
for i in range(len(tokens), vocab_size):
tokens.append(f"[PAD{i}]".encode("utf-8"))
toktypes.append(gguf.TokenType.UNUSED)

self.gguf_writer.add_tokenizer_model("rwkv")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)

def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
head_size = self.hparams["head_size"]
hidden_size = self.hparams["hidden_size"]
layer_norm_eps = self.hparams["layer_norm_epsilon"]
rescale_every_n_layers = self.hparams["rescale_every"]
intermediate_size = self.hparams["intermediate_size"] if self.hparams["intermediate_size"] is not None else int((hidden_size * 3.5) // 32 * 32)
time_mix_extra_dim = 64 if hidden_size == 4096 else 32
time_decay_extra_dim = 128 if hidden_size == 4096 else 64

# RWKV isn't context limited
self.gguf_writer.add_context_length(1048576)
self.gguf_writer.add_embedding_length(hidden_size)
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
self.gguf_writer.add_rescale_every_n_layers(rescale_every_n_layers)
self.gguf_writer.add_wkv_head_size(head_size)
self.gguf_writer.add_time_mix_extra_dim(time_mix_extra_dim)
self.gguf_writer.add_time_decay_extra_dim(time_decay_extra_dim)
self.gguf_writer.add_feed_forward_length(intermediate_size)
self.gguf_writer.add_file_type(self.ftype)

# required by llama.cpp, unused
self.gguf_writer.add_head_count(0)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
new_name = self.map_tensor_name(name)

if not (new_name.endswith(".weight") or new_name.endswith(".bias")):
new_name += ".weight"

if new_name.endswith("time_mix_w1.weight") or new_name.endswith("time_mix_decay_w1.weight") or new_name.endswith("time_mix_decay_w2.weight"):
data_torch = data_torch.transpose(0, 1)

if new_name.endswith("time_mix_w2.weight"):
data_torch = data_torch.permute(0, 2, 1)

rescale_every_n_layers = self.hparams["rescale_every"]
if rescale_every_n_layers > 0:
if new_name.endswith("time_mix_output.weight") or new_name.endswith("channel_mix_value.weight"):
data_torch = data_torch.div_(2 ** int(bid // rescale_every_n_layers))

yield (new_name, data_torch)


@Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
class MambaModel(Model):
model_arch = gguf.MODEL_ARCH.MAMBA
Expand Down
19 changes: 19 additions & 0 deletions ggml/include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,7 @@ extern "C" {
GGML_OP_WIN_UNPART,
GGML_OP_GET_REL_POS,
GGML_OP_ADD_REL_POS,
GGML_OP_RWKV_WKV,

GGML_OP_UNARY,

Expand Down Expand Up @@ -548,6 +549,7 @@ extern "C" {
GGML_UNARY_OP_SILU,
GGML_UNARY_OP_HARDSWISH,
GGML_UNARY_OP_HARDSIGMOID,
GGML_UNARY_OP_EXP,

GGML_UNARY_OP_COUNT,
};
Expand Down Expand Up @@ -1165,6 +1167,14 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);

GGML_API struct ggml_tensor * ggml_exp(
struct ggml_context * ctx,
struct ggml_tensor * a);

GGML_API struct ggml_tensor * ggml_exp_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);

// normalize along rows
GGML_API struct ggml_tensor * ggml_norm(
struct ggml_context * ctx,
Expand Down Expand Up @@ -1913,6 +1923,15 @@ extern "C" {
struct ggml_tensor * pw,
struct ggml_tensor * ph);

GGML_API struct ggml_tensor * ggml_rwkv_wkv(
struct ggml_context * ctx,
struct ggml_tensor * k,
struct ggml_tensor * v,
struct ggml_tensor * r,
struct ggml_tensor * tf,
struct ggml_tensor * td,
struct ggml_tensor * state);

// custom operators

typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
Expand Down
Loading

0 comments on commit a03e32a

Please sign in to comment.