forked from pytorch/executorch
Implement runner for phi-3-mini (pytorch#4500)
Summary: This PR mainly does the following things:

- implement a runner for phi-3-mini
- switch phi-3-mini over to the BPE tokenizer, which is now shared across all LLMs
- fix a small bug in the BPE tokenizer

Pull Request resolved: pytorch#4500

Test Plan:

```
./build/phi_3_mini_runner --model_path phi-3-mini-kv-128.pte --tokenizer_path tokenizer.bin --prompt "Tell me a story" --temperature 0
Prefilling tokens ...
24948 592 263 5828
Generating tokens ...
about a time when you had to overcome a challenge. I remember when I was in high school, I had to prepare for a big exam that would determine my future. I had to study hard, but I also had to balance my schoolwork, my hobbies, and my social life. It was not easy, but I managed to do it. I made a study schedule, set goals, and rewarded myself for my achievements. I also asked for help from my teachers, friends, and family. I faced many difficulties, but I never gave up. I passed the exam with flying colors and
```

Reviewed By: larryliu0820

Differential Revision: D60609165

Pulled By: helunwencser

fbshipit-source-id: 9abab0ba8ea8e50559272c6001fa868e49f40a96
1 parent 14c2473 · commit 864e0b0
Showing 6 changed files with 215 additions and 126 deletions.
examples/models/phi-3-mini/runner.cpp (new file, 109 lines)

```cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/examples/models/phi-3-mini/runner.h>

#include <cinttypes> // for PRIu64, used in run_model_step() below
#include <ctime>
#include <iostream>

#include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>
#include <executorch/extension/runner_util/managed_tensor.h>
#include <executorch/runtime/platform/log.h>

namespace torch::executor {

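// phi-3-mini sampling constants: SAMPLER_TOP is the sampler's top-p value,
// ENDOFTEXT_TOKEN the end-of-text token id, and VOCABULARY_SIZE the model's
// vocabulary size.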
#define SAMPLER_TOP 0.9f
#define ENDOFTEXT_TOKEN 32000
#define VOCABULARY_SIZE 32064

Runner::Runner(
    const std::string& model_path,
    const std::string& tokenizer_path,
    const float temperature)
    : module_(std::make_unique<Module>(model_path, Module::LoadMode::File)),
      tokenizer_(std::make_unique<BPETokenizer>()),
      sampler_(std::make_unique<Sampler>(
          VOCABULARY_SIZE,
          temperature,
          SAMPLER_TOP,
          static_cast<unsigned long long>(std::time(nullptr)))) {
  ET_CHECK_MSG(
      tokenizer_->load(tokenizer_path) == Error::Ok,
      "Failed to load tokenizer at %s",
      tokenizer_path.c_str());
  ET_LOG(
      Info,
      "Created Phi-3-mini runner: model_path=%s, tokenizer_path=%s",
      model_path.c_str(),
      tokenizer_path.c_str());
}

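// Generation runs in two phases: one prefill pass over the whole prompt,
// then one decode step per generated token until the end-of-text token
// appears or max_seq_len is reached.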
void Runner::generate(const std::string& prompt, std::size_t max_seq_len) {
  auto encode_res = tokenizer_->encode(prompt, 0, 0);
  ET_CHECK_MSG(
      encode_res.error() == Error::Ok, "Failed to encode %s", prompt.c_str());
  auto input_tokens = encode_res.get();

  std::cout << "Prefilling tokens ..." << std::endl;
  for (auto token : input_tokens) {
    std::cout << token << " ";
  }
  std::cout << std::endl;
  std::cout.flush();
  auto prev_token = input_tokens.back();
  auto current_token = prefill(input_tokens);

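  // Stream each piece as it is produced; decode() also takes the previous
  // token so the tokenizer can place word-boundary whitespace correctly.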
std::cout << "Generating tokens ..." << std::endl; | ||
std::cout << tokenizer_->decode(prev_token, current_token).get(); | ||
std::cout.flush(); | ||
|
||
std::size_t seq_len = input_tokens.size() + 1; | ||
|
||
while (current_token != ENDOFTEXT_TOKEN && seq_len < max_seq_len) { | ||
prev_token = current_token; | ||
current_token = run_model_step(current_token); | ||
std::cout << tokenizer_->decode(prev_token, current_token).get(); | ||
std::cout.flush(); | ||
|
||
++seq_len; | ||
} | ||
|
||
std::cout << std::endl; | ||
} | ||
|
||
uint64_t Runner::logits_to_token(const exec_aten::Tensor& logits_tensor) { | ||
return sampler_->sample(logits_tensor.data_ptr<float>()); | ||
} | ||
|
||
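// Feeds the entire prompt to the model as a single {1, seq_len} tensor of
// token ids and samples the first generated token from the returned logits.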
uint64_t Runner::prefill(std::vector<uint64_t>& tokens) {
  ManagedTensor input_tokens(
      tokens.data(),
      {1, static_cast<exec_aten::SizesType>(tokens.size())},
      ScalarType::Long);
  std::vector<EValue> inputs = {input_tokens.get_aliasing_tensor()};

  auto result = module_->forward(inputs);
  ET_CHECK_MSG(result.error() == Error::Ok, "Failed to prefill tokens");

  return logits_to_token(result.get()[0].toTensor());
}

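// Runs one decode step: only the latest token is fed in as a {1, 1} tensor;
// the exported model is expected to carry earlier context in its KV cache.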
uint64_t Runner::run_model_step(uint64_t token) {
  ManagedTensor input_token(&token, {1, 1}, ScalarType::Long);
  std::vector<EValue> inputs = {input_token.get_aliasing_tensor()};

  auto result = module_->forward(inputs);
  ET_CHECK_MSG(
      result.error() == Error::Ok,
      "Failed to run forward() for token %" PRIu64,
      token);

  return logits_to_token(result.get()[0].toTensor());
}

} // namespace torch::executor
```
examples/models/phi-3-mini/runner.h (new file, 50 lines)

```cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// A simple phi-3-mini runner that includes preprocessing and post-processing
// logic. The module takes in a string as input and emits a string as output.

#pragma once

#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include <executorch/extension/llm/sampler/sampler.h>
#include <executorch/extension/llm/tokenizer/tokenizer.h>
#include <executorch/extension/module/module.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>

namespace torch::executor {

class Runner {
 public:
  explicit Runner(
      const std::string& model_path,
      const std::string& tokenizer_path,
      const float temperature = 0.8f);

  /**
   * Generates a response for a given prompt.
   *
   * @param[in] prompt The prompt to generate a response for.
   * @param[in] max_seq_len The maximum length of the sequence to generate,
   * including the prompt.
   */
  void generate(const std::string& prompt, std::size_t max_seq_len);

 private:
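  // prefill() consumes the entire prompt in one forward pass;
  // run_model_step() advances generation by a single token;
  // logits_to_token() samples a token id from a logits tensor.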
  uint64_t logits_to_token(const exec_aten::Tensor& logits_tensor);
  uint64_t prefill(std::vector<uint64_t>& tokens);
  uint64_t run_model_step(uint64_t token);

  std::unique_ptr<Module> module_;
  std::unique_ptr<Tokenizer> tokenizer_;
  std::unique_ptr<Sampler> sampler_;
};

} // namespace torch::executor
```
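The test plan above invokes a `phi_3_mini_runner` binary whose `main.cpp` is not among the files shown in this excerpt. Below is a minimal sketch of such a driver, assuming positional arguments with hard-coded defaults in place of the real `--model_path`, `--tokenizer_path`, `--prompt`, and `--temperature` flag parsing; the default values and the `max_seq_len` of 128 are illustrative assumptions, not part of this commit:

```cpp
// Hypothetical driver for the phi-3-mini Runner; the real main.cpp of
// phi_3_mini_runner is not shown in this diff, so everything here is a sketch.
#include <string>

#include <executorch/examples/models/phi-3-mini/runner.h>

int main(int argc, char** argv) {
  // Positional arguments stand in for the real binary's named flags.
  const std::string model_path = argc > 1 ? argv[1] : "phi-3-mini-kv-128.pte";
  const std::string tokenizer_path = argc > 2 ? argv[2] : "tokenizer.bin";
  const std::string prompt = argc > 3 ? argv[3] : "Tell me a story";

  // Temperature 0 matches the deterministic run shown in the test plan.
  torch::executor::Runner runner(
      model_path, tokenizer_path, /*temperature=*/0.0f);

  // 128 is an assumed cap; pick it to match the sequence length the model
  // was exported with (the "kv-128" in the example .pte name suggests 128).
  runner.generate(prompt, /*max_seq_len=*/128);

  return 0;
}
```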
This file was deleted (its contents are not shown in this excerpt).