Skip to content

Commit

Permalink
feat: support manually release context
Browse files Browse the repository at this point in the history
  • Loading branch information
hans00 committed Apr 29, 2024
1 parent 9b4f806 commit a342e38
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 0 deletions.
1 change: 1 addition & 0 deletions lib/binding.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ export interface LlamaContext {
stopCompletion(): void
saveSession(path: string): Promise<void>
loadSession(path: string): Promise<void>
release(): Promise<void>
}

export interface Module {
Expand Down
42 changes: 42 additions & 0 deletions src/addons.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ class LlamaContext : public Napi::ObjectWrap<LlamaContext> {
static_cast<napi_property_attributes>(napi_enumerable)),
InstanceMethod<&LlamaContext::LoadSession>(
"loadSession",
static_cast<napi_property_attributes>(napi_enumerable)),
InstanceMethod<&LlamaContext::Release>(
"release",
static_cast<napi_property_attributes>(napi_enumerable))});
Napi::FunctionReference *constructor = new Napi::FunctionReference();
*constructor = Napi::Persistent(func);
Expand Down Expand Up @@ -139,12 +142,21 @@ class LlamaContext : public Napi::ObjectWrap<LlamaContext> {

std::mutex &getMutex() { return mutex; }

// Drops everything this context holds: the completion-worker pointer is
// cleared and ctx, tokens and model are reset, all while holding `mutex`.
void Dispose() {
  // Held for the whole teardown; any other code path that locks `mutex`
  // cannot interleave with it.
  std::lock_guard<std::mutex> guard(mutex);

  // Only the pointer is cleared here; the worker object itself is not
  // deleted by this method.
  compl_worker = nullptr;

  // ctx is reset before model. NOTE(review): presumably the llama context
  // must not outlive the model it was created from -- confirm before
  // reordering these calls.
  ctx.reset();
  tokens.reset();
  model.reset();
}

private:
Napi::Value GetSystemInfo(const Napi::CallbackInfo &info);
Napi::Value Completion(const Napi::CallbackInfo &info);
void StopCompletion(const Napi::CallbackInfo &info);
Napi::Value SaveSession(const Napi::CallbackInfo &info);
Napi::Value LoadSession(const Napi::CallbackInfo &info);
Napi::Value Release(const Napi::CallbackInfo &info);

gpt_params params;
LlamaCppModel model{nullptr, llama_free_model};
Expand Down Expand Up @@ -389,6 +401,26 @@ class LoadSessionWorker : public Napi::AsyncWorker,
void OnError(const Napi::Error &err) { Reject(err.Value()); }
};

// Async worker that runs LlamaContext::Dispose() from Execute() and exposes
// the outcome as a promise: resolved with `undefined` on success, rejected
// with the error value on failure.
class DisposeWorker : public Napi::AsyncWorker, public Napi::Promise::Deferred {
public:
  // Takes a reference on the wrapped context for as long as this worker
  // exists; the matching Unref() is in the destructor.
  // NOTE(review): presumably this keeps the JS wrapper object from being
  // collected while the dispose is still pending -- confirm.
  DisposeWorker(Napi::Env env, LlamaContext *ctx)
      : AsyncWorker(env), Deferred(env), context_(ctx) {
    context_->Ref();
  }

  ~DisposeWorker() {
    context_->Unref();
  }

protected:
  // Performs the actual teardown.
  void Execute() override {
    context_->Dispose();
  }

  // Env() is qualified with AsyncWorker, as in the original code.
  void OnOK() override {
    Resolve(AsyncWorker::Env().Undefined());
  }

  void OnError(const Napi::Error &err) override {
    Reject(err.Value());
  }

private:
  // Non-owning pointer to the context being disposed.
  LlamaContext *context_;
};

// getSystemInfo(): string
Napi::Value LlamaContext::GetSystemInfo(const Napi::CallbackInfo &info) {
return Napi::String::New(info.Env(), get_system_info(params).c_str());
Expand Down Expand Up @@ -487,6 +519,16 @@ Napi::Value LlamaContext::LoadSession(const Napi::CallbackInfo &info) {
return worker->Promise();
}

// release(): Promise<void>
//
// Signals a running completion (if any) to stop, then queues a DisposeWorker
// for this context and returns that worker's promise.
Napi::Value LlamaContext::Release(const Napi::CallbackInfo &info) {
  // Stop() is only a request made from this thread; the teardown itself
  // happens later in DisposeWorker::Execute() via Dispose().
  if (compl_worker != nullptr) {
    compl_worker->Stop();
  }

  // Allocated with raw `new` and never deleted here.
  // NOTE(review): relies on node-addon-api's AsyncWorker deleting itself
  // once it has completed (its default behaviour) -- confirm.
  DisposeWorker *disposer = new DisposeWorker(info.Env(), this);
  disposer->Queue();
  return disposer->Promise();
}

Napi::Object Init(Napi::Env env, Napi::Object exports) {
LlamaContext::Export(env, exports);
return exports;
Expand Down
1 change: 1 addition & 0 deletions test/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ it('work fine', async () => {
expect(data).toMatchObject({ token: expect.any(String) })
})
expect(result).toMatchSnapshot()
await model.release()
})

0 comments on commit a342e38

Please sign in to comment.