
feat: support message template
hans00 committed Jul 28, 2024
1 parent 35dcc68 commit 359b704
Showing 5 changed files with 59 additions and 2 deletions.
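
In short: `LlamaCompletionOptions` gains an optional `messages` array alongside the now-optional `prompt`, a new `getFormattedChat()` method exposes the model's chat template directly, and `completion()` renders `messages` through that template before inference.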
9 changes: 8 additions & 1 deletion lib/binding.ts
@@ -1,5 +1,10 @@
 import * as path from 'path'
 
+export type ChatMessage = {
+  role: string
+  content: string
+}
+
 export type LlamaModelOptions = {
   model: string
   embedding?: boolean
@@ -12,7 +17,8 @@ export type LlamaModelOptions = {
 }
 
 export type LlamaCompletionOptions = {
-  prompt: string
+  messages?: ChatMessage[]
+  prompt?: string
   n_samples?: number
   temperature?: number
   top_k?: number
@@ -48,6 +54,7 @@ export type EmbeddingResult = {
 export interface LlamaContext {
   new (options: LlamaModelOptions): LlamaContext
   getSystemInfo(): string
+  getFormattedChat(messages: ChatMessage[]): string
   completion(options: LlamaCompletionOptions, callback?: (token: LlamaCompletionToken) => void): Promise<LlamaCompletionResult>
   stopCompletion(): void
   tokenize(text: string): Promise<TokenizeResult>
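To see how the new typings fit together, here is a minimal usage sketch against the `LlamaContext` interface above. `createContext` is a hypothetical stand-in for the module's actual loader, which is outside this diff:

```ts
import type { LlamaContext, LlamaModelOptions } from './binding'

// Hypothetical factory; the real native loader is not part of this commit.
declare function createContext(options: LlamaModelOptions): LlamaContext

const ctx = createContext({ model: 'path/to/model.gguf' })

// Render a conversation through the model's chat template...
const prompt = ctx.getFormattedChat([
  { role: 'user', content: 'Hello' },
  { role: 'assistant', content: 'Hi there!' },
])

// ...or pass messages directly and let completion() apply the
// same template internally before sampling.
const result = await ctx.completion(
  { messages: [{ role: 'user', content: 'Hello' }] },
  (token) => { /* stream partial output here */ },
)
```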
35 changes: 34 additions & 1 deletion src/LlamaContext.cpp
@@ -7,12 +7,27 @@
 #include "SaveSessionWorker.h"
 #include "TokenizeWorker.h"
 
+std::vector<llama_chat_msg> get_messages(Napi::Array messages) {
+  std::vector<llama_chat_msg> chat;
+  for (size_t i = 0; i < messages.Length(); i++) {
+    auto message = messages.Get(i).As<Napi::Object>();
+    chat.push_back({
+        get_option<std::string>(message, "role", ""),
+        get_option<std::string>(message, "content", ""),
+    });
+  }
+  return chat;
+}
+
 void LlamaContext::Init(Napi::Env env, Napi::Object &exports) {
   Napi::Function func = DefineClass(
       env, "LlamaContext",
       {InstanceMethod<&LlamaContext::GetSystemInfo>(
            "getSystemInfo",
            static_cast<napi_property_attributes>(napi_enumerable)),
+       InstanceMethod<&LlamaContext::GetFormattedChat>(
+           "getFormattedChat",
+           static_cast<napi_property_attributes>(napi_enumerable)),
        InstanceMethod<&LlamaContext::Completion>(
            "completion",
            static_cast<napi_property_attributes>(napi_enumerable)),
@@ -89,6 +104,18 @@ Napi::Value LlamaContext::GetSystemInfo(const Napi::CallbackInfo &info) {
   return Napi::String::New(info.Env(), _info);
 }
 
+// getFormattedChat(messages: [{ role: string, content: string }]): string
+Napi::Value LlamaContext::GetFormattedChat(const Napi::CallbackInfo &info) {
+  Napi::Env env = info.Env();
+  if (info.Length() < 1 || !info[0].IsArray()) {
+    Napi::TypeError::New(env, "Array expected").ThrowAsJavaScriptException();
+    return env.Undefined();
+  }
+  auto messages = info[0].As<Napi::Array>();
+  auto formatted = llama_chat_apply_template(_sess->model(), "", get_messages(messages), true);
+  return Napi::String::New(env, formatted);
+}
+
 // completion(options: LlamaCompletionOptions, onToken?: (token: string) =>
 // void): Promise<LlamaCompletionResult>
 Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
@@ -110,7 +137,13 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
   auto options = info[0].As<Napi::Object>();
 
   gpt_params params = _sess->params();
-  params.prompt = get_option<std::string>(options, "prompt", "");
+  if (options.Has("messages") && options.Get("messages").IsArray()) {
+    auto messages = options.Get("messages").As<Napi::Array>();
+    auto formatted = llama_chat_apply_template(_sess->model(), "", get_messages(messages), true);
+    params.prompt = formatted;
+  } else {
+    params.prompt = get_option<std::string>(options, "prompt", "");
+  }
   if (params.prompt.empty()) {
     Napi::TypeError::New(env, "Prompt is required")
         .ThrowAsJavaScriptException();
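Two implementation details worth noting. First, `llama_chat_apply_template` is called with an empty template string, which llama.cpp's common helper resolves to the template embedded in the model's GGUF metadata, falling back to ChatML when none is present — the ChatML markers in the test snapshot below are produced this way. Second, because the `messages` branch is checked first, it takes precedence over `prompt` when both are supplied; a sketch, reusing `ctx` from the example above:

```ts
// `messages` wins: the formatted chat overwrites params.prompt on the
// native side, so the plain prompt string below is never used.
await ctx.completion({
  messages: [{ role: 'user', content: 'Hello' }],
  prompt: 'ignored whenever messages is an array',
})
```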
1 change: 1 addition & 0 deletions src/LlamaContext.h
@@ -9,6 +9,7 @@ class LlamaContext : public Napi::ObjectWrap<LlamaContext> {
 
 private:
   Napi::Value GetSystemInfo(const Napi::CallbackInfo &info);
+  Napi::Value GetFormattedChat(const Napi::CallbackInfo &info);
   Napi::Value Completion(const Napi::CallbackInfo &info);
   void StopCompletion(const Napi::CallbackInfo &info);
   Napi::Value Tokenize(const Napi::CallbackInfo &info);
9 changes: 9 additions & 0 deletions test/__snapshots__/index.test.ts.snap
@@ -404,6 +404,15 @@ exports[`tokenize 1`] = `
 
 exports[`tokenize 2`] = `"xxx"`;
 
+exports[`tokenize 3`] = `
+"<|im_start|>user
+Hello<|im_end|>
+<|im_start|>bot
+Hi<|im_end|>
+<|im_start|>assistant
+"
+`;
+
 exports[`work fine 1`] = `
 {
   "text": " swochadoorter scientific WindowsCa occupiedrå alta",
7 changes: 7 additions & 0 deletions test/index.test.ts
@@ -35,6 +35,13 @@ it('tokenize', async () => {
     const result = await model.detokenize([123, 123, 123])
     expect(result).toMatchSnapshot()
   }
+  {
+    const result = model.getFormattedChat([
+      { role: 'user', content: 'Hello' },
+      { role: 'bot', content: 'Hi' },
+    ])
+    expect(result).toMatchSnapshot()
+  }
   await model.release()
 })
 
