[NPU] Update C++ example with repetition_penalty & update Python code accordingly (#12528)

* Update C++ NPU examples with repetition penalty

* Fit Python code to the updated C++ API

* Style fix

* Small fix

* Small fix
Oscilloscope98 authored Dec 12, 2024
1 parent 2cce896 commit dbaf4ab
Showing 3 changed files with 21 additions and 28 deletions.
21 changes: 14 additions & 7 deletions (C++ NPU example; file path not captured in this view)
@@ -98,11 +98,11 @@ std::string add_chat_history(npu_model_params model_params,
     return prompt;
 }
 
 
 std::string run_generate(void* void_model, int32_t* embd_inp_ptr, int32_t embd_inp_size,
-                         npu_model_params model_params, tokenizer_params tok_params, int32_t max_new_token, bool do_print){
+                         npu_model_params model_params, tokenizer_params tok_params, npu_generation_params generation_params, bool do_print){
     auto start = std::chrono::high_resolution_clock::now();
-    float* logits = run_prefill(void_model, embd_inp_ptr, embd_inp_size);
+    float* logits = run_prefill(void_model, embd_inp_ptr, embd_inp_size,
+                                generation_params.repetition_penalty);
     int32_t token = llm_sample_token(logits, true, model_params.vocab_size);
     auto end = std::chrono::high_resolution_clock::now();
     auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
@@ -115,8 +115,9 @@ std::string run_generate(void* void_model, int32_t* embd_inp_ptr, int32_t embd_inp_size,
 
     int token_nums = 0;
     start = std::chrono::high_resolution_clock::now();
-    for (int i = 1; i < max_new_token; i++){
-        auto logits = run_decode(void_model, embd[i-1]);
+    for (int i = 1; i < generation_params.max_new_token; i++){
+        auto logits = run_decode(void_model, embd[i-1],
+                                 generation_params.repetition_penalty);
         int32_t token = llm_sample_token(logits, true, model_params.vocab_size);
         if (std::find(tok_params.eos_token_id.begin(), tok_params.eos_token_id.end(), token) == tok_params.eos_token_id.end()){
             embd.push_back(token);
@@ -207,6 +208,10 @@ int main(int argc, char ** argv) {
     tokenizer_params tok_params;
     load_tokenizer(tok_params, params.model);
 
+    npu_generation_params generation_params;
+    load_generation_config_from_file(generation_params, params.model);
+    generation_params.max_new_token = n_predict;
+
     if (cnv_mode){
         std::string prompt;
         std::string history = "";
@@ -228,9 +233,11 @@
             full_prompt = add_chat_history(model_params, prompt, "", true);
             embd_inp = llm_tokenize(full_prompt, false);
         }
 
+        generation_params.max_new_token = model_params.kv_len - embd_inp.size();
+
         response = run_generate(model, embd_inp.data(), embd_inp.size(),
-                                model_params, tok_params, model_params.kv_len - embd_inp.size(), false);
+                                model_params, tok_params, generation_params, false);
 
         std::cout << "Assistant:";
         std::cout << response << std::endl;
@@ -251,7 +258,7 @@
 
     // single text generation
     std::string output = run_generate(model, embd_inp.data(), embd_inp.size(),
-                                      model_params, tok_params, params.n_predict, true);
+                                      model_params, tok_params, generation_params, true);
 
     std::cout << "Output: " << std::endl;
     std::cout << output << std::endl;
2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/npu_models/convert.py

@@ -413,7 +413,7 @@ def simple_generate(
             if token in eos:
                 break
             token = run_decode(self.model_ptr, token, self.vocab_size,
-                               input_list, repetition_penalty)
+                               repetition_penalty)
             if streamer is not None:
                 # rest tokens
                 streamer.put(torch.tensor([token]))
26 changes: 6 additions & 20 deletions python/llm/src/ipex_llm/transformers/npu_models/npu_llm_cpp.py

@@ -48,20 +48,16 @@ def get_shared_lib_info(lib_base_name: str):
 _lib.load_model_from_file.argtypes = [ctypes.c_char_p]
 _lib.load_model_from_file.restype = ctypes.c_void_p
 
-_lib.run_prefill.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_int), ctypes.c_int]
+_lib.run_prefill.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_int), ctypes.c_int,
+                             ctypes.c_float]
 _lib.run_prefill.restype = ctypes.POINTER(ctypes.c_float)
 
-_lib.run_decode.argtypes = [ctypes.c_void_p, ctypes.c_int]
+_lib.run_decode.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_float]
 _lib.run_decode.restype = ctypes.POINTER(ctypes.c_float)
 
 _lib.llm_sample_token.argtypes = [ctypes.POINTER(ctypes.c_float), ctypes.c_bool, ctypes.c_int]
 _lib.llm_sample_token.restype = ctypes.c_int
 
-_lib.process_logits.argtypes = [ctypes.POINTER(ctypes.c_float), ctypes.c_int,
-                                ctypes.POINTER(ctypes.c_int), ctypes.c_int,
-                                ctypes.c_float]
-_lib.process_logits.restype = ctypes.POINTER(ctypes.c_float)
-
 _lib.reset.argtypes = [ctypes.c_void_p]
 _lib.reset.restype = None
 
@@ -81,23 +77,13 @@ def load_model_from_file(model_dir: str):
 def run_prefill(model_ptr, input_ids, vocab_size, repetition_penalty=1.0):
     input_ptr = (ctypes.c_int32 * len(input_ids))(*input_ids)
     input_len = len(input_ids)
-    plogits = _lib.run_prefill(model_ptr, input_ptr, input_len)
-    if repetition_penalty != 1:
-        plogits = _lib.process_logits(plogits, vocab_size,
-                                      input_ptr, input_len,
-                                      repetition_penalty)
+    plogits = _lib.run_prefill(model_ptr, input_ptr, input_len, repetition_penalty)
     new_token = _lib.llm_sample_token(plogits, True, vocab_size)
     return new_token
 
 
-def run_decode(model_ptr, input_id, vocab_size, updated_input_ids, repetition_penalty=1.0):
-    plogits = _lib.run_decode(model_ptr, input_id)
-    if repetition_penalty != 1:
-        updated_input_ptr = (ctypes.c_int32 * len(updated_input_ids))(*updated_input_ids)
-        updated_input_len = len(updated_input_ids)
-        plogits = _lib.process_logits(plogits, vocab_size,
-                                      updated_input_ptr, updated_input_len,
-                                      repetition_penalty)
+def run_decode(model_ptr, input_id, vocab_size, repetition_penalty=1.0):
+    plogits = _lib.run_decode(model_ptr, input_id, repetition_penalty)
     new_token = _lib.llm_sample_token(plogits, True, vocab_size)
     return new_token
