Skip to content

Commit

Permalink
Merge branch 'master' into compilade/refactor-kv-cache
Browse files Browse the repository at this point in the history
  • Loading branch information
compilade committed May 24, 2024
2 parents cbc743e + d041d2c commit 0fd13e9
Show file tree
Hide file tree
Showing 37 changed files with 3,540 additions and 7,281 deletions.
5 changes: 5 additions & 0 deletions .github/labeler.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ server:
ggml:
- changed-files:
- any-glob-to-any-file:
- ggml.c
- ggml.h
- ggml-*.c
- ggml-*.h
- ggml-cuda/**
Expand All @@ -71,3 +73,6 @@ nix:
- "**/*.nix"
- .github/workflows/nix-*.yml
- .devops/nix/nixpkgs-instances.nix
embedding:
- changed-files:
- any-glob-to-any-file: examples/embedding/
5 changes: 3 additions & 2 deletions .github/workflows/docker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ jobs:
- { tag: "light-rocm", dockerfile: ".devops/main-rocm.Dockerfile", platforms: "linux/amd64,linux/arm64" }
- { tag: "full-rocm", dockerfile: ".devops/full-rocm.Dockerfile", platforms: "linux/amd64,linux/arm64" }
- { tag: "server-rocm", dockerfile: ".devops/server-rocm.Dockerfile", platforms: "linux/amd64,linux/arm64" }
- { tag: "light-intel", dockerfile: ".devops/main-intel.Dockerfile", platforms: "linux/amd64" }
- { tag: "server-intel", dockerfile: ".devops/server-intel.Dockerfile", platforms: "linux/amd64" }
# TODO: Disabled due to build issues https://github.com/ggerganov/llama.cpp/issues/7507
#- { tag: "light-intel", dockerfile: ".devops/main-intel.Dockerfile", platforms: "linux/amd64" }
#- { tag: "server-intel", dockerfile: ".devops/server-intel.Dockerfile", platforms: "linux/amd64" }
steps:
- name: Check out the repo
uses: actions/checkout@v4
Expand Down
5 changes: 0 additions & 5 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ set(LLAMA_METAL_MACOSX_VERSION_MIN "" CACHE STRING
set(LLAMA_METAL_STD "" CACHE STRING "llama: metal standard version (-std flag)")
option(LLAMA_KOMPUTE "llama: use Kompute" OFF)
option(LLAMA_RPC "llama: use RPC" OFF)
option(LLAMA_QKK_64 "llama: use super-block size of 64 for k-quants" OFF)
option(LLAMA_SYCL "llama: use SYCL" OFF)
option(LLAMA_SYCL_F16 "llama: use 16 bit floats for sycl calculations" OFF)
set(LLAMA_SYCL_TARGET "INTEL" CACHE STRING "llama: sycl target device")
Expand Down Expand Up @@ -384,10 +383,6 @@ if (LLAMA_LLAMAFILE)
set(GGML_SOURCES_LLAMAFILE sgemm.cpp)
endif()

if (LLAMA_QKK_64)
add_compile_definitions(GGML_QKK_64)
endif()

if (LLAMA_CUBLAS)
message(WARNING "LLAMA_CUBLAS is deprecated and will be removed in the future.\nUse LLAMA_CUDA instead")
set(LLAMA_CUDA ON)
Expand Down
4 changes: 0 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -389,10 +389,6 @@ else
MK_CXXFLAGS += -march=rv64gcv -mabi=lp64d
endif

ifdef LLAMA_QKK_64
MK_CPPFLAGS += -DGGML_QKK_64
endif

ifndef LLAMA_NO_ACCELERATE
# Mac OS - include Accelerate framework.
# `-framework Accelerate` works both with Apple Silicon and Mac Intel
Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ Typically finetunes of the base models below are supported as well.
- [x] [SEA-LION](https://huggingface.co/models?search=sea-lion)
- [x] [GritLM-7B](https://huggingface.co/GritLM/GritLM-7B) + [GritLM-8x7B](https://huggingface.co/GritLM/GritLM-8x7B)
- [x] [OLMo](https://allenai.org/olmo)
- [x] [GPT-NeoX](https://github.com/EleutherAI/gpt-neox) + [Pythia](https://github.com/EleutherAI/pythia)

(instructions for supporting more models: [HOWTO-add-model.md](./docs/HOWTO-add-model.md))

Expand All @@ -140,6 +141,7 @@ Typically finetunes of the base models below are supported as well.
- [x] [Yi-VL](https://huggingface.co/models?search=Yi-VL)
- [x] [Mini CPM](https://huggingface.co/models?search=MiniCPM)
- [x] [Moondream](https://huggingface.co/vikhyatk/moondream2)
- [x] [Bunny](https://github.com/BAAI-DCAI/Bunny)

**HTTP server**

Expand Down
215 changes: 176 additions & 39 deletions ci/run.sh

Large diffs are not rendered by default.

189 changes: 189 additions & 0 deletions convert-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,44 @@ def set_gguf_parameters(self):
self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True))
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused

n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))

tensors: list[tuple[str, Tensor]] = []

if re.match(r"gpt_neox\.layers\.\d+\.attention\.query_key_value\.weight", name):
# Map bloom-style qkv_linear to gpt-style qkv_linear
# bloom: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py#L238-L252 # noqa
# gpt-2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L312 # noqa
qkv_weights = data_torch.reshape((n_head, 3, n_embed // n_head, n_embed))
data_torch = torch.cat(
(
qkv_weights[:, 0, :, :].reshape((-1, n_embed)),
qkv_weights[:, 1, :, :].reshape((-1, n_embed)),
qkv_weights[:, 2, :, :].reshape((-1, n_embed)),
),
dim=0,
)
logger.info("re-format attention.linear_qkv.weight")
elif re.match(r"gpt_neox\.layers\.\d+\.attention\.query_key_value\.bias", name):
qkv_bias = data_torch.reshape((n_head, 3, n_embed // n_head))
data_torch = torch.cat(
(
qkv_bias[:, 0, :].reshape((n_embed,)),
qkv_bias[:, 1, :].reshape((n_embed,)),
qkv_bias[:, 2, :].reshape((n_embed,)),
),
dim=0,
)
logger.info("re-format attention.linear_qkv.bias")

tensors.append((self.map_tensor_name(name), data_torch))

return tensors


@Model.register("BloomForCausalLM")
class BloomModel(Model):
Expand Down Expand Up @@ -2529,6 +2567,157 @@ def set_vocab(self, *args, **kwargs):
self.gguf_writer.add_add_eos_token(True)


@Model.register("ArcticForCausalLM")
class ArcticModel(Model):
model_arch = gguf.MODEL_ARCH.ARCTIC

def set_vocab(self):
# The reason for using a custom implementation here is that the
# snowflake-arctic-instruct model redefined tokens 31998 and 31999 from
# tokenizer.model and used them as BOS and EOS instead of adding new tokens.
from sentencepiece import SentencePieceProcessor

tokenizer_path = self.dir_model / 'tokenizer.model'

if not tokenizer_path.is_file():
logger.error(f'Error: Missing {tokenizer_path}')
sys.exit(1)

# Read the whole vocabulary from the tokenizer.model file
tokenizer = SentencePieceProcessor()
tokenizer.LoadFromFile(str(tokenizer_path))

vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())

tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
scores: list[float] = [-10000.0] * vocab_size
toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size

for token_id in range(tokenizer.vocab_size()):

piece = tokenizer.IdToPiece(token_id)
text = piece.encode("utf-8")
score = tokenizer.GetScore(token_id)

toktype = SentencePieceTokenTypes.NORMAL
if tokenizer.IsUnknown(token_id):
toktype = SentencePieceTokenTypes.UNKNOWN
elif tokenizer.IsControl(token_id):
toktype = SentencePieceTokenTypes.CONTROL
elif tokenizer.IsUnused(token_id):
toktype = SentencePieceTokenTypes.UNUSED
elif tokenizer.IsByte(token_id):
toktype = SentencePieceTokenTypes.BYTE

tokens[token_id] = text
scores[token_id] = score
toktypes[token_id] = toktype

# Use the added_tokens_decoder field from tokeniser_config.json as the source
# of information about added/redefined tokens and modify them accordingly.
tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
if tokenizer_config_file.is_file():
with open(tokenizer_config_file, "r", encoding="utf-8") as f:
tokenizer_config_json = json.load(f)

if "added_tokens_decoder" in tokenizer_config_json:
added_tokens_decoder = tokenizer_config_json["added_tokens_decoder"]
for token_id, token_json in added_tokens_decoder.items():
token_id = int(token_id)
if (token_id >= vocab_size):
logger.debug(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}')
continue

token_content = token_json["content"]
token_type = SentencePieceTokenTypes.USER_DEFINED
token_score = -10000.0

# Map unk_token to UNKNOWN, other special tokens to CONTROL
# Set the score to 0.0 as in the original tokenizer.model
if ("special" in token_json) and token_json["special"]:
if token_content == tokenizer_config_json["unk_token"]:
token_type = SentencePieceTokenTypes.UNKNOWN
else:
token_type = SentencePieceTokenTypes.CONTROL
token_score = 0.0

logger.info(f"Setting added token {token_id} to '{token_content}' (type: {token_type}, score: {token_score:.2f})")
tokens[token_id] = token_content.encode("utf-8")
toktypes[token_id] = token_type
scores[token_id] = token_score

self.gguf_writer.add_tokenizer_model("llama")
self.gguf_writer.add_tokenizer_pre("default")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_scores(scores)
self.gguf_writer.add_token_types(toktypes)

special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
special_vocab.add_to_gguf(self.gguf_writer)

def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
self.gguf_writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"])

_experts: list[dict[str, Tensor]] | None = None

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
n_head = self.hparams["num_attention_heads"]
n_kv_head = self.hparams.get("num_key_value_heads")

if name.endswith("q_proj.weight"):
data_torch = LlamaModel.permute(data_torch, n_head, n_head)
if name.endswith("k_proj.weight"):
data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)

# process the experts separately
if name.find("block_sparse_moe.experts") != -1:
n_experts = self.hparams["num_local_experts"]

assert bid is not None

if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]

self._experts[bid][name] = data_torch

if len(self._experts[bid]) >= n_experts * 3:
tensors: list[tuple[str, Tensor]] = []

# merge the experts into a single 3d tensor
for wid in ["w1", "w2", "w3"]:
datas: list[Tensor] = []

for xid in range(n_experts):
ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{wid}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]

data_torch = torch.stack(datas, dim=0)

merged_name = f"layers.{bid}.feed_forward.experts.{wid}.weight"

new_name = self.map_tensor_name(merged_name)

tensors.append((new_name, data_torch))
return tensors
else:
return []

return [(self.map_tensor_name(name), data_torch)]

def write_tensors(self):
super().write_tensors()

if self._experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")


###### CONVERSION LOGIC ######


Expand Down
3 changes: 2 additions & 1 deletion examples/finetune/finetune.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,8 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
struct ggml_tensor * t15 = ggml_permute (ctx, t12, 0, 3, 1, 2); set_name(t15, "t15"); assert_shape_4d(t15, N, n_embd_head, n_head_kv, n_batch);
struct ggml_tensor * t16;
if (enable_flash_attn) {
t16 = ggml_flash_attn(ctx, t13, t14, t15, true); set_name(t16, "t16"); assert_shape_4d(t16, n_embd_head, N, n_head, n_batch);
GGML_ASSERT(false && "TODO: ggml_flash_attn_ext() not yet supported");
//t16 = ggml_flash_attn(ctx, t13, t14, t15, true); set_name(t16, "t16"); assert_shape_4d(t16, n_embd_head, N, n_head, n_batch);
} else {
struct ggml_tensor * t16_0 = ggml_mul_mat (ctx, t14, t13); set_name(t16_0, "t16_0"); assert_shape_4d(t16_0, N, N, n_head, n_batch);
struct ggml_tensor * t16_1 = ggml_scale_inplace (ctx, t16_0, kv_scale); set_name(t16_1, "t16_1"); assert_shape_4d(t16_1, N, N, n_head, n_batch);
Expand Down
6 changes: 3 additions & 3 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -474,12 +474,12 @@ int main(int argc, char ** argv) {
LOG_TEE("\n\n");

if (params.interactive) {
const char *control_message;
const char * control_message;
if (params.multiline_input) {
control_message = " - To return control to LLaMa, end your input with '\\'.\n"
control_message = " - To return control to the AI, end your input with '\\'.\n"
" - To return control without starting a new line, end your input with '/'.\n";
} else {
control_message = " - Press Return to return control to LLaMa.\n"
control_message = " - Press Return to return control to the AI.\n"
" - To return control without starting a new line, end your input with '/'.\n"
" - If you want to submit another line, end your input with '\\'.\n";
}
Expand Down
4 changes: 2 additions & 2 deletions examples/sycl/win-build-sycl.bat
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ if %errorlevel% neq 0 goto ERROR

:: for FP16
:: faster for long-prompt inference
:: cmake -G "MinGW Makefiles" .. -DLLAMA_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icx -DCMAKE_BUILD_TYPE=Release -DLLAMA_SYCL_F16=ON
:: cmake -G "MinGW Makefiles" .. -DLLAMA_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icx -DBUILD_SHARED_LIBS=ON -DCMAKE_BUILD_TYPE=Release -DLLAMA_SYCL_F16=ON

:: for FP32
cmake -G "MinGW Makefiles" .. -DLLAMA_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icx -DCMAKE_BUILD_TYPE=Release
cmake -G "MinGW Makefiles" .. -DLLAMA_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icx -DBUILD_SHARED_LIBS=ON -DCMAKE_BUILD_TYPE=Release
if %errorlevel% neq 0 goto ERROR
:: build example/main only
:: make main
Expand Down
3 changes: 2 additions & 1 deletion examples/train-text-from-scratch/train-text-from-scratch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,8 @@ static struct ggml_tensor * llama_build_train_graphs(
struct ggml_tensor * t15 = ggml_permute (ctx, t12, 0, 3, 1, 2); set_name(t15, "t15"); assert_shape_4d(t15, N, n_embd/n_head, n_head, n_batch);
struct ggml_tensor * t16;
if (enable_flash_attn) {
t16 = ggml_flash_attn(ctx, t13, t14, t15, true); set_name(t16, "t16"); assert_shape_4d(t16, n_embd/n_head, N, n_head, n_batch);
GGML_ASSERT(false && "TODO: ggml_flash_attn_ext() not yet supported");
//t16 = ggml_flash_attn(ctx, t13, t14, t15, true); set_name(t16, "t16"); assert_shape_4d(t16, n_embd/n_head, N, n_head, n_batch);
} else {
struct ggml_tensor * t16_0 = ggml_mul_mat (ctx, t14, t13); set_name(t16_0, "t16_0"); assert_shape_4d(t16_0, N, N, n_head, n_batch);
struct ggml_tensor * t16_1 = ggml_scale_inplace (ctx, t16_0, kv_scale); set_name(t16_1, "t16_1"); assert_shape_4d(t16_1, N, N, n_head, n_batch);
Expand Down
12 changes: 6 additions & 6 deletions flake.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 0fd13e9

Please sign in to comment.