Skip to content

Commit

Permalink
Showing 16 changed files with 465 additions and 167 deletions.
2 changes: 1 addition & 1 deletion android/src/main/jni.cpp
Original file line number Diff line number Diff line change
@@ -328,7 +328,7 @@ Java_com_rnllama_LlamaContext_doCompletion(

sparams.logit_bias.clear();
if (ignore_eos) {
sparams.logit_bias[llama_token_eos(llama->ctx)] = -INFINITY;
sparams.logit_bias[llama_token_eos(llama->model)] = -INFINITY;
}

const int n_vocab = llama_n_vocab(llama_get_model(llama->ctx));
4 changes: 2 additions & 2 deletions cpp/build-info.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#ifndef BUILD_INFO_H
#define BUILD_INFO_H

#define BUILD_NUMBER 1414
#define BUILD_COMMIT "96981f3"
#define BUILD_NUMBER 1429
#define BUILD_COMMIT "34b2a5e"
#define BUILD_COMPILER ""
#define BUILD_TARGET "unknown"

8 changes: 4 additions & 4 deletions cpp/common.cpp
Original file line number Diff line number Diff line change
@@ -880,13 +880,13 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
}

if (params.ignore_eos) {
params.sparams.logit_bias[llama_token_eos(lctx)] = -INFINITY;
params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
}

{
LOG("warming up the model with an empty run\n");

std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), };
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
llama_kv_cache_tokens_rm(lctx, -1, -1);
llama_reset_timings(lctx);
@@ -941,7 +941,7 @@ std::string llama_token_to_piece(const struct llama_context * ctx, llama_token t
}

std::string llama_detokenize_spm(llama_context * ctx, const std::vector<llama_token> & tokens) {
const llama_token bos_id = llama_token_bos(ctx);
const llama_token bos_id = llama_token_bos(llama_get_model(ctx));

std::string piece;
std::string result;
@@ -1186,7 +1186,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);

const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(lctx));
const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(llama_get_model(lctx)));
const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY;
fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false");

10 changes: 9 additions & 1 deletion cpp/ggml-metal-llama.metal
Original file line number Diff line number Diff line change
@@ -125,9 +125,17 @@ kernel void kernel_mul_row(
}

kernel void kernel_scale(
device const float * src0,
device float * dst,
constant float & scale,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * scale;
}

kernel void kernel_scale_4(
device const float4 * src0,
device float4 * dst,
constant float & scale,
constant float & scale,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * scale;
}
18 changes: 13 additions & 5 deletions cpp/ggml-metal.m
Original file line number Diff line number Diff line change
@@ -62,6 +62,7 @@
LM_GGML_METAL_DECL_KERNEL(mul);
LM_GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
LM_GGML_METAL_DECL_KERNEL(scale);
LM_GGML_METAL_DECL_KERNEL(scale_4);
LM_GGML_METAL_DECL_KERNEL(silu);
LM_GGML_METAL_DECL_KERNEL(relu);
LM_GGML_METAL_DECL_KERNEL(gelu);
@@ -249,6 +250,7 @@ static void lm_ggml_metal_log(enum lm_ggml_log_level level, const char* format,
LM_GGML_METAL_ADD_KERNEL(mul);
LM_GGML_METAL_ADD_KERNEL(mul_row);
LM_GGML_METAL_ADD_KERNEL(scale);
LM_GGML_METAL_ADD_KERNEL(scale_4);
LM_GGML_METAL_ADD_KERNEL(silu);
LM_GGML_METAL_ADD_KERNEL(relu);
LM_GGML_METAL_ADD_KERNEL(gelu);
@@ -347,6 +349,7 @@ void lm_ggml_metal_free(struct lm_ggml_metal_context * ctx) {
LM_GGML_METAL_DEL_KERNEL(mul);
LM_GGML_METAL_DEL_KERNEL(mul_row);
LM_GGML_METAL_DEL_KERNEL(scale);
LM_GGML_METAL_DEL_KERNEL(scale_4);
LM_GGML_METAL_DEL_KERNEL(silu);
LM_GGML_METAL_DEL_KERNEL(relu);
LM_GGML_METAL_DEL_KERNEL(gelu);
@@ -923,15 +926,20 @@ void lm_ggml_metal_graph_compute(

const float scale = *(const float *) src1->data;

[encoder setComputePipelineState:ctx->pipeline_scale];
int64_t n = lm_ggml_nelements(dst);

if (n % 4 == 0) {
n /= 4;
[encoder setComputePipelineState:ctx->pipeline_scale_4];
} else {
[encoder setComputePipelineState:ctx->pipeline_scale];
}

[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];

const int64_t n = lm_ggml_nelements(dst);
LM_GGML_ASSERT(n % 4 == 0);

[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case LM_GGML_OP_UNARY:
switch (lm_ggml_get_unary_op(gf->nodes[i])) {
Loading

0 comments on commit 455d83f

Please sign in to comment.