[NPU] Update C++ example with repetition_penalty & update Python code accordingly #12528

Merged
@@ -98,11 +98,11 @@ std::string add_chat_history(npu_model_params model_params,
     return prompt;
 }


 std::string run_generate(void* void_model, int32_t* embd_inp_ptr, int32_t embd_inp_size,
-                         npu_model_params model_params, tokenizer_params tok_params, int32_t max_new_token, bool do_print){
+                         npu_model_params model_params, tokenizer_params tok_params, npu_generation_params generation_params, bool do_print){
     auto start = std::chrono::high_resolution_clock::now();
-    float* logits = run_prefill(void_model, embd_inp_ptr, embd_inp_size);
+    float* logits = run_prefill(void_model, embd_inp_ptr, embd_inp_size,
+                                generation_params.repetition_penalty);
     int32_t token = llm_sample_token(logits, true, model_params.vocab_size);
     auto end = std::chrono::high_resolution_clock::now();
     auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
@@ -115,8 +115,9 @@ std::string run_generate(void* void_model, int32_t* embd_inp_ptr, int32_t embd_inp_size,

     int token_nums = 0;
     start = std::chrono::high_resolution_clock::now();
-    for (int i = 1; i < max_new_token; i++){
-        auto logits = run_decode(void_model, embd[i-1]);
+    for (int i = 1; i < generation_params.max_new_token; i++){
+        auto logits = run_decode(void_model, embd[i-1],
+                                 generation_params.repetition_penalty);
         int32_t token = llm_sample_token(logits, true, model_params.vocab_size);
         if (std::find(tok_params.eos_token_id.begin(), tok_params.eos_token_id.end(), token) == tok_params.eos_token_id.end()){
             embd.push_back(token);
@@ -207,6 +208,10 @@ int main(int argc, char ** argv) {
     tokenizer_params tok_params;
     load_tokenizer(tok_params, params.model);

+    npu_generation_params generation_params;
+    load_generation_config_from_file(generation_params, params.model);
+    generation_params.max_new_token = n_predict;
+
     if (cnv_mode){
         std::string prompt;
         std::string history = "";
@@ -228,9 +233,11 @@ int main(int argc, char ** argv) {
                 full_prompt = add_chat_history(model_params, prompt, "", true);
                 embd_inp = llm_tokenize(full_prompt, false);
             }

+            generation_params.max_new_token = model_params.kv_len - embd_inp.size();
+
             response = run_generate(model, embd_inp.data(), embd_inp.size(),
-                                    model_params, tok_params, model_params.kv_len - embd_inp.size(), false);
+                                    model_params, tok_params, generation_params, false);

             std::cout << "Assistant:";
             std::cout << response << std::endl;
@@ -251,7 +258,7 @@ int main(int argc, char ** argv) {

     // single text generation
     std::string output = run_generate(model, embd_inp.data(), embd_inp.size(),
-                                      model_params, tok_params, params.n_predict, true);
+                                      model_params, tok_params, generation_params, true);

     std::cout << "Output: " << std::endl;
     std::cout << output << std::endl;
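In the C++ example, run_prefill and run_decode now receive the penalty value directly, so the caller no longer post-processes logits itself. For orientation, below is a minimal Python sketch of the repetition-penalty rule commonly used for this purpose; the exact formula applied inside the native run_prefill/run_decode is not shown in this diff, so treat the sketch as an assumption for illustration only.

# Hedged sketch of the widely used repetition-penalty rule; the native
# implementation referenced by this PR may differ in detail.
def apply_repetition_penalty(logits, seen_token_ids, penalty):
    if penalty == 1.0:  # 1.0 means "disabled", matching the old "!= 1" guard
        return logits
    for tok in set(seen_token_ids):
        score = logits[tok]
        # Previously seen tokens become less likely regardless of logit sign.
        logits[tok] = score / penalty if score > 0 else score * penalty
    return logits

# Example: a penalty of 1.1 slightly discourages repeating token id 2.
print(apply_repetition_penalty([0.5, -0.2, 1.0], seen_token_ids=[2], penalty=1.1))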
2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/npu_models/convert.py
@@ -411,7 +411,7 @@ def simple_generate(
             if token in eos:
                 break
             token = run_decode(self.model_ptr, token, self.vocab_size,
-                               input_list, repetition_penalty)
+                               repetition_penalty)
             if streamer is not None:
                 # rest tokens
                 streamer.put(torch.tensor([token]))
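Previously, simple_generate had to keep input_list up to date so the Python-side process_logits call could penalise already-generated tokens; with the penalty forwarded on every native call, each decode step only needs the last token. A sketch of the resulting loop follows; the helper name and its arguments are illustrative, not taken from the repository.

# Sketch only, assuming run_prefill/run_decode from npu_llm_cpp behave as in
# the diff below; generate_ids is a hypothetical helper for illustration.
def generate_ids(model_ptr, input_ids, vocab_size, eos, max_new_tokens,
                 repetition_penalty=1.0):
    token = run_prefill(model_ptr, input_ids, vocab_size, repetition_penalty)
    output_ids = [token]
    while len(output_ids) < max_new_tokens and token not in eos:
        token = run_decode(model_ptr, token, vocab_size, repetition_penalty)
        output_ids.append(token)
    return output_ids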
26 changes: 6 additions & 20 deletions python/llm/src/ipex_llm/transformers/npu_models/npu_llm_cpp.py
@@ -48,20 +48,16 @@ def get_shared_lib_info(lib_base_name: str):
 _lib.load_model_from_file.argtypes = [ctypes.c_char_p]
 _lib.load_model_from_file.restype = ctypes.c_void_p

-_lib.run_prefill.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_int), ctypes.c_int]
+_lib.run_prefill.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_int), ctypes.c_int,
+                             ctypes.c_float]
 _lib.run_prefill.restype = ctypes.POINTER(ctypes.c_float)

-_lib.run_decode.argtypes = [ctypes.c_void_p, ctypes.c_int]
+_lib.run_decode.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_float]
 _lib.run_decode.restype = ctypes.POINTER(ctypes.c_float)

 _lib.llm_sample_token.argtypes = [ctypes.POINTER(ctypes.c_float), ctypes.c_bool, ctypes.c_int]
 _lib.llm_sample_token.restype = ctypes.c_int

-_lib.process_logits.argtypes = [ctypes.POINTER(ctypes.c_float), ctypes.c_int,
-                                ctypes.POINTER(ctypes.c_int), ctypes.c_int,
-                                ctypes.c_float]
-_lib.process_logits.restype = ctypes.POINTER(ctypes.c_float)
-
 _lib.reset.argtypes = [ctypes.c_void_p]
 _lib.reset.restype = None

@@ -81,23 +77,13 @@ def load_model_from_file(model_dir: str):
 def run_prefill(model_ptr, input_ids, vocab_size, repetition_penalty=1.0):
     input_ptr = (ctypes.c_int32 * len(input_ids))(*input_ids)
     input_len = len(input_ids)
-    plogits = _lib.run_prefill(model_ptr, input_ptr, input_len)
-    if repetition_penalty != 1:
-        plogits = _lib.process_logits(plogits, vocab_size,
-                                      input_ptr, input_len,
-                                      repetition_penalty)
+    plogits = _lib.run_prefill(model_ptr, input_ptr, input_len, repetition_penalty)
     new_token = _lib.llm_sample_token(plogits, True, vocab_size)
     return new_token


-def run_decode(model_ptr, input_id, vocab_size, updated_input_ids, repetition_penalty=1.0):
-    plogits = _lib.run_decode(model_ptr, input_id)
-    if repetition_penalty != 1:
-        updated_input_ptr = (ctypes.c_int32 * len(updated_input_ids))(*updated_input_ids)
-        updated_input_len = len(updated_input_ids)
-        plogits = _lib.process_logits(plogits, vocab_size,
-                                      updated_input_ptr, updated_input_len,
-                                      repetition_penalty)
+def run_decode(model_ptr, input_id, vocab_size, repetition_penalty=1.0):
+    plogits = _lib.run_decode(model_ptr, input_id, repetition_penalty)
     new_token = _lib.llm_sample_token(plogits, True, vocab_size)
     return new_token
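With process_logits removed from the ctypes bindings, the repetition penalty travels as a plain c_float on both native entry points, and passing 1.0 presumably leaves the logits untouched, mirroring the old "if repetition_penalty != 1" check. A minimal, hedged usage sketch of the updated wrappers follows; the model directory, vocabulary size, and prompt token ids are placeholders, not values from the repository.

from ipex_llm.transformers.npu_models.npu_llm_cpp import (
    load_model_from_file, run_prefill, run_decode,
)

model_ptr = load_model_from_file("path/to/converted_npu_model")  # placeholder path
vocab_size = 32000                                               # illustrative value
prompt_ids = [1, 15043, 3186]                                    # illustrative token ids

# The penalty is forwarded to the native library on every prefill/decode call.
token = run_prefill(model_ptr, prompt_ids, vocab_size, repetition_penalty=1.05)
next_token = run_decode(model_ptr, token, vocab_size, repetition_penalty=1.05)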