Llama 3.1 (#38)
* Add Llama 3.1

* Use provided max_seq_len and remove llama3_1 feature flag

* Fix llama3 tokenizer w/ special tokens

* Add Llama3 enum for model selection

* Add end of turn token w/ stop tokens

* Add Llama-3.1-8B-Instruct link

* Add llama 3.1 tokenizer new tokens

* Update to burn 0.14 and add import feature flag

* Add cuda backend

* Add known issues

* Add llama 3.1 link

* Add llama 3.1 community license agreement

* Remove space

* Remove deprecated note regarding binary weights (we use mpk now)

* Add missing cuda backend mention
laggui authored Sep 3, 2024
1 parent 7ebd9e3 commit 877996b
Showing 10 changed files with 295 additions and 68 deletions.
11 changes: 7 additions & 4 deletions llama-burn/Cargo.toml
@@ -16,12 +16,15 @@ tiny = ["dep:tokenizers"]
# Example feature flags (backend selection)
tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
cuda = ["burn/cuda-jit"]
wgpu = ["burn/wgpu"]

# To import pytorch weights
import = ["burn-import"]

[dependencies]
# Note: default-features = false is needed to disable std
burn = { git = "https://github.com/tracel-ai/burn", rev = "a53f459f205889a22ecea3713bbae12d3de7eb0c", default-features = false }
burn-import = { git = "https://github.com/tracel-ai/burn", rev = "a53f459f205889a22ecea3713bbae12d3de7eb0c" }
burn = { version = "0.14.0", default-features = false, features = ["std"] }
burn-import = { version = "0.14.0", optional = true }
itertools = { version = "0.12.1", default-features = false, features = [
"use_alloc",
] }
@@ -46,5 +49,5 @@ rand = { version = "0.8.5", default-features = false, features = [
] } # std_rng is for no_std

[dev-dependencies]
burn = { git = "https://github.com/tracel-ai/burn", rev = "a53f459f205889a22ecea3713bbae12d3de7eb0c" }
burn = { version = "0.14.0", default-features = false }
clap = { version = "4.5.4", features = ["derive"] }
2 changes: 2 additions & 0 deletions llama-burn/NOTICES.md
@@ -8,6 +8,8 @@ derived from. The use of the following resources complies with the licenses prov
The model implementation was adapted from the original
[Llama 3 implementation](https://github.com/meta-llama/llama3), which is distributed under the
[Meta Llama 3 Community License Agreement](https://github.com/meta-llama/llama3/blob/main/LICENSE).
The Llama 3.1 model is distributed under the
[Llama 3.1 Community License Agreement](https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/LICENSE).

The TinyLlama implementation is derived from the same code, but its weights and tokenizers were
adapted from the [original implementation](https://github.com/jzhang38/TinyLlama) distributed under
32 changes: 25 additions & 7 deletions llama-burn/README.md
@@ -4,7 +4,8 @@

The popular Llama LLM is here!

This repository contains the [Llama 3](https://github.com/meta-llama/llama3) and
This repository contains the [Llama 3.1](https://github.com/meta-llama/llama-models/),
[Llama 3](https://github.com/meta-llama/llama3) and
[TinyLlama](https://github.com/jzhang38/TinyLlama) implementations with their corresponding
tokenizers. You can find the [Burn](https://github.com/tracel-ai/burn) implementation for the Llama
variants in [src/llama.rs](src/llama.rs).
@@ -23,9 +24,7 @@ llama-burn = { git = "https://github.com/tracel-ai/models", package = "llama-bur
If you want to use Llama 3 or TinyLlama (including pre-trained weights if default features are
active), enable the corresponding feature flag.

> **Important:** these features require `std`. Note that the weights have been saved in the binary
> format, which is more compact and faster to save & load, but might not be compatible in future
> versions if the Burn data schema were to evolve.
> **Important:** these features require `std`.
#### Llama 3

@@ -47,7 +46,7 @@ The [chat completion example](examples/chat.rs) initializes a Llama model from t
file and generates a sequence of text based on the input prompt. The instruction-tuned model is
loaded for dialogue applications, so the prompt is automatically formatted for chat completion.

The example can be executed on the `tch` backend (CUDA or CPU) or `wgpu`.
The example can be executed on the `tch` backend (CUDA or CPU), `cuda` or `wgpu`.

| Argument | Description |
| :-------------- | :------------------------------------------------------------------------------------------------------------- |
@@ -83,9 +82,16 @@ Using the `wgpu` backend:
cargo run --release --features llama3,wgpu --example chat
```

Using the `cuda` backend:

```sh
cargo run --release --features llama3,cuda --example chat
```

**Built with Meta Llama 3.** This example uses the
[Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
instruction-tuned model. Note that the [base pre-trained Llama-3 model](./src/pretrained.rs#L77) is
[Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) (default)
and [Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
instruction-tuned models. Note that the [base pre-trained Llama-3 model](./src/pretrained.rs#L77) is
also available if you wish to use it in your application.

#### TinyLlama
@@ -109,6 +115,18 @@ Using the `wgpu` backend:
cargo run --release --features tiny,wgpu --example chat
```

Using the `cuda` backend:

```sh
cargo run --release --features tiny,cuda --example chat
```

This example uses the
[TinyLlama-1.1B-Chat-v1.0](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0)
instruction-tuned model based on the Llama2 architecture and tokenizer.

## Known Issues

Based on your hardware and the model selected, the `wgpu` backend might not be able to successfully
run the model due to the current memory management strategy. With `cuda` selected, the precision is
set to `f32` due to compilation errors with `f16`.
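In code, that workaround amounts to pinning the CUDA JIT backend's float element type to `f32`. A minimal sketch, assuming burn's `cuda-jit` feature is enabled (the `CudaJit`/`CudaDevice` paths match the `cuda` module added to `chat.rs` below; the alias name is illustrative):

```rust
use burn::backend::{cuda_jit::CudaDevice, CudaJit};

// f32 is forced as the float element type because f16 currently fails to
// compile on the CUDA JIT backend; i32 is the integer element type.
type Cuda32 = CudaJit<f32, i32>;

fn main() {
    // Select the default CUDA device, mirroring the new `cuda` module in chat.rs.
    let _device = CudaDevice::default();
    let _backend = core::marker::PhantomData::<Cuda32>;
}
```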
50 changes: 46 additions & 4 deletions llama-burn/examples/chat.rs
@@ -8,6 +8,9 @@ use llama_burn::{
tokenizer::Tokenizer,
};

#[cfg(feature = "llama3")]
use clap::ValueEnum;

const DEFAULT_PROMPT: &str = "How many helicopters can a human eat in one sitting?";

#[derive(Parser, Debug)]
@@ -26,7 +29,7 @@ pub struct Config {
max_seq_len: usize,

/// The number of new tokens to generate (i.e., the number of generation steps to take).
#[arg(long, short = 'n', default_value_t = 50)]
#[arg(long, short = 'n', default_value_t = 65)]
sample_len: usize,

/// The seed to use when generating random samples.
@@ -36,6 +39,23 @@ pub struct Config {
/// The input prompt.
#[arg(short, long, default_value_t = String::from(DEFAULT_PROMPT))]
prompt: String,

/// The Llama 3 model version.
#[cfg(feature = "llama3")]
#[arg(long, default_value = "llama-3.1-8b-instruct")]
version: Llama3,
}

#[cfg(feature = "llama3")]
#[derive(Clone, Debug, ValueEnum)]
/// Llama-3 model variants to load.
enum Llama3 {
/// Llama-3-8B-Instruct.
#[value(name = "llama-3-8b-instruct")]
V3Instruct,
/// Llama-3.1-8B-Instruct.
#[value(name = "llama-3.1-8b-instruct")]
V31Instruct,
}

pub fn generate<B: Backend, T: Tokenizer>(
@@ -76,7 +96,7 @@ pub fn chat<B: Backend>(args: Config, device: Device<B>) {
#[cfg(feature = "tiny")]
{
// TinyLlama-1.1B Chat v1.0
let mut llama = LlamaConfig::tiny_llama_pretrained::<B>(&device).unwrap();
let mut llama = LlamaConfig::tiny_llama_pretrained::<B>(args.max_seq_len, &device).unwrap();
println!("Processing prompt: {}", prompt);

// Prompt formatting for chat model
@@ -95,8 +115,15 @@

#[cfg(feature = "llama3")]
{
// Llama-3-8B-Instruct
let mut llama = LlamaConfig::llama3_8b_pretrained::<B>(true, &device).unwrap();
// Llama-3-8B-Instruct or Llama-3.1-8B-Instruct
let mut llama = match args.version {
Llama3::V3Instruct => {
LlamaConfig::llama3_8b_pretrained::<B>(args.max_seq_len, &device).unwrap()
}
Llama3::V31Instruct => {
LlamaConfig::llama3_1_8b_pretrained::<B>(args.max_seq_len, &device).unwrap()
}
};
println!("Processing prompt: {}", prompt);

// Prompt formatting for chat model
@@ -156,6 +183,19 @@ mod wgpu {
}
}

#[cfg(feature = "cuda")]
mod cuda {
use super::*;
use burn::backend::{cuda_jit::CudaDevice, CudaJit};

pub fn run(args: Config) {
let device = CudaDevice::default();

// NOTE: compilation errors in f16
chat::<CudaJit<f32, i32>>(args, device);
}
}

pub fn main() {
// Parse arguments
let args = Config::parse();
@@ -166,4 +206,6 @@ pub fn main() {
tch_cpu::run(args);
#[cfg(feature = "wgpu")]
wgpu::run(args);
#[cfg(feature = "cuda")]
cuda::run(args);
}
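For downstream users of the crate, loading the new Llama 3.1 weights with a caller-chosen context length might look like the sketch below. The `llama_burn::llama::LlamaConfig` path, the `Wgpu`/`WgpuDevice` imports, and the `max_seq_len` value are assumptions not shown in this diff, and the sketch presumes the `llama3` feature (with the default pretrained-weights support) is enabled; generation is omitted because the example's `generate` helper is elided above.

```rust
use burn::backend::wgpu::{Wgpu, WgpuDevice};
use llama_burn::llama::LlamaConfig;

fn main() {
    let device = WgpuDevice::default();

    // The maximum sequence length is now supplied by the caller; the hard-coded
    // MAX_SEQ_LEN constant (and its assert) was removed from cache.rs in this commit.
    let max_seq_len = 512; // illustrative value

    // Downloads (if necessary) and loads the Llama-3.1-8B-Instruct weights.
    let mut llama = LlamaConfig::llama3_1_8b_pretrained::<Wgpu>(max_seq_len, &device)
        .expect("failed to load pretrained Llama 3.1");

    // Prompt formatting and sampling are handled by the example's `generate`
    // helper, which is not fully shown in this diff.
    let _ = &mut llama;
}
```

Passing `max_seq_len` through the config rather than a hard-coded constant is presumably what allows Llama 3.1 to be selected at runtime without the former `llama3_1` feature flag.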
11 changes: 0 additions & 11 deletions llama-burn/src/cache.rs
@@ -1,11 +1,5 @@
use burn::tensor::{backend::Backend, Tensor};

/// All Llama-3 models support sequence length up to 8192 tokens.
pub(crate) const MAX_SEQ_LEN: usize = 8192;

// /// All Llama-2 models support sequence length up to 4096 tokens.
// pub(crate) const MAX_SEQ_LEN_V2: usize = 4096;

// Adapted from `burn::nn::cache`
enum CacheState<T> {
Value(T),
@@ -39,11 +33,6 @@ pub(crate) struct AutoregressiveCache<B: Backend> {
impl<B: Backend> AutoregressiveCache<B> {
/// Creates a new empty cache.
pub fn new(max_seq_len: usize) -> Self {
assert!(
max_seq_len <= MAX_SEQ_LEN,
"Maximum sequence length must not exceed {MAX_SEQ_LEN}"
);

Self {
cache: TensorCache::empty(),
max_seq_len,
(Diffs for the remaining 5 changed files were not loaded and are omitted here.)
