Commit
Showing 14 changed files with 626 additions and 538 deletions.
DisposeWorker.cpp
@@ -0,0 +1,11 @@
#include "DisposeWorker.h"

DisposeWorker::DisposeWorker(const Napi::CallbackInfo &info,
                             LlamaSessionPtr sess)
    : AsyncWorker(info.Env()), Deferred(info.Env()), sess_(std::move(sess)) {}

void DisposeWorker::Execute() { sess_->dispose(); }

void DisposeWorker::OnOK() { Resolve(AsyncWorker::Env().Undefined()); }

void DisposeWorker::OnError(const Napi::Error &err) { Reject(err.Value()); }
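For context, a worker like this is normally created inside a binding method, queued, and its promise handed back to JavaScript; Execute() then runs on the libuv thread pool and OnOK()/OnError() settle the promise. A minimal sketch of such a caller, assuming a LlamaContext class with a sess_ member (both are assumed names, not part of this diff):

// Hypothetical caller -- LlamaContext, Release, and sess_ are assumptions.
Napi::Value LlamaContext::Release(const Napi::CallbackInfo &info) {
  auto *worker = new DisposeWorker(info, sess_);
  worker->Queue();          // Execute() runs off the main thread
  return worker->Promise(); // inherited from Napi::Promise::Deferred
}

Inheriting both Napi::AsyncWorker and Napi::Promise::Deferred lets one object own both the background job and the promise it resolves, which is why the constructor initializes both bases with the same environment.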
DisposeWorker.h
@@ -0,0 +1,14 @@
#pragma once

#include "common.hpp"

class DisposeWorker : public Napi::AsyncWorker,
                      public Napi::Promise::Deferred {
public:
  DisposeWorker(const Napi::CallbackInfo &info, LlamaSessionPtr sess);

protected:
  void Execute() override;
  void OnOK() override;
  void OnError(const Napi::Error &err) override;

private:
  LlamaSessionPtr sess_;
};
LlamaCompletionWorker.cpp
@@ -0,0 +1,163 @@
#include "LlamaCompletionWorker.h"
#include "LlamaContext.h"

#include <algorithm> // std::copy
#include <cstdlib>   // free
#include <cstring>   // strdup
#include <mutex>     // std::lock_guard

// Length of the common prefix of two token sequences.
size_t common_part(const std::vector<llama_token> &a,
                   const std::vector<llama_token> &b) {
  size_t i = 0;
  while (i < a.size() && i < b.size() && a[i] == b[i]) {
    i++;
  }
  return i;
}

// Earliest position of any stop word in `text`, searching only the tail that
// the most recent token could have completed. Returns npos when none is found.
size_t findStoppingStrings(const std::string &text,
                           const size_t last_token_size,
                           const std::vector<std::string> &stop_words) {
  size_t stop_pos = std::string::npos;

  for (const std::string &word : stop_words) {
    // A stop word may straddle the newest token, so start the search
    // word.size() + last_token_size characters before the end.
    const size_t tmp = word.size() + last_token_size;
    const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;

    const size_t pos = text.find(word, from_pos);

    if (pos != std::string::npos &&
        (stop_pos == std::string::npos || pos < stop_pos)) {
      stop_pos = pos;
    }
  }

  return stop_pos;
}
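A quick standalone check of what the two helpers compute (values invented for the illustration; compile it together with the definitions above):

// Not part of the addon build -- just demonstrates the helpers above.
#include <cassert>

int main() {
  // shared prefix of the two token sequences is {1, 2}
  assert(common_part({1, 2, 3, 4}, {1, 2, 9}) == 2);
  // the window reaches back past the newest 4-byte piece, so a stop word
  // that straddles it is still found at index 5
  assert(findStoppingStrings("Hello</s>", 4, {"</s>"}) == 5);
  return 0;
}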
LlamaCompletionWorker::LlamaCompletionWorker(
    const Napi::CallbackInfo &info, LlamaSessionPtr &sess,
    Napi::Function callback, gpt_params params,
    std::vector<std::string> stop_words)
    : AsyncWorker(info.Env()), Deferred(info.Env()), _sess(sess),
      _params(params), _stop_words(stop_words) {
  if (!callback.IsEmpty()) {
    // Wrap the JS callback so it can be invoked from the worker thread.
    _tsfn = Napi::ThreadSafeFunction::New(info.Env(), callback,
                                          "LlamaCompletionCallback", 0, 1);
    _has_callback = true;
  }
}

LlamaCompletionWorker::~LlamaCompletionWorker() {
  if (_has_callback) {
    _tsfn.Release();
  }
}

void LlamaCompletionWorker::Execute() {
  // Serialize access to the session; the guard releases the mutex on every
  // exit path, including exceptions that AsyncWorker converts to SetError().
  std::lock_guard<std::mutex> lock(_sess->get_mutex());
  const auto t_main_start = ggml_time_us();
  const size_t n_ctx = _params.n_ctx;
  const auto n_keep = _params.n_keep;
  size_t n_cur = 0;
  size_t n_input = 0;
  const auto model = llama_get_model(_sess->context());
  const bool add_bos = llama_should_add_bos_token(model);
  auto ctx = _sess->context();

  llama_set_rng_seed(ctx, _params.seed);

  LlamaCppSampling sampling{llama_sampling_init(_params.sparams),
                            llama_sampling_free};

  std::vector<llama_token> prompt_tokens =
      ::llama_tokenize(ctx, _params.prompt, add_bos);
  n_input = prompt_tokens.size();
  if (_sess->tokens_ptr()->size() > 0) {
    // Reuse the KV cache for the prefix shared with the previous prompt, but
    // keep at least one token to re-evaluate so llama_decode has input.
    n_cur = common_part(*(_sess->tokens_ptr()), prompt_tokens);
    if (n_cur == n_input) {
      --n_cur;
    }
    n_input -= n_cur;
    llama_kv_cache_seq_rm(ctx, 0, n_cur, -1);
  }
  _sess->set_tokens(std::move(prompt_tokens));

  const int max_len = _params.n_predict < 0 ? 0 : _params.n_predict;
  _sess->tokens_ptr()->reserve(_sess->tokens_ptr()->size() + max_len);

  auto embd = _sess->tokens_ptr();
  // Generate until the token budget is exhausted or Stop() is requested.
  for (int i = 0; i < max_len && !_stop; i++) {
    // context shift: when the window is full, drop half of the tokens after
    // the keep region and compact the KV cache
    if (embd->size() >= n_ctx) {
      const int n_left = n_cur - n_keep - 1;
      const int n_discard = n_left / 2;

      llama_kv_cache_seq_rm(ctx, 0, n_keep + 1, n_keep + n_discard + 1);
      llama_kv_cache_seq_add(ctx, 0, n_keep + 1 + n_discard, n_cur, -n_discard);

      // shift the tokens left; std::copy avoids the undefined behavior of
      // inserting a vector's own range into itself
      std::copy(embd->begin() + n_keep + 1 + n_discard, embd->end(),
                embd->begin() + n_keep + 1);
      embd->resize(embd->size() - n_discard);

      n_cur -= n_discard;
      _result.truncated = true;
    }
    int ret = llama_decode(
        ctx, llama_batch_get_one(embd->data() + n_cur, n_input, n_cur, 0));
    if (ret < 0) {
      SetError("Failed to decode token, code: " + std::to_string(ret));
      break;
    }
    // sample the next token
    const llama_token new_token_id =
        llama_sampling_sample(sampling.get(), ctx, nullptr);
    // prepare the next batch
    embd->emplace_back(new_token_id);
    auto token = llama_token_to_piece(ctx, new_token_id);
    _result.text += token;
    n_cur += n_input;
    _result.tokens_evaluated += n_input;
    _result.tokens_predicted += 1;
    n_input = 1;
    if (_has_callback) {
      // duplicate the piece: the lambda runs later on the JS thread, after
      // `token` has gone out of scope
      const char *c_token = strdup(token.c_str());
      _tsfn.BlockingCall(c_token, [](Napi::Env env, Napi::Function jsCallback,
                                     const char *value) {
        auto obj = Napi::Object::New(env);
        obj.Set("token", Napi::String::New(env, value));
        free(const_cast<char *>(value)); // strdup allocates with malloc
        jsCallback.Call({obj});
      });
    }
    // is it an end of generation?
    if (llama_token_is_eog(model, new_token_id)) {
      break;
    }
    // check for stop words
    if (!_stop_words.empty()) {
      const size_t stop_pos =
          findStoppingStrings(_result.text, token.size(), _stop_words);
      if (stop_pos != std::string::npos) {
        _result.text.erase(stop_pos); // strip the stop word from the output
        break;
      }
    }
  }
  const auto t_main_end = ggml_time_us(); // timing captured, not yet reported
}

void LlamaCompletionWorker::OnOK() {
  auto env = Napi::AsyncWorker::Env();
  auto result = Napi::Object::New(env);
  result.Set("tokens_evaluated", Napi::Number::New(env, _result.tokens_evaluated));
  result.Set("tokens_predicted", Napi::Number::New(env, _result.tokens_predicted));
  result.Set("truncated", Napi::Boolean::New(env, _result.truncated));
  result.Set("text", Napi::String::New(env, _result.text.c_str()));
  Napi::Promise::Deferred::Resolve(result);
}

void LlamaCompletionWorker::OnError(const Napi::Error &err) {
  Napi::Promise::Deferred::Reject(err.Value());
}
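As with DisposeWorker, this worker would typically be created inside a binding method that queues it and returns its promise. A minimal sketch, assuming a LlamaContext class with _sess and _params members and a Completion method (all assumed names, not shown in this diff):

// Hypothetical caller -- LlamaContext, Completion, _sess, _params are assumed.
Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
  gpt_params params = _params; // copy per-context defaults
  params.prompt = info[0].As<Napi::String>().Utf8Value();
  auto callback = info[1].As<Napi::Function>(); // invoked once per token
  auto *worker =
      new LlamaCompletionWorker(info, _sess, callback, params, {"</s>"});
  worker->Queue();
  return worker->Promise(); // settled by OnOK()/OnError()
}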
LlamaCompletionWorker.h
@@ -0,0 +1,34 @@
#pragma once

#include "common.hpp"

#include <atomic>

struct CompletionResult {
  std::string text = "";
  bool truncated = false;
  size_t tokens_predicted = 0;
  size_t tokens_evaluated = 0;
};

class LlamaCompletionWorker : public Napi::AsyncWorker,
                              public Napi::Promise::Deferred {
public:
  LlamaCompletionWorker(const Napi::CallbackInfo &info, LlamaSessionPtr &sess,
                        Napi::Function callback, gpt_params params,
                        std::vector<std::string> stop_words = {});

  ~LlamaCompletionWorker();

  // May be called from the JS thread while Execute() runs on the worker
  // thread, hence the atomic flag.
  inline void Stop() { _stop = true; }

protected:
  void Execute() override;
  void OnOK() override;
  void OnError(const Napi::Error &err) override;

private:
  LlamaSessionPtr _sess;
  gpt_params _params;
  std::vector<std::string> _stop_words;
  Napi::ThreadSafeFunction _tsfn;
  bool _has_callback = false;
  std::atomic<bool> _stop{false};
  CompletionResult _result;
};