diff --git a/llama-burn/Cargo.toml b/llama-burn/Cargo.toml
index 899236d..b417cf1 100644
--- a/llama-burn/Cargo.toml
+++ b/llama-burn/Cargo.toml
@@ -16,12 +16,15 @@ tiny = ["dep:tokenizers"]
 # Example feature flags (backend selection)
 tch-cpu = ["burn/tch"]
 tch-gpu = ["burn/tch"]
+cuda = ["burn/cuda-jit"]
 wgpu = ["burn/wgpu"]
 
+# To import pytorch weights
+import = ["burn-import"]
+
 [dependencies]
-# Note: default-features = false is needed to disable std
-burn = { git = "https://github.com/tracel-ai/burn", rev = "a53f459f205889a22ecea3713bbae12d3de7eb0c", default-features = false }
-burn-import = { git = "https://github.com/tracel-ai/burn", rev = "a53f459f205889a22ecea3713bbae12d3de7eb0c" }
+burn = { version = "0.14.0", default-features = false, features = ["std"] }
+burn-import = { version = "0.14.0", optional = true }
 itertools = { version = "0.12.1", default-features = false, features = [
     "use_alloc",
 ] }
@@ -46,5 +49,5 @@ rand = { version = "0.8.5", default-features = false, features = [
 ] } # std_rng is for no_std
 
 [dev-dependencies]
-burn = { git = "https://github.com/tracel-ai/burn", rev = "a53f459f205889a22ecea3713bbae12d3de7eb0c" }
+burn = { version = "0.14.0", default-features = false }
 clap = { version = "4.5.4", features = ["derive"] }
diff --git a/llama-burn/NOTICES.md b/llama-burn/NOTICES.md
index 2758fe5..08caad9 100644
--- a/llama-burn/NOTICES.md
+++ b/llama-burn/NOTICES.md
@@ -8,6 +8,8 @@ derived from. The use of the following resources complies with the licenses prov
 The model implementation was adapted from the original
 [Llama 3 implementation](https://github.com/meta-llama/llama3), which is distributed under the
 [Meta Llama 3 Community License Agreement](https://github.com/meta-llama/llama3/blob/main/LICENSE).
+The Llama 3.1 model is distributed under the
+[Llama 3.1 Community License Agreement](https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/LICENSE).
 
 The TinyLlama implementation is derived from the same code, but its weights and tokenizers were
 adapted from the [original implementation](https://github.com/jzhang38/TinyLlama) distributed under
diff --git a/llama-burn/README.md b/llama-burn/README.md
index cc757bc..1db56ab 100644
--- a/llama-burn/README.md
+++ b/llama-burn/README.md
@@ -4,7 +4,8 @@
 
 The popular Llama LLM is here!
 
-This repository contains the [Llama 3](https://github.com/meta-llama/llama3) and
+This repository contains the [Llama 3.1](https://github.com/meta-llama/llama-models/),
+[Llama 3](https://github.com/meta-llama/llama3) and
 [TinyLlama](https://github.com/jzhang38/TinyLlama) implementations with their corresponding
 tokenizers. You can find the [Burn](https://github.com/tracel-ai/burn) implementation for the Llama
 variants in [src/llama.rs](src/llama.rs).
@@ -23,9 +24,7 @@ llama-burn = { git = "https://github.com/tracel-ai/models", package = "llama-bur
 If you want to use Llama 3 or TinyLlama (including pre-trained weights if default features are
 active), enable the corresponding feature flag.
 
-> **Important:** these features require `std`. Note that the weights have been saved in the binary
-> format, which is more compact and faster to save & load, but might not be compatible in future
-> versions if the Burn data schema were to evolve.
+> **Important:** these features require `std`.
 
 #### Llama 3
 
@@ -47,7 +46,7 @@ The [chat completion example](examples/chat.rs) initializes a Llama model from t
 file and generates a sequence of text based on the input prompt. The instruction-tuned model is
 loaded for dialogue applications, so the prompt is automatically formatted for chat completion.
 
-The example can be executed on the `tch` backend (CUDA or CPU) or `wgpu`.
+The example can be executed on the `tch` backend (CUDA or CPU), `cuda` or `wgpu`.
 
 | Argument        | Description                                                                                                      |
 | :-------------- | :------------------------------------------------------------------------------------------------------------- |
@@ -83,9 +82,16 @@ Using the `wgpu` backend:
 ```sh
 cargo run --release --features llama3,wgpu --example chat
 ```
+Using the `cuda` backend:
+
+```sh
+cargo run --release --features llama3,cuda --example chat
+```
+
 **Built with Meta Llama 3.** This example uses the
-[Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
-instruction-tuned model. Note that the [base pre-trained Llama-3 model](./src/pretrained.rs#L77) is
+[Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) (default)
+and [Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
+instruction-tuned models. Note that the [base pre-trained Llama-3 model](./src/pretrained.rs#L77) is
 also available if you wish to use it in your application.
 
 #### TinyLlama
@@ -109,6 +115,18 @@ Using the `wgpu` backend:
 ```sh
 cargo run --release --features tiny,wgpu --example chat
 ```
+Using the `cuda` backend:
+
+```sh
+cargo run --release --features tiny,cuda --example chat
+```
+
 This example uses the
 [TinyLlama-1.1B-Chat-v1.0](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0)
 instruction-tuned model based on the Llama2 architecture and tokenizer.
+
+## Known Issues
+
+Based on your hardware and the model selected, the `wgpu` backend might not be able to successfully
+run the model due to the current memory management strategy. With `cuda` selected, the precision is
+set to `f32` due to compilation errors with `f16`.
diff --git a/llama-burn/examples/chat.rs b/llama-burn/examples/chat.rs
index e22103d..26215e1 100644
--- a/llama-burn/examples/chat.rs
+++ b/llama-burn/examples/chat.rs
@@ -8,6 +8,9 @@ use llama_burn::{
     tokenizer::Tokenizer,
 };
 
+#[cfg(feature = "llama3")]
+use clap::ValueEnum;
+
 const DEFAULT_PROMPT: &str = "How many helicopters can a human eat in one sitting?";
 
 #[derive(Parser, Debug)]
@@ -26,7 +29,7 @@ pub struct Config {
     max_seq_len: usize,
 
     /// The number of new tokens to generate (i.e., the number of generation steps to take).
-    #[arg(long, short = 'n', default_value_t = 50)]
+    #[arg(long, short = 'n', default_value_t = 65)]
     sample_len: usize,
 
     /// The seed to use when generating random samples.
@@ -36,6 +39,23 @@ pub struct Config {
     /// The input prompt.
     #[arg(short, long, default_value_t = String::from(DEFAULT_PROMPT))]
     prompt: String,
+
+    /// The Llama 3 model version.
+    #[cfg(feature = "llama3")]
+    #[arg(long, default_value = "llama-3.1-8b-instruct")]
+    version: Llama3,
+}
+
+#[cfg(feature = "llama3")]
+#[derive(Clone, Debug, ValueEnum)]
+/// Llama-3 model variants to load.
+enum Llama3 {
+    /// Llama-3-8B-Instruct.
+    #[value(name = "llama-3-8b-instruct")]
+    V3Instruct,
+    /// Llama-3.1-8B-Instruct.
+ #[value(name = "llama-3.1-8b-instruct")] + V31Instruct, } pub fn generate( @@ -76,7 +96,7 @@ pub fn chat(args: Config, device: Device) { #[cfg(feature = "tiny")] { // TinyLlama-1.1B Chat v1.0 - let mut llama = LlamaConfig::tiny_llama_pretrained::(&device).unwrap(); + let mut llama = LlamaConfig::tiny_llama_pretrained::(args.max_seq_len, &device).unwrap(); println!("Processing prompt: {}", prompt); // Prompt formatting for chat model @@ -95,8 +115,15 @@ pub fn chat(args: Config, device: Device) { #[cfg(feature = "llama3")] { - // Llama-3-8B-Instruct - let mut llama = LlamaConfig::llama3_8b_pretrained::(true, &device).unwrap(); + // Llama-3-8B-Instruct or Llama-3.1-8B-Instruct + let mut llama = match args.version { + Llama3::V3Instruct => { + LlamaConfig::llama3_8b_pretrained::(args.max_seq_len, &device).unwrap() + } + Llama3::V31Instruct => { + LlamaConfig::llama3_1_8b_pretrained::(args.max_seq_len, &device).unwrap() + } + }; println!("Processing prompt: {}", prompt); // Prompt formatting for chat model @@ -156,6 +183,19 @@ mod wgpu { } } +#[cfg(feature = "cuda")] +mod cuda { + use super::*; + use burn::backend::{cuda_jit::CudaDevice, CudaJit}; + + pub fn run(args: Config) { + let device = CudaDevice::default(); + + // NOTE: compilation errors in f16 + chat::>(args, device); + } +} + pub fn main() { // Parse arguments let args = Config::parse(); @@ -166,4 +206,6 @@ pub fn main() { tch_cpu::run(args); #[cfg(feature = "wgpu")] wgpu::run(args); + #[cfg(feature = "cuda")] + cuda::run(args); } diff --git a/llama-burn/src/cache.rs b/llama-burn/src/cache.rs index 714f92f..9be0808 100644 --- a/llama-burn/src/cache.rs +++ b/llama-burn/src/cache.rs @@ -1,11 +1,5 @@ use burn::tensor::{backend::Backend, Tensor}; -/// All Llama-3 models support sequence length up to 8192 tokens. -pub(crate) const MAX_SEQ_LEN: usize = 8192; - -// /// All Llama-2 models support sequence length up to 4096 tokens. -// pub(crate) const MAX_SEQ_LEN_V2: usize = 4096; - // Adapted from `burn::nn::cache` enum CacheState { Value(T), @@ -39,11 +33,6 @@ pub(crate) struct AutoregressiveCache { impl AutoregressiveCache { /// Creates a new empty cache. pub fn new(max_seq_len: usize) -> Self { - assert!( - max_seq_len <= MAX_SEQ_LEN, - "Maximum sequence length must not exceed {MAX_SEQ_LEN}" - ); - Self { cache: TensorCache::empty(), max_seq_len, diff --git a/llama-burn/src/llama.rs b/llama-burn/src/llama.rs index cb657e7..055ec41 100644 --- a/llama-burn/src/llama.rs +++ b/llama-burn/src/llama.rs @@ -4,18 +4,24 @@ use burn::{ config::Config, module::Module, nn::{RotaryEncoding, RotaryEncodingConfig}, - record::{FileRecorder, HalfPrecisionSettings, Recorder, RecorderError}, + record::{FileRecorder, HalfPrecisionSettings, RecorderError}, tensor::{ activation::softmax, backend::Backend, Device, ElementConversion, Int, Shape, Tensor, TensorData, }, }; -use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder}; + +#[cfg(feature = "import")] +use { + crate::transformer::TransformerRecord, + burn::record::Recorder, + burn_import::pytorch::{LoadArgs, PyTorchFileRecorder}, +}; use crate::{ sampling::Sampler, tokenizer::Tokenizer, - transformer::{KeyValueCache, Transformer, TransformerConfig, TransformerRecord}, + transformer::{KeyValueCache, Transformer, TransformerConfig}, }; #[cfg(feature = "pretrained")] @@ -48,6 +54,9 @@ pub struct LlamaConfig { /// Rotary positional encoding (RoPE) theta. #[config(default = "10000.0")] pub rope_theta: f32, + /// Use scaled RoPE. 
+ #[config(default = "false")] + pub rope_scaled: bool, /// Maximum sequence length for input text. #[config(default = "128")] pub max_seq_len: usize, @@ -56,6 +65,15 @@ pub struct LlamaConfig { } impl LlamaConfig { + /// Llama-3.1-8B configuration. + pub fn llama3_1_8b(tokenizer_path: &str) -> Self { + // hidden_size = 14336; vocab_size = 128256 + Self::new(14336, 128256, tokenizer_path.to_string()) + .with_num_key_value_heads(Some(8)) + .with_rope_theta(500000.0) + .with_rope_scaled(true) + } + /// Llama-3-8B configuration. pub fn llama3_8b(tokenizer_path: &str) -> Self { // hidden_size = 14336; vocab_size = 128256 @@ -64,16 +82,83 @@ impl LlamaConfig { .with_rope_theta(500000.0) } + /// TinyLlama-1.1B Chat v1.0 configuration. + pub fn tiny_llama(tokenizer_path: &str) -> Self { + // hidden_size = 5632; vocab_size = 32000 + Self::new(5632, 32000, tokenizer_path.to_string()) + .with_d_model(2048) + .with_num_hidden_layers(22) + .with_num_key_value_heads(Some(4)) + .with_rope_theta(10000.0) + } + + /// Load pre-trained Llama-3.1-8B model with [Tiktoken](https://github.com/openai/tiktoken) tokenizer. + #[cfg(feature = "llama3")] + pub fn load_llama3_1_8b( + checkpoint: &str, + tokenizer_path: &str, + max_seq_len: usize, + device: &Device, + ) -> Result, String> { + use burn::record::NamedMpkFileRecorder; + + let llama = Self::llama3_1_8b(tokenizer_path) + .with_max_seq_len(max_seq_len) + .init::(device)?; + + let recorder = NamedMpkFileRecorder::::new(); + let llama = llama + .load(checkpoint, &recorder) + .map_err(|err| format!("Failed to load pre-trained Llama model.\nError: {err}"))?; + + Ok(llama) + } + + /// Load pre-trained Llama-3.1-8B-Instruct model with [Tiktoken](https://github.com/openai/tiktoken) tokenizer. + /// + /// # Arguments + /// - `max_seq_len` - The maximum sequence length for input text. + /// - `device` - The device to load the model on. + #[cfg(all(feature = "llama3", feature = "pretrained"))] + pub fn llama3_1_8b_pretrained( + max_seq_len: usize, + device: &Device, + ) -> Result, String> { + // Llama-3.1 models support context length up to 128K tokens. + check_context_length(max_seq_len, 128 * 1024); + + // Download checkpoint and tokenizer + let model = pretrained::Llama::Llama31Instruct.pretrained(); + let checkpoint = model + .download_weights() + .map_err(|err| format!("Could not download weights.\nError: {err}"))?; + let tokenizer = model + .download_tokenizer() + .map_err(|err| format!("Could not download tokenizer.\nError: {err}"))?; + + Self::load_llama3_1_8b( + checkpoint.to_str().unwrap(), + tokenizer.to_str().unwrap(), + // "/home/laggui/workspace/llama-models/models/llama3_1/Meta-Llama-3.1-8B-Instruct/model", + // "/home/laggui/workspace/llama-models/models/llama3_1/Meta-Llama-3.1-8B-Instruct/tokenizer.model", + max_seq_len, + device, + ) + } + /// Load pre-trained Llama-3-8B model with [Tiktoken](https://github.com/openai/tiktoken) tokenizer. #[cfg(feature = "llama3")] pub fn load_llama3_8b( checkpoint: &str, tokenizer_path: &str, + max_seq_len: usize, device: &Device, ) -> Result, String> { use burn::record::NamedMpkFileRecorder; - let llama = Self::llama3_8b(tokenizer_path).init::(device)?; + let llama = Self::llama3_8b(tokenizer_path) + .with_max_seq_len(max_seq_len) + .init::(device)?; let recorder = NamedMpkFileRecorder::::new(); let llama = llama @@ -83,22 +168,21 @@ impl LlamaConfig { Ok(llama) } - /// Load pre-trained Llama-3-8B model with [Tiktoken](https://github.com/openai/tiktoken) tokenizer. 
+    /// Load pre-trained Llama-3-8B-Instruct model with [Tiktoken](https://github.com/openai/tiktoken) tokenizer.
     ///
     /// # Arguments
-    /// - `instruct`: If true, load the instruction-tuned model for dialogue applications (e.g., chat).
+    /// - `max_seq_len` - The maximum sequence length for input text.
     /// - `device` - The device to load the model on.
     #[cfg(all(feature = "llama3", feature = "pretrained"))]
     pub fn llama3_8b_pretrained<B: Backend>(
-        instruct: bool,
+        max_seq_len: usize,
         device: &Device<B>,
     ) -> Result<Llama<B, Tiktoken>, String> {
+        // Llama-3 models support context length up to 8K tokens.
+        check_context_length(max_seq_len, 8 * 1024);
+
         // Download checkpoint and tokenizer
-        let model = if instruct {
-            pretrained::Llama::Llama3Instruct.pretrained()
-        } else {
-            pretrained::Llama::Llama3.pretrained()
-        };
+        let model = pretrained::Llama::Llama3Instruct.pretrained();
         let checkpoint = model
             .download_weights()
             .map_err(|err| format!("Could not download weights.\nError: {err}"))?;
@@ -109,30 +193,24 @@ impl LlamaConfig {
         Self::load_llama3_8b(
             checkpoint.to_str().unwrap(),
             tokenizer.to_str().unwrap(),
+            max_seq_len,
             device,
         )
     }
 
-    /// TinyLlama-1.1B Chat v1.0 configuration.
-    pub fn tiny_llama(tokenizer_path: &str) -> Self {
-        // hidden_size = 5632; vocab_size = 32000
-        Self::new(5632, 32000, tokenizer_path.to_string())
-            .with_d_model(2048)
-            .with_num_hidden_layers(22)
-            .with_num_key_value_heads(Some(4))
-            .with_rope_theta(10000.0)
-    }
-
     /// Load pre-trained TinyLlama-1.1B Chat v1.0 model with [SentenciePiece](https://github.com/google/sentencepiece) tokenizer.
     #[cfg(feature = "tiny")]
     pub fn load_tiny_llama<B: Backend>(
         checkpoint: &str,
         tokenizer_path: &str,
+        max_seq_len: usize,
         device: &Device<B>,
     ) -> Result<Llama<B, SentiencePieceTokenizer>, String> {
         use burn::record::NamedMpkFileRecorder;
 
-        let llama = Self::tiny_llama(tokenizer_path).init::<B>(device)?;
+        let llama = Self::tiny_llama(tokenizer_path)
+            .with_max_seq_len(max_seq_len)
+            .init::<B>(device)?;
 
         let recorder = NamedMpkFileRecorder::<HalfPrecisionSettings>::new();
         let llama = llama
@@ -145,8 +223,12 @@ impl LlamaConfig {
     /// Load pre-trained TinyLlama-1.1B Chat v1.0 model with [SentenciePiece](https://github.com/google/sentencepiece) tokenizer.
     #[cfg(all(feature = "tiny", feature = "pretrained"))]
     pub fn tiny_llama_pretrained<B: Backend>(
+        max_seq_len: usize,
         device: &Device<B>,
     ) -> Result<Llama<B, SentiencePieceTokenizer>, String> {
+        // TinyLlama models support context length up to 2K tokens.
+        check_context_length(max_seq_len, 2 * 1024);
+
         // Download checkpoint and tokenizer
         let model = pretrained::Llama::TinyLlama.pretrained();
         let checkpoint = model
@@ -159,6 +241,7 @@ impl LlamaConfig {
         Self::load_tiny_llama(
             checkpoint.to_str().unwrap(),
             tokenizer.to_str().unwrap(),
+            max_seq_len,
             device,
         )
     }
@@ -190,8 +273,13 @@ impl LlamaConfig {
             self.max_seq_len * 2,
             self.d_model / self.num_attention_heads,
         )
-        .with_theta(self.rope_theta)
-        .init(device);
+        .with_theta(self.rope_theta);
+
+        let rope = if self.rope_scaled {
+            rope.init_with_frequency_scaling(freq_scaling_by_parts, device)
+        } else {
+            rope.init(device)
+        };
 
         Ok(Llama {
             tokenizer,
@@ -203,6 +291,7 @@ impl LlamaConfig {
     }
 
     /// Load pre-trained Llama checkpoint.
+    #[cfg(feature = "import")]
     pub fn load_pretrained(
         &self,
         checkpoint: &str,
@@ -333,6 +422,13 @@ impl LlamaConfig {
     }
 }
 
+fn check_context_length(max_seq_len: usize, max_context_len: usize) {
+    assert!(
+        max_seq_len <= max_context_len,
+        "Maximum sequence length must not exceed {max_context_len}"
+    );
+}
+
 /// Generated text sample output.
 pub struct GenerationOutput {
     /// The generated text.
@@ -377,7 +473,7 @@ impl<B: Backend, T: Tokenizer> Llama<B, T> {
     ) -> GenerationOutput {
         let mut tokens = self.tokenize(prompt);
         let prompt_len = tokens.dims()[0];
-        let eos_token = self.tokenizer.eos_id() as i64;
+        let stop_tokens = Tensor::from_ints(self.tokenizer.stop_ids().as_slice(), &self.device);
 
         let mut num_tokens: usize = 0;
         let mut input_pos = Tensor::<B, 1, Int>::arange(0..tokens.dims()[0] as i64, &self.device);
@@ -395,19 +491,25 @@ impl<B: Backend, T: Tokenizer> Llama<B, T> {
                 next_token_logits = softmax(next_token_logits / temperature, 1);
             };
 
-            let next_token = sampler.sample(next_token_logits);
+            let next_token = sampler.sample(next_token_logits).squeeze(0);
+
+            // Stop when any of the valid stop tokens is encountered
+            if stop_tokens
+                .clone()
+                .equal(next_token.clone())
+                .any()
+                .into_scalar()
+            {
+                break;
+            }
 
             // Concatenate the new generated token
-            tokens = Tensor::cat(vec![tokens, next_token.clone().squeeze(0)], 0);
+            tokens = Tensor::cat(vec![tokens, next_token], 0);
             num_tokens += 1;
 
             // Advance
             let t = input_pos.dims()[0];
             input_pos = input_pos.slice([t - 1..t]) + 1;
-
-            if next_token.equal_elem(eos_token).all().into_scalar() {
-                break;
-            }
         }
 
         let tokens = tokens.into_data().as_slice::<B::IntElem>().unwrap()[prompt_len..]
@@ -464,3 +566,45 @@ impl<B: Backend, T: Tokenizer> Llama<B, T> {
         Ok(self)
     }
 }
+
+/// Applies frequency scaling by parts following Llama 3.1's scheme.
+///
+/// Adapted from: https://github.com/meta-llama/llama-models/blob/main/models/llama3/reference_impl/model.py#L45
+fn freq_scaling_by_parts<B: Backend>(freqs: Tensor<B, 1>) -> Tensor<B, 1> {
+    let scale_factor = 8.;
+    let low_freq_factor = 1.;
+    let high_freq_factor = 4.;
+    let old_context_len = 8192.;
+
+    let low_freq_wavelen = old_context_len / low_freq_factor;
+    let high_freq_wavelen = old_context_len / high_freq_factor;
+
+    let wavelen = freqs.clone().recip().mul_scalar(2. * core::f32::consts::PI);
+
+    // if wavelen >= high_freq_wavelen
+    let cond = wavelen.clone().greater_equal_elem(high_freq_wavelen);
+    let smooth = wavelen
+        .clone()
+        .recip()
+        .mul_scalar(old_context_len)
+        .sub_scalar(low_freq_factor)
+        .div_scalar(high_freq_factor - low_freq_factor);
+    // (1 - smooth) * freq / scale_factor + smooth * freq
+    let new_freqs = smooth
+        .clone()
+        .neg()
+        .add_scalar(1.)
+        .mul(freqs.clone().div_scalar(scale_factor))
+        .add(smooth.clone().mul(freqs.clone()));
+    let new_freqs = freqs.clone().mask_where(cond, new_freqs);
+
+    // if wavelen > low_freq_wavelen
+    let cond = wavelen.clone().greater_elem(low_freq_wavelen);
+    let new_freqs = new_freqs.mask_where(cond, freqs.clone().div_scalar(scale_factor));
+
+    // if wavelen < high_freq_wavelen
+    let cond = wavelen.lower_elem(high_freq_wavelen);
+    let new_freqs = new_freqs.mask_where(cond, freqs);
+
+    new_freqs
+}
diff --git a/llama-burn/src/pretrained.rs b/llama-burn/src/pretrained.rs
index 405ca4c..97ab7e2 100644
--- a/llama-burn/src/pretrained.rs
+++ b/llama-burn/src/pretrained.rs
@@ -67,6 +67,8 @@ pub enum Llama {
     Llama3,
     /// Llama-3-8B-Instruct.
     Llama3Instruct,
+    /// Llama-3.1-8B-Instruct.
+    Llama31Instruct,
     /// TinyLlama-1.1B Chat v1.0.
     TinyLlama,
 }
 
@@ -84,6 +86,11 @@ impl ModelMeta for Llama {
                 model: "https://huggingface.co/tracel-ai/llama-3-8b-instruct-burn/resolve/main/model.mpk?download=true",
                 tokenizer: "https://huggingface.co/tracel-ai/llama-3-8b-instruct-burn/resolve/main/tokenizer.model?download=true",
             },
+            Self::Llama31Instruct => Pretrained {
+                name: "Llama-3.1-8B-Instruct",
+                model: "https://huggingface.co/tracel-ai/llama-3.1-8b-instruct-burn/resolve/main/model.mpk?download=true",
+                tokenizer: "https://huggingface.co/tracel-ai/llama-3.1-8b-instruct-burn/resolve/main/tokenizer.model?download=true",
+            },
             Self::TinyLlama => Pretrained {
                 name: "TinyLlama-1.1B",
                 model: "https://huggingface.co/tracel-ai/tiny-llama-1.1b-burn/resolve/main/model.mpk?download=true",
diff --git a/llama-burn/src/tokenizer/base.rs b/llama-burn/src/tokenizer/base.rs
index 046cf32..b1fa58a 100644
--- a/llama-burn/src/tokenizer/base.rs
+++ b/llama-burn/src/tokenizer/base.rs
@@ -25,4 +25,7 @@ pub trait Tokenizer {
 
     /// End of sentence token identifier.
     fn eos_id(&self) -> u32;
+
+    /// Stop token identifiers.
+    fn stop_ids(&self) -> Vec<u32>;
 }
diff --git a/llama-burn/src/tokenizer/sentence_piece.rs b/llama-burn/src/tokenizer/sentence_piece.rs
index 801fe00..622cc00 100644
--- a/llama-burn/src/tokenizer/sentence_piece.rs
+++ b/llama-burn/src/tokenizer/sentence_piece.rs
@@ -51,4 +51,8 @@ impl Tokenizer for SentiencePieceTokenizer {
     fn eos_id(&self) -> u32 {
         self.eos_token_id
     }
+
+    fn stop_ids(&self) -> Vec<u32> {
+        vec![self.eos_id()]
+    }
 }
diff --git a/llama-burn/src/tokenizer/tiktoken.rs b/llama-burn/src/tokenizer/tiktoken.rs
index 96321ce..e7af7e8 100644
--- a/llama-burn/src/tokenizer/tiktoken.rs
+++ b/llama-burn/src/tokenizer/tiktoken.rs
@@ -1,5 +1,4 @@
 use std::{
-    collections::HashSet,
     fs::File,
     io::{BufRead, BufReader},
 };
@@ -12,19 +11,22 @@ use super::Tokenizer;
 
 const BOS_TOKEN: &str = "<|begin_of_text|>";
 const EOS_TOKEN: &str = "<|end_of_text|>";
+const EOT_TOKEN: &str = "<|eot_id|>";
+const EOM_TOKEN: &str = "<|eom_id|>";
 
 const NUM_RESERVED_SPECIAL_TOKENS: usize = 256;
-const SPECIAL_TOKENS: [&str; 10] = [
+const SPECIAL_TOKENS: [&str; 11] = [
     BOS_TOKEN,
     EOS_TOKEN,
     "<|reserved_special_token_0|>",
     "<|reserved_special_token_1|>",
-    "<|reserved_special_token_2|>",
-    "<|reserved_special_token_3|>",
+    "<|finetune_right_pad_id|>",
+    "<|step_id|>",
     "<|start_header_id|>",
     "<|end_header_id|>",
-    "<|reserved_special_token_4|>",
-    "<|eot_id|>", // end of turn
+    EOM_TOKEN, // end of message
+    EOT_TOKEN, // end of turn
+    "<|python_tag|>",
 ];
 
 const PATTERN: &str = r#"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"#;
@@ -33,6 +35,8 @@ pub struct Tiktoken {
     bpe: CoreBPE,
     bos_token_id: usize,
     eos_token_id: usize,
+    eot_token_id: usize,
+    eom_token_id: usize,
 }
 
 impl Tokenizer for Tiktoken {
@@ -61,9 +65,9 @@ impl Tokenizer for Tiktoken {
                 .iter()
                 .map(|t| t.to_string())
                 .collect::<Vec<_>>(),
-            (5..NUM_RESERVED_SPECIAL_TOKENS - 5)
+            (0..NUM_RESERVED_SPECIAL_TOKENS - SPECIAL_TOKENS.len())
                 .into_iter()
-                .map(|i| format!("<|reserved_special_token_{i}|>"))
+                .map(|i| format!("<|reserved_special_token_{}|>", i + 2))
                 .collect::<Vec<_>>(),
         ]
         .concat();
@@ -75,6 +79,8 @@ impl Tokenizer for Tiktoken {
 
         let bos_token_id = special_tokens[BOS_TOKEN];
         let eos_token_id = special_tokens[EOS_TOKEN];
+        let eot_token_id = special_tokens[EOT_TOKEN];
+        let eom_token_id = special_tokens[EOM_TOKEN];
 
         let bpe =
            CoreBPE::new(mergeable_ranks, special_tokens, PATTERN).map_err(|e| e.to_string())?;
@@ -82,6 +88,8 @@ impl Tokenizer for Tiktoken {
             bpe,
             bos_token_id,
             eos_token_id,
+            eot_token_id,
+            eom_token_id,
         })
     }
 
@@ -89,8 +97,7 @@ impl Tokenizer for Tiktoken {
         let bos_token = if bos { vec![self.bos_token_id] } else { vec![] };
         let eos_token = if eos { vec![self.eos_token_id] } else { vec![] };
 
-        // `allowed_special` is an empty set
-        let tokens = self.bpe.encode(text, HashSet::new());
+        let tokens = self.bpe.encode_with_special_tokens(text);
 
         [bos_token, tokens, eos_token]
             .into_iter()
@@ -112,4 +119,12 @@ impl Tokenizer for Tiktoken {
     fn eos_id(&self) -> u32 {
         self.eos_token_id as u32
     }
+
+    fn stop_ids(&self) -> Vec<u32> {
+        vec![
+            self.eos_id(),
+            self.eom_token_id as u32,
+            self.eot_token_id as u32,
+        ]
+    }
 }
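
For reference, a minimal usage sketch of the pre-trained loader added by this patch (not part of the diff itself). It assumes the crate is built with the `llama3`, `pretrained`, and `cuda` features; the backend types and the `llama3_1_8b_pretrained` call are taken from the changes above, while the `main` wrapper and the 4096-token context value are illustrative only. Prompt formatting and sampling are left to `examples/chat.rs`.

```rust
use burn::backend::{cuda_jit::CudaDevice, CudaJit};
use llama_burn::llama::LlamaConfig;

fn main() {
    // f32 precision on the CUDA JIT backend, per the README's known-issues note about `f16`.
    type B = CudaJit<f32, i32>;
    let device = CudaDevice::default();

    // Request a 4096-token context; llama3_1_8b_pretrained() checks it against the 128K limit
    // via check_context_length() and downloads the weights and tokenizer on first use.
    let llama = LlamaConfig::llama3_1_8b_pretrained::<B>(4096, &device)
        .expect("pre-trained weights and tokenizer should download and load");

    // Generation then follows examples/chat.rs: format the prompt for chat completion and
    // call the model's generate method with a sampler, sample length, and temperature.
    let _ = llama;
}
```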