diff --git a/examples/models/llama2/runner/runner.cpp b/examples/models/llama2/runner/runner.cpp
index 02063ebfa5..2c72b4c724 100644
--- a/examples/models/llama2/runner/runner.cpp
+++ b/examples/models/llama2/runner/runner.cpp
@@ -204,8 +204,8 @@ Error Runner::generate(
 
   // print prompts
   wrapped_callback(prompt);
-
-  auto prefill_res = text_prefiller_->prefill(prompt_tokens, 0);
+  int64_t pos = 0;
+  auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
   stats_.first_token_ms = util::time_in_ms();
   stats_.prompt_eval_end_ms = util::time_in_ms();
   ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
diff --git a/examples/models/llava/runner/llava_runner.cpp b/examples/models/llava/runner/llava_runner.cpp
index 0fc06da0c5..04c77a1064 100644
--- a/examples/models/llava/runner/llava_runner.cpp
+++ b/examples/models/llava/runner/llava_runner.cpp
@@ -72,6 +72,54 @@ Error LlavaRunner::load() {
   return Error::Ok;
 }
 
+Error LlavaRunner::prefill_images(
+    std::vector<Image>& images,
+    int64_t& start_pos) {
+  for (auto& image : images) {
+    // pos is updated inside image prefill.
+    ET_UNWRAP(image_prefiller_->prefill(image, start_pos));
+  }
+  return Error::Ok;
+}
+
+Result<uint64_t> LlavaRunner::prefill_prompt(
+    const std::string& prompt,
+    int64_t& start_pos,
+    int8_t bos,
+    int8_t eos) {
+  std::vector<uint64_t> prompt_tokens =
+      ET_UNWRAP(tokenizer_->encode(prompt, bos, eos));
+
+  return text_prefiller_->prefill(prompt_tokens, start_pos);
+}
+
+Error LlavaRunner::generate_from_pos(
+    const std::string& prompt,
+    int32_t seq_len,
+    int64_t start_pos,
+    std::function<void(const std::string&)> token_callback,
+    std::function<void(const ::executorch::llm::Stats&)>
+        stats_callback) {
+  // prefill user prompt. No BOS because preset prompt already has it.
+  token_callback(prompt);
+
+  uint64_t prefill_next_token =
+      ET_UNWRAP(prefill_prompt(prompt, start_pos, /*bos=*/0, /*eos*/ 0));
+  stats_.num_prompt_tokens = start_pos;
+
+  // Generate tokens
+  int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
+      {prefill_next_token}, start_pos, seq_len, token_callback));
+
+  // Bookkeeping
+  stats_.num_generated_tokens = num_generated_tokens;
+  ::executorch::llm::print_report(stats_);
+  if (stats_callback) {
+    stats_callback(stats_);
+  }
+  return Error::Ok;
+}
+
 Error LlavaRunner::generate(
     std::vector<Image> images,
     const std::string& prompt,
@@ -96,43 +144,14 @@ Error LlavaRunner::generate(
   int64_t pos = 0;
 
   // prefill preset prompt
-  std::vector<uint64_t> preset_prompt_tokens =
-      ET_UNWRAP(tokenizer_->encode(kPresetPrompt, /*bos=*/1, /*eos=*/0));
-  size_t num_preset_tokens = preset_prompt_tokens.size();
-
-  ET_UNWRAP(text_prefiller_->prefill(preset_prompt_tokens, pos));
-  pos += num_preset_tokens;
+  prefill_prompt(kPresetPrompt, pos, /*bos=*/1, /*eos*/ 0);
 
   // prefill images
-  for (auto& image : images) {
-    // pos is updated inside image prefill.
-    ET_UNWRAP(image_prefiller_->prefill(image, pos));
-  }
-
-  // prefill user prompt. No BOS because preset prompt already has it.
-  wrapped_callback(prompt);
-
-  std::vector<uint64_t> user_prompt_tokens =
-      ET_UNWRAP(tokenizer_->encode(prompt, /*bos=*/0, /*eos=*/0));
-  size_t num_user_tokens = user_prompt_tokens.size();
-
-  uint64_t prefill_next_token =
-      ET_UNWRAP(text_prefiller_->prefill(user_prompt_tokens, pos));
-  pos += num_user_tokens;
+  prefill_images(images, pos);
 
   // Generate tokens
-  int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
-      {prefill_next_token}, pos, seq_len, wrapped_callback));
-
-  // Bookkeeping
-  stats_.num_prompt_tokens = num_preset_tokens + num_user_tokens;
-  stats_.num_generated_tokens = num_generated_tokens;
-  ::executorch::llm::print_report(stats_);
-  if (stats_callback) {
-    stats_callback(stats_);
-  }
-
-  return Error::Ok;
+  return generate_from_pos(
+      prompt, seq_len, pos, wrapped_callback, stats_callback);
 }
 
 } // namespace torch::executor
diff --git a/examples/models/llava/runner/llava_runner.h b/examples/models/llava/runner/llava_runner.h
index 9b14bc9283..923f8180a8 100644
--- a/examples/models/llava/runner/llava_runner.h
+++ b/examples/models/llava/runner/llava_runner.h
@@ -38,6 +38,48 @@ class LlavaRunner : public MultimodalRunner {
       std::function<void(const ::executorch::llm::Stats&)>
           stats_callback = {});
 
+  /**
+   * Prefill an LLaVA Module with the given images input.
+   * @param images The image input to LLaVA.
+   * @param start_pos The starting position in KV cache of the input in the
+   * LLM. It's passed as reference and will be updated inside this function.
+   * @return The error status of prefilling images.
+   */
+  Error prefill_images(std::vector<Image>& images, int64_t& start_pos);
+
+  /**
+   * Prefill an LLaVA Module with the given text input.
+   * @param prompt The text prompt to LLaVA.
+   * @param start_pos The starting position in KV cache of the input in the
+   * LLM. It's passed as reference and will be updated inside this function.
+   * @param bos The number of BOS (begin of sequence) token.
+   * @param eos The number of EOS (end of sequence) token.
+   * @return The generated token of the LLaVA Module after prefill prompt.
+   */
+  Result<uint64_t> prefill_prompt(
+      const std::string& prompt,
+      int64_t& start_pos,
+      int8_t bos = 0,
+      int8_t eos = 0);
+
+  /**
+   * Generate tokens from the given prompt, starting from the given position.
+   * @param prompt The text prompt to LLaVA.
+   * @param seq_len The total sequence length, including the prompt tokens and
+   * new tokens.
+   * @param start_pos The starting position in KV cache of the input in the
+   * LLM.
+   * @param token_callback What to do after a token is generated.
+   * @param stats_callback What to do with Stats.
+   * @return The error code.
+   */
+  Error generate_from_pos(
+      const std::string& prompt,
+      int32_t seq_len = 1024,
+      int64_t start_pos = 0,
+      std::function<void(const std::string&)> token_callback = {},
+      std::function<void(const ::executorch::llm::Stats&)>
+          stats_callback = {});
+
  private:
   inline static const std::string kPresetPrompt =
       "A chat between a curious human and an artificial intelligence "
       "assistant. The assistant gives helpful, detailed, and polite answers "
       "to the human's questions. "
       "USER: ";
diff --git a/extension/llm/runner/text_prefiller.cpp b/extension/llm/runner/text_prefiller.cpp
index e6229e0b80..705583d638 100644
--- a/extension/llm/runner/text_prefiller.cpp
+++ b/extension/llm/runner/text_prefiller.cpp
@@ -25,7 +25,7 @@ TextPrefiller::TextPrefiller(
 
 ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
     std::vector<uint64_t>& prompt_tokens,
-    int64_t start_pos_index) {
+    int64_t& start_pos) {
   ET_CHECK_MSG(!prompt_tokens.empty(), "Prompt cannot be null");
   if (!text_decoder_runner_->is_method_loaded()) {
     ET_CHECK_OK_OR_RETURN_ERROR(text_decoder_runner_->load());
@@ -43,45 +43,46 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
         {1, num_prompt_tokens},
         exec_aten::ScalarType::Long);
 
-    auto start_pos =
-        from_blob(&start_pos_index, {1}, exec_aten::ScalarType::Long);
+    auto start_pos_tensor =
+        from_blob(&start_pos, {1}, exec_aten::ScalarType::Long);
 
-    auto outputs_res = text_decoder_runner_->step(tokens, start_pos);
+    auto outputs_res = text_decoder_runner_->step(tokens, start_pos_tensor);
     ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
     ET_LOG(
         Info, "Prefill token result numel(): %zu", outputs_res.get().numel());
 
+    start_pos += num_prompt_tokens;
     cur_token = text_decoder_runner_->logits_to_token(outputs_res.get());
   } else { // sequential prefill
     int64_t pos = 0; // position in the sequence
-    // token & pos
-    int64_t pos_data = 0;
     // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
     cur_token = prompt_tokens[0];
 
     // initialize tensor wrappers
     auto tokens = from_blob(&cur_token, {1, 1}, exec_aten::ScalarType::Long);
 
-    auto start_pos = from_blob(&pos_data, {1}, exec_aten::ScalarType::Long);
+    auto start_pos_tensor =
+        from_blob(&start_pos, {1}, exec_aten::ScalarType::Long);
 
     // run the first token and get back logits tensor. Assuming the first token
     // is bos so don't callback.
     auto logits_tensor =
-        ET_UNWRAP(text_decoder_runner_->step(tokens, start_pos));
+        ET_UNWRAP(text_decoder_runner_->step(tokens, start_pos_tensor));
 
-    pos = 1; // start from index 1
+    pos += 1; // start the loop from index 1
+    start_pos += 1;
 
     while (pos < num_prompt_tokens) {
       // Run the model
-      pos_data = start_pos_index + pos;
-      // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
       cur_token = prompt_tokens[pos];
 
-      logits_tensor = ET_UNWRAP(text_decoder_runner_->step(tokens, start_pos));
+      logits_tensor =
+          ET_UNWRAP(text_decoder_runner_->step(tokens, start_pos_tensor));
 
       pos++;
+      start_pos++;
     }
 
     cur_token = text_decoder_runner_->logits_to_token(logits_tensor);
diff --git a/extension/llm/runner/text_prefiller.h b/extension/llm/runner/text_prefiller.h
index a8ba77b860..0ea126f32d 100644
--- a/extension/llm/runner/text_prefiller.h
+++ b/extension/llm/runner/text_prefiller.h
@@ -36,7 +36,7 @@ class TextPrefiller {
    */
   ::executorch::runtime::Result<uint64_t> prefill(
       std::vector<uint64_t>& prompt_tokens,
-      int64_t start_pos = 0);
+      int64_t& start_pos);
 
  private:
   TextDecoderRunner* text_decoder_runner_;