
refactor: enhance memory management and worker management

hans00 committed May 10, 2024
1 parent fa88550 commit 5408c9d
Showing 4 changed files with 25 additions and 9 deletions.
src/LlamaCompletionWorker.cpp (4 additions, 1 deletion)
@@ -58,7 +58,7 @@ void LlamaCompletionWorker::Execute() {
   const auto n_keep = _params.n_keep;
   size_t n_cur = 0;
   size_t n_input = 0;
-  const auto model = llama_get_model(_sess->context());
+  const auto model = _sess->model();
   const bool add_bos = llama_should_add_bos_token(model);
   auto ctx = _sess->context();
 
@@ -144,6 +144,9 @@ void LlamaCompletionWorker::Execute() {
   }
   const auto t_main_end = ggml_time_us();
   _sess->get_mutex().unlock();
+  if (_onComplete) {
+    _onComplete();
+  }
 }
 
 void LlamaCompletionWorker::OnOK() {
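
The ordering of the new lines in Execute() matters: the session mutex is released before _onComplete fires. A minimal sketch of that pattern (hypothetical Worker name, not the commit's full class), assuming the registered callback may re-enter code that locks the same session:

```cpp
#include <functional>
#include <mutex>

class Worker {
public:
  // Register a hook to run when Execute() finishes, as in the diff above.
  void onComplete(std::function<void()> cb) { on_complete_ = cb; }

  void Execute(std::mutex &session_mutex) {
    session_mutex.lock();
    // ... run inference while holding the session ...
    session_mutex.unlock();
    // Fire the hook only after the lock is released: invoking it while the
    // mutex is held could deadlock if the callback re-enters the session.
    if (on_complete_) {
      on_complete_();
    }
  }

private:
  std::function<void()> on_complete_; // empty until onComplete() is called
};

int main() {
  std::mutex session_mutex;
  Worker worker;
  worker.onComplete([] { /* e.g. clear a "work in progress" flag */ });
  worker.Execute(session_mutex);
}
```
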
src/LlamaCompletionWorker.h (4 additions, 0 deletions)
@@ -1,4 +1,5 @@
 #include "common.hpp"
+#include <functional>
 
 struct CompletionResult {
   std::string text = "";
@@ -18,6 +19,8 @@ class LlamaCompletionWorker : public Napi::AsyncWorker,
 
   inline void Stop() { _stop = true; }
 
+  inline void onComplete(std::function<void()> cb) { _onComplete = cb; }
+
 protected:
   void Execute();
   void OnOK();
@@ -30,5 +33,6 @@ class LlamaCompletionWorker : public Napi::AsyncWorker,
   Napi::ThreadSafeFunction _tsfn;
   bool _has_callback = false;
   bool _stop = false;
+  std::function<void()> _onComplete;
   CompletionResult _result;
 };
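
_onComplete is a default-constructed std::function<void()>, which stays empty until a callback is registered; that emptiness is what makes the `if (_onComplete)` guard in Execute() safe. A standalone illustration:

```cpp
#include <cassert>
#include <functional>

int main() {
  std::function<void()> hook;  // default-constructed: holds no target
  assert(!hook);               // converts to false while empty
  hook = [] { /* completion bookkeeping would go here */ };
  assert(hook);                // now holds a callable
  hook();                      // safe to invoke
}
```
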
src/LlamaContext.cpp (6 additions, 1 deletion)
@@ -70,7 +70,7 @@ LlamaContext::LlamaContext(const Napi::CallbackInfo &info)
         .ThrowAsJavaScriptException();
   }
 
-  _sess = std::make_shared<LlamaSession>(ctx, params);
+  _sess = std::make_shared<LlamaSession>(model, ctx, params);
   _info = get_system_info(params);
 }
 
@@ -93,6 +93,10 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
     Napi::TypeError::New(env, "Context is disposed")
         .ThrowAsJavaScriptException();
   }
+  if (_wip != nullptr) {
+    Napi::TypeError::New(env, "Another completion is in progress")
+        .ThrowAsJavaScriptException();
+  }
   auto options = info[0].As<Napi::Object>();
 
   gpt_params params = _sess->params();
@@ -143,6 +147,7 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
       new LlamaCompletionWorker(info, _sess, callback, params, stop_words);
   worker->Queue();
   _wip = worker;
+  worker->onComplete([this]() { _wip = nullptr; });
   return worker->Promise();
 }
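
Together, _wip and the new hook implement a single-flight guard: a second Completion() call is rejected while one is running, and the guard clears itself when the worker finishes. A reduced sketch of the mechanism (hypothetical Worker and Context stand-ins for the N-API classes):

```cpp
#include <functional>

struct Worker {
  std::function<void()> on_complete;
  void onComplete(std::function<void()> cb) { on_complete = cb; }
  // In the real code this runs at the end of Execute() on the worker thread.
  void Finish() {
    if (on_complete) on_complete();
  }
};

class Context {
public:
  bool StartCompletion(Worker &worker) {
    if (wip_ != nullptr) {
      return false; // "Another completion is in progress"
    }
    wip_ = &worker;
    // Clearing the guard in the hook makes the context reusable as soon as
    // the previous completion ends.
    worker.onComplete([this]() { wip_ = nullptr; });
    return true;
  }

private:
  Worker *wip_ = nullptr; // the in-flight worker, if any
};

int main() {
  Context ctx;
  Worker first, second;
  ctx.StartCompletion(first);            // accepted
  bool ok = ctx.StartCompletion(second); // rejected: first is still running
  first.Finish();                        // hook clears the guard
  ok = ctx.StartCompletion(second);      // accepted now
  (void)ok;
}
```
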
src/common.hpp (11 additions, 7 deletions)
@@ -46,32 +46,36 @@ constexpr T get_option(const Napi::Object &options, const std::string &name,
 
 class LlamaSession {
 public:
-  LlamaSession(llama_context *ctx, gpt_params params)
-      : ctx_(LlamaCppContext(ctx, llama_free)), params_(params) {
+  LlamaSession(llama_model *model, llama_context *ctx, gpt_params params)
+      : model_(LlamaCppModel(model, llama_free_model)), ctx_(LlamaCppContext(ctx, llama_free)), params_(params) {
     tokens_.reserve(params.n_ctx);
   }
 
   ~LlamaSession() { dispose(); }
 
-  llama_context *context() { return ctx_.get(); }
+  inline llama_context *context() { return ctx_.get(); }
 
-  std::vector<llama_token>* tokens_ptr() { return &tokens_; }
+  inline llama_model *model() { return model_.get(); }
 
-  void set_tokens(std::vector<llama_token> tokens) {
+  inline std::vector<llama_token>* tokens_ptr() { return &tokens_; }
+
+  inline void set_tokens(std::vector<llama_token> tokens) {
     tokens_ = std::move(tokens);
   }
 
-  const gpt_params &params() const { return params_; }
+  inline const gpt_params &params() const { return params_; }
 
-  std::mutex &get_mutex() { return mutex; }
+  inline std::mutex &get_mutex() { return mutex; }
 
   void dispose() {
     std::lock_guard<std::mutex> lock(mutex);
     tokens_.clear();
     ctx_.reset();
+    model_.reset();
   }
 
 private:
+  LlamaCppModel model_;
   LlamaCppContext ctx_;
   const gpt_params params_;
   std::vector<llama_token> tokens_{};
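
LlamaCppModel and LlamaCppContext are not defined in this diff; given the (pointer, deleter) construction in the LlamaSession constructor and the .get()/.reset() calls above, they are presumably std::unique_ptr aliases over llama.cpp's free functions. A sketch of that assumed definition:

```cpp
#include <memory>

// Assumed definitions (not shown in this commit): unique_ptr aliases whose
// deleters call the matching llama.cpp free functions, so each native handle
// is released exactly once when the owning LlamaSession is disposed.
struct llama_model;   // opaque llama.cpp handles, forward-declared for the sketch
struct llama_context;
extern "C" void llama_free_model(llama_model *model);
extern "C" void llama_free(llama_context *ctx);

using LlamaCppModel = std::unique_ptr<llama_model, decltype(&llama_free_model)>;
using LlamaCppContext = std::unique_ptr<llama_context, decltype(&llama_free)>;
```

Note that dispose() resets ctx_ before model_, which respects the usual llama.cpp teardown order: a context should be freed before the model it was created from.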
