[llava] Expose prefill image and prompt APIs
Differential Revision: D62273041

Pull Request resolved: pytorch#5119
larryliu0820 authored Sep 6, 2024
1 parent 030fc3f commit 9739609
Showing 5 changed files with 110 additions and 48 deletions.
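
In effect, the body of LlavaRunner::generate() is factored into three steps that callers can now invoke individually. A minimal sketch of the intended call pattern (hypothetical driver code: the runner, images, prompts, and callbacks are assumed to exist, and error checks are elided; only the three API names come from this commit):

int64_t pos = 0; // KV-cache position, advanced by each prefill call

// 1. Prefill the preset system prompt; BOS exactly once, at the start.
runner.prefill_prompt(system_prompt, pos, /*bos=*/1, /*eos=*/0);

// 2. Prefill the image embeddings; pos advances inside the call.
runner.prefill_images(images, pos);

// 3. Prefill the user prompt, then decode until seq_len is reached.
runner.generate_from_pos(user_prompt, /*seq_len=*/1024, pos, on_token, on_stats);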
4 changes: 2 additions & 2 deletions examples/models/llama2/runner/runner.cpp
@@ -204,8 +204,8 @@ Error Runner::generate(
 
   // print prompts
   wrapped_callback(prompt);
-
-  auto prefill_res = text_prefiller_->prefill(prompt_tokens, 0);
+  int64_t pos = 0;
+  auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
   stats_.first_token_ms = util::time_in_ms();
   stats_.prompt_eval_end_ms = util::time_in_ms();
   ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
85 changes: 52 additions & 33 deletions examples/models/llava/runner/llava_runner.cpp
@@ -72,6 +72,54 @@ Error LlavaRunner::load() {
   return Error::Ok;
 }
 
+Error LlavaRunner::prefill_images(
+    std::vector<Image>& images,
+    int64_t& start_pos) {
+  for (auto& image : images) {
+    // pos is updated inside image prefill.
+    ET_UNWRAP(image_prefiller_->prefill(image, start_pos));
+  }
+  return Error::Ok;
+}
+
+Result<uint64_t> LlavaRunner::prefill_prompt(
+    const std::string& prompt,
+    int64_t& start_pos,
+    int8_t bos,
+    int8_t eos) {
+  std::vector<uint64_t> prompt_tokens =
+      ET_UNWRAP(tokenizer_->encode(prompt, bos, eos));
+
+  return text_prefiller_->prefill(prompt_tokens, start_pos);
+}
+
+Error LlavaRunner::generate_from_pos(
+    const std::string& prompt,
+    int32_t seq_len,
+    int64_t start_pos,
+    std::function<void(const std::string&)> token_callback,
+    std::function<void(const ::executorch::extension::llm::Stats&)>
+        stats_callback) {
+  // prefill user prompt. No BOS because preset prompt already has it.
+  token_callback(prompt);
+
+  uint64_t prefill_next_token =
+      ET_UNWRAP(prefill_prompt(prompt, start_pos, /*bos=*/0, /*eos*/ 0));
+  stats_.num_prompt_tokens = start_pos;
+
+  // Generate tokens
+  int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
+      {prefill_next_token}, start_pos, seq_len, token_callback));
+
+  // Bookkeeping
+  stats_.num_generated_tokens = num_generated_tokens;
+  ::executorch::llm::print_report(stats_);
+  if (stats_callback) {
+    stats_callback(stats_);
+  }
+  return Error::Ok;
+}
+
 Error LlavaRunner::generate(
     std::vector<Image> images,
     const std::string& prompt,
@@ -96,43 +144,14 @@ Error LlavaRunner::generate(
   int64_t pos = 0;
 
   // prefill preset prompt
-  std::vector<uint64_t> preset_prompt_tokens =
-      ET_UNWRAP(tokenizer_->encode(kPresetPrompt, /*bos=*/1, /*eos=*/0));
-  size_t num_preset_tokens = preset_prompt_tokens.size();
-
-  ET_UNWRAP(text_prefiller_->prefill(preset_prompt_tokens, pos));
-  pos += num_preset_tokens;
+  prefill_prompt(kPresetPrompt, pos, /*bos=*/1, /*eos*/ 0);
 
   // prefill images
-  for (auto& image : images) {
-    // pos is updated inside image prefill.
-    ET_UNWRAP(image_prefiller_->prefill(image, pos));
-  }
-
-  // prefill user prompt. No BOS because preset prompt already has it.
-  wrapped_callback(prompt);
-
-  std::vector<uint64_t> user_prompt_tokens =
-      ET_UNWRAP(tokenizer_->encode(prompt, /*bos=*/0, /*eos=*/0));
-  size_t num_user_tokens = user_prompt_tokens.size();
-
-  uint64_t prefill_next_token =
-      ET_UNWRAP(text_prefiller_->prefill(user_prompt_tokens, pos));
-  pos += num_user_tokens;
+  prefill_images(images, pos);
 
-  // Generate tokens
-  int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
-      {prefill_next_token}, pos, seq_len, wrapped_callback));
-
-  // Bookkeeping
-  stats_.num_prompt_tokens = num_preset_tokens + num_user_tokens;
-  stats_.num_generated_tokens = num_generated_tokens;
-  ::executorch::llm::print_report(stats_);
-  if (stats_callback) {
-    stats_callback(stats_);
-  }
 
-  return Error::Ok;
+  return generate_from_pos(
+      prompt, seq_len, pos, wrapped_callback, stats_callback);
 }
 
 } // namespace torch::executor
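
For orientation, here is a worked trace of the position bookkeeping in the new generate(), under assumed token counts (the numbers are hypothetical, not from the commit):

int64_t pos = 0;
// Assume the preset prompt encodes to 40 tokens and the image prefill
// consumes 576 positions (illustrative numbers only).
prefill_prompt(kPresetPrompt, pos, /*bos=*/1, /*eos*/ 0); // pos: 0 -> 40
prefill_images(images, pos);                              // pos: 40 -> 616
// generate_from_pos takes pos by value; its internal prefill_prompt advances
// the local copy past the user prompt (say 12 tokens, reaching 628), so
// stats_.num_prompt_tokens becomes 628. Note that this count now includes
// image positions, whereas the old code summed only preset + user tokens.
generate_from_pos(prompt, seq_len, pos, wrapped_callback, stats_callback);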
42 changes: 42 additions & 0 deletions examples/models/llava/runner/llava_runner.h
@@ -38,6 +38,48 @@ class LlavaRunner : public MultimodalRunner {
       std::function<void(const ::executorch::extension::llm::Stats&)>
           stats_callback = {});
 
+  /**
+   * Prefill an LLaVA Module with the given images input.
+   * @param images The image input to LLaVA.
+   * @param start_pos The starting position in KV cache of the input in the LLM.
+   * It's passed as reference and will be updated inside this function.
+   * @return The error status of prefilling images.
+   */
+  Error prefill_images(std::vector<Image>& images, int64_t& start_pos);
+
+  /**
+   * Prefill an LLaVA Module with the given text input.
+   * @param prompt The text prompt to LLaVA.
+   * @param start_pos The starting position in KV cache of the input in the LLM.
+   * It's passed as reference and will be updated inside this function.
+   * @param bos The number of BOS (begin of sequence) token.
+   * @param eos The number of EOS (end of sequence) token.
+   * @return The generated token of the LLaVA Module after prefill prompt.
+   */
+  Result<uint64_t> prefill_prompt(
+      const std::string& prompt,
+      int64_t& start_pos,
+      int8_t bos = 0,
+      int8_t eos = 0);
+
+  /**
+   * Generate tokens from the given prompt, starting from the given position.
+   * @param prompt The text prompt to LLaVA.
+   * @param seq_len The total sequence length, including the prompt tokens and
+   * new tokens.
+   * @param start_pos The starting position in KV cache of the input in the LLM.
+   * @param token_callback What to do after a token is generated.
+   * @param stats_callback What to do with Stats.
+   * @return The error code.
+   */
+  Error generate_from_pos(
+      const std::string& prompt,
+      int32_t seq_len = 1024,
+      int64_t start_pos = 0,
+      std::function<void(const std::string&)> token_callback = {},
+      std::function<void(const ::executorch::extension::llm::Stats&)>
+          stats_callback = {});
+
  private:
   inline static const std::string kPresetPrompt =
       "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: ";
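
The Result<uint64_t> returned by prefill_prompt is the first token sampled after the prefilled text; generate_from_pos feeds it to the token generator. A sketch of consuming it directly at a hypothetical call site (error handling per the usual ExecuTorch Result conventions):

int64_t pos = 0;
auto next = runner.prefill_prompt(
    "USER: describe the image. ASSISTANT:", pos, /*bos=*/1, /*eos=*/0);
if (!next.ok()) {
  return next.error(); // propagate the failure to the caller
}
uint64_t first_token = next.get(); // seed token for subsequent decoding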
25 changes: 13 additions & 12 deletions extension/llm/runner/text_prefiller.cpp
@@ -25,7 +25,7 @@ TextPrefiller::TextPrefiller(
 
 ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
     std::vector<uint64_t>& prompt_tokens,
-    int64_t start_pos_index) {
+    int64_t& start_pos) {
   ET_CHECK_MSG(!prompt_tokens.empty(), "Prompt cannot be null");
   if (!text_decoder_runner_->is_method_loaded()) {
     ET_CHECK_OK_OR_RETURN_ERROR(text_decoder_runner_->load());
@@ -43,45 +43,46 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
         {1, num_prompt_tokens},
         exec_aten::ScalarType::Long);
 
-    auto start_pos =
-        from_blob(&start_pos_index, {1}, exec_aten::ScalarType::Long);
+    auto start_pos_tensor =
+        from_blob(&start_pos, {1}, exec_aten::ScalarType::Long);
 
-    auto outputs_res = text_decoder_runner_->step(tokens, start_pos);
+    auto outputs_res = text_decoder_runner_->step(tokens, start_pos_tensor);
 
     ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
     ET_LOG(
         Info, "Prefill token result numel(): %zu", outputs_res.get().numel());
 
+    start_pos += num_prompt_tokens;
     cur_token = text_decoder_runner_->logits_to_token(outputs_res.get());
   } else { // sequential prefill
     int64_t pos = 0; // position in the sequence
-    // token & pos
-    int64_t pos_data = 0;
     // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
     cur_token = prompt_tokens[0];
 
     // initialize tensor wrappers
     auto tokens = from_blob(&cur_token, {1, 1}, exec_aten::ScalarType::Long);
 
-    auto start_pos = from_blob(&pos_data, {1}, exec_aten::ScalarType::Long);
+    auto start_pos_tensor =
+        from_blob(&start_pos, {1}, exec_aten::ScalarType::Long);
 
     // run the first token and get back logits tensor. Assuming the first token
     // is bos so don't callback.
     auto logits_tensor =
-        ET_UNWRAP(text_decoder_runner_->step(tokens, start_pos));
+        ET_UNWRAP(text_decoder_runner_->step(tokens, start_pos_tensor));
 
-    pos = 1; // start from index 1
+    pos += 1; // start the loop from index 1
+    start_pos += 1;
 
     while (pos < num_prompt_tokens) {
       // Run the model
-      pos_data = start_pos_index + pos;
-
       // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
       cur_token = prompt_tokens[pos];
 
-      logits_tensor = ET_UNWRAP(text_decoder_runner_->step(tokens, start_pos));
+      logits_tensor =
+          ET_UNWRAP(text_decoder_runner_->step(tokens, start_pos_tensor));
 
       pos++;
+      start_pos++;
     }
 
     cur_token = text_decoder_runner_->logits_to_token(logits_tensor);
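
Taking start_pos by reference is what lets successive prefill calls chain without the manual pos += num_*_tokens arithmetic the LLaVA runner previously did by hand; it is also why the llama2 runner above now passes a named int64_t instead of a literal 0, since a literal cannot bind to a non-const reference. A sketch of the new contract (hypothetical caller inside a function whose return type lets ET_UNWRAP propagate errors):

int64_t pos = 0;
// First segment: prefill advances pos by the number of tokens consumed.
uint64_t tok1 = ET_UNWRAP(prefiller.prefill(system_tokens, pos));
// Second segment continues at the updated position in the same KV cache.
uint64_t tok2 = ET_UNWRAP(prefiller.prefill(user_tokens, pos));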
2 changes: 1 addition & 1 deletion extension/llm/runner/text_prefiller.h
@@ -36,7 +36,7 @@ class TextPrefiller {
    */
   ::executorch::runtime::Result<uint64_t> prefill(
       std::vector<uint64_t>& prompt_tokens,
-      int64_t start_pos = 0);
+      int64_t& start_pos);
 
  private:
   TextDecoderRunner* text_decoder_runner_;
