diff --git a/llama-burn/Cargo.toml b/llama-burn/Cargo.toml
index 6d65632..d61a26d 100644
--- a/llama-burn/Cargo.toml
+++ b/llama-burn/Cargo.toml
@@ -12,7 +12,7 @@
 std = []
 pretrained = ["burn/network", "std", "dep:dirs"]
 llama3 = ["pretrained", "dep:tiktoken-rs", "dep:rustc-hash", "dep:base64"]
-tiny = ["pretrained", "dep:rust_tokenizers"]
+tiny = ["pretrained", "dep:tokenizers"]
 
 [dependencies]
 # Note: default-features = false is needed to disable std
@@ -35,7 +35,7 @@
 base64 = { version = "0.22.1", optional = true }
 rustc-hash = {version = "1.1.0", optional = true }
 
 # SentencePiece tokenizer (tiny llama / llama 2)
-rust_tokenizers = { version = "8.1.1", optional = true }
+tokenizers = { version = "0.19.1", default-features = false, features = ["onig"], optional = true }
 
 rand = { version = "0.8.5", default-features = false, features = [
     "std_rng",
diff --git a/llama-burn/examples/generate.rs b/llama-burn/examples/generate.rs
index 1c166c5..e2d0c1d 100644
--- a/llama-burn/examples/generate.rs
+++ b/llama-burn/examples/generate.rs
@@ -11,20 +11,11 @@ use llama_burn::{
     tokenizer::Tokenizer,
 };
 
-const DEFAULT_PROMPT: &str = "I believe the meaning of life is";
+const DEFAULT_PROMPT: &str = "How many helicopters can a human eat in one sitting?";
 
 #[derive(Parser, Debug)]
 #[command(version, about, long_about = None)]
 pub struct Config {
-    // TODO: download checkpoint from HF hub.
-    /// Model checkpoint path (automatically downloaded from the web if not present).
-    #[arg(short, long)]
-    model: String,
-
-    /// Tokenizer path.
-    #[arg(short, long)]
-    tokenizer: String,
-
     /// Top-p probability threshold.
     #[arg(long, default_value_t = 0.9)]
     top_p: f64,
@@ -48,6 +39,10 @@
     /// The input prompt.
     #[arg(short, long, default_value_t = String::from(DEFAULT_PROMPT))]
     prompt: String,
+
+    /// Chat assistant mode.
+    #[arg(short, long, default_value_t = cfg!(feature = "tiny"))]
+    chat: bool,
 }
 
 pub fn generate(
@@ -57,7 +52,6 @@
     temperature: f64,
     sampler: &mut Sampler,
 ) {
-    println!("Processing prompt: {}", prompt);
     let now = Instant::now();
     let generated = llama.generate(prompt, sample_len, temperature, sampler);
     let elapsed = now.elapsed().as_secs();
@@ -83,6 +77,7 @@ pub fn main() {
     let args = Config::parse();
     let device = LibTorchDevice::Cuda(0);
+    let prompt = args.prompt;
 
     // Sampling strategy
     let mut sampler = if args.temperature > 0.0 {
@@ -94,10 +89,21 @@
     #[cfg(feature = "tiny")]
     {
         let mut llama = LlamaConfig::tiny_llama_pretrained::(&device).unwrap();
+        println!("Processing prompt: {}", prompt);
+
+        let prompt = if args.chat {
+            // Prompt formatting for chat model
+            format!(
+                "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s>\n<|user|>\n{prompt}</s>\n<|assistant|>\n"
+            )
+        } else {
+            // Prompt with BOS token
+            format!("{}{prompt}", llama.tokenizer.bos())
+        };
 
         generate(
             &mut llama,
-            &args.prompt,
+            &prompt,
             args.sample_len,
             args.temperature,
             &mut sampler,
@@ -107,10 +113,18 @@ pub fn main() {
     #[cfg(feature = "llama3")]
     {
         let mut llama = LlamaConfig::llama3_8b_pretrained::(&device).unwrap();
+        println!("Processing prompt: {}", prompt);
+
+        let prompt = if args.chat {
+            panic!("Llama-8B-Instruct is not available yet.");
+        } else {
+            // Prompt with BOS token
+            format!("{}{prompt}", llama.tokenizer.bos())
+        };
 
         generate(
             &mut llama,
-            &args.prompt,
+            &prompt,
             args.sample_len,
             args.temperature,
             &mut sampler,
diff --git a/llama-burn/src/llama.rs b/llama-burn/src/llama.rs
index 5548d78..3087938 100644
--- a/llama-burn/src/llama.rs
+++ b/llama-burn/src/llama.rs
@@ -14,7 +14,7 @@ use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};
 use crate::{
     sampling::Sampler,
     tokenizer::Tokenizer,
-    transformer::{KeyValueCache, Transformer, TransformerConfig},
+    transformer::{KeyValueCache, Transformer, TransformerConfig, TransformerRecord},
 };
 
 #[cfg(feature = "pretrained")]
@@ -275,12 +275,47 @@ impl LlamaConfig {
         }
         println!("Loading record...");
         let now = Instant::now();
-        let record = PyTorchFileRecorder::<HalfPrecisionSettings>::new()
+        let mut record: TransformerRecord<B> = PyTorchFileRecorder::<HalfPrecisionSettings>::new()
             .load(load_args, device)
             .map_err(|e| e.to_string())?;
         let elapsed = now.elapsed().as_secs();
         println!("Loaded in {}s", elapsed);
 
+        if cfg!(feature = "tiny") {
+            // TinyLlama weights from HuggingFace use a different rotary positional encoding
+            // which requires weight permutation:
+            // https://github.com/huggingface/transformers/issues/25199#issuecomment-1687720247
+            // https://github.com/jzhang38/TinyLlama/issues/24
+            let n_heads = self.num_attention_heads;
+            let n_kv_heads = self.num_key_value_heads.unwrap_or(n_heads);
+            let wk_dim = self.d_model * n_kv_heads / n_heads;
+            let permute = |w: Tensor<B, 2>, n_heads: usize, dim1: usize, dim2: usize| {
+                let w = w // [2048, 256]
+                    .reshape([dim1, n_heads, 2, dim2 / n_heads / 2]) // [2048, 4, 2, 32]
+                    .swap_dims(2, 3) // [2048, 4, 32, 2]
+                    .reshape([dim1, dim2]);
+                w
+            };
+
+            record.layers = record
+                .layers
+                .into_iter()
+                .map(|mut layer| {
+                    layer.attention.wq.weight = layer
+                        .attention
+                        .wq
+                        .weight
+                        .map(|w| permute(w, n_heads, self.d_model, self.d_model));
+                    layer.attention.wk.weight = layer
+                        .attention
+                        .wk
+                        .weight
+                        .map(|w| permute(w, n_kv_heads, self.d_model, wk_dim));
+                    layer
+                })
+                .collect::<Vec<_>>();
+        }
+
         llama.model = llama.model.load_record(record);
         println!("Llama record loaded");
@@ -301,14 +336,14 @@ pub struct GenerationOutput {
 /// Meta Llama large language model and tokenizer.
 pub struct Llama<B: Backend, T: Tokenizer> {
     /// The tokenizer.
-    tokenizer: T,
+    pub tokenizer: T,
     /// Llama decoder-only transformer.
-    model: Transformer<B>,
+    pub model: Transformer<B>,
     /// Key-value cache for each transformer block.
-    cache: Vec<KeyValueCache<B>>,
+    pub cache: Vec<KeyValueCache<B>>,
     /// Rotary positional encoding (RoPE).
-    rope: RotaryEncoding<B>,
-    device: Device<B>,
+    pub rope: RotaryEncoding<B>,
+    pub device: Device<B>,
 }
 
 impl<B: Backend, T: Tokenizer> Llama<B, T> {
@@ -346,7 +381,6 @@ impl<B: Backend, T: Tokenizer> Llama<B, T> {
                 .slice([0..batch_size, seq_len - 1..seq_len])
                 .squeeze(1); // [batch_size=1, vocab_size]
 
-            // TODO: naive sampling w/o cumsum tensor op to first test llama implementation correctness
             if temperature > 0.0 {
                 next_token_logits = softmax(next_token_logits / temperature, 1);
             };
@@ -383,7 +417,7 @@ impl<B: Backend, T: Tokenizer> Llama<B, T> {
     /// Encode a string into a tensor of tokens.
     fn tokenize(&self, text: &str) -> Tensor<B, 1, Int> {
-        let tokens = self.tokenizer.encode(text, true, false);
+        let tokens = self.tokenizer.encode(text, false, false);
         let shape = Shape::new([tokens.len()]);
 
         Tensor::<B, 1, Int>::from_data(Data::new(tokens, shape).convert(), &self.device)
diff --git a/llama-burn/src/pretrained.rs b/llama-burn/src/pretrained.rs
index 75cc21e..aa1b217 100644
--- a/llama-burn/src/pretrained.rs
+++ b/llama-burn/src/pretrained.rs
@@ -15,7 +15,7 @@ mod downloader {
     impl Pretrained {
         /// Download the file to the local cache directory.
-        fn download(&self, url: &str, file: &str) -> Result<PathBuf, std::io::Error> {
+        fn download(&self, url: &str) -> Result<PathBuf, std::io::Error> {
             // Model cache directory
             let model_dir = dirs::home_dir()
                 .expect("Should be able to get home directory")
@@ -27,10 +27,15 @@ mod downloader {
                 create_dir_all(&model_dir)?;
             }
 
-            let file_name = model_dir.join(file);
+            let file_base_name = url
+                .rsplit_once('/')
+                .unwrap()
+                .1
+                .replace("?download=true", "");
+            let file_name = model_dir.join(&file_base_name);
             if !file_name.exists() {
                 // Download file content
-                let bytes = downloader::download_file_as_bytes(url, file);
+                let bytes = downloader::download_file_as_bytes(url, &file_base_name);
 
                 // Write content to file
                 let mut output_file = File::create(&file_name)?;
@@ -42,12 +47,12 @@ mod downloader {
 
         /// Download the pre-trained model weights to the local cache directory.
         pub fn download_weights(&self) -> Result<PathBuf, std::io::Error> {
-            self.download(self.model, "model.mpk")
+            self.download(self.model)
         }
 
         /// Download the tokenizer to the local cache directory.
         pub fn download_tokenizer(&self) -> Result<PathBuf, std::io::Error> {
-            self.download(self.tokenizer, "tokenizer.model")
+            self.download(self.tokenizer)
         }
     }
 }
@@ -75,7 +80,7 @@ impl ModelMeta for Llama {
             Self::TinyLlama => Pretrained {
                 name: "TinyLlama-1.1B",
                 model: "https://huggingface.co/tracel-ai/tiny-llama-1.1b-burn/resolve/main/model.mpk?download=true",
-                tokenizer: "https://huggingface.co/tracel-ai/tiny-llama-1.1b-burn/resolve/main/tokenizer.model?download=true",
+                tokenizer: "https://huggingface.co/tracel-ai/tiny-llama-1.1b-burn/resolve/main/tokenizer.json?download=true",
             },
         }
     }
diff --git a/llama-burn/src/tokenizer/base.rs b/llama-burn/src/tokenizer/base.rs
index de185e7..046cf32 100644
--- a/llama-burn/src/tokenizer/base.rs
+++ b/llama-burn/src/tokenizer/base.rs
@@ -10,9 +10,19 @@ pub trait Tokenizer {
     /// Decode a list of token identifiers into a string.
     fn decode(&self, tokens: Vec<u32>) -> String;
 
+    /// Beginning of sentence token.
+    fn bos(&self) -> String {
+        self.decode(vec![self.bos_id()])
+    }
+
     /// Beginning of sentence token identifier.
     fn bos_id(&self) -> u32;
 
+    /// End of sentence token.
+    fn eos(&self) -> String {
+        self.decode(vec![self.eos_id()])
+    }
+
     /// End of sentence token identifier.
     fn eos_id(&self) -> u32;
 }
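
A standalone sketch, not part of the patch, of how the new Tokenizer::bos()/eos() default methods and the reworked Pretrained::download file-name derivation behave. The trait body and the URL handling are copied from the diff above; ToyTokenizer is a hypothetical stand-in used only to make the snippet runnable.

// Hypothetical stand-in tokenizer to exercise the trait's new default methods.
trait Tokenizer {
    fn decode(&self, tokens: Vec<u32>) -> String;
    fn bos_id(&self) -> u32;
    fn eos_id(&self) -> u32;

    /// Beginning of sentence token, decoded from its identifier (added by the patch).
    fn bos(&self) -> String {
        self.decode(vec![self.bos_id()])
    }

    /// End of sentence token, decoded from its identifier (added by the patch).
    fn eos(&self) -> String {
        self.decode(vec![self.eos_id()])
    }
}

struct ToyTokenizer;

impl Tokenizer for ToyTokenizer {
    fn decode(&self, tokens: Vec<u32>) -> String {
        tokens
            .into_iter()
            .map(|t| match t {
                1 => "<s>".to_string(),
                2 => "</s>".to_string(),
                other => format!("[{other}]"),
            })
            .collect()
    }

    fn bos_id(&self) -> u32 {
        1
    }

    fn eos_id(&self) -> u32 {
        2
    }
}

fn main() {
    let tokenizer = ToyTokenizer;

    // Non-chat prompts in the example are now prefixed with the BOS string
    // instead of asking the tokenizer to insert it during encoding.
    let prompt = format!("{}{}", tokenizer.bos(), "How many helicopters can a human eat in one sitting?");
    assert!(prompt.starts_with("<s>"));

    // File name derivation used by the reworked `Pretrained::download`:
    // the last URL segment with the `?download=true` suffix stripped.
    let url = "https://huggingface.co/tracel-ai/tiny-llama-1.1b-burn/resolve/main/tokenizer.json?download=true";
    let file_base_name = url.rsplit_once('/').unwrap().1.replace("?download=true", "");
    assert_eq!(file_base_name, "tokenizer.json");
}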
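
A minimal sketch of the column reordering performed by the permute closure added in llama.rs: viewing one attention head's head_dim weight columns as two halves, the reshape/swap_dims/reshape sequence interleaves them, which, per the patch comments, is the layout this implementation's rotary encoding expects for the HuggingFace TinyLlama checkpoint. The function below mirrors the same index arithmetic on a plain slice and is illustrative only.

/// Reorders one head's weight columns from half-split order
/// (a0, a1, ..., b0, b1, ...) to interleaved order (a0, b0, a1, b1, ...),
/// mirroring reshape([.., n_heads, 2, half]) -> swap_dims(2, 3) -> reshape.
fn permute_head(cols: &[f32]) -> Vec<f32> {
    let half = cols.len() / 2;
    let mut out = vec![0.0; cols.len()];
    for j in 0..half {
        out[2 * j] = cols[j]; // first half goes to even positions
        out[2 * j + 1] = cols[half + j]; // second half goes to odd positions
    }
    out
}

fn main() {
    // head_dim = 8: first half is 0..4, second half is 10..14.
    let cols = [0.0, 1.0, 2.0, 3.0, 10.0, 11.0, 12.0, 13.0];
    assert_eq!(
        permute_head(&cols),
        vec![0.0, 10.0, 1.0, 11.0, 2.0, 12.0, 3.0, 13.0]
    );
}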