diff --git a/common/common.cpp b/common/common.cpp index 0535508ba98df..ba1ecf0e59c8b 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -901,6 +901,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.interactive = true; return true; } + if (arg == "--interactive-specials") { + params.interactive_specials = true; + return true; + } if (arg == "--embedding") { params.embedding = true; return true; @@ -1367,14 +1371,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) { std::replace(arg.begin(), arg.end(), '_', '-'); } - if (!gpt_params_find_arg(argc, argv, arg, params, i, invalid_param)) { throw std::invalid_argument("error: unknown argument: " + arg); } - } - - if (invalid_param) { - throw std::invalid_argument("error: invalid parameter for argument: " + arg); + if (invalid_param) { + throw std::invalid_argument("error: invalid parameter for argument: " + arg); + } } if (params.prompt_cache_all && @@ -1422,6 +1424,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -h, --help show this help message and exit\n"); printf(" --version show version and build info\n"); printf(" -i, --interactive run in interactive mode\n"); + printf(" --interactive-specials allow special tokens in user text, in interactive mode\n"); printf(" --interactive-first run in interactive mode and wait for input right away\n"); printf(" -cnv, --conversation run in conversation mode (does not print special tokens and suffix/prefix)\n"); printf(" -ins, --instruct run in instruction mode (use with Alpaca models)\n"); @@ -2652,6 +2655,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l dump_string_yaml_multiline(stream, "in_suffix", params.input_prefix.c_str()); fprintf(stream, "instruct: %s # default: false\n", params.instruct ? "true" : "false"); fprintf(stream, "interactive: %s # default: false\n", params.interactive ? "true" : "false"); + fprintf(stream, "interactive_specials: %s # default: false\n", params.interactive_specials ? "true" : "false"); fprintf(stream, "interactive_first: %s # default: false\n", params.interactive_first ? 
"true" : "false"); fprintf(stream, "keep: %d # default: 0\n", params.n_keep); fprintf(stream, "logdir: %s # default: unset (no logging)\n", params.logdir.c_str()); diff --git a/common/common.h b/common/common.h index 6f00a2cca8883..d80344f2a61bb 100644 --- a/common/common.h +++ b/common/common.h @@ -140,6 +140,7 @@ struct gpt_params { bool random_prompt = false; // do not randomize prompt if none provided bool use_color = false; // use color to distinguish generations and inputs bool interactive = false; // interactive mode + bool interactive_specials = false; // whether to allow special tokens from user, during interactive mode bool conversation = false; // conversation mode (does not print special tokens and suffix/prefix) bool chatml = false; // chatml mode (used for models trained on chatml syntax) bool prompt_cache_all = false; // save user input and generations to prompt cache diff --git a/common/grammar-parser.cpp b/common/grammar-parser.cpp index 2a1301569793a..fecb7cd713ea1 100644 --- a/common/grammar-parser.cpp +++ b/common/grammar-parser.cpp @@ -142,6 +142,9 @@ namespace grammar_parser { pos++; last_sym_start = out_elements.size(); while (*pos != '"') { + if (!*pos) { + throw std::runtime_error("unexpected end of input"); + } auto char_pair = parse_char(pos); pos = char_pair.second; out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first}); @@ -156,6 +159,9 @@ namespace grammar_parser { } last_sym_start = out_elements.size(); while (*pos != ']') { + if (!*pos) { + throw std::runtime_error("unexpected end of input"); + } auto char_pair = parse_char(pos); pos = char_pair.second; enum llama_gretype type = last_sym_start < out_elements.size() @@ -164,6 +170,9 @@ namespace grammar_parser { out_elements.push_back({type, char_pair.first}); if (pos[0] == '-' && pos[1] != ']') { + if (!pos[1]) { + throw std::runtime_error("unexpected end of input"); + } auto endchar_pair = parse_char(pos + 1); pos = endchar_pair.second; out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); diff --git a/common/sampling.cpp b/common/sampling.cpp index 3715a798531ae..f0f1b92d37f59 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -35,7 +35,7 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_ result->prev.resize(params.n_prev); - result->n_considered = 0; + result->n_valid = 0; llama_sampling_set_rng_seed(result, params.seed); @@ -66,7 +66,7 @@ void llama_sampling_reset(llama_sampling_context * ctx) { std::fill(ctx->prev.begin(), ctx->prev.end(), 0); ctx->cur.clear(); - ctx->n_considered = 0; + ctx->n_valid = 0; } void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) { @@ -256,7 +256,7 @@ static llama_token llama_sampling_sample_impl( } } - ctx_sampling->n_considered = cur_p.size; + ctx_sampling->n_valid = temp == 0.0f ? 0 : cur_p.size; return id; } diff --git a/common/sampling.h b/common/sampling.h index 5b73ecdcdb37b..655732ad17206 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -81,7 +81,7 @@ struct llama_sampling_context { // TODO: replace with ring-buffer std::vector prev; std::vector cur; - size_t n_considered; + size_t n_valid; // Number of correct top tokens with correct probabilities. 
std::mt19937 rng; }; diff --git a/convert-hf-to-gguf-update.py b/convert-hf-to-gguf-update.py index e6468772210f8..cd2674a0ea97d 100755 --- a/convert-hf-to-gguf-update.py +++ b/convert-hf-to-gguf-update.py @@ -74,6 +74,9 @@ class TOKENIZER_TYPE(IntEnum): {"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", }, {"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", }, {"name": "dbrx", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", }, + {"name": "jina-en", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-en", }, # WPM! + {"name": "jina-es", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-es", }, + {"name": "jina-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-de", }, ] # make directory "models/tokenizers" if it doesn't exist @@ -142,8 +145,17 @@ def download_file_with_auth(url, token, save_path): if tokt == TOKENIZER_TYPE.SPM: continue + # Skip if the tokenizer folder does not exist or there are other download issues previously + if not os.path.exists(f"models/tokenizers/{name}"): + logger.warning(f"Directory for tokenizer {name} not found. Skipping...") + continue + # create the tokenizer - tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}") + try: + tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}") + except OSError as e: + logger.error(f"Error loading tokenizer for model {name}. The model may not exist or is not accessible with the provided token. Error: {e}") + continue # Skip to the next model if the tokenizer can't be loaded chktok = tokenizer.encode(chktxt) chkhsh = sha256(str(chktok).encode()).hexdigest() @@ -161,6 +173,8 @@ def download_file_with_auth(url, token, save_path): logger.info("normalizer: " + json.dumps(normalizer, indent=4)) pre_tokenizer = cfg["pre_tokenizer"] logger.info("pre_tokenizer: " + json.dumps(pre_tokenizer, indent=4)) + if "ignore_merges" in cfg["model"]: + logger.info("ignore_merges: " + json.dumps(cfg["model"]["ignore_merges"], indent=4)) logger.info("") @@ -282,8 +296,17 @@ def get_vocab_base_pre(self, tokenizer) -> str: name = model["name"] tokt = model["tokt"] + # Skip if the tokenizer folder does not exist or there are other download issues previously + if not os.path.exists(f"models/tokenizers/{name}"): + logger.warning(f"Directory for tokenizer {name} not found. Skipping...") + continue + # create the tokenizer - tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}") + try: + tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}") + except OSError as e: + logger.error(f"Failed to load tokenizer for model {name}. 
Error: {e}") + continue # Skip this model and continue with the next one in the loop with open(f"models/ggml-vocab-{name}.gguf.inp", "w", encoding="utf-8") as f: for text in tests: diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 1dc18b2a57721..fbaed64da1cac 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -404,8 +404,17 @@ def get_vocab_base_pre(self, tokenizer) -> str: # ref: https://huggingface.co/allenai/OLMo-1.7-7B-hf res = "olmo" if chkhsh == "a8594e3edff7c29c003940395316294b2c623e09894deebbc65f33f1515df79e": - # ref: https://huggingface.co/databricks/dbrx-instruct + # ref: https://huggingface.co/databricks/dbrx-base res = "dbrx" + if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f": + # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-en + res = "jina-en" + if chkhsh == "171aeeedd6fb548d418a7461d053f11b6f1f1fc9b387bd66640d28a4b9f5c643": + # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-es + res = "jina-es" + if chkhsh == "27949a2493fc4a9f53f5b9b029c82689cfbe5d3a1929bb25e043089e28466de6": + # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-de + res = "jina-de" if res is None: logger.warning("\n") @@ -1013,6 +1022,18 @@ def set_gguf_parameters(self): class RefactModel(Model): model_arch = gguf.MODEL_ARCH.REFACT + def set_vocab(self): + super().set_vocab() + + # TODO: how to determine special FIM tokens automatically? + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False, + special_token_types = ['prefix', 'suffix', 'middle', 'fsep', 'eot']) + special_vocab._set_special_token("prefix", 1) + special_vocab._set_special_token("suffix", 3) + special_vocab._set_special_token("middle", 2) + special_vocab._set_special_token("fsep", 4) # is this correct? 
+ special_vocab.add_to_gguf(self.gguf_writer) + def set_gguf_parameters(self): hidden_dim = self.hparams["n_embd"] inner_dim = 4 * hidden_dim @@ -2277,6 +2298,43 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [(self.map_tensor_name(name), data_torch)] +@Model.register("JinaBertModel", "JinaBertForMaskedLM") +class JinaBertV2Model(BertModel): + model_arch = gguf.MODEL_ARCH.JINA_BERT_V2 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.intermediate_size = self.hparams["intermediate_size"] + + def get_tensors(self): + for name, data in super().get_tensors(): + if 'gated_layers' in name: + d1 = data[:self.intermediate_size, :] + name1 = name.replace('gated_layers', 'gated_layers_w') + d2 = data[self.intermediate_size:, :] + name2 = name.replace('gated_layers', 'gated_layers_v') + yield name1, d1 + yield name2, d2 + continue + + yield name, data + + def set_vocab(self, *args, **kwargs): + tokenizer_class = 'BertTokenizer' + with open(self.dir_model / "tokenizer_config.json", "r", encoding="utf-8") as f: + tokenizer_class = json.load(f)['tokenizer_class'] + + if tokenizer_class == 'BertTokenizer': + super().set_vocab() + elif tokenizer_class == 'RobertaTokenizer': + self._set_vocab_gpt2() + self.gguf_writer.add_token_type_count(2) + else: + raise NotImplementedError(f'Tokenizer {tokenizer_class} is not supported for JinaBertModel') + self.gguf_writer.add_add_bos_token(True) + self.gguf_writer.add_add_eos_token(True) + + ###### CONVERSION LOGIC ###### diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 6a93147d70e88..c85a2da53d129 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -49,6 +49,12 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu } float * out = output + batch.seq_id[i][0] * n_embd; + //TODO: I would also add a parameter here to enable normalization or not. 
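For context, `llama_embd_normalize` (called right below) rescales each embedding to unit length, which is what the TODO above proposes making optional. A standalone sketch of the operation, assuming the usual Euclidean (L2) norm; the function here is a simplified stand-in, not the `common.cpp` implementation:

```
#include <cmath>
#include <cstdio>

// Simplified stand-in for llama_embd_normalize(): scale a vector so its
// Euclidean norm becomes 1, mapping the zero vector to zero.
static void embd_normalize(const float * inp, float * out, int n) {
    double sum = 0.0;
    for (int i = 0; i < n; i++) {
        sum += (double) inp[i] * (double) inp[i];
    }
    const float norm = sum > 0.0 ? 1.0f / (float) std::sqrt(sum) : 0.0f;
    for (int i = 0; i < n; i++) {
        out[i] = inp[i] * norm;
    }
}

int main() {
    const float v[4] = {3.0f, 4.0f, 0.0f, 0.0f};
    float o[4];
    embd_normalize(v, o, 4);
    std::printf("%.2f %.2f %.2f %.2f\n", o[0], o[1], o[2], o[3]); // 0.60 0.80 0.00 0.00
    return 0;
}
```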
+        /*fprintf(stdout, "unnormalized_embedding:");
+        for (int hh = 0; hh < n_embd; hh++) {
+            fprintf(stdout, "%9.6f ", embd[hh]);
+        }
+        fprintf(stdout, "\n");*/
         llama_embd_normalize(embd, out, n_embd);
     }
 }
@@ -123,10 +129,12 @@ int main(int argc, char ** argv) {
         inputs.push_back(inp);
     }

-    // add SEP if not present
+    // check if the last token is SEP
+    // it should be automatically added by the tokenizer when 'tokenizer.ggml.add_eos_token' is set to 'true'
     for (auto & inp : inputs) {
         if (inp.empty() || inp.back() != llama_token_sep(model)) {
-            inp.push_back(llama_token_sep(model));
+            fprintf(stderr, "%s: warning: last token in the prompt is not SEP\n", __func__);
+            fprintf(stderr, "%s: 'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF header\n", __func__);
         }
     }

diff --git a/examples/llama-bench/README.md b/examples/llama-bench/README.md
index 10f37b4418897..8578405646af7 100644
--- a/examples/llama-bench/README.md
+++ b/examples/llama-bench/README.md
@@ -26,16 +26,21 @@ options:
   -m, --model <filename>              (default: models/7B/ggml-model-q4_0.gguf)
   -p, --n-prompt <n>                  (default: 512)
   -n, --n-gen <n>                     (default: 128)
-  -b, --batch-size <n>                (default: 512)
-  -ctk <t>, --cache-type-k <t>        (default: f16)
-  -ctv <t>, --cache-type-v <t>        (default: f16)
-  -t, --threads <n>                   (default: 112)
+  -pg <pp,tg>                         (default: 512,128)
+  -b, --batch-size <n>                (default: 2048)
+  -ub, --ubatch-size <n>              (default: 512)
+  -ctk, --cache-type-k <t>            (default: f16)
+  -ctv, --cache-type-v <t>            (default: f16)
+  -t, --threads <n>                   (default: 16)
   -ngl, --n-gpu-layers <n>            (default: 99)
   -sm, --split-mode <none|layer|row>  (default: layer)
   -mg, --main-gpu <i>                 (default: 0)
   -nkvo, --no-kv-offload <0|1>        (default: 0)
+  -fa, --flash-attn <0|1>             (default: 0)
   -mmp, --mmap <0|1>                  (default: 1)
-  -ts, --tensor_split <ts0/ts1/..>    (default: 0)
+  --numa <distribute|isolate|numactl> (default: disabled)
+  -embd, --embeddings <0|1>           (default: 0)
+  -ts, --tensor-split <ts0/ts1/..>    (default: 0)
   -r, --repetitions <n>               (default: 5)
   -o, --output <csv|json|md|sql>      (default: md)
   -v, --verbose                       (default: 0)
@@ -43,10 +48,11 @@ options:
 Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.
 ```

-llama-bench can perform two types of tests:
+llama-bench can perform three types of tests:

 - Prompt processing (pp): processing a prompt in batches (`-p`)
 - Text generation (tg): generating a sequence of tokens (`-n`)
+- Prompt processing + text generation (pg): processing a prompt followed by generating a sequence of tokens (`-pg`)

 With the exception of `-r`, `-o` and `-v`, all options can be specified multiple times to run multiple tests. Each pp and tg test is run with all combinations of the specified options. To specify multiple values for an option, the values can be separated by commas (e.g. `-n 16,32`), or the option can be specified multiple times (e.g. `-n 16 -n 32`).
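The new `-pg` flag takes a `<pp,tg>` pair. The parsing added to `llama-bench.cpp` below splits the value on a comma and rejects anything that is not exactly two integers; a standalone sketch of that contract (`parse_pg` is a name made up for this example):

```
#include <cstdio>
#include <stdexcept>
#include <string>
#include <utility>

// Split a "512,128" style value into (n_prompt, n_gen), mirroring the
// "-pg" handling below: exactly two comma-separated integers.
static std::pair<int, int> parse_pg(const std::string & s) {
    const size_t comma = s.find(',');
    if (comma == std::string::npos || s.find(',', comma + 1) != std::string::npos) {
        throw std::invalid_argument("expected <pp,tg>, e.g. 512,128");
    }
    return { std::stoi(s.substr(0, comma)), std::stoi(s.substr(comma + 1)) };
}

int main() {
    const auto pg = parse_pg("512,128");
    // each -pg pair becomes one extra test instance, labeled like this
    // in the new markdown output format:
    std::printf("pp%d+tg%d\n", pg.first, pg.second);
    return 0;
}
```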
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 40128ec444334..8b965e1990ba5 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -161,10 +161,17 @@ static const char * split_mode_str(llama_split_mode mode) { } } +static std::string pair_str(const std::pair & p) { + static char buf[32]; + snprintf(buf, sizeof(buf), "%d,%d", p.first, p.second); + return buf; +} + struct cmd_params { std::vector model; std::vector n_prompt; std::vector n_gen; + std::vector> n_pg; std::vector n_batch; std::vector n_ubatch; std::vector type_k; @@ -188,6 +195,7 @@ static const cmd_params cmd_params_defaults = { /* model */ {"models/7B/ggml-model-q4_0.gguf"}, /* n_prompt */ {512}, /* n_gen */ {128}, + /* n_pg */ {{512, 128}}, /* n_batch */ {2048}, /* n_ubatch */ {512}, /* type_k */ {GGML_TYPE_F16}, @@ -215,10 +223,11 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -m, --model (default: %s)\n", join(cmd_params_defaults.model, ",").c_str()); printf(" -p, --n-prompt (default: %s)\n", join(cmd_params_defaults.n_prompt, ",").c_str()); printf(" -n, --n-gen (default: %s)\n", join(cmd_params_defaults.n_gen, ",").c_str()); + printf(" -pg (default: %s)\n", join(transform_to_str(cmd_params_defaults.n_pg, pair_str), ",").c_str()); printf(" -b, --batch-size (default: %s)\n", join(cmd_params_defaults.n_batch, ",").c_str()); - printf(" -ub N, --ubatch-size (default: %s)\n", join(cmd_params_defaults.n_ubatch, ",").c_str()); - printf(" -ctk , --cache-type-k (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_k, ggml_type_name), ",").c_str()); - printf(" -ctv , --cache-type-v (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_v, ggml_type_name), ",").c_str()); + printf(" -ub, --ubatch-size (default: %s)\n", join(cmd_params_defaults.n_ubatch, ",").c_str()); + printf(" -ctk, --cache-type-k (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_k, ggml_type_name), ",").c_str()); + printf(" -ctv, --cache-type-v (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_v, ggml_type_name), ",").c_str()); printf(" -t, --threads (default: %s)\n", join(cmd_params_defaults.n_threads, ",").c_str()); printf(" -ngl, --n-gpu-layers (default: %s)\n", join(cmd_params_defaults.n_gpu_layers, ",").c_str()); printf(" -sm, --split-mode (default: %s)\n", join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str()); @@ -304,6 +313,17 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { } auto p = split(argv[i], split_delim); params.n_gen.insert(params.n_gen.end(), p.begin(), p.end()); + } else if (arg == "-pg") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = split(argv[i], ','); + if (p.size() != 2) { + invalid_param = true; + break; + } + params.n_pg.push_back({std::stoi(p[0]), std::stoi(p[1])}); } else if (arg == "-b" || arg == "--batch-size") { if (++i >= argc) { invalid_param = true; @@ -493,6 +513,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { if (params.model.empty()) { params.model = cmd_params_defaults.model; } if (params.n_prompt.empty()) { params.n_prompt = cmd_params_defaults.n_prompt; } if (params.n_gen.empty()) { params.n_gen = cmd_params_defaults.n_gen; } + if (params.n_pg.empty()) { params.n_pg = cmd_params_defaults.n_pg; } if (params.n_batch.empty()) { params.n_batch = cmd_params_defaults.n_batch; } if (params.n_ubatch.empty()) { params.n_ubatch = cmd_params_defaults.n_ubatch; } if (params.type_k.empty()) { 
params.type_k = cmd_params_defaults.type_k; } @@ -632,6 +653,31 @@ static std::vector get_cmd_params_instances(const cmd_param }; instances.push_back(instance); } + + for (const auto & n_pg : params.n_pg) { + if (n_pg.first == 0 && n_pg.second == 0) { + continue; + } + cmd_params_instance instance = { + /* .model = */ m, + /* .n_prompt = */ n_pg.first, + /* .n_gen = */ n_pg.second, + /* .n_batch = */ nb, + /* .n_ubatch = */ nub, + /* .type_k = */ tk, + /* .type_v = */ tv, + /* .n_threads = */ nt, + /* .n_gpu_layers = */ nl, + /* .split_mode = */ sm, + /* .main_gpu = */ mg, + /* .no_kv_offload= */ nkvo, + /* .flash_attn = */ fa, + /* .tensor_split = */ ts, + /* .use_mmap = */ mmp, + /* .embeddings = */ embd, + }; + instances.push_back(instance); + } } return instances; @@ -965,6 +1011,9 @@ struct markdown_printer : public printer { if (field == "n_gpu_layers") { return 3; } + if (field == "test") { + return 13; + } int width = std::max((int)field.length(), 10); @@ -1091,12 +1140,11 @@ struct markdown_printer : public printer { value = test::get_backend(); } else if (field == "test") { if (t.n_prompt > 0 && t.n_gen == 0) { - snprintf(buf, sizeof(buf), "pp %d", t.n_prompt); + snprintf(buf, sizeof(buf), "pp%d", t.n_prompt); } else if (t.n_gen > 0 && t.n_prompt == 0) { - snprintf(buf, sizeof(buf), "tg %d", t.n_gen); + snprintf(buf, sizeof(buf), "tg%d", t.n_gen); } else { - assert(false); - exit(1); + snprintf(buf, sizeof(buf), "pp%d+tg%d", t.n_prompt, t.n_gen); } value = buf; } else if (field == "t/s") { @@ -1297,6 +1345,7 @@ int main(int argc, char ** argv) { llama_kv_cache_clear(ctx); uint64_t t_start = get_time_ns(); + if (t.n_prompt > 0) { test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads); } diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index 157a680b5ecdb..da60ddf2f057d 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -189,6 +189,11 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_ LOG_TEE("\n"); struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams); + if (!ctx_sampling) { + fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__); + exit(1); + } + std::string response = ""; for (int i = 0; i < max_tgt_len; i++) { const char * tmp = sample(ctx_sampling, ctx_llava->ctx_llama, &n_past); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 49acd6bab4074..9dee41001f12c 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -523,6 +523,10 @@ int main(int argc, char ** argv) { } struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams); + if (!ctx_sampling) { + fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__); + exit(1); + } while ((n_remain != 0 && !is_antiprompt) || params.interactive) { // predict @@ -879,7 +883,7 @@ int main(int argc, char ** argv) { } const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true); - const auto line_inp = ::llama_tokenize(ctx, buffer, false, false); + const auto line_inp = ::llama_tokenize(ctx, buffer, false, params.interactive_specials); const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true); LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str()); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 305f79492a055..55c1d41298cd6 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -673,6 +673,8 @@ struct server_context { llama_free_model(model); 
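The `main.cpp` hunk above routes `params.interactive_specials` into the `parse_special` argument of `llama_tokenize`, so text typed by the user is only scanned for special tokens when the new flag is given. A toy standalone illustration of what that flag changes (the `</s>` literal and all token ids are made up for the example):

```
#include <cstdio>
#include <string>
#include <vector>

// Toy tokenizer: with parse_special == false, text that merely *looks*
// like a control token stays literal bytes; with true, it is mapped to
// the dedicated token id (here: 2 for "</s>").
static std::vector<int> toy_tokenize(const std::string & text, bool parse_special) {
    const std::string special = "</s>";
    std::vector<int> out;
    for (size_t i = 0; i < text.size();) {
        if (parse_special && text.compare(i, special.size(), special) == 0) {
            out.push_back(2);                             // special token id
            i += special.size();
        } else {
            out.push_back(256 + (unsigned char) text[i]); // byte fallback
            i += 1;
        }
    }
    return out;
}

int main() {
    for (const bool ps : {false, true}) {
        const auto toks = toy_tokenize("hi</s>", ps);
        std::printf("parse_special=%d -> %zu tokens\n", (int) ps, toks.size());
    }
    return 0;
}
```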
model = nullptr; } + + llama_batch_free(batch); } bool load_model(const gpt_params & params_) { @@ -2270,10 +2272,10 @@ struct server_context { const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs); if (n_probs > 0) { - const size_t n_considered = slot.ctx_sampling->n_considered; + const size_t n_valid = slot.ctx_sampling->n_valid; // Make sure at least n_probs top tokens are at the front of the vector: - if (slot.sparams.temp == 0.0f && n_probs > n_considered) { + if (slot.sparams.temp == 0.0f && n_probs > n_valid) { llama_sample_top_k(ctx, &cur_p, n_probs, 0); } @@ -2289,7 +2291,7 @@ struct server_context { for (size_t i = 0; i < n_probs; ++i) { result.probs.push_back({ cur_p.data[i].id, - i >= n_considered ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability. + i >= n_valid ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability. }); } } diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 6f89a7cc3e900..c5c7787969f5f 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -4,7 +4,6 @@ #include "ggml-cuda/common.cuh" #include "ggml-cuda/acc.cuh" -#include "ggml-cuda/alibi.cuh" #include "ggml-cuda/arange.cuh" #include "ggml-cuda/argsort.cuh" #include "ggml-cuda/binbcast.cuh" @@ -2277,9 +2276,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_ROPE: ggml_cuda_op_rope(ctx, dst); break; - case GGML_OP_ALIBI: - ggml_cuda_op_alibi(ctx, dst); - break; case GGML_OP_IM2COL: ggml_cuda_op_im2col(ctx, dst); break; @@ -2829,7 +2825,6 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: case GGML_OP_ROPE: - case GGML_OP_ALIBI: case GGML_OP_IM2COL: case GGML_OP_POOL_2D: case GGML_OP_SUM_ROWS: diff --git a/ggml-cuda/alibi.cu b/ggml-cuda/alibi.cu deleted file mode 100644 index 6c7f1fd9562e4..0000000000000 --- a/ggml-cuda/alibi.cu +++ /dev/null @@ -1,63 +0,0 @@ -#include "alibi.cuh" - -static __global__ void alibi_f32(const float * x, float * dst, const int ncols, const int k_rows, - const int n_heads_log2_floor, const float m0, const float m1) { - const int col = blockDim.x*blockIdx.x + threadIdx.x; - - if (col >= ncols) { - return; - } - - const int row = blockDim.y*blockIdx.y + threadIdx.y; - const int i = row*ncols + col; - - const int k = row/k_rows; - - float m_k; - if (k < n_heads_log2_floor) { - m_k = powf(m0, k + 1); - } else { - m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1); - } - - dst[i] = col * m_k + x[i]; -} - -static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, - const int k_rows, const int n_heads_log2_floor, const float m0, - const float m1, cudaStream_t stream) { - const dim3 block_dims(CUDA_ALIBI_BLOCK_SIZE, 1, 1); - const int num_blocks_x = (ncols + CUDA_ALIBI_BLOCK_SIZE - 1) / (CUDA_ALIBI_BLOCK_SIZE); - const dim3 block_nums(num_blocks_x, nrows, 1); - alibi_f32<<>>(x, dst, ncols, k_rows, n_heads_log2_floor, m0, m1); -} - -void ggml_cuda_op_alibi(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const float * src0_d = (const float *)src0->data; - float * dst_d = (float *)dst->data; - cudaStream_t stream = ctx.stream(); - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - const int64_t nrows = ggml_nrows(src0); - - //const int n_past = ((int32_t *) 
dst->op_params)[0]; - const int n_head = ((int32_t *) dst->op_params)[1]; - float max_bias; - memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - - //GGML_ASSERT(ne01 + n_past == ne00); - GGML_ASSERT(n_head == ne02); - - const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); - - const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); - - alibi_f32_cuda(src0_d, dst_d, ne00, nrows, ne01, n_heads_log2_floor, m0, m1, stream); -} diff --git a/ggml-cuda/alibi.cuh b/ggml-cuda/alibi.cuh deleted file mode 100644 index 630adfc7f6396..0000000000000 --- a/ggml-cuda/alibi.cuh +++ /dev/null @@ -1,5 +0,0 @@ -#include "common.cuh" - -#define CUDA_ALIBI_BLOCK_SIZE 32 - -void ggml_cuda_op_alibi(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 7c486f4829bdd..ac5d6672b30d3 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -23,6 +23,10 @@ static __global__ void flash_attn_vec_ext_f16( float * __restrict__ dst, float2 * __restrict__ dst_meta, const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, const int ne00, const int ne01, const int ne02, @@ -58,6 +62,18 @@ static __global__ void flash_attn_vec_ext_f16( const int stride_KV = nb11 / sizeof(half); const int stride_KV2 = nb11 / sizeof(half2); + half slopeh = __float2half(1.0f); + + // ALiBi + if (max_bias > 0.0f) { + const int h = blockIdx.y; + + const float base = h < n_head_log2 ? m0 : m1; + const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slopeh = __float2half(powf(base, exph)); + } + static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); constexpr int nwarps = D / WARP_SIZE; const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; @@ -141,7 +157,7 @@ static __global__ void flash_attn_vec_ext_f16( for (int j = 0; j < ncols; ++j) { sum2[j] = warp_reduce_sum(sum2[j]); half sum = __low2half(sum2[j]) + __high2half(sum2[j]); - sum += mask ? maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f); + sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f); if (ncols == 1) { kqmax_new = ggml_cuda_hmax(kqmax_new, sum); @@ -249,6 +265,10 @@ static __global__ void flash_attn_ext_f16( float * __restrict__ dst, float2 * __restrict__ dst_meta, const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, const int ne00, const int ne01, const int ne02, @@ -305,6 +325,20 @@ static __global__ void flash_attn_ext_f16( const int stride_Q = nb01 / sizeof(float); const int stride_KV = nb11 / sizeof(half); + half slopeh = __float2half(1.0f); + half2 slope2 = make_half2(1.0f, 1.0f); + + // ALiBi + if (max_bias > 0.0f) { + const int h = blockIdx.y; + + const float base = h < n_head_log2 ? m0 : m1; + const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slopeh = __float2half(powf(base, exph)); + slope2 = make_half2(slopeh, slopeh); + } + frag_b Q_b[D/16][ncols/frag_n]; // A single buffer for temporarily holding tiles of KQ and VKQ parts: @@ -421,7 +455,7 @@ static __global__ void flash_attn_ext_f16( for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { const int k = k0 + threadIdx.x; - KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; + KQ_f_tmp[k0/WARP_SIZE] += mask ? 
__half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]); } KQ_max_new = warp_reduce_max(KQ_max_new); @@ -464,7 +498,7 @@ static __global__ void flash_attn_ext_f16( for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { const int k = k0 + threadIdx.x; - KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); } KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); @@ -710,8 +744,17 @@ template void launch_fattn_vec_ const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]); const int shmem = 0; - float scale; - memcpy(&scale, KQV->op_params, sizeof(float)); + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float)); + + const uint32_t n_head = Q->ne[2]; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); flash_attn_vec_ext_f16 <<>> ( @@ -720,7 +763,7 @@ template void launch_fattn_vec_ (const char *) V->data, mask ? ((const char *) mask->data) : nullptr, parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, + scale, max_bias, m0, m1, n_head_log2, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, @@ -761,8 +804,17 @@ template ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]); const int shmem = 0; - float scale; - memcpy(&scale, KQV->op_params, sizeof(float)); + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float)); + + const uint32_t n_head = Q->ne[2]; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); flash_attn_ext_f16 <<>> ( @@ -771,7 +823,7 @@ template data, mask ? ((const char *) mask->data) : nullptr, (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, + scale, max_bias, m0, m1, n_head_log2, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], mask ? mask->ne[1] : 0, mask ? 
mask->nb[1] : 0, @@ -837,7 +889,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; - const int32_t precision = KQV->op_params[1]; + const int32_t precision = KQV->op_params[2]; if (!fp16_mma_available(cc)) { GGML_ASSERT(precision == GGML_PREC_DEFAULT); diff --git a/ggml-cuda/softmax.cu b/ggml-cuda/softmax.cu index 6ed225999bddf..ca85285a3f458 100644 --- a/ggml-cuda/softmax.cu +++ b/ggml-cuda/softmax.cu @@ -11,7 +11,7 @@ __device__ float __forceinline__ t2f32(half val) { } template -static __global__ void soft_max_f32(const float * x, const T * mask, const T * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { +static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { const int ncols = ncols_template == 0 ? ncols_par : ncols_template; const int tid = threadIdx.x; @@ -23,16 +23,16 @@ static __global__ void soft_max_f32(const float * x, const T * mask, const T * p const int warp_id = threadIdx.x / WARP_SIZE; const int lane_id = threadIdx.x % WARP_SIZE; - float slope = 0.0f; + float slope = 1.0f; // ALiBi if (max_bias > 0.0f) { const int h = rowx/nrows_y; // head index const float base = h < n_head_log2 ? m0 : m1; - const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - slope = powf(base, exp); + slope = powf(base, exph); } extern __shared__ float data_soft_max_f32[]; @@ -53,7 +53,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, const T * p const int64_t ix = (int64_t)rowx*ncols + col; const int64_t iy = (int64_t)rowy*ncols + col; - const float val = x[ix]*scale + (mask ? t2f32(mask[iy]) : 0.0f) + (pos ? slope*t2f32(pos[col]) : 0.0f); + const float val = x[ix]*scale + (mask ? 
slope*t2f32(mask[iy]) : 0.0f); vals[col] = val; max_val = max(max_val, val); @@ -125,7 +125,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, const T * p } template -static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) { +static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) { int nth = WARP_SIZE; while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); @@ -133,8 +133,8 @@ static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, fl const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float); static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted."); - const uint32_t n_head_kv = nrows_x/nrows_y; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); + const uint32_t n_head = nrows_x/nrows_y; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); @@ -142,43 +142,42 @@ static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, fl if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) { switch (ncols_x) { case 32: - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 64: - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 128: - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 256: - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 512: - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 1024: - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 2048: - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 4096: - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; default: - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; } } else { const size_t shmem_low = WARP_SIZE*sizeof(float); - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, 
ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); } } void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; - const ggml_tensor * src2 = dst->src[2]; const float * src0_d = (const float *)src0->data; const void * src1_d = src1 ? (const void *)src1->data : nullptr; @@ -190,7 +189,6 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional - GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); @@ -202,26 +200,15 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); - // positions tensor - void * src2_d = nullptr; - - const bool use_src2 = src2 != nullptr; - - if (use_src2) { - src2_d = (void *)src2->data; - } - - const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16); + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); if (use_f16) { const half * src1_dd = (const half *)src1_d; - const half * src2_dd = (const half *)src2_d; - soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); } else { const float * src1_dd = (const float *)src1_d; - const float * src2_dd = (const float *)src2_d; - soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); } } diff --git a/ggml-kompute.cpp b/ggml-kompute.cpp index 9a469821d8042..3f033d58be481 100644 --- a/ggml-kompute.cpp +++ b/ggml-kompute.cpp @@ -1559,12 +1559,18 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml case GGML_OP_SOFT_MAX: { float scale; - memcpy(&scale, dst->op_params, sizeof(float)); + float max_bias; -#pragma message("TODO: add ggml_vk_soft_max() F16/F32 src1 and src2 support") + memcpy(&scale, (float *)dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *)dst->op_params + 1, sizeof(float)); + +#pragma message("TODO: add ggml_vk_soft_max() F16 src1 support") #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021") GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32); - GGML_ASSERT(src2 == nullptr); + +#pragma message("TODO: add ALiBi support") +#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/7192") + GGML_ASSERT(max_bias == 0.0f); ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale); } break; diff --git a/ggml-metal.m b/ggml-metal.m index c6817f01f2bdb..1bbb8fb4fefa1 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -169,7 +169,6 @@ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, GGML_METAL_KERNEL_TYPE_ROPE_F32, GGML_METAL_KERNEL_TYPE_ROPE_F16, - GGML_METAL_KERNEL_TYPE_ALIBI_F32, GGML_METAL_KERNEL_TYPE_IM2COL_F16, GGML_METAL_KERNEL_TYPE_IM2COL_F32, GGML_METAL_KERNEL_TYPE_UPSCALE_F32, @@ -623,7 +622,6 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); @@ -633,14 +631,14 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); @@ -759,7 +757,6 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const case GGML_OP_GROUP_NORM: return ctx->support_simdgroup_reduction; case GGML_OP_NORM: - case GGML_OP_ALIBI: case GGML_OP_ROPE: case GGML_OP_IM2COL: return true; @@ -772,8 +769,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_ARGSORT: case GGML_OP_LEAKY_RELU: - case GGML_OP_FLASH_ATTN_EXT: return true; + case 
GGML_OP_FLASH_ATTN_EXT: + return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: return ctx->support_simdgroup_reduction && @@ -1357,13 +1355,12 @@ static enum ggml_status ggml_metal_graph_compute( case GGML_OP_SOFT_MAX: { GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); - GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32); int nth = 32; // SIMD width id pipeline = nil; - const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16); + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); if (ne00%4 == 0) { while (nth < ne00/4 && nth < 256) { @@ -1394,8 +1391,8 @@ static enum ggml_status ggml_metal_graph_compute( const int64_t nrows_x = ggml_nrows(src0); const int64_t nrows_y = src0->ne[1]; - const uint32_t n_head_kv = nrows_x/nrows_y; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); + const uint32_t n_head = nrows_x/nrows_y; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); @@ -1407,20 +1404,15 @@ static enum ggml_status ggml_metal_graph_compute( } else { [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; } - if (id_src2) { - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2]; - } - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:4]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:5]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6]; - [encoder setBytes:&scale length:sizeof(scale) atIndex:7]; - [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:8]; - [encoder setBytes:&m0 length:sizeof(m0) atIndex:9]; - [encoder setBytes:&m1 length:sizeof(m1) atIndex:10]; - [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&scale length:sizeof(scale) atIndex:6]; + [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7]; + [encoder setBytes:&m0 length:sizeof(m0) atIndex:8]; + [encoder setBytes:&m1 length:sizeof(m1) atIndex:9]; + [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10]; [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; @@ -2225,49 +2217,6 @@ static enum ggml_status ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; - case GGML_OP_ALIBI: - { - GGML_ASSERT((src0t == GGML_TYPE_F32)); - - const int nth = MIN(1024, ne00); - - //const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_head = ((int32_t *) dst->op_params)[1]; - - float max_bias; - memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - - const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); - const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline; - - 
[encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&m0 length:sizeof( float) atIndex:18]; - [encoder setBytes:&m1 length:sizeof( float) atIndex:19]; - [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; case GGML_OP_ROPE: { GGML_ASSERT(ne10 == ne02); @@ -2565,7 +2514,7 @@ static enum ggml_status ggml_metal_graph_compute( "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); - const int64_t ne31 = src3 ? src3->ne[1] : 0; + //const int64_t ne31 = src3 ? src3->ne[1] : 0; const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33); @@ -2577,7 +2526,16 @@ static enum ggml_status ggml_metal_graph_compute( const enum ggml_type src2t = src2 ? 
src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t); float scale; - memcpy(&scale, dst->op_params, sizeof(float)); + float max_bias; + + memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); + memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); + + const uint32_t n_head = src0->ne[2]; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); id pipeline = nil; @@ -2614,34 +2572,37 @@ static enum ggml_status ggml_metal_graph_compute( } [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; - [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12]; - [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14]; - [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15]; - [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16]; - [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19]; - [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20]; - [encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21]; - [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:23]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:24]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:25]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; - [encoder setBytes:&scale length:sizeof( float) atIndex:27]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12]; + [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14]; + [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15]; + [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16]; + [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18]; + 
[encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20]; + [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:21]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:22]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:23]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:24]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:25]; + [encoder setBytes:&scale length:sizeof( float) atIndex:26]; + [encoder setBytes:&max_bias length:sizeof( float) atIndex:27]; + [encoder setBytes:&m0 length:sizeof(m0) atIndex:28]; + [encoder setBytes:&m1 length:sizeof(m1) atIndex:29]; + [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:30]; if (!use_vec_kernel) { // half8x8 kernel diff --git a/ggml-metal.metal b/ggml-metal.metal index 46c7d503930a4..ee9de57a3a6f1 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -356,7 +356,6 @@ template kernel void kernel_soft_max( device const char * src0, device const char * src1, - device const char * src2, device char * dst, constant int64_t & ne00, constant int64_t & ne01, @@ -378,10 +377,9 @@ kernel void kernel_soft_max( device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr; - device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr; device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - float slope = 0.0f; + float slope = 1.0f; // ALiBi if (max_bias > 0.0f) { @@ -397,7 +395,7 @@ kernel void kernel_soft_max( float lmax = -INFINITY; for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)); + lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); } // find the max value in the block @@ -422,7 +420,7 @@ kernel void kernel_soft_max( // parallel sum float lsum = 0.0f; for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val); + const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val); lsum += exp_psrc0; pdst[i00] = exp_psrc0; } @@ -461,7 +459,6 @@ template kernel void kernel_soft_max_4( device const char * src0, device const char * src1, - device const char * src2, device char * dst, constant int64_t & ne00, constant int64_t & ne01, @@ -483,10 +480,9 @@ kernel void kernel_soft_max_4( device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr; - device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr; device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; - float slope = 0.0f; + float slope = 1.0f; if (max_bias > 0.0f) { const int64_t h = i02; @@ -501,7 +497,7 @@ kernel void kernel_soft_max_4( float4 lmax4 = -INFINITY; for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))); + lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? 
slope*pmask[i00] : 0.0f))); } const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); @@ -527,7 +523,7 @@ kernel void kernel_soft_max_4( // parallel sum float4 lsum4 = 0.0f; for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))) - max_val); + const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val); lsum4 += exp_psrc4; pdst4[i00] = exp_psrc4; } @@ -1595,60 +1591,6 @@ kernel void kernel_mul_mv_f16_f32_l4( } } -kernel void kernel_alibi_f32( - device const float * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant float & m0, - constant float & m1, - constant int & n_heads_log2_floor, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - //const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - const int64_t k = i3*ne3 + i2; - - float m_k; - if (k < n_heads_log2_floor) { - m_k = pow(m0, k + 1); - } else { - m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1); - } - - device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1; - device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01; - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - const float src_v = *(device float *)(src_row + i00*nb00); - device float * dst_v = (device float *)(dst_row + i00*nb0); - *dst_v = i00 * m_k + src_v; - } -} - static float rope_yarn_ramp(const float low, const float high, const int i0) { const float y = (i0 / 2 - low) / max(0.001f, high - low); return 1.0f - min(1.0f, max(0.0f, y)); @@ -2116,13 +2058,16 @@ typedef void (flash_attn_ext_f16_t)( constant uint64_t & nb11, constant uint64_t & nb12, constant uint64_t & nb13, - constant int64_t & ne31, constant uint64_t & nb31, constant int64_t & ne0, constant int64_t & ne1, constant int64_t & ne2, constant int64_t & ne3, constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, threadgroup half * shared, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -2154,13 +2099,16 @@ kernel void kernel_flash_attn_ext_f16( constant uint64_t & nb11, constant uint64_t & nb12, constant uint64_t & nb13, - constant int64_t & ne31, constant uint64_t & nb31, constant int64_t & ne0, constant int64_t & ne1, constant int64_t & ne2, constant int64_t & ne3, constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 
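With the ppos term gone, both Metal softmax kernels now compute a single fused expression per row. A scalar C++ reference of what one row does, useful for checking the kernels against the CPU path (a sketch, not the kernel itself):

    #include <algorithm>
    #include <cmath>

    // dst = softmax(x*scale + slope*mask) over one row of length nc;
    // mask may be null, and slope is 1.0f unless ALiBi is active.
    static void soft_max_row_ref(float * dst, const float * x, const float * mask,
                                 int nc, float scale, float slope) {
        float max_val = -INFINITY;
        for (int i = 0; i < nc; ++i) {
            dst[i]  = x[i]*scale + (mask ? slope*mask[i] : 0.0f);
            max_val = std::max(max_val, dst[i]);
        }
        float sum = 0.0f;
        for (int i = 0; i < nc; ++i) {
            dst[i] = expf(dst[i] - max_val);
            sum   += dst[i];
        }
        for (int i = 0; i < nc; ++i) {
            dst[i] /= sum;
        }
    }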
tpitg[[thread_position_in_threadgroup]], @@ -2257,6 +2205,19 @@ kernel void kernel_flash_attn_ext_f16( // prepare diagonal scale matrix simdgroup_float8x8 mscale(scale); + // prepare diagonal slope matrix + simdgroup_float8x8 mslope(1.0f); + + // ALiBi + if (max_bias > 0.0f) { + const short h = iq2; + + const float base = h < n_head_log2 ? m0 : m1; + const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + mslope = simdgroup_float8x8(pow(base, exph)); + } + // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { @@ -2279,9 +2240,10 @@ kernel void kernel_flash_attn_ext_f16( simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); } - // mqk = mqk*scale + mask + // mqk = mqk*scale + mask*slope simdgroup_half8x8 mm; simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false); + simdgroup_multiply(mm, mslope, mm); simdgroup_multiply_accumulate(mqk, mqk, mscale, mm); simdgroup_store(mqk, ss + 8*cc, TF, 0, false); @@ -2472,13 +2434,16 @@ kernel void kernel_flash_attn_ext_vec_f16( constant uint64_t & nb11, constant uint64_t & nb12, constant uint64_t & nb13, - constant int64_t & ne31, constant uint64_t & nb31, constant int64_t & ne0, constant int64_t & ne1, constant int64_t & ne2, constant int64_t & ne3, constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -2497,6 +2462,18 @@ kernel void kernel_flash_attn_ext_vec_f16( const short T = D + 2*nsg*SH; // shared memory size per query in (half) + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + const short h = iq2; + + const float base = h < n_head_log2 ? m0 : m1; + const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix @@ -2603,10 +2580,10 @@ kernel void kernel_flash_attn_ext_vec_f16( mqk += simd_shuffle_down(mqk, 2); mqk += simd_shuffle_down(mqk, 1); - // mqk = mqk*scale + mask + // mqk = mqk*scale + mask*slope if (tiisg == 0) { float4 mm = (float4) mp4[ic/4 + cc]; - mqk = mqk*scale + mm; + mqk = mqk*scale + mm*slope; ss4[cc] = mqk; } @@ -2840,7 +2817,8 @@ kernel void kernel_cpy_f32_f16( for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - dst_data[i00] = src[0]; + // TODO: is there a better way to handle -INFINITY? + dst_data[i00] = src[0] == -INFINITY ? 
-MAXHALF : src[0]; } } diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index 79aec4d9f02e1..e93d2af631cc0 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -3154,7 +3154,6 @@ typedef float (*vec_dot_q_mul_mat_sycl_t)( #define SYCL_SCALE_BLOCK_SIZE 256 #define SYCL_CLAMP_BLOCK_SIZE 256 #define SYCL_ROPE_BLOCK_SIZE 256 -#define SYCL_ALIBI_BLOCK_SIZE 32 #define SYCL_DIAG_MASK_INF_BLOCK_SIZE 32 #define SYCL_QUANTIZE_BLOCK_SIZE 256 #define SYCL_DEQUANTIZE_BLOCK_SIZE 256 @@ -9316,32 +9315,6 @@ static void rope_glm_f32( dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta; } -static void alibi_f32(const float * x, float * dst, const int ncols, const int k_rows, - const int n_heads_log2_floor, const float m0, const float m1, - const sycl::nd_item<3> &item_ct1) { - const int col = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (col >= ncols) { - return; - } - - const int row = item_ct1.get_local_range(1) * item_ct1.get_group(1) + - item_ct1.get_local_id(1); - const int i = row*ncols + col; - - const int k = row/k_rows; - - float m_k; - if (k < n_heads_log2_floor) { - m_k = dpct::pow(m0, k + 1); - } else { - m_k = dpct::pow(m1, 2 * (k - n_heads_log2_floor) + 1); - } - - dst[i] = col * m_k + x[i]; -} - static void k_sum_rows_f32(const float * x, float * dst, const int ncols, const sycl::nd_item<3> &item_ct1) { const int row = item_ct1.get_group(1); @@ -9443,7 +9416,7 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con template <bool vals_smem, int ncols_template, int block_size_template> -static void soft_max_f32(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par, +static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) { const int ncols = ncols_template == 0 ? ncols_par : ncols_template; @@ -9457,7 +9430,7 @@ static void soft_max_f32(const float * x, const float * mask, const float *pos, const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE; const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE; - float slope = 0.0f; + float slope = 1.0f; // ALiBi if (max_bias > 0.0f) { @@ -9482,7 +9455,7 @@ static void soft_max_f32(const float * x, const float * mask, const float *pos, const int ix = rowx*ncols + col; const int iy = rowy*ncols + col; - const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f); + const float val = x[ix]*scale + (mask ? 
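The kernel_cpy_f32_f16 change above is what makes the F16 mask path safe: when the KQ mask is cast to half for the flash-attention kernels, -INFINITY is clamped to the most negative finite half, presumably so that products such as slope*mask can never produce NaN via 0 * -inf. A sketch of the clamp (MAXHALF mirrors the Metal constant):

    #include <cmath>

    constexpr float MAXHALF = 65504.0f; // largest finite IEEE-754 half value

    // applied element-wise when copying an F32 mask into an F16 buffer
    static inline float clamp_neg_inf_for_f16(float v) {
        return v == -INFINITY ? -MAXHALF : v;
    }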
slope*mask[iy] : 0.0f); vals[col] = val; max_val = sycl::max(max_val, val); @@ -12964,20 +12937,6 @@ static void rope_glm_f32_sycl(const float *x, float *dst, int ncols, int nrows, }); } -static void alibi_f32_sycl(const float *x, float *dst, const int ncols, - const int nrows, const int k_rows, - const int n_heads_log2_floor, const float m0, - const float m1, dpct::queue_ptr stream) { - const sycl::range<3> block_dims(1, 1, SYCL_ALIBI_BLOCK_SIZE); - const int num_blocks_x = (ncols + SYCL_ALIBI_BLOCK_SIZE - 1) / (SYCL_ALIBI_BLOCK_SIZE); - const sycl::range<3> block_nums(1, nrows, num_blocks_x); - stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - alibi_f32(x, dst, ncols, k_rows, - n_heads_log2_floor, m0, m1, item_ct1); - }); -} - static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols, const int nrows, dpct::queue_ptr stream) { const sycl::range<3> block_dims(1, 1, WARP_SIZE); @@ -13058,7 +13017,7 @@ static void diag_mask_inf_f32_sycl(const float *x, float *dst, } template -static void soft_max_f32_submitter(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par, +static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims, const size_t n_local_scratch, dpct::queue_ptr stream) { @@ -13068,7 +13027,7 @@ static void soft_max_f32_submitter(const float * x, const float * mask, const fl cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { - soft_max_f32(x, mask, pos, dst, ncols_par, + soft_max_f32(x, mask, dst, ncols_par, nrows_y, scale, max_bias, m0, m1, n_head_log2, item_ct1, local_buf_acc.get_pointer()); @@ -13076,7 +13035,7 @@ static void soft_max_f32_submitter(const float * x, const float * mask, const fl }); } -static void soft_max_f32_sycl(const float * x, const float * mask, const float * pos, +static void soft_max_f32_sycl(const float * x, const float * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, dpct::queue_ptr stream) { @@ -13098,60 +13057,60 @@ static void soft_max_f32_sycl(const float * x, const float * mask, const float * const size_t local_mem_size = stream->get_device().get_info(); if (n_local_scratch*sizeof(float) < local_mem_size) { if (ncols_x > max_block_size) { - soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); return; } switch (ncols_x) { case 32: - soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; case 64: - soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; case 128: - soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; case 256: - 
soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; case 512: - soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; case 1024: - soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; case 2048: - soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; case 4096: - soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; default: - soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; } } else { - soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, WARP_SIZE, stream); } @@ -14562,36 +14521,6 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1, (void) src1_dd; } -inline void ggml_sycl_op_alibi(const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const dpct::queue_ptr &main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - GGML_TENSOR_LOCALS_3(int64_t, ne0, src0, ne); - const int64_t nrows = ggml_nrows(src0); - - //const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_head = ((int32_t *) dst->op_params)[1]; - float max_bias; - memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - - //GGML_ASSERT(ne01 + n_past == ne00); - GGML_ASSERT(n_head == ne02); - - const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); - - const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); - - alibi_f32_sycl(src0_dd, dst_dd, ne00, nrows, ne01, n_heads_log2_floor, m0, m1, main_stream); - - (void) src1; - (void) src1_dd; -} - static void ggml_sycl_op_pool2d(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, @@ -14746,12 +14675,9 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0, GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - const ggml_tensor * src2 = dst->src[2]; - -#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 and src2 support") +#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support") #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021") GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional - GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = 
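The submitter switch above trades code size for speed: for the common power-of-two row widths the column count is promoted to a compile-time constant so the kernel loops can be unrolled, with template argument 0 as the runtime fallback. A condensed sketch of the dispatch pattern (names illustrative; it reuses soft_max_row_ref from the scalar reference earlier, and omits the mask broadcast across batch rows that the real kernel handles via nrows_y):

    template <int ncols_template>
    static void fused_soft_max_rows(const float * x, const float * mask, float * dst,
                                    int ncols_par, int nrows, float scale, float slope) {
        // ncols is a compile-time constant whenever ncols_template != 0
        const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
        for (int r = 0; r < nrows; ++r) {
            soft_max_row_ref(dst + r*ncols, x + r*ncols,
                             mask ? mask + r*ncols : nullptr, ncols, scale, slope);
        }
    }

    static void fused_soft_max(const float * x, const float * mask, float * dst,
                               int ncols, int nrows, float scale, float slope) {
        switch (ncols) {
            case 32: fused_soft_max_rows<32>(x, mask, dst, ncols, nrows, scale, slope); break;
            case 64: fused_soft_max_rows<64>(x, mask, dst, ncols, nrows, scale, slope); break;
            // ... 128 through 4096 follow the same shape ...
            default: fused_soft_max_rows<0>(x, mask, dst, ncols, nrows, scale, slope); break;
        }
    }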
ggml_nrows(src0); @@ -14763,25 +14689,7 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0, memcpy(&scale, dst->op_params + 0, sizeof(float)); memcpy(&max_bias, dst->op_params + 1, sizeof(float)); - // positions tensor - float * src2_dd = nullptr; - sycl_pool_alloc src2_f; - - const bool use_src2 = src2 != nullptr; - - if (use_src2) { - const bool src2_on_device = src2->backend == GGML_BACKEND_TYPE_GPU; - - if (src2_on_device) { - ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) src2->extra; - src2_dd = (float *) src2_extra->data_device[g_main_device]; - } else { - src2_dd = src2_f.alloc(ggml_nelements(src2)); - SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src2_dd, src2, 0, 0, 0, 1, main_stream)); - } - } - - soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, src2_dd, dst_dd, ne00, + soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream); } @@ -16232,10 +16140,6 @@ static void ggml_sycl_rope(const ggml_tensor * src0, const ggml_tensor * src1, g ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_rope); } -static void ggml_sycl_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_alibi); -} - static void ggml_sycl_pool2d(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_pool2d); } @@ -16612,9 +16516,6 @@ bool ggml_sycl_compute_forward(struct ggml_compute_params * params, struct ggml_ case GGML_OP_ROPE: func = ggml_sycl_rope; break; - case GGML_OP_ALIBI: - func = ggml_sycl_alibi; - break; case GGML_OP_IM2COL: func = ggml_sycl_im2col; break; @@ -17744,7 +17645,6 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: case GGML_OP_ROPE: - case GGML_OP_ALIBI: case GGML_OP_IM2COL: case GGML_OP_POOL_2D: case GGML_OP_SUM_ROWS: diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 95f7189740539..b9449be0357b5 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -3830,9 +3830,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; case GGML_OP_SOFT_MAX: GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); - GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32 || src2->type == GGML_TYPE_F16); - if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && (src2 == nullptr || src2->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { + if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_soft_max_f32; } if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && src2->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { @@ -4286,6 +4285,9 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx, const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); +#pragma message("TODO: src2 is no longer used in soft_max - should be removed and ALiBi calculation should be updated") +#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/7192") + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, { ncols, src1 != nullptr ? 
nrows_y : (uint32_t)0, diff --git a/ggml.c b/ggml.c index 093d38d00ae35..4ee5d24af75d0 100644 --- a/ggml.c +++ b/ggml.c @@ -2185,7 +2185,6 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "SOFT_MAX_BACK", "ROPE", "ROPE_BACK", - "ALIBI", "CLAMP", "CONV_TRANSPOSE_1D", "IM2COL", @@ -2227,7 +2226,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77"); +static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -2276,7 +2275,6 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "soft_max_back(x)", "rope(x)", "rope_back(x)", - "alibi(x)", "clamp(x)", "conv_transpose_1d(x)", "im2col(x)", @@ -2318,7 +2316,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77"); +static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -5646,7 +5644,6 @@ static struct ggml_tensor * ggml_soft_max_impl( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * mask, - struct ggml_tensor * pos, float scale, float max_bias, bool inplace) { @@ -5660,18 +5657,8 @@ static struct ggml_tensor * ggml_soft_max_impl( GGML_ASSERT(mask->ne[1] >= a->ne[1]); } - if (pos) { - GGML_ASSERT(ggml_is_vector(pos)); - GGML_ASSERT(pos->type == GGML_TYPE_F16 || pos->type == GGML_TYPE_F32); - GGML_ASSERT(pos->ne[0] == a->ne[0]); - } - - if (pos && mask) { - GGML_ASSERT(pos->type == mask->type); - } - if (max_bias > 0.0f) { - GGML_ASSERT(pos); + GGML_ASSERT(mask); } bool is_node = false; @@ -5689,7 +5676,6 @@ static struct ggml_tensor * ggml_soft_max_impl( result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; result->src[1] = mask; - result->src[2] = pos; return result; } @@ -5697,23 +5683,22 @@ static struct ggml_tensor * ggml_soft_max_impl( struct ggml_tensor * ggml_soft_max( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, false); + return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, false); } struct ggml_tensor * ggml_soft_max_inplace( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, true); + return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, true); } struct ggml_tensor * ggml_soft_max_ext( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * mask, - struct ggml_tensor * pos, float scale, float max_bias) { - return ggml_soft_max_impl(ctx, a, mask, pos, scale, max_bias, false); + return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false); } // ggml_soft_max_back @@ -5928,37 +5913,6 @@ struct ggml_tensor * ggml_rope_back( return result; } -// ggml_alibi - -struct ggml_tensor * ggml_alibi( - struct ggml_context * ctx, - struct ggml_tensor * a, - int n_past, - int n_head, - float bias_max) { - GGML_ASSERT(n_past >= 0); - bool is_node = false; - - if (a->grad) { - GGML_ASSERT(false); // TODO: implement backward - is_node = true; - } - - // TODO: when implement backward, fix this: - //struct ggml_tensor * result = inplace ? 
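For graph code, the migration is purely mechanical: the pos argument disappears and ALiBi is driven entirely by max_bias. This works because the KQ mask for ALiBi models is set up to carry negative relative distances rather than only 0/-inf, so slope*mask reproduces the old position-based bias; correspondingly, the GGML_ASSERT(mask) above makes the mask mandatory whenever max_bias > 0.0f. Call-shape fragment (not compilable on its own):

    // before: kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_pos, kq_scale, hparams.f_max_alibi_bias);
    // after:  ALiBi rides on the mask, no positions tensor
    kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);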
ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - struct ggml_tensor * result = ggml_view_tensor(ctx, a); - - int32_t op_params[3] = { n_past, n_head }; - memcpy(op_params + 2, &bias_max, sizeof(float)); - ggml_set_op_params(result, op_params, sizeof(op_params)); - - result->op = GGML_OP_ALIBI; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - // ggml_clamp struct ggml_tensor * ggml_clamp( @@ -6486,9 +6440,11 @@ struct ggml_tensor * ggml_flash_attn_ext( struct ggml_tensor * k, struct ggml_tensor * v, struct ggml_tensor * mask, - float scale) { + float scale, + float max_bias) { GGML_ASSERT(ggml_can_mul_mat(k, q)); // TODO: check if vT can be multiplied by (k*qT) + if (mask) { GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(mask->ne[2] == 1); @@ -6498,6 +6454,10 @@ struct ggml_tensor * ggml_flash_attn_ext( //GGML_ASSERT(ggml_can_repeat_rows(mask, qk)); } + if (max_bias > 0.0f) { + GGML_ASSERT(mask); + } + bool is_node = false; if (q->grad || k->grad || v->grad) { @@ -6508,7 +6468,7 @@ struct ggml_tensor * ggml_flash_attn_ext( int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - float params[] = { scale }; + float params[] = { scale, max_bias }; ggml_set_op_params(result, params, sizeof(params)); result->op = GGML_OP_FLASH_ATTN_EXT; @@ -6528,7 +6488,7 @@ void ggml_flash_attn_ext_set_prec( const int32_t prec_i32 = (int32_t) prec; - ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos + ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second } // ggml_flash_ff @@ -13333,7 +13293,6 @@ static void ggml_compute_forward_soft_max_f32( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - const struct ggml_tensor * src2 = dst->src[2]; assert(ggml_is_contiguous(dst)); assert(ggml_are_same_shape(src0, dst)); @@ -13359,8 +13318,8 @@ static void ggml_compute_forward_soft_max_f32( // TODO: is this supposed to be ceil instead of floor? // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370 - const uint32_t n_head_kv = ne02; - const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head_kv)); + const uint32_t n_head = ne02; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); @@ -13377,13 +13336,13 @@ static void ggml_compute_forward_soft_max_f32( float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith; - // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching - ggml_fp16_t * pos_f16 = src2 ? (ggml_fp16_t *) src2->data : src0->data; - float * pos_f32 = src2 ? (float *) src2->data : src0->data; - - const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16); + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); for (int i1 = ir0; i1 < ir1; i1++) { + // ALiBi + const uint32_t h = (i1/ne01)%ne02; // head + const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? 
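ggml_flash_attn_ext now stores two floats in op_params, which is why ggml_flash_attn_ext_set_prec moves the precision flag from slot 1 to slot 2. A fragment showing both sides of the layout, as written by the op constructor and read back by the backends above:

    // write side, in ggml_flash_attn_ext():
    float params[] = { scale, max_bias };              // 32-bit slots 0 and 1
    ggml_set_op_params(result, params, sizeof(params));
    ggml_set_op_params_i32(result, 2, GGML_PREC_F32);  // slot 2: precision

    // read side, in the backends:
    memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));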
powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; + float * sp = (float *)((char *) src0->data + i1*src0->nb[1]); float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); @@ -13396,27 +13355,11 @@ static void ggml_compute_forward_soft_max_f32( if (mp_f32) { if (use_f16) { for (int i = 0; i < nc; ++i) { - wp[i] += GGML_FP16_TO_FP32(mp_f16[i]); + wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]); } } else { for (int i = 0; i < nc; ++i) { - wp[i] += mp_f32[i]; - } - } - } - - // ALiBi bias - if (max_bias > 0.0f) { - const uint32_t h = (i1/ne01)%ne02; // head - const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1); - - if (use_f16) { - for (int i = 0; i < nc; ++i) { - wp[i] += slope*GGML_FP16_TO_FP32(pos_f16[i]); - } - } else { - for (int i = 0; i < nc; ++i) { - wp[i] += slope*pos_f32[i]; + wp[i] += slope*mp_f32[i]; } } } @@ -13578,178 +13521,6 @@ static void ggml_compute_forward_soft_max_back( } } -// ggml_compute_forward_alibi - -static void ggml_compute_forward_alibi_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - assert(params->ith == 0); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - - //const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_head = ((int32_t *) dst->op_params)[1]; - float max_bias; - memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - - const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 - const int64_t ne1 = src0->ne[1]; // seq_len_without_past - const int64_t ne2 = src0->ne[2]; // n_head -> this is k - //const int64_t ne3 = src0->ne[3]; // 1 -> bsz - - const int64_t n = ggml_nrows(src0); - const int64_t ne2_ne3 = n/ne1; // ne2*ne3 - - const size_t nb0 = src0->nb[0]; - const size_t nb1 = src0->nb[1]; - const size_t nb2 = src0->nb[2]; - //const int nb3 = src0->nb[3]; - - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(n_head == ne2); - - // add alibi to src0 (KQ_scaled) - const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); - - const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); - - for (int64_t k = 0; k < ne2_ne3; k++) { - // TODO: k*nb2 or k*nb3 - float m_k; - - if (k < n_heads_log2_floor) { - m_k = powf(m0, k + 1); - } else { - m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1); - } - - for (int64_t i = 0; i < ne0; i++) { - for (int64_t j = 0; j < ne1; j++) { - float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2); - float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2); - pdst[0] = i * m_k + src[0]; - } - } - } -} - -static void ggml_compute_forward_alibi_f16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - assert(params->ith == 0); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - - //const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_head = ((int32_t *) dst->op_params)[1]; - float max_bias; - memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - - const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 - const int ne1 = src0->ne[1]; // seq_len_without_past - const int ne2 = src0->ne[2]; // n_head -> this is k - //const int ne3 = src0->ne[3]; // 1 -> bsz - - const int n = ggml_nrows(src0); - const int ne2_ne3 = n/ne1; // ne2*ne3 - - const int nb0 = 
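In the CPU softmax the head index is recovered directly from the flattened row counter: rows are laid out as i1 = (i3*ne02 + i2)*ne01 + i01, so (i1/ne01) % ne02 yields the head i2, and no positions tensor is needed to pick the slope. Restated with the nested ternary spelled out:

    const uint32_t h = (i1/ne01) % ne02; // head index of row i1
    const float slope = (max_bias > 0.0f)
        ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1))
        : 1.0f;                          // 1.0f leaves the mask term unscaled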
src0->nb[0]; - const int nb1 = src0->nb[1]; - const int nb2 = src0->nb[2]; - //const int nb3 = src0->nb[3]; - - GGML_ASSERT(nb0 == sizeof(ggml_fp16_t)); - //GGML_ASSERT(ne1 + n_past == ne0); (void) n_past; - GGML_ASSERT(n_head == ne2); - - // add alibi to src0 (KQ_scaled) - const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); - - const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); - - for (int k = 0; k < ne2_ne3; k++) { - // TODO: k*nb2 or k*nb3 - float m_k; - - if (k < n_heads_log2_floor) { - m_k = powf(m0, k + 1); - } else { - m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1); - } - - for (int i = 0; i < ne0; i++) { - for (int j = 0; j < ne1; j++) { - ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2); - float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2); - - // we return F32 - pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]); - } - } - } -} - -static void ggml_compute_forward_alibi( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_alibi_f16(params, dst); - } break; - case GGML_TYPE_F32: - { - ggml_compute_forward_alibi_f32(params, dst); - } break; - case GGML_TYPE_BF16: - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q8_1: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - case GGML_TYPE_Q6_K: - case GGML_TYPE_IQ2_XXS: - case GGML_TYPE_IQ2_XS: - case GGML_TYPE_IQ3_XXS: - case GGML_TYPE_IQ1_S: - case GGML_TYPE_IQ1_M: - case GGML_TYPE_IQ4_NL: - case GGML_TYPE_IQ4_XS: - case GGML_TYPE_IQ3_S: - case GGML_TYPE_IQ2_S: - case GGML_TYPE_Q8_K: - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_I64: - case GGML_TYPE_F64: - case GGML_TYPE_COUNT: - { - GGML_ASSERT(false); - } break; - } -} - // ggml_compute_forward_clamp static void ggml_compute_forward_clamp_f32( @@ -15763,8 +15534,17 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - float scale = 1.0f; - memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); + + const uint32_t n_head = neq2; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); // loop over n_batch and n_head for (int ir = ir0; ir < ir1; ++ir) { @@ -15773,6 +15553,9 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int iq2 = (ir - iq3*neq2*neq1)/neq1; const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); + const uint32_t h = iq2; // head + const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; + float S = 0.0f; float M = -INFINITY; @@ -15796,7 +15579,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( // loop over n_kv and n_head_kv // ref: https://arxiv.org/pdf/2112.05682.pdf for (int64_t ic = 0; ic < nek1; ++ic) { - const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f; + const float mv = mp ? 
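The mv == -INFINITY skip pairs with the streaming softmax of the flash-attention loop (the arXiv reference above): a fully masked column would contribute exp(-inf) = 0, so its dot product can be skipped outright. A condensed sketch of the running update, tracking the denominator only (dot_qk() is a hypothetical placeholder for the Q·K product of column ic; the real loop also rescales the V accumulator by the same factor):

    float S = 0.0f;       // running softmax denominator
    float M = -INFINITY;  // running maximum
    for (int64_t ic = 0; ic < nek1; ++ic) {
        const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
        if (mv == -INFINITY) {
            continue; // fully masked column
        }
        const float s    = dot_qk(ic)*scale + mv;
        const float Mnew = std::max(M, s);
        S = S*expf(M - Mnew) + expf(s - Mnew); // rescale old terms, add the new one
        M = Mnew;
    }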
slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f; if (mv == -INFINITY) { continue; } @@ -15867,7 +15650,7 @@ static void ggml_compute_forward_flash_attn_ext( const struct ggml_tensor * v, const struct ggml_tensor * mask, struct ggml_tensor * dst) { - switch (dst->op_params[1]) { + switch (dst->op_params[2]) { case GGML_PREC_DEFAULT: case GGML_PREC_F32: { @@ -17630,10 +17413,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_rope_back(params, tensor); } break; - case GGML_OP_ALIBI: - { - ggml_compute_forward_alibi(params, tensor); - } break; case GGML_OP_CLAMP: { ggml_compute_forward_clamp(params, tensor); @@ -18652,10 +18431,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor zero_table); } } break; - case GGML_OP_ALIBI: - { - GGML_ASSERT(false); // TODO: not implemented - } break; case GGML_OP_CLAMP: { GGML_ASSERT(false); // TODO: not implemented @@ -19428,10 +19203,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_ { n_tasks = n_threads; } break; - case GGML_OP_ALIBI: - { - n_tasks = 1; //TODO - } break; case GGML_OP_CLAMP: { n_tasks = 1; //TODO diff --git a/ggml.h b/ggml.h index fe60538220680..76c3328315964 100644 --- a/ggml.h +++ b/ggml.h @@ -468,7 +468,6 @@ extern "C" { GGML_OP_SOFT_MAX_BACK, GGML_OP_ROPE, GGML_OP_ROPE_BACK, - GGML_OP_ALIBI, GGML_OP_CLAMP, GGML_OP_CONV_TRANSPOSE_1D, GGML_OP_IM2COL, @@ -1428,15 +1427,13 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); - // fused soft_max(a*scale + mask + pos[i]*(ALiBi slope)) + // fused soft_max(a*scale + mask*(ALiBi slope)) // mask is optional - // pos is required when max_bias > 0.0f // max_bias = 0.0f for no ALiBi GGML_API struct ggml_tensor * ggml_soft_max_ext( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * mask, - struct ggml_tensor * pos, float scale, float max_bias); @@ -1538,16 +1535,6 @@ extern "C" { float xpos_base, bool xpos_down); - // alibi position embedding - // in-place, returns view(a) - GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_alibi( - struct ggml_context * ctx, - struct ggml_tensor * a, - int n_past, - int n_head, - float bias_max), - "use ggml_soft_max_ext instead (will be removed in Mar 2024)"); - // clamp // in-place, returns view(a) GGML_API struct ggml_tensor * ggml_clamp( @@ -1744,7 +1731,8 @@ extern "C" { struct ggml_tensor * k, struct ggml_tensor * v, struct ggml_tensor * mask, - float scale); + float scale, + float max_bias); GGML_API void ggml_flash_attn_ext_set_prec( struct ggml_tensor * a, diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 5951c0bb0fb5e..a4fbfc5e09d06 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -118,6 +118,7 @@ class MODEL_ARCH(IntEnum): REFACT = auto() BERT = auto() NOMIC_BERT = auto() + JINA_BERT_V2 = auto() BLOOM = auto() STABLELM = auto() QWEN = auto() @@ -195,6 +196,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.REFACT: "refact", MODEL_ARCH.BERT: "bert", MODEL_ARCH.NOMIC_BERT: "nomic-bert", + MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2", MODEL_ARCH.BLOOM: "bloom", MODEL_ARCH.STABLELM: "stablelm", MODEL_ARCH.QWEN: "qwen", @@ -380,6 +382,22 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_UP, MODEL_TENSOR.LAYER_OUT_NORM, ], + MODEL_ARCH.JINA_BERT_V2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.TOKEN_EMBD_NORM, + MODEL_TENSOR.TOKEN_TYPES, + MODEL_TENSOR.ATTN_OUT_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + 
MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.LAYER_OUT_NORM, + ], MODEL_ARCH.MPT: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index e5750d4191f6b..8e1cac9152f55 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -137,6 +137,7 @@ class TensorNameMap: "layers.{bid}.attention.wk", # llama-pth "encoder.layer.{bid}.attention.self.key", # bert "transformer.h.{bid}.attn.k_proj", # gpt-j + "transformer.h.{bid}.attn.k", # refact "model.layers.layers.{bid}.self_attn.k_proj", # plamo "model.layers.{bid}.attention.wk", # internlm2 "transformer.decoder_layer.{bid}.multi_head_attention.key" # Grok @@ -148,6 +149,7 @@ class TensorNameMap: "layers.{bid}.attention.wv", # llama-pth "encoder.layer.{bid}.attention.self.value", # bert "transformer.h.{bid}.attn.v_proj", # gpt-j + "transformer.h.{bid}.attn.v", # refact "model.layers.layers.{bid}.self_attn.v_proj", # plamo "model.layers.{bid}.attention.wv", # internlm2 "transformer.decoder_layer.{bid}.multi_head_attention.value" # Grok @@ -229,6 +231,7 @@ class TensorNameMap: "layers.{bid}.feed_forward.w3", # llama-pth "encoder.layer.{bid}.intermediate.dense", # bert "transformer.h.{bid}.mlp.fc_in", # gpt-j + "transformer.h.{bid}.mlp.linear_3", # refact "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon "model.layers.{bid}.mlp.dense_h_to_4h", # persimmon "transformer.h.{bid}.mlp.w1", # qwen @@ -240,6 +243,7 @@ class TensorNameMap: "model.layers.{bid}.feed_forward.w3", # internlm2 "encoder.layers.{bid}.mlp.fc11", # nomic-bert "model.layers.{bid}.mlp.c_fc", # starcoder2 + "encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2 ), MODEL_TENSOR.FFN_UP_EXP: ( @@ -266,6 +270,8 @@ class TensorNameMap: "model.layers.layers.{bid}.mlp.gate_proj", # plamo "model.layers.{bid}.feed_forward.w1", # internlm2 "encoder.layers.{bid}.mlp.fc12", # nomic-bert + "encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2 + "transformer.h.{bid}.mlp.linear_1", # refact ), MODEL_TENSOR.FFN_GATE_EXP: ( @@ -299,6 +305,7 @@ class TensorNameMap: "model.layers.{bid}.feed_forward.w2", # internlm2 "encoder.layers.{bid}.mlp.fc2", # nomic-bert "model.layers.{bid}.mlp.c_proj", # starcoder2 + "encoder.layer.{bid}.mlp.wo", # jina-bert-v2 ), MODEL_TENSOR.FFN_DOWN_EXP: ( @@ -317,6 +324,7 @@ class TensorNameMap: "model.layers.{bid}.self_attn.q_layernorm", # persimmon "model.layers.{bid}.self_attn.q_norm", # cohere "transformer.blocks.{bid}.attn.q_ln", # sea-lion + "encoder.layer.{bid}.attention.self.layer_norm_q" # jina-bert-v2 ), MODEL_TENSOR.ATTN_K_NORM: ( @@ -324,6 +332,7 @@ class TensorNameMap: "model.layers.{bid}.self_attn.k_layernorm", # persimmon "model.layers.{bid}.self_attn.k_norm", # cohere "transformer.blocks.{bid}.attn.k_ln", # sea-lion + "encoder.layer.{bid}.attention.self.layer_norm_k" # jina-bert-v2 ), MODEL_TENSOR.ROPE_FREQS: ( @@ -334,6 +343,7 @@ class TensorNameMap: "encoder.layer.{bid}.output.LayerNorm", # bert "encoder.layers.{bid}.norm2", # nomic-bert "transformer.decoder_layer.{bid}.rms_norm_3", # Grok + "encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2 ), MODEL_TENSOR.SSM_IN: ( diff --git a/llama.cpp b/llama.cpp index e7b3fd8b433b4..e91ad7285da99 100644 --- a/llama.cpp +++ b/llama.cpp @@ -205,6 +205,7 @@ enum llm_arch { LLM_ARCH_REFACT, LLM_ARCH_BERT, LLM_ARCH_NOMIC_BERT, + LLM_ARCH_JINA_BERT_V2, LLM_ARCH_BLOOM, LLM_ARCH_STABLELM, LLM_ARCH_QWEN, @@ -228,39 
+229,40 @@ enum llm_arch { }; static const std::map LLM_ARCH_NAMES = { - { LLM_ARCH_LLAMA, "llama" }, - { LLM_ARCH_FALCON, "falcon" }, - { LLM_ARCH_GROK, "grok" }, - { LLM_ARCH_GPT2, "gpt2" }, - { LLM_ARCH_GPTJ, "gptj" }, - { LLM_ARCH_GPTNEOX, "gptneox" }, - { LLM_ARCH_MPT, "mpt" }, - { LLM_ARCH_BAICHUAN, "baichuan" }, - { LLM_ARCH_STARCODER, "starcoder" }, - { LLM_ARCH_PERSIMMON, "persimmon" }, - { LLM_ARCH_REFACT, "refact" }, - { LLM_ARCH_BERT, "bert" }, - { LLM_ARCH_NOMIC_BERT, "nomic-bert" }, - { LLM_ARCH_BLOOM, "bloom" }, - { LLM_ARCH_STABLELM, "stablelm" }, - { LLM_ARCH_QWEN, "qwen" }, - { LLM_ARCH_QWEN2, "qwen2" }, - { LLM_ARCH_QWEN2MOE, "qwen2moe" }, - { LLM_ARCH_PHI2, "phi2" }, - { LLM_ARCH_PHI3, "phi3" }, - { LLM_ARCH_PLAMO, "plamo" }, - { LLM_ARCH_CODESHELL, "codeshell" }, - { LLM_ARCH_ORION, "orion" }, - { LLM_ARCH_INTERNLM2, "internlm2" }, - { LLM_ARCH_MINICPM, "minicpm" }, - { LLM_ARCH_GEMMA, "gemma" }, - { LLM_ARCH_STARCODER2, "starcoder2" }, - { LLM_ARCH_MAMBA, "mamba" }, - { LLM_ARCH_XVERSE, "xverse" }, - { LLM_ARCH_COMMAND_R, "command-r" }, - { LLM_ARCH_DBRX, "dbrx" }, - { LLM_ARCH_OLMO, "olmo" }, - { LLM_ARCH_UNKNOWN, "(unknown)" }, + { LLM_ARCH_LLAMA, "llama" }, + { LLM_ARCH_FALCON, "falcon" }, + { LLM_ARCH_GROK, "grok" }, + { LLM_ARCH_GPT2, "gpt2" }, + { LLM_ARCH_GPTJ, "gptj" }, + { LLM_ARCH_GPTNEOX, "gptneox" }, + { LLM_ARCH_MPT, "mpt" }, + { LLM_ARCH_BAICHUAN, "baichuan" }, + { LLM_ARCH_STARCODER, "starcoder" }, + { LLM_ARCH_PERSIMMON, "persimmon" }, + { LLM_ARCH_REFACT, "refact" }, + { LLM_ARCH_BERT, "bert" }, + { LLM_ARCH_NOMIC_BERT, "nomic-bert" }, + { LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" }, + { LLM_ARCH_BLOOM, "bloom" }, + { LLM_ARCH_STABLELM, "stablelm" }, + { LLM_ARCH_QWEN, "qwen" }, + { LLM_ARCH_QWEN2, "qwen2" }, + { LLM_ARCH_QWEN2MOE, "qwen2moe" }, + { LLM_ARCH_PHI2, "phi2" }, + { LLM_ARCH_PHI3, "phi3" }, + { LLM_ARCH_PLAMO, "plamo" }, + { LLM_ARCH_CODESHELL, "codeshell" }, + { LLM_ARCH_ORION, "orion" }, + { LLM_ARCH_INTERNLM2, "internlm2" }, + { LLM_ARCH_MINICPM, "minicpm" }, + { LLM_ARCH_GEMMA, "gemma" }, + { LLM_ARCH_STARCODER2, "starcoder2" }, + { LLM_ARCH_MAMBA, "mamba" }, + { LLM_ARCH_XVERSE, "xverse" }, + { LLM_ARCH_COMMAND_R, "command-r" }, + { LLM_ARCH_DBRX, "dbrx" }, + { LLM_ARCH_OLMO, "olmo" }, + { LLM_ARCH_UNKNOWN, "(unknown)" }, }; enum llm_kv { @@ -691,6 +693,25 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_JINA_BERT_V2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_TOKEN_TYPES, "token_types" }, + { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_BLOOM, { @@ -1845,7 +1866,7 @@ struct llama_hparams { float f_logit_scale = 0.0f; bool causal_attn = true; - bool use_alibi = false; // currently, we need KQ_pos data for ALiBi-based models + bool use_alibi = false; enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE; enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; @@ -2317,7 +2338,6 @@ struct llama_context { struct 
ggml_tensor * inp_pos; // I32 [n_batch] struct ggml_tensor * inp_out_ids; // I32 [n_outputs] struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch] - struct ggml_tensor * inp_KQ_pos; // F32 [n_kv] struct ggml_tensor * inp_K_shift; // I32 [kv_size] struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch] struct ggml_tensor * inp_cls; // I32 [n_batch] @@ -3779,6 +3799,12 @@ static void llm_load_hparams( // get hparams kv ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab); + + // everything past this point is not vocab-related + if (hparams.vocab_only) { + return; + } + ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); ml.get_key(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff); @@ -3860,7 +3886,7 @@ static void llm_load_hparams( switch (hparams.n_layer) { case 22: model.type = e_model::MODEL_1B; break; case 26: model.type = e_model::MODEL_3B; break; - case 32: model.type = hparams.n_head == hparams.n_head_kv ? e_model::MODEL_7B : e_model::MODEL_8B; break; // LLaMa 8B v3 uses GQA + case 32: model.type = hparams.n_vocab < 40000 ? e_model::MODEL_7B : e_model::MODEL_8B; break; case 40: model.type = e_model::MODEL_13B; break; case 48: model.type = e_model::MODEL_34B; break; case 60: model.type = e_model::MODEL_30B; break; @@ -3962,6 +3988,19 @@ static void llm_load_hparams( model.type = e_model::MODEL_335M; break; // bge-large } } break; + case LLM_ARCH_JINA_BERT_V2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + hparams.f_max_alibi_bias = 8.0f; + + switch (hparams.n_layer) { + case 4: model.type = e_model::MODEL_33M; break; // jina-embeddings-small + case 12: model.type = e_model::MODEL_137M; break; // jina-embeddings-base + } + } break; case LLM_ARCH_NOMIC_BERT: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -4383,7 +4422,9 @@ static void llm_load_vocab( tokenizer_pre == "starcoder") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_STARCODER; } else if ( - tokenizer_pre == "gpt-2") { + tokenizer_pre == "gpt-2" || + tokenizer_pre == "jina-es" || + tokenizer_pre == "jina-de") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_GPT2; } else if ( tokenizer_pre == "refact") { @@ -5242,6 +5283,50 @@ static bool llm_load_tensors( layer.layer_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}); } } break; + case LLM_ARCH_JINA_BERT_V2: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // word_embeddings + model.type_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}); //token_type_embeddings + model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); // LayerNorm + model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); //LayerNorm bias + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; // JinaBertLayer + + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); + + layer.attn_q_norm = 
ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, false); + layer.attn_q_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, false); + + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); + + layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, false); + layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, false); + + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); + + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); //output_dens + layer.bo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); //output_dens + + layer.attn_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); //output_norm + layer.attn_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}); + + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); + + layer.layer_out_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}); + layer.layer_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}); + } + } break; case LLM_ARCH_BLOOM: { model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); @@ -6318,7 +6403,7 @@ static struct ggml_tensor * llm_build_ffn( llm_ffn_gate_type type_gate, const llm_build_cb & cb, int il) { - struct ggml_tensor * tmp = ggml_mul_mat(ctx, up, cur); + struct ggml_tensor * tmp = up ? 
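The jina-bert-v2 loader marks the per-head Q/K norms as optional: the trailing false (the loader's required flag, by this reading) makes create_tensor return null when the checkpoint lacks the tensor instead of aborting, and the graph builder later branches on presence, as in this fragment mirroring the BERT block further down:

    // load: nullptr when the tensor is absent from the GGUF
    layer.attn_q_norm = ml.create_tensor(ctx_layer,
            tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, /*required=*/false);

    // build: apply the norm only when the tensor exists
    if (model.layers[il].attn_q_norm) {
        Qcur = llm_build_norm(ctx0, Qcur, hparams,
                model.layers[il].attn_q_norm,
                model.layers[il].attn_q_norm_b,
                LLM_NORM, cb, il);
    }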
ggml_mul_mat(ctx, up, cur) : cur; cb(tmp, "ffn_up", il); if (up_b) { @@ -6500,7 +6585,6 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * wo_b, struct ggml_tensor * q_cur, struct ggml_tensor * kq_mask, - struct ggml_tensor * kq_pos, int32_t n_tokens, int32_t n_kv, float kq_scale, @@ -6530,10 +6614,6 @@ static struct ggml_tensor * llm_build_kqv( GGML_UNUSED(model); GGML_UNUSED(n_ctx); - // note: if this assert triggers, then some check has failed earlier - // the idea is to detect during context creation that ALiBi would be used and disable Flash Attention - GGML_ASSERT(kq_pos == nullptr && "ALiBi is not yet supported with Flash Attention"); - // split cached v into n_head heads (not transposed) struct ggml_tensor * v = ggml_view_3d(ctx, kv.v_l[il], @@ -6543,7 +6623,7 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(v, "v", il); - cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale); + cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias); if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3) { ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); @@ -6574,28 +6654,8 @@ static struct ggml_tensor * llm_build_kqv( kq = ggml_scale(ctx, kq, 30); } -#if defined(GGML_USE_KOMPUTE) -#pragma message("TODO: ALiBi support in ggml_soft_max_ext is not implemented for Kompute") -#pragma message(" Falling back to ggml_alibi(). Will become an error in Mar 2024") -#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5488") - if (hparams.use_alibi) { - kq = ggml_scale(ctx, kq, kq_scale); - cb(kq, "kq_scaled", il); - - kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, hparams.f_max_alibi_bias); - cb(kq, "kq_scaled_alibi", il); - - kq = ggml_add(ctx, kq, kq_mask); - cb(kq, "kq_masked", il); - - kq = ggml_soft_max(ctx, kq); - cb(kq, "kq_soft_max", il); - } else -#endif - { - kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_pos, kq_scale, hparams.f_max_alibi_bias); - cb(kq, "kq_soft_max_ext", il); - } + kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias); + cb(kq, "kq_soft_max_ext", il); GGML_ASSERT(kv.size == n_ctx); @@ -6645,7 +6705,6 @@ static struct ggml_tensor * llm_build_kv( struct ggml_tensor * v_cur, struct ggml_tensor * q_cur, struct ggml_tensor * kq_mask, - struct ggml_tensor * kq_pos, int32_t n_tokens, int32_t kv_head, int32_t n_kv, @@ -6664,7 +6723,7 @@ static struct ggml_tensor * llm_build_kv( struct ggml_tensor * cur; cur = llm_build_kqv(ctx, model, hparams, cparams, kv, graph, wo, wo_b, - q_cur, kq_mask, kq_pos, n_tokens, n_kv, kq_scale, cb, il); + q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il); cb(cur, "kqv_out", il); return cur; @@ -6771,18 +6830,17 @@ struct llm_build_context { ctx0 = ggml_init(params); - lctx.inp_tokens = nullptr; - lctx.inp_embd = nullptr; - lctx.inp_pos = nullptr; + lctx.inp_tokens = nullptr; + lctx.inp_embd = nullptr; + lctx.inp_pos = nullptr; lctx.inp_out_ids = nullptr; lctx.inp_KQ_mask = nullptr; - lctx.inp_KQ_pos = nullptr; lctx.inp_K_shift = nullptr; - lctx.inp_mean = nullptr; - lctx.inp_cls = nullptr; - lctx.inp_s_copy = nullptr; - lctx.inp_s_mask = nullptr; - lctx.inp_s_seq = nullptr; + lctx.inp_mean = nullptr; + lctx.inp_cls = nullptr; + lctx.inp_s_copy = nullptr; + lctx.inp_s_mask = nullptr; + lctx.inp_s_seq = nullptr; } void free() { @@ -6932,19 +6990,6 @@ struct llm_build_context { return flash_attn ? 
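Two things fall out of the llm_build_kqv changes above. The removed GGML_USE_KOMPUTE block was the last in-tree user of ggml_alibi, and its four-node sequence (scale, alibi, add-mask, softmax) collapses into the single fused ggml_soft_max_ext call. On the flash-attention side, the old hard assert is gone and the bias cap is simply forwarded, so ALiBi and flash attention are no longer mutually exclusive (fragment, builder names as above):

    // previously llm_build_kqv hard-failed here:
    //   GGML_ASSERT(kq_pos == nullptr && "ALiBi is not yet supported with Flash Attention");
    // now the cap is forwarded and the kernel derives per-head slopes itself:
    cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias);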
ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask; } - struct ggml_tensor * build_inp_KQ_pos(bool causal = true) { - if (causal) { - lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv); - } else { - // TODO: this will be needed for ALiBi-based BERT models - // https://github.com/ggerganov/llama.cpp/pull/6826 - lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_tokens); - } - cb(lctx.inp_KQ_pos, "KQ_pos", -1); - ggml_set_input(lctx.inp_KQ_pos); - return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_pos, GGML_TYPE_F16) : lctx.inp_KQ_pos; - } - struct ggml_tensor * build_inp_mean() { lctx.inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens); cb(lctx.inp_mean, "inp_mean", -1); @@ -7050,7 +7095,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7143,9 +7188,6 @@ struct llm_build_context { // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); - // positions of the tokens in the KV cache - struct ggml_tensor * KQ_pos = build_inp_KQ_pos(); - for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -7190,7 +7232,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7260,9 +7302,6 @@ struct llm_build_context { // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); - // positions of the tokens in the KV cache - struct ggml_tensor * KQ_pos = build_inp_KQ_pos(); - for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -7297,7 +7336,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7417,7 +7456,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7542,7 +7581,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il); } if (il == n_layer - 1) { @@ -7694,7 +7733,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 
1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7806,7 +7845,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8010,7 +8049,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Q, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Q, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8076,9 +8115,6 @@ struct llm_build_context { // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); - // positions of the tokens in the KV cache - struct ggml_tensor * KQ_pos = build_inp_KQ_pos(); - for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -8106,7 +8142,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8168,8 +8204,11 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; + struct ggml_tensor * inp_pos = nullptr; - struct ggml_tensor * inp_pos = build_inp_pos(); + if (model.arch != LLM_ARCH_JINA_BERT_V2) { + inp_pos = build_inp_pos(); + } struct ggml_tensor * inp_mean = build_inp_mean(); struct ggml_tensor * inp_cls = build_inp_cls(); @@ -8200,13 +8239,26 @@ struct llm_build_context { struct ggml_tensor * Vcur; // self-attention - if (model.arch == LLM_ARCH_BERT) { + if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT_V2) { Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), model.layers[il].bq); cb(Qcur, "Qcur", il); + if (model.layers[il].attn_q_norm) { + Qcur = llm_build_norm(ctx0, Qcur, hparams, + model.layers[il].attn_q_norm, + model.layers[il].attn_q_norm_b, + LLM_NORM, cb, il); + } + Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), model.layers[il].bk); cb(Kcur, "Kcur", il); + if (model.layers[il].attn_k_norm) { + Kcur = llm_build_norm(ctx0, Kcur, hparams, + model.layers[il].attn_k_norm, + model.layers[il].attn_k_norm_b, + LLM_NORM, cb, il); + } Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), model.layers[il].bv); cb(Vcur, "Vcur", il); @@ -8246,7 +8298,7 @@ struct llm_build_context { struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); cb(kq, "kq", il); - kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, nullptr, 1.0f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias); + kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, 1.0f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias); cb(kq, "kq_soft_max_ext", il); struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens))); @@ -8297,6 +8349,13 @@ struct llm_build_context { model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); + } else if (model.arch == LLM_ARCH_JINA_BERT_V2) { + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + 
model.layers[il].ffn_gate, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, + NULL, + LLM_FFN_GELU, LLM_FFN_PAR, cb, il); } else { cur = llm_build_ffn(ctx0, cur, model.layers[il].ffn_up, NULL, @@ -8363,9 +8422,6 @@ struct llm_build_context { // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); - // positions of the tokens in the KV cache - struct ggml_tensor * KQ_pos = build_inp_KQ_pos(); - inpL = llm_build_norm(ctx0, inpL, hparams, model.tok_norm, model.tok_norm_b, @@ -8399,7 +8455,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8464,9 +8520,6 @@ struct llm_build_context { // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); - // positions of the tokens in the KV cache - struct ggml_tensor * KQ_pos = build_inp_KQ_pos(); - if (model.pos_embd) { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); @@ -8530,13 +8583,13 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } else { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } } @@ -8680,7 +8733,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8798,7 +8851,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8911,7 +8964,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9025,7 +9078,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9180,7 +9233,7 @@ 
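// The LLM_FFN_GELU + LLM_FFN_PAR combination introduced above for
// LLM_ARCH_JINA_BERT_V2 is a gated (parallel) GELU feed-forward. Below is a
// minimal self-contained sketch of what llm_build_ffn() computes for that
// combination; the helper name and parameter names are illustrative stand-ins,
// not part of the llama.cpp API.
static struct ggml_tensor * jina_ffn_par_gelu_sketch(
        struct ggml_context * ctx,
        struct ggml_tensor  * x,        // layer input
        struct ggml_tensor  * w_up,     // ffn_up
        struct ggml_tensor  * w_gate,   // ffn_gate
        struct ggml_tensor  * w_down,   // ffn_down
        struct ggml_tensor  * b_down) { // ffn_down_b, may be NULL
    // LLM_FFN_PAR: the up and gate branches both read the layer input
    struct ggml_tensor * up   = ggml_mul_mat(ctx, w_up,   x);
    struct ggml_tensor * gate = ggml_mul_mat(ctx, w_gate, x);
    // LLM_FFN_GELU: the activation is applied to the gate branch ...
    gate = ggml_gelu(ctx, gate);
    // ... which then scales the up branch element-wise before the down projection
    struct ggml_tensor * cur = ggml_mul(ctx, gate, up);
    cur = ggml_mul_mat(ctx, w_down, cur);
    return b_down ? ggml_add(ctx, cur, b_down) : cur;
}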
struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il); } if (il == n_layer - 1) { @@ -9297,7 +9350,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il); } if (il == n_layer - 1) { @@ -9410,7 +9463,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } struct ggml_tensor * sa_out = cur; @@ -9513,7 +9566,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9620,7 +9673,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9736,7 +9789,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9853,7 +9906,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9983,7 +10036,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -10104,7 +10157,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il); } if (il == n_layer - 1) { @@ -10223,7 +10276,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 
1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -10513,7 +10566,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -10644,7 +10697,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, nullptr, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -10825,6 +10878,7 @@ static struct ggml_cgraph * llama_build_graph( result = llm.build_refact(); } break; case LLM_ARCH_BERT: + case LLM_ARCH_JINA_BERT_V2: case LLM_ARCH_NOMIC_BERT: { result = llm.build_bert(); @@ -11032,11 +11086,21 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) { f = -INFINITY; } else { - f = 0.0f; + if (hparams.use_alibi) { + f = -fabs(lctx.kv_self.cells[i].pos - pos); + } else { + f = 0.0f; + } } data[h*(n_kv*n_tokens) + j*n_kv + i] = f; } } + + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_kv; ++j) { + data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + } + } } } else { // when using kv cache, the mask needs to match the kv cache size @@ -11055,7 +11119,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { float f = -INFINITY; for (int s = 0; s < batch.n_seq_id[i]; ++s) { if (batch.seq_id[i][s] == seq_id) { - f = 0.0f; + if (hparams.use_alibi) { + f = -fabs(batch.pos[i] - batch.pos[j]); + } else { + f = 0.0f; + } break; } } @@ -11071,21 +11139,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } - // ALiBi requires the KQ_pos tensor to provide the sequence position of each token in the batch - // this allows to process multiple sequences in parallel with ALiBi-based models - if (hparams.use_alibi) { - const int64_t n_kv = kv_self.n; - - GGML_ASSERT(lctx.inp_KQ_pos); - GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_pos->buffer)); - - float * data = (float *) lctx.inp_KQ_pos->data; - - for (int i = 0; i < n_kv; ++i) { - data[i] = float(lctx.kv_self.cells[i].pos); - } - } - if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { const int64_t n_tokens = batch.n_tokens; @@ -12200,13 +12253,14 @@ struct llm_tokenizer_bpe { void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) { int final_prev_index = -1; + bool ignore_merges = false; std::vector<std::string> word_collection; switch (vocab.type) { case LLAMA_VOCAB_TYPE_BPE: switch (vocab.type_pre) { case LLAMA_VOCAB_PRE_TYPE_LLAMA3: - case LLAMA_VOCAB_PRE_TYPE_DBRX: + ignore_merges = true; word_collection = unicode_regex_split(text, { // original regex from tokenizer.json //"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", @@ -12215,6 +12269,12 @@ struct llm_tokenizer_bpe { "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }); break; + case LLAMA_VOCAB_PRE_TYPE_DBRX:
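// With build_inp_KQ_pos() removed, ALiBi is carried entirely by the mask that
// llama_set_inputs() fills above: entries for tokens in the same sequence hold
// -|pos_i - pos_j| instead of 0.0f, and ggml_soft_max_ext(kq, mask, scale,
// max_bias) later multiplies that distance by a per-head slope. A standalone
// sketch of the mask layout under those assumptions (names and the int-based
// positions are illustrative, not the real API):
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// build an n_kv x n_tokens additive mask; use_alibi switches between the plain
// causal mask (0 / -inf) and the distance-based ALiBi variant
static std::vector<float> build_kq_mask_sketch(const std::vector<int> & kv_pos,
                                               const std::vector<int> & tok_pos,
                                               bool use_alibi) {
    const int n_kv     = (int) kv_pos.size();
    const int n_tokens = (int) tok_pos.size();
    std::vector<float> mask((size_t) n_kv * n_tokens,
                            -std::numeric_limits<float>::infinity());
    for (int j = 0; j < n_tokens; ++j) {
        for (int i = 0; i < n_kv; ++i) {
            if (kv_pos[i] <= tok_pos[j]) { // causal: a token only sees the past
                mask[(size_t) j*n_kv + i] = use_alibi
                    ? -std::fabs((float) (kv_pos[i] - tok_pos[j])) // ALiBi distance
                    : 0.0f;
            }
        }
    }
    return mask;
}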
+ word_collection = unicode_regex_split(text, { + // same as llama3 + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }); + break; case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM: word_collection = unicode_regex_split(text, { "[\r\n]", @@ -12298,6 +12358,11 @@ struct llm_tokenizer_bpe { int index = 0; size_t offset = 0; + if (ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) { + symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()}); + offset = word.size(); + } + while (offset < word.size()) { llm_symbol sym; size_t char_len = std::min(word.size() - offset, (size_t) ::utf8_len(word[offset])); @@ -12752,7 +12817,10 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & } } - GGML_ASSERT(vocab.special_add_eos != 1); + if (add_special && vocab.special_add_eos == 1) { + GGML_ASSERT(vocab.special_eos_id != -1); + output.push_back(vocab.special_eos_id); + } } break; case LLAMA_VOCAB_TYPE_WPM: { @@ -15509,11 +15577,6 @@ struct llama_context * llama_new_context_with_model( } } - if (cparams.flash_attn && hparams.use_alibi) { - LLAMA_LOG_WARN("%s: flash_attn is not yet compatible with ALiBi - forcing off\n", __func__); - cparams.flash_attn = false; - } - if (cparams.flash_attn && model->arch == LLM_ARCH_GROK) { LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__); cparams.flash_attn = false; @@ -15808,6 +15871,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_REFACT: case LLM_ARCH_BLOOM: case LLM_ARCH_MAMBA: + case LLM_ARCH_JINA_BERT_V2: return LLAMA_ROPE_TYPE_NONE; // use what we call a normal RoPE, operating on pairs of consecutive head values diff --git a/models/ggml-vocab-llama-bpe.gguf.inp index 0a89107c60d7f..9380bf355202a 100644 --- a/models/ggml-vocab-llama-bpe.gguf.inp +++ b/models/ggml-vocab-llama-bpe.gguf.inp @@ -104,3 +104,5 @@ __ggml_vocab_test__ 🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea?
We'Ve a'lL __ggml_vocab_test__ + Việt +__ggml_vocab_test__ diff --git a/models/ggml-vocab-llama-bpe.gguf.out b/models/ggml-vocab-llama-bpe.gguf.out index 1f00e3812e227..1f3607fb6a378 100644 --- a/models/ggml-vocab-llama-bpe.gguf.out +++ b/models/ggml-vocab-llama-bpe.gguf.out @@ -41,3 +41,4 @@ 8765 8765 1644 8765 8765 8765 198 4815 15073 66597 8004 1602 2355 79772 11187 9468 248 222 320 8416 8 27623 114 102470 9468 234 104 31643 320 36773 100166 98634 8 26602 227 11410 99 247 9468 99 247 220 18 220 1644 220 8765 220 8765 18 220 8765 1644 220 8765 8765 220 8765 8765 18 220 8765 8765 1644 220 18 13 18 220 18 497 18 220 18 1131 18 220 21549 222 98629 241 45358 233 21549 237 45358 224 21549 244 21549 115 21549 253 45358 223 21549 253 21549 95 98629 227 76460 223 949 37046 101067 19000 23182 102301 9263 18136 16 36827 21909 56560 54337 19175 102118 13373 64571 34694 3114 112203 80112 3436 106451 14196 14196 74694 3089 3089 29249 17523 3001 27708 7801 358 3077 1027 364 83 820 568 596 1070 11 364 793 499 2771 30 364 44 539 2771 358 3358 1304 433 11 364 35 499 1093 1063 15600 30 1226 6 43712 264 64966 43 + 101798 diff --git a/scripts/compare-llama-bench.py b/scripts/compare-llama-bench.py index fed3c1ee3d298..0ede9e67c2f15 100755 --- a/scripts/compare-llama-bench.py +++ b/scripts/compare-llama-bench.py @@ -325,8 +325,12 @@ def get_rows(properties): for row in rows_show: n_prompt = int(row[-4]) n_gen = int(row[-3]) - assert n_prompt == 0 or n_gen == 0 - test_name = f"tg{n_gen}" if n_prompt == 0 else f"pp{n_prompt}" + if n_prompt != 0 and n_gen == 0: + test_name = f"pp{n_prompt}" + elif n_prompt == 0 and n_gen != 0: + test_name = f"tg{n_gen}" + else: + test_name = f"pp{n_prompt}+tg{n_gen}" # Regular columns test name avg t/s values Speedup # VVVVVVVVVVVVV VVVVVVVVV VVVVVVVVVVVVVV VVVVVVV table.append(list(row[:-4]) + [test_name] + list(row[-2:]) + [float(row[-1]) / float(row[-2])]) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index d409a1d6b42ec..766a017524237 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -92,7 +92,7 @@ target_link_libraries(test-tokenizer-1-bpe PRIVATE common) install(TARGETS test-tokenizer-1-bpe RUNTIME) # TODO: disabled due to slowness -#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-llama-bpe ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-bpe.gguf) +#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-llama-bpe ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-bpe.gguf --ignore-merges) #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-falcon ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf) #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-aquila ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-aquila.gguf) #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-mpt ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-mpt.gguf) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 0d66de5d9d389..731788b958fca 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1111,11 +1111,7 @@ struct test_soft_max : public test_case { if (this->mask) { mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne[0], ne[1]); } - ggml_tensor * pos = nullptr; - if (max_bias > 0.0f) { - pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[0]); - } - ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, pos, scale, max_bias); + ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, scale, max_bias); return out; } }; @@ -1490,23 +1486,25 @@ struct test_flash_attn_ext : public 
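// test_soft_max above now exercises the reworked ggml_soft_max_ext(): the
// separate pos tensor is gone and ALiBi is requested through (mask, max_bias)
// alone, which is also why the loop below skips mask == false together with
// max_bias > 0. A minimal usage sketch (the shapes are chosen only for
// illustration):
static struct ggml_tensor * soft_max_ext_sketch(struct ggml_context * ctx) {
    struct ggml_tensor * logits = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1024, 16); // [n_kv, n_tokens]
    struct ggml_tensor * mask   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1024, 16); // 0 / -inf / -distance
    // scale multiplies the logits; max_bias > 0 turns mask distances into
    // per-head ALiBi bias, so a mask must be present in that case
    return ggml_soft_max_ext(ctx, logits, mask, 0.125f, 8.0f);
}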
test_case { const int64_t kv; // kv size const int64_t nb; // batch size + const float max_bias; // ALiBi + std::string vars() override { - return VARS_TO_STR4(hs, nh, kv, nb); + return VARS_TO_STR5(hs, nh, kv, nb, max_bias); } double max_nmse_err() override { return 5e-4; } - test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) - : hs(hs), nh(nh), kv(kv), nb(nb) {} + test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, float max_bias = 0.0f) + : hs(hs), nh(nh), kv(kv), nb(nb), max_bias(max_bias) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1); ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1); - ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs)); + ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs), max_bias); return out; } }; @@ -1611,7 +1609,7 @@ struct test_llm : public test_case { struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); - kq = ggml_soft_max_ext(ctx, kq, kq_mask, nullptr, kq_scale, 0.0f); + kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, 0.0f); // split cached v into n_head heads struct ggml_tensor * v = @@ -2128,6 +2126,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op #endif for (bool mask : {false, true}) { for (float max_bias : {0.0f, 8.0f}) { + if (!mask && max_bias > 0.0f) continue; for (float scale : {1.0f, 0.1f}) { for (int64_t ne0 : {16, 1024}) { for (int64_t ne1 : {16, 1024}) { @@ -2141,7 +2140,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 0.0f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 8.0f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 8.0f)); for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) { @@ -2180,10 +2178,12 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op #else for (int hs : { 64, 80, 128, 256, }) { #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - for (int nh : { 32, }) { - for (int kv : { 512, 1024, }) { - for (int nb : { 1, 2, 4, 8, }) { - test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb)); + for (float max_bias : {0.0f, 8.0f}) { + for (int nh : { 32, }) { + for (int kv : { 512, 1024, }) { + for (int nb : { 1, 2, 4, 8, }) { + test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, max_bias)); + } } } } diff --git a/tests/test-tokenizer-1-bpe.cpp index a0e2caf9427eb..209a04ad6f77a 100644 --- a/tests/test-tokenizer-1-bpe.cpp +++ b/tests/test-tokenizer-1-bpe.cpp @@ -13,15 +13,27 @@ #include <vector> int main(int argc, char **argv) { - if (argc < 2) { - fprintf(stderr, "Usage: %s <vocab-file>\n", argv[0]); + if (argc < 2 || argc > 3) { + fprintf(stderr, "Usage: %s <vocab-file> [--ignore-merges]\n", argv[0]); return 1; } const std::string fname = argv[1]; + bool ignore_merges = false; + if (argc == 3) { + if (std::strcmp(argv[2], "--ignore-merges") != 0) { + fprintf(stderr, "Usage: %s <vocab-file> [--ignore-merges]\n", argv[0]); + return 1; + }
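// max_bias is the ALiBi strength the extended flash-attention tests sweep over
// ({0.0f, 8.0f}); each attention head h derives its slope from it. A sketch of
// the slope schedule along the lines of ggml's CPU soft_max path (reproduced
// from memory, so treat the exact constants as an assumption):
#include <cmath>
#include <cstdint>

static float alibi_slope_sketch(uint32_t h, uint32_t n_head, float max_bias) {
    if (max_bias <= 0.0f) {
        return 1.0f; // ALiBi disabled: every head gets a neutral slope
    }
    const uint32_t n_head_log2 = 1u << (uint32_t) std::floor(std::log2((float) n_head));
    const float m0 = std::pow(2.0f, -max_bias          / n_head_log2);
    const float m1 = std::pow(2.0f, -(max_bias / 2.0f) / n_head_log2);
    // the first n_head_log2 heads use powers of m0, the rest odd powers of m1
    return h < n_head_log2 ? std::pow(m0, (float) (h + 1))
                           : std::pow(m1, (float) (2*(h - n_head_log2) + 1));
}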
+ ignore_merges = true; + } fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str()); + if (ignore_merges) { + fprintf(stderr, "%s : ignoring merges for tokens inside vocab\n", __func__); + } + llama_model * model; llama_context * ctx; @@ -65,7 +77,19 @@ int main(int argc, char **argv) { std::string str = llama_detokenize_bpe(ctx, std::vector<llama_token>(1, i)); try { auto cps = unicode_cpts_from_utf8(str); - std::vector<llama_token> tokens = llama_tokenize(ctx, str, false); + std::vector<llama_token> tokens = llama_tokenize(ctx, str, false, true); + if (ignore_merges && tokens.size() > 1) { + fprintf(stderr, + "%s : error: token %d detokenizes to '%s'(%zu) but " + "tokenization of this yields multiple tokens: [", + __func__, i, str.c_str(), str.length()); + fprintf(stderr, "%d", tokens[0]); + for (size_t i = 1; i < tokens.size(); i++) { + fprintf(stderr, ", %d", tokens[i]); + } + fprintf(stderr, "]\n"); + return 2; + } std::string check = llama_detokenize_bpe(ctx, tokens); if (check != str) { fprintf(stderr, "%s : error: token %d detokenizes to '%s'(%zu) but tokenization of this detokenizes to '%s'(%zu)\n",
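// The --ignore-merges check above encodes the invariant behind
// LLAMA_VOCAB_PRE_TYPE_LLAMA3's ignore_merges flag: any string that is itself
// a vocab entry must come back as that single token, i.e. detokenize ->
// tokenize must round-trip. A condensed sketch of the same check, using the
// llama_tokenize / llama_detokenize_bpe helpers from common.h that the test
// already relies on:
static bool token_roundtrips_sketch(llama_context * ctx, llama_token id) {
    const std::string str = llama_detokenize_bpe(ctx, std::vector<llama_token>(1, id));
    const std::vector<llama_token> toks =
        llama_tokenize(ctx, str, /*add_special=*/false, /*parse_special=*/true);
    // the text must map back to exactly one token, and detokenizing that token
    // must reproduce the original string
    return toks.size() == 1 && llama_detokenize_bpe(ctx, toks) == str;
}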