Skip to content

Commit

Permalink
fix(cpp): migrate chat_template api
Browse files · Browse the repository at this point in the history
  • Loading branch information
jhen0409 committed Jan 25, 2025
1 parent 87ee079 commit cc3feb4
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 3 deletions.
10 changes: 9 additions & 1 deletion android/src/main/jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,15 @@ Java_com_rnllama_LlamaContext_getFormattedChat(
}

const char *tmpl_chars = env->GetStringUTFChars(chat_template, nullptr);
std::string formatted_chat = common_chat_apply_template(llama->model, tmpl_chars, chat, true);
common_chat_templates templates = common_chat_templates_from_model(llama->model, tmpl_chars);
std::string formatted_chat = common_chat_apply_template(
*templates.template_default,
chat,
true,
/* use_jinja= */ false
);

env->ReleaseStringUTFChars(chat_template, tmpl_chars);

return env->NewStringUTF(formatted_chat.c_str());
}
Expand Down
6 changes: 5 additions & 1 deletion cpp/rn-llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,11 @@ bool llama_rn_context::loadModel(common_params &params_)
}

bool llama_rn_context::validateModelChatTemplate() const {
const char * tmpl = llama_model_chat_template(model);
const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
if (tmpl == nullptr) {
return false;
}

llama_chat_message chat[] = {{"user", "test"}};
int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
return chat_res > 0;
Expand Down
1 change: 1 addition & 0 deletions cpp/rn-llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <sstream>
#include <iostream>
#include "chat-template.hpp"
#include "common.h"
#include "ggml.h"
#include "gguf.h"
Expand Down
10 changes: 9 additions & 1 deletion ios/RNLlamaContext.mm
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ - (NSDictionary *)modelInfo {
@"nEmbd": @(llama_model_n_embd(llama->model)),
@"nParams": @(llama_model_n_params(llama->model)),
@"isChatTemplateSupported": @(llama->validateModelChatTemplate()),
@"isChatTemplateToolUseSupported": @(llama->validateModelChatTemplateToolUse()),
@"metadata": meta
};
}
Expand All @@ -246,7 +247,14 @@ - (NSString *)getFormattedChat:(NSArray *)messages withTemplate:(NSString *)chat
}

auto tmpl = chatTemplate == nil ? "" : [chatTemplate UTF8String];
auto formatted_chat = common_chat_apply_template(llama->model, tmpl, chat, true);
common_chat_templates templates = common_chat_templates_from_model(llama->model, tmpl);
auto formatted_chat = common_chat_apply_template(
*templates.template_default,
chat,
true,
/* use_jinja= */ true
);

return [NSString stringWithUTF8String:formatted_chat.c_str()];
}

Expand Down

0 comments on commit cc3feb4

Please sign in to comment.