diff --git a/samples/CMakeLists.txt b/samples/CMakeLists.txt
index 0839d58428..5339817c1f 100644
--- a/samples/CMakeLists.txt
+++ b/samples/CMakeLists.txt
@@ -10,6 +10,7 @@ add_subdirectory(cpp/greedy_causal_lm)
 add_subdirectory(cpp/multinomial_causal_lm)
 add_subdirectory(cpp/prompt_lookup_decoding_lm)
 add_subdirectory(cpp/speculative_decoding_lm)
+add_subdirectory(cpp/benchmark_genai)
 
 install(FILES requirements.txt DESTINATION samples COMPONENT cpp_samples_genai)
diff --git a/samples/cpp/benchmark_genai/CMakeLists.txt b/samples/cpp/benchmark_genai/CMakeLists.txt
new file mode 100644
index 0000000000..5443439de5
--- /dev/null
+++ b/samples/cpp/benchmark_genai/CMakeLists.txt
@@ -0,0 +1,24 @@
+# Copyright (C) 2023-2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+
+find_package(OpenVINOGenAI REQUIRED PATHS
+    "${CMAKE_BINARY_DIR}"  # Reuse the package from the build.
+    ${OpenVINO_DIR}  # GenAI may be installed alongside OpenVINO.
+)
+
+FetchContent_Declare(cxxopts
+    URL https://github.com/jarro2783/cxxopts/archive/refs/tags/v3.1.1.tar.gz
+    URL_HASH SHA256=523175f792eb0ff04f9e653c90746c12655f10cb70f1d5e6d6d9491420298a08)
+FetchContent_MakeAvailable(cxxopts)
+
+add_executable(benchmark_genai benchmark_genai.cpp)
+target_link_libraries(benchmark_genai PRIVATE openvino::genai cxxopts::cxxopts)
+set_target_properties(benchmark_genai PROPERTIES
+    COMPILE_PDB_NAME benchmark_genai
+    # Ensure out of box LC_RPATH on macOS with SIP
+    INSTALL_RPATH_USE_LINK_PATH ON)
+install(TARGETS benchmark_genai
+    RUNTIME DESTINATION samples_bin/
+    COMPONENT samples_bin
+    EXCLUDE_FROM_ALL)
diff --git a/samples/cpp/benchmark_genai/README.md b/samples/cpp/benchmark_genai/README.md
new file mode 100644
index 0000000000..616bb6a36d
--- /dev/null
+++ b/samples/cpp/benchmark_genai/README.md
@@ -0,0 +1,47 @@
+# LLM benchmarking sample
+
+This sample demonstrates how to benchmark an LLM in OpenVINO GenAI. It includes functionality for warm-up iterations, generating text, and calculating various performance metrics.
+
+## Download and convert the model and tokenizers
+
+The `--upgrade-strategy eager` option is needed to ensure `optimum-intel` is upgraded to the latest version.
+
+It's not required to install [../../requirements.txt](../../requirements.txt) for deployment if the model has already been exported.
+
+```sh
+pip install --upgrade-strategy eager -r ../../requirements.txt
+optimum-cli export openvino --trust-remote-code --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 TinyLlama-1.1B-Chat-v1.0
+```
+
+## Usage
+
+```sh
+benchmark_genai [OPTIONS]
+```
+
+### Options
+
+- `-m, --model`: Path to the model and tokenizers base directory.
+- `-p, --prompt` (default: `"The Sky is blue because"`): The prompt to generate text.
+- `-nw, --num_warmup` (default: `1`): Number of warmup iterations.
+- `-mt, --max_new_tokens` (default: `20`): Maximal number of new tokens to generate.
+- `-n, --num_iter` (default: `3`): Number of iterations.
+- `-d, --device` (default: `"CPU"`): Device to run the model on.
+
+### Output
+
+```sh
+benchmark_genai -m TinyLlama-1.1B-Chat-v1.0 -n 10
+```
+
+```
+Load time: 3405.69 ms
+Generate time: 1430.77 ± 3.04 ms
+Tokenization time: 0.51 ± 0.02 ms
+Detokenization time: 0.37 ± 0.01 ms
+TTFT: 81.60 ± 0.54 ms
+TPOT: 71.52 ± 2.72 ms
+Throughput: 13.98 ± 0.53 tokens/s
+```
+
+For more information on how performance metrics are calculated, please refer to the [performance metrics tutorial](../../../src/README.md#performance-metrics).
diff --git a/samples/cpp/benchmark_genai/benchmark_genai.cpp b/samples/cpp/benchmark_genai/benchmark_genai.cpp
new file mode 100644
index 0000000000..287d6b379a
--- /dev/null
+++ b/samples/cpp/benchmark_genai/benchmark_genai.cpp
@@ -0,0 +1,70 @@
+// Copyright (C) 2023-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+
+#include "openvino/genai/llm_pipeline.hpp"
+#include <cxxopts.hpp>
+
+int main(int argc, char* argv[]) try {
+    cxxopts::Options options("benchmark_genai", "Help command");
+
+    options.add_options()
+    ("m,model", "Path to model and tokenizers base directory", cxxopts::value<std::string>()->default_value("."))
+    ("p,prompt", "Prompt", cxxopts::value<std::string>()->default_value("The Sky is blue because"))
+    ("nw,num_warmup", "Number of warmup iterations", cxxopts::value<size_t>()->default_value(std::to_string(1)))
+    ("n,num_iter", "Number of iterations", cxxopts::value<size_t>()->default_value(std::to_string(3)))
+    ("mt,max_new_tokens", "Maximal number of new tokens", cxxopts::value<size_t>()->default_value(std::to_string(20)))
+    ("d,device", "device", cxxopts::value<std::string>()->default_value("CPU"))
+    ("h,help", "Print usage");
+
+    cxxopts::ParseResult result;
+    try {
+        result = options.parse(argc, argv);
+    } catch (const cxxopts::exceptions::exception& e) {
+        std::cout << e.what() << "\n\n";
+        std::cout << options.help() << std::endl;
+        return EXIT_FAILURE;
+    }
+
+    if (result.count("help")) {
+        std::cout << options.help() << std::endl;
+        return EXIT_SUCCESS;
+    }
+
+    std::string prompt = result["prompt"].as<std::string>();
+    const std::string model_path = result["model"].as<std::string>();
+    std::string device = result["device"].as<std::string>();
+    size_t num_warmup = result["num_warmup"].as<size_t>();
+    size_t num_iter = result["num_iter"].as<size_t>();
+
+    ov::genai::GenerationConfig config;
+    config.max_new_tokens = result["max_new_tokens"].as<size_t>();
+
+    ov::genai::LLMPipeline pipe(model_path, device);
+
+    for (size_t i = 0; i < num_warmup; i++)
+        pipe.generate(prompt, config);
+
+    ov::genai::DecodedResults res = pipe.generate(prompt, config);
+    ov::genai::PerfMetrics metrics = res.perf_metrics;
+    for (size_t i = 0; i < num_iter - 1; i++) {
+        res = pipe.generate(prompt, config);
+        metrics = metrics + res.perf_metrics;
+    }
+
+    std::cout << std::fixed << std::setprecision(2);
+    std::cout << "Load time: " << metrics.get_load_time() << " ms" << std::endl;
+    std::cout << "Generate time: " << metrics.get_generate_duration().mean << " ± " << metrics.get_generate_duration().std << " ms" << std::endl;
+    std::cout << "Tokenization time: " << metrics.get_tokenization_duration().mean << " ± " << metrics.get_tokenization_duration().std << " ms" << std::endl;
+    std::cout << "Detokenization time: " << metrics.get_detokenization_duration().mean << " ± " << metrics.get_detokenization_duration().std << " ms" << std::endl;
+    std::cout << "TTFT: " << metrics.get_ttft().mean << " ± " << metrics.get_ttft().std << " ms" << std::endl;
+    std::cout << "TPOT: " << metrics.get_tpot().mean << " ± " << metrics.get_tpot().std << " ms/token " << std::endl;
+    std::cout << "Throughput: " << metrics.get_throughput().mean << " ± " << metrics.get_throughput().std << " tokens/s" << std::endl;
+
+    return 0;
+} catch (const std::exception& error) {
+    std::cerr << error.what() << '\n';
+    return EXIT_FAILURE;
+} catch (...) {
+    std::cerr << "Non-exception object thrown\n";
+    return EXIT_FAILURE;
+}
diff --git a/samples/python/benchmark_genai/README.md b/samples/python/benchmark_genai/README.md
new file mode 100644
index 0000000000..1ff9ef4305
--- /dev/null
+++ b/samples/python/benchmark_genai/README.md
@@ -0,0 +1,47 @@
+# LLM benchmarking sample
+
+This sample script demonstrates how to benchmark an LLM in OpenVINO GenAI. The script includes functionality for warm-up iterations, generating text, and calculating various performance metrics.
+
+## Download and convert the model and tokenizers
+
+The `--upgrade-strategy eager` option is needed to ensure `optimum-intel` is upgraded to the latest version.
+
+It's not required to install [../../requirements.txt](../../requirements.txt) for deployment if the model has already been exported.
+
+```sh
+pip install --upgrade-strategy eager -r ../../requirements.txt
+optimum-cli export openvino --trust-remote-code --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 TinyLlama-1.1B-Chat-v1.0
+```
+
+## Usage
+
+```sh
+python benchmark_genai.py [OPTIONS]
+```
+
+### Options
+
+- `-m, --model`: Path to the model and tokenizers base directory.
+- `-p, --prompt` (default: `"The Sky is blue because"`): The prompt to generate text.
+- `-nw, --num_warmup` (default: `1`): Number of warmup iterations.
+- `-n, --num_iter` (default: `3`): Number of iterations.
+- `-mt, --max_new_tokens` (default: `20`): Maximal number of new tokens to generate.
+- `-d, --device` (default: `"CPU"`): Device to run the model on.
+
+### Output
+
+```sh
+python benchmark_genai.py -m TinyLlama-1.1B-Chat-v1.0 -n 10
+```
+
+```
+Load time: 3405.69 ms
+Generate time: 1430.77 ± 3.04 ms
+Tokenization time: 0.51 ± 0.02 ms
+Detokenization time: 0.37 ± 0.01 ms
+TTFT: 81.60 ± 0.54 ms
+TPOT: 71.52 ± 2.72 ms
+Throughput: 13.98 ± 0.53 tokens/s
+```
+
+For more information on how performance metrics are calculated, see the [performance metrics readme](../../../src/README.md#performance-metrics).
diff --git a/samples/python/benchmark_genai/benchmark_genai.py b/samples/python/benchmark_genai/benchmark_genai.py
new file mode 100755
index 0000000000..9851483880
--- /dev/null
+++ b/samples/python/benchmark_genai/benchmark_genai.py
@@ -0,0 +1,49 @@
+# Copyright (C) 2023-2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import argparse
+import openvino_genai as ov_genai
+
+def main():
+    parser = argparse.ArgumentParser(description="Help command")
+    parser.add_argument("-m", "--model", type=str, help="Path to model and tokenizers base directory")
+    parser.add_argument("-p", "--prompt", type=str, default="The Sky is blue because", help="Prompt")
+    parser.add_argument("-nw", "--num_warmup", type=int, default=1, help="Number of warmup iterations")
+    parser.add_argument("-n", "--num_iter", type=int, default=3, help="Number of iterations")
+    parser.add_argument("-mt", "--max_new_tokens", type=int, default=20, help="Maximal number of new tokens")
+    parser.add_argument("-d", "--device", type=str, default="CPU", help="Device")
+
+    args = parser.parse_args()
+
+    # Perf metrics are stored in DecodedResults.
+    # In order to get DecodedResults instead of a string, the input should be a list.
+    prompt = [args.prompt]
+    model_path = args.model
+    device = args.device
+    num_warmup = args.num_warmup
+    num_iter = args.num_iter
+
+    config = ov_genai.GenerationConfig()
+    config.max_new_tokens = args.max_new_tokens
+
+    pipe = ov_genai.LLMPipeline(model_path, device)
+
+    for _ in range(num_warmup):
+        pipe.generate(prompt, config)
+
+    res = pipe.generate(prompt, config)
+    perf_metrics = res.perf_metrics
+    for _ in range(num_iter - 1):
+        res = pipe.generate(prompt, config)
+        perf_metrics += res.perf_metrics
+
+    print(f"Load time: {perf_metrics.get_load_time():.2f} ms")
+    print(f"Generate time: {perf_metrics.get_generate_duration().mean:.2f} ± {perf_metrics.get_generate_duration().std:.2f} ms")
+    print(f"Tokenization time: {perf_metrics.get_tokenization_duration().mean:.2f} ± {perf_metrics.get_tokenization_duration().std:.2f} ms")
+    print(f"Detokenization time: {perf_metrics.get_detokenization_duration().mean:.2f} ± {perf_metrics.get_detokenization_duration().std:.2f} ms")
+    print(f"TTFT: {perf_metrics.get_ttft().mean:.2f} ± {perf_metrics.get_ttft().std:.2f} ms")
+    print(f"TPOT: {perf_metrics.get_tpot().mean:.2f} ± {perf_metrics.get_tpot().std:.2f} ms")
+    print(f"Throughput: {perf_metrics.get_throughput().mean:.2f} ± {perf_metrics.get_throughput().std:.2f} tokens/s")
+
+if __name__ == "__main__":
+    main()
diff --git a/src/README.md b/src/README.md
index b404794977..e88c2f784f 100644
--- a/src/README.md
+++ b/src/README.md
@@ -196,6 +196,97 @@ int main(int argc, char* argv[]) {
 }
 ```
 
+### Performance Metrics
+
+`openvino_genai.PerfMetrics` (referred to as `PerfMetrics` for simplicity) is a structure that holds performance metrics for each generate call. `PerfMetrics` holds fields with mean and standard deviation values for the following metrics:
+- Time To the First Token (TTFT), ms
+- Time per Output Token (TPOT), ms/token
+- Generate total duration, ms
+- Tokenization duration, ms
+- Detokenization duration, ms
+- Throughput, tokens/s
+
+and:
+- Load time, ms
+- Number of generated tokens
+- Number of tokens in the input prompt
+
+Performance metrics are stored in the `perf_metrics` field of either `DecodedResults` or `EncodedResults`. In addition to the fields mentioned above, `PerfMetrics` has a member `raw_metrics` of type `openvino_genai.RawPerfMetrics` (referred to as `RawPerfMetrics` for simplicity) that contains raw values for the durations of each batch of new token generation, tokenization durations, detokenization durations, and more. These raw metrics are accessible if you wish to calculate your own statistical values such as median or percentiles. However, since mean and standard deviation values are usually sufficient, we will focus on `PerfMetrics`.
+
+```python
+import openvino_genai as ov_genai
+pipe = ov_genai.LLMPipeline(model_path, "CPU")
+result = pipe.generate(["The Sun is yellow because"], max_new_tokens=20)
+perf_metrics = result.perf_metrics
+
+print(f'Generate duration: {perf_metrics.get_generate_duration().mean:.2f} ms')
+print(f'TTFT: {perf_metrics.get_ttft().mean:.2f} ms')
+print(f'TPOT: {perf_metrics.get_tpot().mean:.2f} ms/token')
+print(f'Throughput: {perf_metrics.get_throughput().mean:.2f} tokens/s')
+```
+
+```cpp
+#include "openvino/genai/llm_pipeline.hpp"
+#include <iostream>
+
+int main(int argc, char* argv[]) {
+    std::string model_path = argv[1];
+    ov::genai::LLMPipeline pipe(model_path, "CPU");
+    auto result = pipe.generate("The Sun is yellow because", ov::genai::max_new_tokens(20));
+    auto perf_metrics = result.perf_metrics;
+
+    std::cout << std::fixed << std::setprecision(2);
+    std::cout << "Generate duration: " << perf_metrics.get_generate_duration().mean << " ms" << std::endl;
+    std::cout << "TTFT: " << perf_metrics.get_ttft().mean << " ms" << std::endl;
+    std::cout << "TPOT: " << perf_metrics.get_tpot().mean << " ms/token " << std::endl;
+    std::cout << "Throughput: " << perf_metrics.get_throughput().mean << " tokens/s" << std::endl;
+}
+```
+output:
+```sh
+Generate duration: 76.28 ms
+TTFT: 42.58 ms
+TPOT: 3.80 ms/token
+```
+
+>**Note**: If the input prompt is just a string, the generate function returns only a string without `perf_metrics`. To obtain `perf_metrics`, provide the prompt as a list with at least one element or call generate with encoded inputs.
+
+Several `perf_metrics` objects can be added to each other. In that case their `raw_metrics` are concatenated and the mean/std values are recalculated. This accumulates statistics from several `generate()` calls:
+
+```cpp
+#include "openvino/genai/llm_pipeline.hpp"
+#include <iostream>
+
+int main(int argc, char* argv[]) {
+    std::string model_path = argv[1];
+    ov::genai::LLMPipeline pipe(model_path, "CPU");
+    auto result_1 = pipe.generate("The Sun is yellow because", ov::genai::max_new_tokens(20));
+    auto result_2 = pipe.generate("The Sun is yellow because", ov::genai::max_new_tokens(20));
+    auto perf_metrics = result_1.perf_metrics + result_2.perf_metrics;
+
+    std::cout << std::fixed << std::setprecision(2);
+    std::cout << "Generate duration: " << perf_metrics.get_generate_duration().mean << " ms" << std::endl;
+    std::cout << "TTFT: " << perf_metrics.get_ttft().mean << " ms" << std::endl;
+    std::cout << "TPOT: " << perf_metrics.get_tpot().mean << " ms/token " << std::endl;
+    std::cout << "Throughput: " << perf_metrics.get_throughput().mean << " tokens/s" << std::endl;
+}
+```
+
+```python
+import openvino_genai as ov_genai
+pipe = ov_genai.LLMPipeline(model_path, "CPU")
+res_1 = pipe.generate(["The Sun is yellow because"], max_new_tokens=20)
+res_2 = pipe.generate(["Why Sky is blue because"], max_new_tokens=20)
+perf_metrics = res_1.perf_metrics + res_2.perf_metrics
+
+print(f'Generate duration: {perf_metrics.get_generate_duration().mean:.2f} ms')
+print(f'TTFT: {perf_metrics.get_ttft().mean:.2f} ms')
+print(f'TPOT: {perf_metrics.get_tpot().mean:.2f} ms/token')
+print(f'Throughput: {perf_metrics.get_throughput().mean:.2f} tokens/s')
+```
+
+For more examples of how metrics are used, please refer to the Python [benchmark_genai.py](https://github.com/openvinotoolkit/openvino.genai/tree/releases/2024/3/samples/python/benchmark_genai/README.md) and C++ [benchmark_genai](https://github.com/openvinotoolkit/openvino.genai/tree/releases/2024/3/samples/cpp/benchmark_genai/README.md) samples.
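+The `raw_metrics` field can also be post-processed directly, for example to compute a median or percentile instead of mean/std. The snippet below is a minimal sketch of that idea: the `model_path` value and the use of NumPy are illustrative assumptions, and the raw per-token durations are interpreted as microseconds, following the `RawPerfMetrics` definition in this release.
+
+```python
+import numpy as np
+import openvino_genai as ov_genai
+
+model_path = "TinyLlama-1.1B-Chat-v1.0"  # assumption: an already exported model
+pipe = ov_genai.LLMPipeline(model_path, "CPU")
+result = pipe.generate(["The Sun is yellow because"], max_new_tokens=20)
+raw_metrics = result.perf_metrics.raw_metrics
+
+# m_durations holds per-token generation durations (microseconds in this implementation);
+# convert to milliseconds and compute custom statistics.
+durations_ms = np.array(raw_metrics.m_durations) / 1000.0
+print(f'Median TPOT: {np.median(durations_ms):.2f} ms')
+print(f'P95 TPOT: {np.percentile(durations_ms, 95):.2f} ms')
+```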
+
 
 ## How It Works
 
 For information on how OpenVINO™ GenAI works, refer to the [How It Works Section](https://github.com/openvinotoolkit/openvino.genai/tree/releases/2024/2/src/docs/HOW_IT_WORKS.md).
diff --git a/src/cpp/include/openvino/genai/llm_pipeline.hpp b/src/cpp/include/openvino/genai/llm_pipeline.hpp
index abd4ee5a44..e79a6e65f0 100644
--- a/src/cpp/include/openvino/genai/llm_pipeline.hpp
+++ b/src/cpp/include/openvino/genai/llm_pipeline.hpp
@@ -5,11 +5,13 @@
 #include <optional>
 #include <variant>
+#include <chrono>
 
 #include "openvino/core/any.hpp"
 #include "openvino/genai/generation_config.hpp"
 #include "openvino/genai/tokenizer.hpp"
 #include "openvino/genai/streamer_base.hpp"
+#include "openvino/genai/perf_metrics.hpp"
 
 namespace ov {
 namespace genai {
@@ -29,11 +31,13 @@ using StringInputs = std::variant<std::string, std::vector<std::string>>;
 *
 * @param tokens sequence of resulting tokens
 * @param scores sum of logarithmic probabilities of all tokens in the sequence
+* @param perf_metrics performance metrics with tpot, ttft, etc. of type ov::genai::PerfMetrics
 */
 class EncodedResults {
 public:
     std::vector<std::vector<int64_t>> tokens;
     std::vector<float> scores;
+    PerfMetrics perf_metrics;
 };
 
 /**
@@ -42,11 +46,13 @@ class EncodedResults {
 *
 * @param texts vector of resulting sequences
 * @param scores scores for each sequence
+* @param perf_metrics performance metrics with tpot, ttft, etc. of type ov::genai::PerfMetrics
 */
 class DecodedResults {
 public:
     std::vector<std::string> texts;
     std::vector<float> scores;
+    PerfMetrics perf_metrics;
 
     // @brief Convert DecodedResults to a string.
     operator std::string() const {
diff --git a/src/cpp/include/openvino/genai/perf_metrics.hpp b/src/cpp/include/openvino/genai/perf_metrics.hpp
new file mode 100644
index 0000000000..ddb9ff581f
--- /dev/null
+++ b/src/cpp/include/openvino/genai/perf_metrics.hpp
@@ -0,0 +1,97 @@
+// Copyright (C) 2023-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include <chrono>
+#include "openvino/genai/visibility.hpp"
+#include <vector>
+#include <memory>
+#include <optional>
+
+namespace ov {
+namespace genai {
+
+using TimePoint = std::chrono::steady_clock::time_point;
+using MicroSeconds = std::chrono::duration<float, std::ratio<1, 1000000>>;
+
+/**
+* @brief Structure with raw performance metrics for each generation before any statistics are calculated.
+*/
+struct OPENVINO_GENAI_EXPORTS RawPerfMetrics {
+    std::vector<MicroSeconds> generate_durations;
+    std::vector<MicroSeconds> tokenization_durations;
+    std::vector<MicroSeconds> detokenization_durations;
+
+    std::vector<MicroSeconds> m_times_to_first_token;
+    std::vector<TimePoint> m_new_token_times;
+    std::vector<size_t> m_batch_sizes;
+    std::vector<MicroSeconds> m_durations;
+
+    size_t num_generated_tokens;
+    size_t num_input_tokens;
+};
+
+/**
+* @brief Structure to store mean and standard deviation values.
+*/
+struct OPENVINO_GENAI_EXPORTS MeanStdPair {
+    float mean;
+    float std;
+};
+
+/**
+* @brief Structure to store performance metrics for each generation.
+*/
+struct OPENVINO_GENAI_EXPORTS PerfMetrics {
+    float load_time;           // Load time in ms.
+    MeanStdPair ttft;          // Time to the first token (in ms) (TTFT).
+    MeanStdPair tpot;          // Time (in ms) per output token (TPOT).
+    MeanStdPair throughput;    // Tokens per second.
+
+    MeanStdPair generate_duration;
+    MeanStdPair tokenization_duration = {-1, -1};
+    MeanStdPair detokenization_duration = {-1, -1};
+
+    size_t num_generated_tokens;
+    size_t num_input_tokens;
+
+    float get_load_time();             // Load time in ms.
+    float get_num_generated_tokens();
+    float get_num_input_tokens();
+    MeanStdPair get_ttft();            // Time to the first token (in ms) (TTFT).
+    MeanStdPair get_tpot();            // Time (in ms) per output token (TPOT).
+    MeanStdPair get_throughput();      // Tokens per second.
+
+    MeanStdPair get_generate_duration();
+    MeanStdPair get_tokenization_duration();
+    MeanStdPair get_detokenization_duration();
+
+    // Flag indicating if raw metrics were evaluated.
+    // If false, the current mean/std ttft, tpot, etc. are not up to date
+    // and evaluate_statistics() should recalculate them.
+    bool m_evaluated = false;
+
+    /**
+     * @brief calculates mean/std values from raw_metrics.
+     *
+     * @param start_time optional start_time in case durations need to be updated.
+     */
+    void evaluate_statistics(std::optional<TimePoint> start_time = std::nullopt);
+
+    /**
+     * @brief convert duration to microseconds
+     *
+     * @param duration steady clock duration to convert
+     */
+    static float get_microsec(std::chrono::steady_clock::duration duration);
+    PerfMetrics operator+(const PerfMetrics& metrics) const;
+    PerfMetrics& operator+=(const PerfMetrics& right);
+
+    RawPerfMetrics raw_metrics;
+};
+
+}  // namespace genai
+}  // namespace ov
diff --git a/src/cpp/src/greedy_decoding.cpp b/src/cpp/src/greedy_decoding.cpp
index 9170c7d2f9..8dc56b4ba8 100644
--- a/src/cpp/src/greedy_decoding.cpp
+++ b/src/cpp/src/greedy_decoding.cpp
@@ -1,7 +1,7 @@
 // Copyright (C) 2023-2024 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 
-#include "openvino/genai/llm_pipeline.hpp"
+#include "openvino/genai/perf_metrics.hpp"
 #include "utils.hpp"
 
 namespace ov {
@@ -19,12 +19,16 @@ EncodedResults greedy_decoding(
     const size_t batch_size = prompts_shape[0];
     size_t running_batch_size = batch_size;
     size_t prompt_len = prompts_shape[1];
+    size_t max_new_tokens = generation_config.get_max_new_tokens(prompt_len);
 
+    // Initialize results and performance metrics.
     EncodedResults results;
+    auto& raw_perf_counters = results.perf_metrics.raw_metrics;
+
     results.scores.resize(running_batch_size);
     results.tokens.resize(running_batch_size);
     std::fill(results.scores.begin(), results.scores.end(), 0);
-    
+
     m_model_runner.set_tensor("input_ids", input_ids);
     m_model_runner.set_tensor("attention_mask", attention_mask);
     if (position_ids.has_value())
@@ -50,6 +54,9 @@ EncodedResults greedy_decoding(
         eos_met[batch] = (out_token == generation_config.eos_token_id);
         m_model_runner.get_tensor("input_ids").data<int64_t>()[batch] = out_token;
     }
+    raw_perf_counters.m_new_token_times.emplace_back(std::chrono::steady_clock::now());
+    raw_perf_counters.m_batch_sizes.emplace_back(batch_size);
+
     if (streamer && streamer->put(token_iter_results[0])) {
         return results;
     }
@@ -58,8 +65,8 @@ EncodedResults greedy_decoding(
     if (!generation_config.ignore_eos && all_are_eos)
         return results;
 
-    size_t max_tokens = generation_config.get_max_new_tokens(prompt_len);
-    for (size_t i = 0; i < max_tokens - 1; ++i) {
+
+    for (size_t i = 0; i < max_new_tokens - 1; ++i) {
         if (position_ids.has_value())
             utils::update_position_ids(m_model_runner.get_tensor("position_ids"), m_model_runner.get_tensor("attention_mask"));
         m_model_runner.set_tensor("attention_mask", utils::extend_attention(m_model_runner.get_tensor("attention_mask")));
@@ -80,6 +87,8 @@ EncodedResults greedy_decoding(
             m_model_runner.get_tensor("input_ids").data<int64_t>()[batch] = out_token;
         }
+        raw_perf_counters.m_new_token_times.emplace_back(std::chrono::steady_clock::now());
+        raw_perf_counters.m_batch_sizes.emplace_back(batch_size);
 
         if (streamer && streamer->put(token_iter_results[0]))
             return results;
@@ -106,8 +115,9 @@ EncodedResults greedy_decoding(
     if (streamer) {
         streamer->end();
     }
+
     return results;
 }
 
 }  //namespace genai
-}  //namespace ov
\ No newline at end of file
+}  //namespace ov
diff --git a/src/cpp/src/group_beam_searcher.cpp b/src/cpp/src/group_beam_searcher.cpp
index 8695aeac02..1b9729b2f6 100644
--- a/src/cpp/src/group_beam_searcher.cpp
+++ b/src/cpp/src/group_beam_searcher.cpp
@@ -362,14 +362,15 @@ std::pair<EncodedResults, int32_t> beam_search(ov::InferRequest& lm,
                                                std::optional<int32_t> selected_beam_idx) {
     OPENVINO_ASSERT(config.num_beams % config.num_beam_groups == 0, "number of beams should be divisible by number of groups");
-    
-    // Initialize beam search
+    auto batch_size = input_ids.get_shape().at(0);
+    auto sequence_length = input_ids.get_shape().at(1);
+
+    // Initialize beam search.
     const int64_t* prompt_data = input_ids.data<const int64_t>();
     std::vector<std::vector<int64_t>> prompts;
     prompts.reserve(batch_size);
     for (size_t batch = 0; batch < batch_size; batch++) {
-        size_t sequence_length = input_ids.get_shape().at(1);
         size_t batch_offset = batch * sequence_length;
         const int64_t* prompt_start = prompt_data + batch_offset;
         prompts.push_back(std::vector<int64_t>{prompt_start, prompt_start + sequence_length});
@@ -389,7 +390,7 @@ std::pair<EncodedResults, int32_t> beam_search(ov::InferRequest& lm,
     lm.set_tensor("beam_idx", beam_idx);
 
     Parameters parameters{std::move(prompts)};
-    parameters.max_new_tokens = config.max_new_tokens;
+    parameters.max_new_tokens = config.get_max_new_tokens(sequence_length);
     parameters.eos_token_id = config.eos_token_id;
     parameters.n_groups = config.num_beam_groups;
     parameters.group_size = config.num_beams / config.num_beam_groups;
@@ -401,11 +402,20 @@ std::pair<EncodedResults, int32_t> beam_search(ov::InferRequest& lm,
     std::vector<int64_t> next_tokens;
     std::vector<int32_t> next_beams;
-    
+
+    // Reserve for performance counters.
+    std::vector<std::chrono::steady_clock::time_point> new_token_times;
+    std::vector<size_t> batch_sizes;
+    new_token_times.reserve(parameters.max_new_tokens);
+    batch_sizes.reserve(parameters.max_new_tokens);
+
     for (size_t length_count = 0; ; ++length_count) {
         lm.infer();
 
         std::tie(next_tokens, next_beams) = group_beam_searcher.select_next_tokens(lm.get_tensor("logits"));
+        new_token_times.emplace_back(std::chrono::steady_clock::now());
+        batch_sizes.emplace_back(batch_size);
+
         if (next_tokens.empty() || length_count == parameters.max_new_tokens - 1) {
             // Break the cycle before masks are extended in update_attention_mask_with_beams.
             // If generation is continued, attention_mask length should be equal to KV cache size.
@@ -434,6 +444,9 @@ std::pair<EncodedResults, int32_t> beam_search(ov::InferRequest& lm,
     int32_t res_selected_beam_idx = 0;
     results.scores.reserve(config.num_return_sequences * result.size());
     results.tokens.reserve(config.num_return_sequences * result.size());
+    auto& raw_perf_counters = results.perf_metrics.raw_metrics;
+    raw_perf_counters.m_new_token_times = new_token_times;
+    raw_perf_counters.m_batch_sizes = batch_sizes;
 
     // align output with HF
     for (size_t prompt_id = 0; prompt_id < result.size(); prompt_id++) {
@@ -462,7 +475,7 @@ std::pair<EncodedResults, int32_t> beam_search(ov::InferRequest& lm,
             results.tokens.push_back(std::move(beam->get().tokens));
         }
     }
-    
+
     return {results, res_selected_beam_idx};
 }
 
diff --git a/src/cpp/src/llm_pipeline.cpp b/src/cpp/src/llm_pipeline.cpp
index 1d68d4c746..8505daf3b2 100644
--- a/src/cpp/src/llm_pipeline.cpp
+++ b/src/cpp/src/llm_pipeline.cpp
@@ -10,6 +10,7 @@
 #include "openvino/genai/continuous_batching_pipeline.hpp"
 #include "openvino/genai/generation_config.hpp"
 #include "openvino/genai/llm_pipeline.hpp"
+#include "openvino/genai/perf_metrics.hpp"
 #include "llm_pipeline_base.hpp"
 #include "llm_pipeline_static.hpp"
 #include "utils.hpp"
@@ -111,8 +112,9 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         OptionalGenerationConfig generation_config,
         StreamerVariant streamer
     ) override {
+        auto start_time = std::chrono::steady_clock::now();
         GenerationConfig config = (generation_config.has_value()) ? *generation_config : m_generation_config;
-        EncodedInputs encoded_input;
+        TokenizedInputs encoded_input;
 
         if (auto input_vector = std::get_if<std::vector<std::string>>(&inputs)) {
             OPENVINO_ASSERT(!is_chat_conversation, "Can't chat with multiple prompts");
@@ -145,9 +147,12 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
                 encoded_input = m_tokenizer.encode(prompt);
             }
         }
+        auto encode_stop_time = std::chrono::steady_clock::now();
+        auto encoded_results = generate(encoded_input, config, streamer);
 
-        auto encoded_results = generate(encoded_input, config, streamer);
+        auto decode_start_time = std::chrono::steady_clock::now();
         DecodedResults decoded_results = {m_tokenizer.decode(encoded_results.tokens), encoded_results.scores};
+        auto decode_stop_time = std::chrono::steady_clock::now();
 
         if (is_chat_conversation) {
             // Tail of chat template is missing in KV cache.
@@ -157,6 +162,17 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
             m_history.push_back({{"role", "assistant"}, {"content", answer}});
         }
 
+        // generate_durations
+        decoded_results.perf_metrics = encoded_results.perf_metrics;
+
+        auto& raw_counters = decoded_results.perf_metrics.raw_metrics;
+        auto stop_time = std::chrono::steady_clock::now();
+        raw_counters.generate_durations = std::vector<MicroSeconds>();
+        raw_counters.generate_durations.emplace_back(PerfMetrics::get_microsec(stop_time - start_time));
+        raw_counters.tokenization_durations.emplace_back(PerfMetrics::get_microsec(encode_stop_time - start_time));
+        raw_counters.detokenization_durations.emplace_back(PerfMetrics::get_microsec(decode_stop_time - decode_start_time));
+
+        decoded_results.perf_metrics.evaluate_statistics(start_time);
         return decoded_results;
     }
 
@@ -165,9 +181,9 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         OptionalGenerationConfig generation_config,
         StreamerVariant streamer
     ) override {
+        auto start_time = std::chrono::steady_clock::now();
         ov::Tensor input_ids;
         ov::Tensor attention_mask;
-
         if (auto data = std::get_if<ov::Tensor>(&inputs)) {
             input_ids = *data;
             attention_mask = ov::genai::utils::init_attention_mask(input_ids);
@@ -255,7 +271,14 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         } else {
             m_is_cache_empty = false;
         }
-        
+        auto stop_time = std::chrono::steady_clock::now();
+
+        // If the call is made without tokenization, the tokenization statistics will not be reported.
+        auto& metrics = result.perf_metrics;
+        metrics.num_input_tokens = batch_size * input_ids.get_shape().at(1);
+        metrics.load_time = this->m_load_time_ms;
+        metrics.raw_metrics.generate_durations.emplace_back(PerfMetrics::get_microsec(stop_time - start_time));
+        metrics.evaluate_statistics(start_time);
         return result;
     }
 
@@ -487,7 +510,10 @@ ov::genai::LLMPipeline::LLMPipeline(
     const ov::genai::Tokenizer& tokenizer,
     OptionalGenerationConfig generation_config
 ) {
+    auto start_time = std::chrono::steady_clock::now();
     m_pimpl = std::make_unique<StatefulLLMPipeline>(request, tokenizer, generation_config);
+    auto stop_time = std::chrono::steady_clock::now();
+    m_pimpl->m_load_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(stop_time - start_time).count();
 }
 
 ov::genai::LLMPipeline::LLMPipeline(
@@ -495,27 +521,35 @@ ov::genai::LLMPipeline::LLMPipeline(
     const ov::genai::Tokenizer& tokenizer,
     const std::string& device,
     const ov::AnyMap& plugin_config
-): m_pimpl{[&]() -> std::unique_ptr<LLMPipelineImplBase> {
+){
+    auto start_time = std::chrono::steady_clock::now();
     if ("CB" == device) {
-        return std::make_unique<ContinuousBatchingAdapter>(model_path, tokenizer, "CPU", plugin_config);
-    } if ("NPU" == device) {
-        return std::make_unique<StaticLLMPipeline>(model_path, tokenizer, device, plugin_config);
+        m_pimpl = std::make_unique<ContinuousBatchingAdapter>(model_path, tokenizer, "CPU", plugin_config);
+    } else if ("NPU" == device) {
+        m_pimpl = std::make_unique<StaticLLMPipeline>(model_path, tokenizer, device, plugin_config);
+    } else {
+        m_pimpl = std::make_unique<StatefulLLMPipeline>(model_path, tokenizer, device, plugin_config);
     }
-    return std::make_unique<StatefulLLMPipeline>(model_path, tokenizer, device, plugin_config);
-}()} {}
+    auto stop_time = std::chrono::steady_clock::now();
+    m_pimpl->m_load_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(stop_time - start_time).count();
+}
 
 ov::genai::LLMPipeline::LLMPipeline(
     const std::string& path,
     const std::string& device,
     const ov::AnyMap& config
-): m_pimpl{[&]() -> std::unique_ptr<LLMPipelineImplBase> {
+){
+    auto start_time = std::chrono::steady_clock::now();
     if ("CB" == device) {
-        return std::make_unique<ContinuousBatchingAdapter>(path, "CPU", config);
-    } if ("NPU" == device) {
-        return std::make_unique<StaticLLMPipeline>(path, device, config);
+        m_pimpl = std::make_unique<ContinuousBatchingAdapter>(path, "CPU", config);
+    } else if ("NPU" == device) {
+        m_pimpl = std::make_unique<StaticLLMPipeline>(path, device, config);
+    } else {
+        m_pimpl = std::make_unique<StatefulLLMPipeline>(path, device, config);
     }
-    return std::make_unique<StatefulLLMPipeline>(path, device, config);
-}()} {}
+    auto stop_time = std::chrono::steady_clock::now();
+    m_pimpl->m_load_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(stop_time - start_time).count();
+}
 
 ov::genai::GenerationConfig ov::genai::LLMPipeline::get_generation_config() const {
     return m_pimpl->m_generation_config;
diff --git a/src/cpp/src/llm_pipeline_base.hpp b/src/cpp/src/llm_pipeline_base.hpp
index 9df6442b35..7e58cd3b37 100644
--- a/src/cpp/src/llm_pipeline_base.hpp
+++ b/src/cpp/src/llm_pipeline_base.hpp
@@ -36,6 +36,8 @@ class LLMPipelineImplBase {
 
     Tokenizer m_tokenizer;
     GenerationConfig m_generation_config;
+
+    float m_load_time_ms = 0;
 };
 
 }  // namespace genai
diff --git a/src/cpp/src/multinomial_decoding.cpp b/src/cpp/src/multinomial_decoding.cpp
index fd16e948c1..b00c62aed7 100644
--- a/src/cpp/src/multinomial_decoding.cpp
+++ b/src/cpp/src/multinomial_decoding.cpp
@@ -162,7 +162,9 @@ ov::genai::EncodedResults multinominal_decoding(ov::InferRequest& m_model_runner
 
     size_t prompt_len = prompts_shape[1];
 
-    ov::genai::EncodedResults results;
+    // Initialize results and performance metrics.
+    EncodedResults results;
+    auto& raw_perf_counters = results.perf_metrics.raw_metrics;
     results.scores.resize(batch_size, 0);
     results.tokens.resize(batch_size);
@@ -179,6 +181,8 @@ ov::genai::EncodedResults multinominal_decoding(ov::InferRequest& m_model_runner
     m_model_runner.get_tensor("beam_idx").data<int32_t>()[0] = 0;
 
     m_model_runner.infer();
+    raw_perf_counters.m_new_token_times.emplace_back(std::chrono::steady_clock::now());
+    raw_perf_counters.m_batch_sizes.emplace_back(batch_size);
 
     auto logits_tensor = m_model_runner.get_tensor("logits");
@@ -222,6 +226,8 @@ ov::genai::EncodedResults multinominal_decoding(ov::InferRequest& m_model_runner
         m_model_runner.get_tensor("input_ids").data<int64_t>()[0] = out_token.id;
 
         m_model_runner.infer();
+        raw_perf_counters.m_new_token_times.emplace_back(std::chrono::steady_clock::now());
+        raw_perf_counters.m_batch_sizes.emplace_back(batch_size);
 
         logits = m_model_runner.get_tensor("logits").data<float>();
 
         out_token = sampling.get_out_token(logits, vocab_size, tokens);
diff --git a/src/cpp/src/perf_metrics.cpp b/src/cpp/src/perf_metrics.cpp
new file mode 100644
index 0000000000..2f378ab302
--- /dev/null
+++ b/src/cpp/src/perf_metrics.cpp
@@ -0,0 +1,164 @@
+// Copyright (C) 2023-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+
+#include "openvino/genai/perf_metrics.hpp"
+#include "openvino/openvino.hpp"
+#include <tuple>
+#include <numeric>
+#include <cmath>
+
+namespace {
+
+ov::genai::MeanStdPair calc_mean_and_std(const std::vector<ov::genai::MicroSeconds>& durations) {
+    // Accepts time durations in microseconds and returns the mean and standard deviation in milliseconds.
+    float mean = std::accumulate(durations.begin(), durations.end(), 0.0f,
+        [](const float& acc, const ov::genai::MicroSeconds& duration) -> float {
+            return acc + duration.count() / 1000.0f;
+        });
+    mean /= durations.size();
+
+    float sum_square_durations = std::accumulate(durations.begin(), durations.end(), 0.0f,
+        [](const float& acc, const ov::genai::MicroSeconds& duration) -> float {
+            auto d = duration.count() / 1000.0f;
+            return acc + d * d;
+        });
+    float std = std::sqrt(sum_square_durations / durations.size() - mean * mean);
+    return {mean, std};
+}
+
+}  // namespace
+
+namespace ov {
+namespace genai {
+
+float PerfMetrics::get_load_time() {
+    return load_time;
+}
+
+float PerfMetrics::get_num_generated_tokens() {
+    evaluate_statistics();
+    return num_generated_tokens;
+}
+
+float PerfMetrics::get_num_input_tokens() {
+    evaluate_statistics();
+    return num_input_tokens;
+}
+
+MeanStdPair PerfMetrics::get_ttft() {
+    evaluate_statistics();
+    return ttft;
+}
+
+MeanStdPair PerfMetrics::get_tpot() {
+    evaluate_statistics();
+    return tpot;
+}
+
+MeanStdPair PerfMetrics::get_throughput() {
+    evaluate_statistics();
+    return throughput;
+}
+
+MeanStdPair PerfMetrics::get_generate_duration() {
+    evaluate_statistics();
+    return generate_duration;
+}
+
+MeanStdPair PerfMetrics::get_tokenization_duration() {
+    evaluate_statistics();
+    return tokenization_duration;
+}
+
+MeanStdPair PerfMetrics::get_detokenization_duration() {
+    evaluate_statistics();
+    return detokenization_duration;
+}
+
+float PerfMetrics::get_microsec(std::chrono::steady_clock::duration duration) {
+    return std::chrono::duration_cast<std::chrono::microseconds>(duration).count();
+}
+
+void PerfMetrics::evaluate_statistics(std::optional<TimePoint> start_time) {
+    if (m_evaluated) {
+        return;
+    }
+    // If start_time is specified, recalculate durations according to start times and calculate statistics only after that.
+    if (start_time.has_value()) {
+        auto start_time_val = *start_time;
+        auto& tok_times = raw_metrics.m_new_token_times;
+        auto& batch_sizes = raw_metrics.m_batch_sizes;
+        raw_metrics.m_durations = std::vector<MicroSeconds>(tok_times.size());
+
+        auto ttft = tok_times[0] - start_time_val;
+        raw_metrics.m_times_to_first_token = std::vector<MicroSeconds>();
+        raw_metrics.m_times_to_first_token.emplace_back(ttft);
+        num_generated_tokens = 0;
+        for (size_t i = 0; i < tok_times.size(); ++i) {
+            raw_metrics.m_durations[i] = tok_times[i] - start_time_val;
+
+            // If a batch of 5 new tokens is generated in 10 ms, then the per-token duration is 10 / 5 = 2 ms per token.
+            raw_metrics.m_durations[i] /= batch_sizes[i];
+            num_generated_tokens += batch_sizes[i];
+            start_time_val = tok_times[i];
+        }
+    }
+
+    // calc_mean_and_std will convert microseconds to milliseconds.
+    tpot = calc_mean_and_std(raw_metrics.m_durations);
+    ttft = calc_mean_and_std(raw_metrics.m_times_to_first_token);
+
+    generate_duration = calc_mean_and_std(raw_metrics.generate_durations);
+    tokenization_duration = calc_mean_and_std(raw_metrics.tokenization_durations);
+    detokenization_duration = calc_mean_and_std(raw_metrics.detokenization_durations);
+
+    // Tokens per second.
+    throughput = {1000.0f / tpot.mean, (tpot.std * 1000.0f) / (tpot.mean * tpot.mean)};
+    m_evaluated = true;
+}
+
+PerfMetrics PerfMetrics::operator+(const PerfMetrics& right) const {
+    OPENVINO_ASSERT(right.load_time == load_time, "generation metrics can be accumulated only for the same pipeline");
+
+    // Copy left value to res.
+    PerfMetrics res = *this;
+
+    // Concatenate durations, batch_sizes and first token times.
+    auto& new_durations = res.raw_metrics.m_durations;
+    auto& new_batch_sizes = res.raw_metrics.m_batch_sizes;
+    auto& new_times_to_first_token = res.raw_metrics.m_times_to_first_token;
+    auto& right_durations = right.raw_metrics.m_durations;
+    auto& right_batch_sizes = right.raw_metrics.m_batch_sizes;
+    auto& right_times_to_first_token = right.raw_metrics.m_times_to_first_token;
+
+    new_durations.insert(new_durations.end(), right_durations.begin(), right_durations.end());
+    new_times_to_first_token.insert(new_times_to_first_token.end(), right_times_to_first_token.begin(), right_times_to_first_token.end());
+    new_batch_sizes.insert(new_batch_sizes.end(), right_batch_sizes.begin(), right_batch_sizes.end());
+
+    // Concatenate tokenization/detokenization and total generation times.
+    auto& new_tok_durations = res.raw_metrics.tokenization_durations;
+    auto& new_detok_durations = res.raw_metrics.detokenization_durations;
+    auto& new_gen_durations = res.raw_metrics.generate_durations;
+    auto& right_tok_durations = right.raw_metrics.tokenization_durations;
+    auto& right_detok_durations = right.raw_metrics.detokenization_durations;
+    auto& right_gen_durations = right.raw_metrics.generate_durations;
+
+    new_tok_durations.insert(new_tok_durations.end(), right_tok_durations.begin(), right_tok_durations.end());
+    new_detok_durations.insert(new_detok_durations.end(), right_detok_durations.begin(), right_detok_durations.end());
+    new_gen_durations.insert(new_gen_durations.end(), right_gen_durations.begin(), right_gen_durations.end());
+
+    res.num_generated_tokens = num_generated_tokens + right.num_generated_tokens;
+    res.num_input_tokens = num_input_tokens + right.num_input_tokens;
+    res.load_time = load_time;
+    res.m_evaluated = false;
+    return res;
+}
+
+PerfMetrics& PerfMetrics::operator+=(const PerfMetrics& right) {
+    *this = *this + right;
+    return *this;
+}
+
+}  // namespace genai
+}  // namespace ov
diff --git a/src/python/py_generate_pipeline.cpp b/src/python/py_generate_pipeline.cpp
index 6175001c29..031c8fb97b 100644
--- a/src/python/py_generate_pipeline.cpp
+++ b/src/python/py_generate_pipeline.cpp
@@ -20,7 +20,10 @@ using ov::genai::EncodedResults;
 using ov::genai::GenerationConfig;
 using ov::genai::GenerationResult;
 using ov::genai::LLMPipeline;
+using ov::genai::MeanStdPair;
 using ov::genai::OptionalGenerationConfig;
+using ov::genai::PerfMetrics;
+using ov::genai::RawPerfMetrics;
 using ov::genai::SchedulerConfig;
 using ov::genai::StopCriteria;
 using ov::genai::StreamerBase;
@@ -36,6 +39,17 @@ using PyBindStreamerVariant = std::variant<std::function<bool(py::str)>, std::shared_ptr<StreamerBase>, std::monostate>;
 template <typename... Ts> struct overloaded : Ts... { using Ts::operator()...; };
 template <typename... Ts> overloaded(Ts...) -> overloaded<Ts...>;
 
+template <typename T, typename U>
+std::vector<float> get_ms(const T& instance, U T::*member) {
+    // Converts C++ durations to float so that they can be used in Python.
+    std::vector<float> res;
+    const auto& durations = instance.*member;
+    res.reserve(durations.size());
+    std::transform(durations.begin(), durations.end(), std::back_inserter(res),
+                   [](const auto& duration) { return duration.count(); });
+    return res;
+}
+
 namespace {
 
 auto generate_docstring = R"(
@@ -563,7 +577,45 @@ PYBIND11_MODULE(py_generate_pipeline, m) {
         .def(py::init<>())
         .def_property_readonly("texts", [](const DecodedResults &dr) { return handle_utf8_results(dr); })
         .def_readonly("scores", &DecodedResults::scores)
-        .def("__str__", &DecodedResults::operator std::string);;
+        .def_readonly("perf_metrics", &DecodedResults::perf_metrics)
+        .def("__str__", &DecodedResults::operator std::string);
+
+    py::class_<RawPerfMetrics>(m, "RawPerfMetrics")
+        .def(py::init<>())
+        .def_readonly("generate_durations", &RawPerfMetrics::generate_durations)
+        .def_property_readonly("tokenization_durations", [](const RawPerfMetrics &rw) {
+            return get_ms(rw, &RawPerfMetrics::tokenization_durations);
+        })
+        .def_property_readonly("detokenization_durations", [](const RawPerfMetrics &rw) {
+            return get_ms(rw, &RawPerfMetrics::detokenization_durations);
+        })
+        .def_property_readonly("m_times_to_first_token", [](const RawPerfMetrics &rw) {
+            return get_ms(rw, &RawPerfMetrics::m_times_to_first_token);
+        })
+        .def_property_readonly("m_durations", [](const RawPerfMetrics &rw) {
+            return get_ms(rw, &RawPerfMetrics::m_durations);
+        })
+        .def_readonly("m_batch_sizes", &RawPerfMetrics::m_batch_sizes)
+        .def_readonly("num_generated_tokens", &RawPerfMetrics::num_generated_tokens)
+        .def_readonly("num_input_tokens", &RawPerfMetrics::num_input_tokens);
+
+    py::class_<MeanStdPair>(m, "MeanStdPair")
+        .def(py::init<>())
+        .def_readonly("mean", &MeanStdPair::mean)
+        .def_readonly("std", &MeanStdPair::std);
+
+    py::class_<PerfMetrics>(m, "PerfMetrics")
+        .def(py::init<>())
+        .def("get_generate_duration", &PerfMetrics::get_generate_duration)
+        .def("get_tokenization_duration", &PerfMetrics::get_tokenization_duration)
+        .def("get_detokenization_duration", &PerfMetrics::get_detokenization_duration)
+        .def("get_throughput", &PerfMetrics::get_throughput)
+        .def("get_tpot", &PerfMetrics::get_tpot)
+        .def("get_ttft", &PerfMetrics::get_ttft)
+        .def("get_load_time", &PerfMetrics::get_load_time)
+        .def("__add__", &PerfMetrics::operator+)
+        .def("__iadd__", &PerfMetrics::operator+=)
+        .def_readonly("raw_metrics", &PerfMetrics::raw_metrics);
 
     py::class_<TokenizedInputs>(m, "TokenizedInputs")
         .def(py::init<ov::Tensor, ov::Tensor>())
@@ -572,7 +624,8 @@ PYBIND11_MODULE(py_generate_pipeline, m) {
 
     py::class_<EncodedResults>(m, "EncodedResults")
         .def_readonly("tokens", &EncodedResults::tokens)
-        .def_readonly("scores", &EncodedResults::scores);
+        .def_readonly("scores", &EncodedResults::scores)
+        .def_readonly("perf_metrics", &EncodedResults::perf_metrics);
 
     py::class_<StreamerBase, std::shared_ptr<StreamerBase>>(m, "StreamerBase")  // Change the holder form unique_ptr to shared_ptr
         .def(py::init<>())