Skip to content

Commit

Permalink
merge master to resolve conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
Septa2112 committed Aug 22, 2024
2 parents b0c6ad7 + a1631e5 commit c8ef556
Show file tree
Hide file tree
Showing 46 changed files with 3,023 additions and 1,263 deletions.
44 changes: 44 additions & 0 deletions .devops/llama-cli-cann.Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
ARG ASCEND_VERSION=8.0.rc2.alpha003-910b-openeuler22.03-py3.8

FROM cosdt/cann:$ASCEND_VERSION AS build

WORKDIR /app

COPY . .

RUN yum install -y gcc g++ cmake make
ENV ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/latest
ENV LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:$LIBRARY_PATH
ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/lib64/plugin/opskernel:${ASCEND_TOOLKIT_HOME}/lib64/plugin/nnengine:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe/op_tiling:${LD_LIBRARY_PATH}
ENV PYTHONPATH=${ASCEND_TOOLKIT_HOME}/python/site-packages:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe:${PYTHONPATH}
ENV PATH=${ASCEND_TOOLKIT_HOME}/bin:${ASCEND_TOOLKIT_HOME}/compiler/ccec_compiler/bin:${PATH}
ENV ASCEND_AICPU_PATH=${ASCEND_TOOLKIT_HOME}
ENV ASCEND_OPP_PATH=${ASCEND_TOOLKIT_HOME}/opp
ENV TOOLCHAIN_HOME=${ASCEND_TOOLKIT_HOME}/toolkit
ENV ASCEND_HOME_PATH=${ASCEND_TOOLKIT_HOME}

# find libascend_hal.so, because the drive hasn`t been mounted.
ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/runtime/lib64/stub:$LD_LIBRARY_PATH

RUN echo "Building with static libs" && \
source /usr/local/Ascend/ascend-toolkit/set_env.sh --force && \
cmake -B build -DGGML_CANN=ON -DBUILD_SHARED_LIBS=OFF && \
cmake --build build --config Release --target llama-cli

# TODO: use image with NNRT
FROM cosdt/cann:$ASCEND_VERSION AS runtime
COPY --from=build /app/build/bin/llama-cli /llama-cli

ENV LC_ALL=C.utf8

ENV ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/latest
ENV LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:$LIBRARY_PATH
ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/lib64/plugin/opskernel:${ASCEND_TOOLKIT_HOME}/lib64/plugin/nnengine:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe/op_tiling:${LD_LIBRARY_PATH}
ENV PYTHONPATH=${ASCEND_TOOLKIT_HOME}/python/site-packages:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe:${PYTHONPATH}
ENV PATH=${ASCEND_TOOLKIT_HOME}/bin:${ASCEND_TOOLKIT_HOME}/compiler/ccec_compiler/bin:${PATH}
ENV ASCEND_AICPU_PATH=${ASCEND_TOOLKIT_HOME}
ENV ASCEND_OPP_PATH=${ASCEND_TOOLKIT_HOME}/opp
ENV TOOLCHAIN_HOME=${ASCEND_TOOLKIT_HOME}/toolkit
ENV ASCEND_HOME_PATH=${ASCEND_TOOLKIT_HOME}

ENTRYPOINT ["/llama-cli" ]
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,6 @@ poetry.toml

# Scripts
!/scripts/install-oneapi.bat

# Test models for lora adapters
/lora-tests
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ Typically finetunes of the base models below are supported as well.
- [x] [ChatGLM3-6b](https://huggingface.co/THUDM/chatglm3-6b) + [ChatGLM4-9b](https://huggingface.co/THUDM/glm-4-9b)
- [x] [SmolLM](https://huggingface.co/collections/HuggingFaceTB/smollm-6695016cad7167254ce15966)
- [x] [EXAONE-3.0-7.8B-Instruct](https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct)
- [x] [FalconMamba Models](https://huggingface.co/collections/tiiuae/falconmamba-7b-66b9a580324dd1598b0f6d4a)

(instructions for supporting more models: [HOWTO-add-model.md](./docs/development/HOWTO-add-model.md))

Expand Down Expand Up @@ -425,6 +426,7 @@ Please refer to [Build llama.cpp locally](./docs/build.md)
| [CUDA](./docs/build.md#cuda) | Nvidia GPU |
| [hipBLAS](./docs/build.md#hipblas) | AMD GPU |
| [Vulkan](./docs/build.md#vulkan) | GPU |
| [CANN](./docs/build.md#cann) | Ascend NPU |

## Tools

Expand Down
64 changes: 57 additions & 7 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,41 @@

using json = nlohmann::ordered_json;

//
// Environment variable utils
//

template<typename T>
static typename std::enable_if<std::is_same<T, std::string>::value, void>::type
get_env(std::string name, T & target) {
char * value = std::getenv(name.c_str());
target = value ? std::string(value) : target;
}

template<typename T>
static typename std::enable_if<!std::is_same<T, bool>::value && std::is_integral<T>::value, void>::type
get_env(std::string name, T & target) {
char * value = std::getenv(name.c_str());
target = value ? std::stoi(value) : target;
}

template<typename T>
static typename std::enable_if<std::is_floating_point<T>::value, void>::type
get_env(std::string name, T & target) {
char * value = std::getenv(name.c_str());
target = value ? std::stof(value) : target;
}

template<typename T>
static typename std::enable_if<std::is_same<T, bool>::value, void>::type
get_env(std::string name, T & target) {
char * value = std::getenv(name.c_str());
if (value) {
std::string val(value);
target = val == "1" || val == "true";
}
}

//
// CPU utils
//
Expand Down Expand Up @@ -220,12 +255,6 @@ int32_t cpu_get_num_math() {
// CLI argument parsing
//

void gpt_params_handle_hf_token(gpt_params & params) {
if (params.hf_token.empty() && std::getenv("HF_TOKEN")) {
params.hf_token = std::getenv("HF_TOKEN");
}
}

void gpt_params_handle_model_default(gpt_params & params) {
if (!params.hf_repo.empty()) {
// short-hand to avoid specifying --hf-file -> default it to --model
Expand Down Expand Up @@ -273,7 +302,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {

gpt_params_handle_model_default(params);

gpt_params_handle_hf_token(params);
if (params.hf_token.empty()) {
get_env("HF_TOKEN", params.hf_token);
}

if (params.escape) {
string_process_escapes(params.prompt);
Expand All @@ -293,6 +324,25 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
return true;
}

void gpt_params_parse_from_env(gpt_params & params) {
// we only care about server-related params for now
get_env("LLAMA_ARG_MODEL", params.model);
get_env("LLAMA_ARG_THREADS", params.n_threads);
get_env("LLAMA_ARG_CTX_SIZE", params.n_ctx);
get_env("LLAMA_ARG_N_PARALLEL", params.n_parallel);
get_env("LLAMA_ARG_BATCH", params.n_batch);
get_env("LLAMA_ARG_UBATCH", params.n_ubatch);
get_env("LLAMA_ARG_N_GPU_LAYERS", params.n_gpu_layers);
get_env("LLAMA_ARG_THREADS_HTTP", params.n_threads_http);
get_env("LLAMA_ARG_CHAT_TEMPLATE", params.chat_template);
get_env("LLAMA_ARG_N_PREDICT", params.n_predict);
get_env("LLAMA_ARG_ENDPOINT_METRICS", params.endpoint_metrics);
get_env("LLAMA_ARG_ENDPOINT_SLOTS", params.endpoint_slots);
get_env("LLAMA_ARG_EMBEDDINGS", params.embedding);
get_env("LLAMA_ARG_FLASH_ATTN", params.flash_attn);
get_env("LLAMA_ARG_DEFRAG_THOLD", params.defrag_thold);
}

bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
const auto params_org = params; // the example can modify the default params

Expand Down
2 changes: 1 addition & 1 deletion common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ struct gpt_params {
std::string lora_outfile = "ggml-lora-merged-f16.gguf";
};

void gpt_params_handle_hf_token(gpt_params & params);
void gpt_params_parse_from_env(gpt_params & params);
void gpt_params_handle_model_default(gpt_params & params);

bool gpt_params_parse_ex (int argc, char ** argv, gpt_params & params);
Expand Down
32 changes: 10 additions & 22 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ def prepare_tensors(self):
gguf.MODEL_TENSOR.FFN_GATE_INP,
gguf.MODEL_TENSOR.POS_EMBD,
gguf.MODEL_TENSOR.TOKEN_TYPES,
gguf.MODEL_TENSOR.SSM_CONV1D,
)
)
or not name.endswith(".weight")
Expand Down Expand Up @@ -2711,7 +2712,7 @@ class StarCoder2Model(Model):
model_arch = gguf.MODEL_ARCH.STARCODER2


@Model.register("MambaForCausalLM", "MambaLMHeadModel")
@Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
class MambaModel(Model):
model_arch = gguf.MODEL_ARCH.MAMBA

Expand Down Expand Up @@ -2742,20 +2743,24 @@ def set_gguf_parameters(self):
# ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -(d_model // -16)
rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5

use_dt_b_c_norm = False
# For falconmamba we do apply RMS norm on B / DT and C layers
if self.find_hparam(["model_type"], optional=True) in ("falcon_mamba",):
use_dt_b_c_norm = True
# Fail early for models which don't have a block expansion factor of 2
assert d_inner == 2 * d_model

self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
self.gguf_writer.add_embedding_length(d_model)
self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
self.gguf_writer.add_block_count(self.hparams["n_layer"])
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_ssm_conv_kernel(d_conv)
self.gguf_writer.add_ssm_inner_size(d_inner)
self.gguf_writer.add_ssm_state_size(d_state)
self.gguf_writer.add_ssm_time_step_rank(dt_rank)
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
self.gguf_writer.add_ssm_dt_b_c_rms(use_dt_b_c_norm) # For classic Mamba we don't apply rms norm on B / DT layers
self.gguf_writer.add_file_type(self.ftype)

_tok_embd = None
Expand All @@ -2782,23 +2787,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter

return [(new_name, data_torch)]

def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
if bid is not None and new_name in (
self.format_tensor_name(
n, bid, ".weight" if name.endswith(".weight") else ""
)
for n in [
gguf.MODEL_TENSOR.SSM_CONV1D,
gguf.MODEL_TENSOR.SSM_X,
gguf.MODEL_TENSOR.SSM_DT,
gguf.MODEL_TENSOR.SSM_A,
gguf.MODEL_TENSOR.SSM_D,
]
):
return gguf.GGMLQuantizationType.F32

return super().tensor_force_quant(name, new_name, bid, n_dims)


@Model.register("CohereForCausalLM")
class CommandR2Model(Model):
Expand Down Expand Up @@ -3792,7 +3780,7 @@ class ExaoneModel(Model):
def set_gguf_parameters(self):
hparams = self.hparams

assert(hparams["activation_function"] == "silu")
assert (hparams["activation_function"] == "silu")

max_position_embeddings = hparams["max_position_embeddings"]
embed_dim = hparams["hidden_size"]
Expand Down Expand Up @@ -3855,8 +3843,8 @@ def prepare_tensors(self):

super().prepare_tensors()

###### CONVERSION LOGIC ######

###### CONVERSION LOGIC ######

# tree of lazy tensors
class LazyTorchTensor(gguf.LazyBase):
Expand Down
2 changes: 1 addition & 1 deletion convert_llama_ggml_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def load(self, data, offset):
assert quant is not None, 'Unknown tensor type'
(blksize, tysize) = quant
offset += 12
self.dtype= dtype
self.dtype= gguf.GGMLQuantizationType(dtype)
self.dims = struct.unpack(f'<{n_dims}I', data[offset:offset + (4 * n_dims)])
offset += 4 * n_dims
self.name = bytes(data[offset:offset + name_len])
Expand Down
Loading

0 comments on commit c8ef556

Please sign in to comment.