Implement runner for phi-3-mini (pytorch#4500)
Summary:
This PR mainly does the following:
- implement a runner for phi-3-mini
- switch phi-3-mini to the BPE tokenizer, which is now shared by all LLMs
- fix a small bug in the BPE tokenizer

Pull Request resolved: pytorch#4500

Test Plan:
```
./build/phi_3_mini_runner --model_path phi-3-mini-kv-128.pte --tokenizer_path tokenizer.bin --prompt "Tell me a story" --temperature 0
Prefilling tokens ...
24948 592 263 5828
Generating tokens ...
 about a time when you had to overcome a challenge.
I remember when I was in high school, I had to prepare for a big exam that would determine my future. I had to study hard, but I also had to balance my schoolwork, my hobbies, and my social life. It was not easy, but I managed to do it. I made a study schedule, set goals, and rewarded myself for my achievements. I also asked for help from my teachers, friends, and family. I faced many difficulties, but I never gave up. I passed the exam with flying colors and
```

Reviewed By: larryliu0820

Differential Revision: D60609165

Pulled By: helunwencser

fbshipit-source-id: 9abab0ba8ea8e50559272c6001fa868e49f40a96
helunwencser authored and facebook-github-bot committed Aug 5, 2024
1 parent 14c2473 commit 864e0b0
Showing 6 changed files with 215 additions and 126 deletions.
45 changes: 29 additions & 16 deletions examples/models/phi-3-mini/CMakeLists.txt
```diff
@@ -4,6 +4,15 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# ### Editing this file ###
+#
+# This file should be formatted with
+# ~~~
+# cmake-format -i CMakeLists.txt
+# ~~~
+# It should also be cmake-lint clean.
+#
+
 cmake_minimum_required(VERSION 3.19)
 project(phi_3_mini_runner)
 
@@ -18,22 +27,26 @@ option(EXECUTORCH_BUILD_KERNELS_OPTIMIZED "" ON)
 option(EXECUTORCH_BUILD_XNNPACK "" ON)
 
 add_subdirectory(
-  ${CMAKE_CURRENT_SOURCE_DIR}/../../..
-  ${CMAKE_BINARY_DIR}/../../..)
-add_subdirectory(
-  ${CMAKE_CURRENT_SOURCE_DIR}/../../../extension/llm/third-party/sentencepiece
-  ${CMAKE_BINARY_DIR}/sentencepiece)
+  ${CMAKE_CURRENT_SOURCE_DIR}/../../.. ${CMAKE_BINARY_DIR}/../../..
+)
+if(NOT TARGET gflags)
+  add_subdirectory(
+    ${CMAKE_CURRENT_SOURCE_DIR}/../../../third-party/gflags
+    ${CMAKE_BINARY_DIR}/gflags
+  )
+endif()
 
-add_executable(phi_3_mini_runner main.cpp)
+add_executable(
+  phi_3_mini_runner
+  main.cpp runner.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/../../../extension/llm/sampler/sampler.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/../../../extension/llm/tokenizer/bpe_tokenizer.cpp
+)
 target_include_directories(
-  phi_3_mini_runner
-  PUBLIC
-  ${CMAKE_CURRENT_SOURCE_DIR}/../../../extension/llm/third-party/sentencepiece/src)
+  phi_3_mini_runner
+  PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../../../third-party/gflags/src
+)
 target_link_libraries(
-  phi_3_mini_runner
-  PRIVATE
-  executorch
-  extension_module_static
-  optimized_native_cpu_ops_lib
-  xnnpack_backend
-  sentencepiece)
+  phi_3_mini_runner PRIVATE executorch extension_module_static
+                            optimized_native_cpu_ops_lib xnnpack_backend gflags
+)
```
92 changes: 26 additions & 66 deletions examples/models/phi-3-mini/main.cpp
```diff
@@ -6,85 +6,45 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-// main.cpp
+#include <gflags/gflags.h>
 
-#include <iostream>
+#include <executorch/examples/models/phi-3-mini/runner.h>
 
-#include <executorch/extension/module/module.h>
-#include <executorch/extension/runner_util/managed_tensor.h>
+DEFINE_string(
+    model_path,
+    "phi-3-mini.pte",
+    "File path for model serialized in flatbuffer format.");
 
-#include "sentence_piece_tokenizer.h"
+DEFINE_string(tokenizer_path, "tokenizer.bin", "File path for tokenizer.");
 
-using namespace torch::executor;
+DEFINE_string(prompt, "Tell me a story", "Prompt.");
 
-// The value of the phi-3-mini `<|endoftext|>` token.
-#define ENDOFTEXT_TOKEN 32000
-#define VOCABULARY_SIZE 32064
+DEFINE_double(
+    temperature,
+    0.8f,
+    "Temperature; Default is 0.8f. 0 = greedy argmax sampling (deterministic). Lower temperature = more deterministic");
 
-// TODO(lunwenh): refactor and share with llama
-void generate(
-    Module& llm_model,
-    std::string& prompt,
-    SentencePieceTokenizer& tokenizer,
-    size_t max_output_length) {
-  // Convert the input text into a list of integers (tokens) that represents
-  // it, using the string-to-token mapping that the model was trained on.
-  // Each token is an integer that represents a word or part of a word.
-  std::vector<int64_t> input_tokens = tokenizer.encode(prompt);
+DEFINE_int32(
+    seq_len,
+    128,
+    "Total number of tokens to generate (prompt + output).");
 
-  std::cout << "Generating tokens ..." << std::endl;
+int main(int32_t argc, char** argv) {
+  gflags::ParseCommandLineFlags(&argc, &argv, true);
 
-  std::vector<int64_t> output_tokens;
+  const char* model_path = FLAGS_model_path.c_str();
 
-  for (size_t i = 0; i < max_output_length; i++) {
-    ManagedTensor tensor_tokens(
-        input_tokens.data(),
-        {1, static_cast<int>(input_tokens.size())},
-        ScalarType::Long);
-    std::vector<EValue> inputs = {tensor_tokens.get_aliasing_tensor()};
+  const char* tokenizer_path = FLAGS_tokenizer_path.c_str();
 
-    Result<std::vector<EValue>> result_evalue = llm_model.forward(inputs);
+  const char* prompt = FLAGS_prompt.c_str();
 
-    const auto error = result_evalue.error();
-    Tensor logits_tensor = result_evalue.get()[0].toTensor();
-    const auto sentence_length = logits_tensor.size(1);
-    std::vector<float> logits(
-        logits_tensor.data_ptr<float>() +
-            (sentence_length - 1) * VOCABULARY_SIZE,
-        logits_tensor.data_ptr<float>() + sentence_length * VOCABULARY_SIZE);
+  double temperature = FLAGS_temperature;
 
-    // Sample the next token from the logits.
-    int64_t next_token =
-        std::max_element(logits.begin(), logits.end()) - logits.begin();
+  int32_t seq_len = FLAGS_seq_len;
 
-    std::cout << next_token << "\t";
-    std::cout.flush();
+  ::torch::executor::Runner runner(model_path, tokenizer_path, temperature);
 
-    // Break if we reached the end of the text.
-    if (next_token == ENDOFTEXT_TOKEN) {
-      break;
-    }
+  runner.generate(prompt, seq_len);
 
-    output_tokens.push_back(next_token);
-
-    // Update next input.
-    input_tokens.push_back(next_token);
-  }
-
-  std::cout << std::endl;
-  std::cout << tokenizer.decode(output_tokens) << std::endl;
-}
-
-int main() {
-  // Set up the prompt. This provides the seed text for the model to elaborate.
-  std::cout << "Enter model prompt: ";
-  std::string prompt;
-  std::getline(std::cin, prompt);
-
-  SentencePieceTokenizer tokenizer("tokenizer.model");
-
-  Module model("phi-3-mini.pte", Module::LoadMode::MmapUseMlockIgnoreErrors);
-
-  const auto max_output_tokens = 128;
-  generate(model, prompt, tokenizer, max_output_tokens);
   return 0;
 }
```
109 changes: 109 additions & 0 deletions examples/models/phi-3-mini/runner.cpp
```cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/examples/models/phi-3-mini/runner.h>

#include <ctime>
#include <iostream>

#include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>
#include <executorch/extension/runner_util/managed_tensor.h>
#include <executorch/runtime/platform/log.h>

namespace torch::executor {

#define SAMPLER_TOP 0.9f
#define ENDOFTEXT_TOKEN 32000
#define VOCABULARY_SIZE 32064

Runner::Runner(
    const std::string& model_path,
    const std::string& tokenizer_path,
    const float temperature)
    : module_(std::make_unique<Module>(model_path, Module::LoadMode::File)),
      tokenizer_(std::make_unique<BPETokenizer>()),
      sampler_(std::make_unique<Sampler>(
          VOCABULARY_SIZE,
          temperature,
          SAMPLER_TOP,
          static_cast<unsigned long long>(std::time(nullptr)))) {
  ET_CHECK_MSG(
      tokenizer_->load(tokenizer_path) == Error::Ok,
      "Failed to load tokenizer at %s",
      tokenizer_path.c_str());
  ET_LOG(
      Info,
      "Created Phi-3-mini runner: model_path=%s, tokenizer_path=%s",
      model_path.c_str(),
      tokenizer_path.c_str());
}

void Runner::generate(const std::string& prompt, std::size_t max_seq_len) {
  auto encode_res = tokenizer_->encode(prompt, 0, 0);
  ET_CHECK_MSG(
      encode_res.error() == Error::Ok, "Failed to encode %s", prompt.c_str());
  auto input_tokens = encode_res.get();

  std::cout << "Prefilling tokens ..." << std::endl;
  for (auto token : input_tokens) {
    std::cout << token << " ";
  }
  std::cout << std::endl;
  std::cout.flush();
  auto prev_token = input_tokens.back();
  auto current_token = prefill(input_tokens);

  std::cout << "Generating tokens ..." << std::endl;
  std::cout << tokenizer_->decode(prev_token, current_token).get();
  std::cout.flush();

  std::size_t seq_len = input_tokens.size() + 1;

  while (current_token != ENDOFTEXT_TOKEN && seq_len < max_seq_len) {
    prev_token = current_token;
    current_token = run_model_step(current_token);
    std::cout << tokenizer_->decode(prev_token, current_token).get();
    std::cout.flush();

    ++seq_len;
  }

  std::cout << std::endl;
}

uint64_t Runner::logits_to_token(const exec_aten::Tensor& logits_tensor) {
  return sampler_->sample(logits_tensor.data_ptr<float>());
}

uint64_t Runner::prefill(std::vector<uint64_t>& tokens) {
  ManagedTensor input_tokens(
      tokens.data(),
      {1, static_cast<exec_aten::SizesType>(tokens.size())},
      ScalarType::Long);
  std::vector<EValue> inputs = {input_tokens.get_aliasing_tensor()};

  auto result = module_->forward(inputs);
  ET_CHECK_MSG(result.error() == Error::Ok, "Failed to prefill tokens");

  return logits_to_token(result.get()[0].toTensor());
}

uint64_t Runner::run_model_step(uint64_t token) {
  ManagedTensor input_token(&token, {1, 1}, ScalarType::Long);
  std::vector<EValue> inputs = {input_token.get_aliasing_tensor()};

  auto result = module_->forward(inputs);
  ET_CHECK_MSG(
      result.error() == Error::Ok,
      "Failed to run forward() for token %" PRIu64,
      token);

  return logits_to_token(result.get()[0].toTensor());
}

} // namespace torch::executor
```
50 changes: 50 additions & 0 deletions examples/models/phi-3-mini/runner.h
```cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// A simple phi-3-mini runner that includes preprocessing and post processing
// logic. The module takes in a string as input and emits a string as output.

#pragma once

#include <memory>
#include <string>

#include <executorch/extension/llm/sampler/sampler.h>
#include <executorch/extension/llm/tokenizer/tokenizer.h>
#include <executorch/extension/module/module.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>

namespace torch::executor {

class Runner {
 public:
  explicit Runner(
      const std::string& model_path,
      const std::string& tokenizer_path,
      const float temperature = 0.8f);

  /**
   * Generates response for a given prompt.
   *
   * @param[in] prompt The prompt to generate a response for.
   * @param[in] max_seq_len The maximum length of the sequence to generate,
   * including prompt.
   */
  void generate(const std::string& prompt, std::size_t max_seq_len);

 private:
  uint64_t logits_to_token(const exec_aten::Tensor& logits_tensor);
  uint64_t prefill(std::vector<uint64_t>& tokens);
  uint64_t run_model_step(uint64_t token);

  std::unique_ptr<Module> module_;
  std::unique_ptr<Tokenizer> tokenizer_;
  std::unique_ptr<Sampler> sampler_;
};

} // namespace torch::executor
```
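
Intended usage of this class mirrors the new main.cpp above: construct a Runner and call generate(). A minimal sketch follows; the model and tokenizer paths are placeholders, and producing the .pte and tokenizer.bin files is outside the scope of this PR.

```cpp
#include <executorch/examples/models/phi-3-mini/runner.h>

int main() {
  // Placeholder artifact paths; these come from the phi-3-mini export step,
  // not from this PR. Per the --temperature flag description in main.cpp,
  // a temperature of 0 selects greedy (argmax) sampling.
  ::torch::executor::Runner runner(
      "phi-3-mini-kv-128.pte", "tokenizer.bin", /*temperature=*/0.0f);

  // Generate at most 128 tokens in total (prompt + output), matching the
  // --seq_len default in main.cpp.
  runner.generate("Tell me a story", /*max_seq_len=*/128);
  return 0;
}
```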
43 changes: 0 additions & 43 deletions examples/models/phi-3-mini/sentence_piece_tokenizer.h

This file was deleted.

2 changes: 1 addition & 1 deletion extension/llm/tokenizer/bpe_tokenizer.cpp
```diff
@@ -190,7 +190,7 @@ BPETokenizer::encode(const std::string& text, int8_t bos, int8_t eos) const {
   std::vector<uint64_t> tokens;
 
   // add optional BOS token, if desired
-  if (bos > 0) {
+  if (bos >= 0) {
     while (bos--) {
       tokens.push_back(bos_tok_);
     }
```
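
Why this one-character change matters: the new phi-3-mini runner calls `tokenizer_->encode(prompt, /*bos=*/0, /*eos=*/0)` (see runner.cpp above), so a BOS count of 0 must be accepted as valid. A minimal sketch of the intent, assuming the unshown branch of `encode()` rejects counts that fail the check; the error-handling shape and `bos_tok` value here are placeholders, not code from this PR.

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for the BOS handling at the top of
// BPETokenizer::encode(); returns nullopt on an invalid count.
std::optional<std::vector<uint64_t>> start_tokens(int8_t bos, uint64_t bos_tok) {
  std::vector<uint64_t> tokens;
  if (bos >= 0) {    // was `bos > 0`, which sent a legal count of 0 to the error path
    while (bos--) {  // push `bos` copies of the BOS token (zero copies when bos == 0)
      tokens.push_back(bos_tok);
    }
  } else {
    // Assumed error path for an invalid (negative) count. With the old `> 0`
    // check, encode(prompt, /*bos=*/0, /*eos=*/0) would have landed here.
    return std::nullopt;
  }
  return tokens;
}
```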
