Skip to content

Commit

Permalink
Merge branch 'master' into compilade/batch-splits
Browse files Browse the repository at this point in the history
  • Loading branch information
compilade committed Aug 21, 2024
2 parents 1be5ea7 + b40eb84 commit 80d9d2a
Show file tree
Hide file tree
Showing 58 changed files with 2,500 additions and 658 deletions.
44 changes: 44 additions & 0 deletions .devops/llama-cli-cann.Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
ARG ASCEND_VERSION=8.0.rc2.alpha003-910b-openeuler22.03-py3.8

FROM cosdt/cann:$ASCEND_VERSION AS build

WORKDIR /app

COPY . .

RUN yum install -y gcc g++ cmake make
ENV ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/latest
ENV LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:$LIBRARY_PATH
ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/lib64/plugin/opskernel:${ASCEND_TOOLKIT_HOME}/lib64/plugin/nnengine:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe/op_tiling:${LD_LIBRARY_PATH}
ENV PYTHONPATH=${ASCEND_TOOLKIT_HOME}/python/site-packages:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe:${PYTHONPATH}
ENV PATH=${ASCEND_TOOLKIT_HOME}/bin:${ASCEND_TOOLKIT_HOME}/compiler/ccec_compiler/bin:${PATH}
ENV ASCEND_AICPU_PATH=${ASCEND_TOOLKIT_HOME}
ENV ASCEND_OPP_PATH=${ASCEND_TOOLKIT_HOME}/opp
ENV TOOLCHAIN_HOME=${ASCEND_TOOLKIT_HOME}/toolkit
ENV ASCEND_HOME_PATH=${ASCEND_TOOLKIT_HOME}

# find libascend_hal.so, because the drive hasn`t been mounted.
ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/runtime/lib64/stub:$LD_LIBRARY_PATH

RUN echo "Building with static libs" && \
source /usr/local/Ascend/ascend-toolkit/set_env.sh --force && \
cmake -B build -DGGML_CANN=ON -DBUILD_SHARED_LIBS=OFF && \
cmake --build build --config Release --target llama-cli

# TODO: use image with NNRT
FROM cosdt/cann:$ASCEND_VERSION AS runtime
COPY --from=build /app/build/bin/llama-cli /llama-cli

ENV LC_ALL=C.utf8

ENV ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/latest
ENV LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:$LIBRARY_PATH
ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/lib64/plugin/opskernel:${ASCEND_TOOLKIT_HOME}/lib64/plugin/nnengine:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe/op_tiling:${LD_LIBRARY_PATH}
ENV PYTHONPATH=${ASCEND_TOOLKIT_HOME}/python/site-packages:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe:${PYTHONPATH}
ENV PATH=${ASCEND_TOOLKIT_HOME}/bin:${ASCEND_TOOLKIT_HOME}/compiler/ccec_compiler/bin:${PATH}
ENV ASCEND_AICPU_PATH=${ASCEND_TOOLKIT_HOME}
ENV ASCEND_OPP_PATH=${ASCEND_TOOLKIT_HOME}/opp
ENV TOOLCHAIN_HOME=${ASCEND_TOOLKIT_HOME}/toolkit
ENV ASCEND_HOME_PATH=${ASCEND_TOOLKIT_HOME}

ENTRYPOINT ["/llama-cli" ]
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# TODO: there have been some issues with the workflow, so disabling for now
# https://github.com/ggerganov/llama.cpp/issues/7893
#
# Benchmark
name: Benchmark

Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,6 @@ poetry.toml

# Scripts
!/scripts/install-oneapi.bat

# Test models for lora adapters
/lora-tests
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ Typically finetunes of the base models below are supported as well.
- [x] [Open Elm models](https://huggingface.co/collections/apple/openelm-instruct-models-6619ad295d7ae9f868b759ca)
- [x] [ChatGLM3-6b](https://huggingface.co/THUDM/chatglm3-6b) + [ChatGLM4-9b](https://huggingface.co/THUDM/glm-4-9b)
- [x] [SmolLM](https://huggingface.co/collections/HuggingFaceTB/smollm-6695016cad7167254ce15966)
- [x] [EXAONE-3.0-7.8B-Instruct](https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct)
- [x] [FalconMamba Models](https://huggingface.co/collections/tiiuae/falconmamba-7b-66b9a580324dd1598b0f6d4a)

(instructions for supporting more models: [HOWTO-add-model.md](./docs/development/HOWTO-add-model.md))

Expand Down Expand Up @@ -424,6 +426,7 @@ Please refer to [Build llama.cpp locally](./docs/build.md)
| [CUDA](./docs/build.md#cuda) | Nvidia GPU |
| [hipBLAS](./docs/build.md#hipblas) | AMD GPU |
| [Vulkan](./docs/build.md#vulkan) | GPU |
| [CANN](./docs/build.md#cann) | Ascend NPU |

## Tools

Expand Down
42 changes: 34 additions & 8 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,34 @@ int32_t cpu_get_num_physical_cores() {
if (result == 0) {
return num_physical_cores;
}
#elif defined(_WIN32)
//TODO: Implement
#elif defined(_WIN32) && (_WIN32_WINNT >= 0x0601) && !defined(__MINGW64__) // windows 7 and later
// TODO: windows + arm64 + mingw64
unsigned int n_threads_win = std::thread::hardware_concurrency();
unsigned int default_threads = n_threads_win > 0 ? (n_threads_win <= 4 ? n_threads_win : n_threads_win / 2) : 4;

DWORD buffer_size = 0;
if (!GetLogicalProcessorInformationEx(RelationProcessorCore, nullptr, &buffer_size)) {
if (GetLastError() != ERROR_INSUFFICIENT_BUFFER) {
return default_threads;
}
}

std::vector<char> buffer(buffer_size);
if (!GetLogicalProcessorInformationEx(RelationProcessorCore, reinterpret_cast<PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX>(buffer.data()), &buffer_size)) {
return default_threads;
}

int32_t num_physical_cores = 0;
PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX info = reinterpret_cast<PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX>(buffer.data());
while (buffer_size > 0) {
if (info->Relationship == RelationProcessorCore) {
num_physical_cores += info->Processor.GroupCount;
}
buffer_size -= info->Size;
info = reinterpret_cast<PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX>(reinterpret_cast<char*>(info) + info->Size);
}

return num_physical_cores > 0 ? num_physical_cores : default_threads;
#endif
unsigned int n_threads = std::thread::hardware_concurrency();
return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;
Expand Down Expand Up @@ -1727,7 +1753,13 @@ std::string gpt_params_get_system_info(const gpt_params & params) {
if (params.n_threads_batch != -1) {
os << " (n_threads_batch = " << params.n_threads_batch << ")";
}
#if defined(_WIN32) && (_WIN32_WINNT >= 0x0601) && !defined(__MINGW64__) // windows 7 and later
// TODO: windows + arm64 + mingw64
DWORD logicalProcessorCount = GetActiveProcessorCount(ALL_PROCESSOR_GROUPS);
os << " / " << logicalProcessorCount << " | " << llama_print_system_info();
#else
os << " / " << std::thread::hardware_concurrency() << " | " << llama_print_system_info();
#endif

return os.str();
}
Expand Down Expand Up @@ -2702,12 +2734,6 @@ std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token>
return text;
}

bool llama_should_add_bos_token(const llama_model * model) {
const int add_bos = llama_add_bos_token(model);

return add_bos != -1 ? bool(add_bos) : (llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM);
}

//
// Chat template utils
//
Expand Down
4 changes: 0 additions & 4 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -380,10 +380,6 @@ std::string llama_detokenize(
const std::vector<llama_token> & tokens,
bool special = true);

// Uses the value from the model metadata if possible, otherwise
// defaults to true when model type is SPM, otherwise false.
bool llama_should_add_bos_token(const llama_model * model);

//
// Chat template utils
//
Expand Down
153 changes: 131 additions & 22 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ def prepare_tensors(self):
gguf.MODEL_TENSOR.FFN_GATE_INP,
gguf.MODEL_TENSOR.POS_EMBD,
gguf.MODEL_TENSOR.TOKEN_TYPES,
gguf.MODEL_TENSOR.SSM_CONV1D,
)
)
or not name.endswith(".weight")
Expand Down Expand Up @@ -590,6 +591,15 @@ def get_vocab_base_pre(self, tokenizer) -> str:
if chkhsh == "855059429035d75a914d1eda9f10a876752e281a054a7a3d421ef0533e5b6249":
# ref: https://huggingface.co/HuggingFaceTB/SmolLM-135M
res = "smollm"
if chkhsh == "3c30d3ad1d6b64202cd222813e7736c2db6e1bd6d67197090fc1211fbc612ae7":
# ref: https://huggingface.co/bigscience/bloom
res = "bloom"
if chkhsh == "bc01ce58980e1db43859146dc51b1758b3b88729b217a74792e9f8d43e479d21":
# ref: https://huggingface.co/TurkuNLP/gpt3-finnish-small
res = "gpt3-finnish"
if chkhsh == "4e2b24cc4770243d65a2c9ec19770a72f08cffc161adbb73fcbb6b7dd45a0aae":
# ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct
res = "exaone"

if res is None:
logger.warning("\n")
Expand Down Expand Up @@ -893,7 +903,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
return tensors


@Model.register("BloomForCausalLM")
@Model.register("BloomForCausalLM", "BloomModel")
class BloomModel(Model):
model_arch = gguf.MODEL_ARCH.BLOOM

Expand Down Expand Up @@ -2702,7 +2712,7 @@ class StarCoder2Model(Model):
model_arch = gguf.MODEL_ARCH.STARCODER2


@Model.register("MambaForCausalLM", "MambaLMHeadModel")
@Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
class MambaModel(Model):
model_arch = gguf.MODEL_ARCH.MAMBA

Expand Down Expand Up @@ -2733,20 +2743,24 @@ def set_gguf_parameters(self):
# ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -(d_model // -16)
rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5

use_dt_b_c_norm = False
# For falconmamba we do apply RMS norm on B / DT and C layers
if self.find_hparam(["model_type"], optional=True) in ("falcon_mamba",):
use_dt_b_c_norm = True
# Fail early for models which don't have a block expansion factor of 2
assert d_inner == 2 * d_model

self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
self.gguf_writer.add_embedding_length(d_model)
self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
self.gguf_writer.add_block_count(self.hparams["n_layer"])
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_ssm_conv_kernel(d_conv)
self.gguf_writer.add_ssm_inner_size(d_inner)
self.gguf_writer.add_ssm_state_size(d_state)
self.gguf_writer.add_ssm_time_step_rank(dt_rank)
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
self.gguf_writer.add_ssm_dt_b_c_rms(use_dt_b_c_norm) # For classic Mamba we don't apply rms norm on B / DT layers
self.gguf_writer.add_file_type(self.ftype)

_tok_embd = None
Expand All @@ -2773,23 +2787,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter

return [(new_name, data_torch)]

def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
if bid is not None and new_name in (
self.format_tensor_name(
n, bid, ".weight" if name.endswith(".weight") else ""
)
for n in [
gguf.MODEL_TENSOR.SSM_CONV1D,
gguf.MODEL_TENSOR.SSM_X,
gguf.MODEL_TENSOR.SSM_DT,
gguf.MODEL_TENSOR.SSM_A,
gguf.MODEL_TENSOR.SSM_D,
]
):
return gguf.GGMLQuantizationType.F32

return super().tensor_force_quant(name, new_name, bid, n_dims)


@Model.register("CohereForCausalLM")
class CommandR2Model(Model):
Expand Down Expand Up @@ -3734,8 +3731,120 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
name = name.removeprefix("transformer.")
return [(self.map_tensor_name(name), data_torch)]

###### CONVERSION LOGIC ######

@Model.register("NemotronForCausalLM")
class NemotronModel(Model):
model_arch = gguf.MODEL_ARCH.NEMOTRON

def set_vocab(self):
self._set_vocab_sentencepiece()
self.gguf_writer.add_pad_token_id(0)
self.gguf_writer.add_unk_token_id(1)

def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
self.gguf_writer.add_vocab_size(hparams["vocab_size"])

f_norm_eps = self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon", "norm_eps"])
self.gguf_writer.add_layer_norm_eps(f_norm_eps)

# * Partial RoPE
rot_pct = self.find_hparam(["partial_rotary_factor", "rope_pct", "rope_percent"])
n_embd = self.find_hparam(["hidden_size", "n_embd"])
n_head = self.find_hparam(["num_attention_heads", "n_head"])
self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head)

# * RopeScaling for Nemotron
if "rope_scaling" not in self.hparams or self.hparams["rope_scaling"] is None:
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
else:
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
self.gguf_writer.add_rope_scaling_factor(self.hparams["factor"])

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# * Adding +1 to LayerNorm's weights here to implement layernorm1p w/o changing anything on the GGML engine side
# model.layers.{l}.input_layernorm.weight
# model.layers.{l}.post_attention_layernorm.weight
# model.norm.weight
if name.endswith("norm.weight"):
data_torch = data_torch + 1

return [(self.map_tensor_name(name), data_torch)]


@Model.register("ExaoneForCausalLM")
class ExaoneModel(Model):
model_arch = gguf.MODEL_ARCH.EXAONE

def set_gguf_parameters(self):
hparams = self.hparams

assert (hparams["activation_function"] == "silu")

max_position_embeddings = hparams["max_position_embeddings"]
embed_dim = hparams["hidden_size"]
num_heads = hparams["num_attention_heads"]
num_kv_heads = hparams.get("num_key_value_heads", num_heads)
layer_norm_eps = hparams["layer_norm_epsilon"]
intermediate_size = hparams["intermediate_size"] if "intermediate_size" in hparams else 4 * embed_dim
num_layers = hparams["num_layers"]
# ignore for now as EXAONE-3.0-7.8B-Instruct attentino_dropout is 0.0
# attention_dropout_rate = hparams["attention_dropout"]
# ignore for now as EXAONE-3.0-7.8B-Instruct embed_dropout is 0.0
# embed_dropout_rate = hparams["embed_dropout"]
self.gguf_writer.add_embedding_length(embed_dim)
self.gguf_writer.add_head_count(num_heads)
self.gguf_writer.add_head_count_kv(num_kv_heads)
self.gguf_writer.add_context_length(max_position_embeddings)
self.gguf_writer.add_layer_norm_rms_eps(layer_norm_eps)
self.gguf_writer.add_feed_forward_length(intermediate_size)
self.gguf_writer.add_block_count(num_layers)
self.gguf_writer.add_file_type(self.ftype)

if (rope_theta := self.hparams.get("rope_theta")) is not None:
self.gguf_writer.add_rope_freq_base(rope_theta)
rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"], optional=True)
rotary_factor = rotary_factor if rotary_factor is not None else 1.0
self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"])))
if hparams.get("rope_scaling") is not None and "factor" in hparams["rope_scaling"]:
if hparams["rope_scaling"].get("type") == "linear":
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"])

def prepare_tensors(self):
if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
if rope_scaling.get("rope_type", '').lower() == "llama3":
base = self.hparams.get("rope_theta", 10000.0)
dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

factor = rope_scaling.get("factor", 8.0)
low_freq_factor = rope_scaling.get("low_freq_factor", 1.0)
high_freq_factor = rope_scaling.get("high_freq_factor", 4.0)
old_context_len = self.hparams.get("original_max_position_embeddings", 8192)

low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
assert low_freq_wavelen != high_freq_wavelen

rope_factors = []
for freq in freqs:
wavelen = 2 * math.pi / freq
if wavelen < high_freq_wavelen:
rope_factors.append(1)
elif wavelen > low_freq_wavelen:
rope_factors.append(factor)
else:
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
rope_factors.append(1 / ((1 - smooth) / factor + smooth))

self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32))

super().prepare_tensors()


###### CONVERSION LOGIC ######

# tree of lazy tensors
class LazyTorchTensor(gguf.LazyBase):
Expand Down
3 changes: 3 additions & 0 deletions convert_hf_to_gguf_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ class TOKENIZER_TYPE(IntEnum):
{"name": "codeshell", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/WisdomShell/CodeShell-7B", },
{"name": "tekken", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistralai/Mistral-Nemo-Base-2407", },
{"name": "smollm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/HuggingFaceTB/SmolLM-135M", },
{'name': "bloom", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigscience/bloom", },
{'name': "gpt3-finnish", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/TurkuNLP/gpt3-finnish-small", },
{"name": "exaone", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", },
]


Expand Down
2 changes: 1 addition & 1 deletion convert_llama_ggml_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def load(self, data, offset):
assert quant is not None, 'Unknown tensor type'
(blksize, tysize) = quant
offset += 12
self.dtype= dtype
self.dtype= gguf.GGMLQuantizationType(dtype)
self.dims = struct.unpack(f'<{n_dims}I', data[offset:offset + (4 * n_dims)])
offset += 4 * n_dims
self.name = bytes(data[offset:offset + name_len])
Expand Down
Loading

0 comments on commit 80d9d2a

Please sign in to comment.