Fast encode #1560

Closed
wants to merge 17 commits
1,327 changes: 1,327 additions & 0 deletions bindings/python/Cargo.lock

Large diffs are not rendered by default.

77 changes: 77 additions & 0 deletions bindings/python/benches/bench_gpt2.py
@@ -0,0 +1,77 @@
#!/usr/bin/env python3
import base64
import functools
import gzip
import json
import os
import random
import time
from typing import Any, cast

import blobfile

import tiktoken
from tokenizers import Tokenizer


def format_byte_size(num_bytes: float) -> tuple[str, str]:
    """Convert a byte count to a human-readable string; returns (formatted size, unit)."""
    size = float(num_bytes)
    for unit in ["B", "KB", "MB", "GB", "TB"]:
        if size < 1024:
            return f"{size:.2f} {unit}", unit
        size /= 1024
    return f"{size:.2f} PB", "PB"


def benchmark_batch(documents: list[str]) -> None:
num_threads = int(os.environ["RAYON_NUM_THREADS"])
num_bytes = sum(map(len, map(str.encode, documents)))
readable_size, unit = format_byte_size(num_bytes)
print(f"num_threads: {num_threads}, data size: {readable_size}")
enc = tiktoken.get_encoding("gpt2")
enc.encode("warmup")

start = time.perf_counter_ns()
enc.encode_ordinary_batch(documents, num_threads=num_threads)
end = time.perf_counter_ns()

readable_size, unit = format_byte_size(num_bytes / (end - start) * 1e9)
print(f"tiktoken \t{readable_size} / s")

hf_enc = Tokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
hf_enc.encode("warmup")

start = time.perf_counter_ns()
hf_enc.encode_batch(documents)
end = time.perf_counter_ns()
readable_size, unit = format_byte_size(num_bytes / (end - start) * 1e9)
print(f"huggingface \t{readable_size} / s")


import tqdm
from datasets import load_dataset


def test_on_xnli():
dataset_xnli = load_dataset("facebook/xnli", "all_languages")

# Varying the number of threads and length of input
num_threads_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 16, 32] # Example thread counts
input_lengths = [10_000] # Example input lengths

documents = ["".join(item["premise"].values()) for item in dataset_xnli["train"]]
for num_threads in num_threads_list:
os.environ["RAYON_NUM_THREADS"] = str(num_threads)
os.environ["TOKENIZER_PARALLELISM"] = str(num_threads)
os.environ["RAYON_RS_NUM_THREADS"] = str(num_threads)
for length in input_lengths:
if length == 100_000 and num_threads == 1:
break
benchmark_batch(documents[:length])


# Call the function to run the benchmark
test_on_xnli()
12 changes: 11 additions & 1 deletion tokenizers/Cargo.toml
@@ -36,6 +36,14 @@ harness = false
name = "unigram_benchmark"
harness = false

[[bench]]
name = "bert_decode_benchmark"
harness = false

[[bench]]
name = "important_tokenizer_benchmark"
harness = false

[dependencies]
lazy_static = "1.4"
rand = "0.8"
@@ -63,6 +71,7 @@ fancy-regex = { version = "0.13", optional = true}
getrandom = { version = "0.2.10" }
esaxx-rs = { version = "0.1.10", default-features = false, features=[]}
monostate = "0.1.12"
mimalloc = "0.1"

[features]
default = ["progressbar", "onig", "esaxx_fast"]
@@ -79,4 +88,5 @@ tracing = "0.1"
tracing-subscriber = "0.3.18"

[profile.release]
lto = "fat"
debug = true
strip = false
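The diff adds `mimalloc = "0.1"` as a dependency, but the hunks shown here do not include where it gets registered. A minimal sketch of the conventional wiring, assuming it is (or would be) installed as the global allocator somewhere in the crate or its bindings:

```rust
// Sketch only: the usual way to register mimalloc as the global allocator.
// Where (or whether) this PR actually does this is not visible in this diff.
use mimalloc::MiMalloc;

#[global_allocator]
static GLOBAL: MiMalloc = MiMalloc;

fn main() {
    // Every heap allocation from here on is served by mimalloc.
    let ids: Vec<u32> = (0..1_000).collect();
    println!("{}", ids.len());
}
```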
25 changes: 25 additions & 0 deletions tokenizers/benches/bert_decode_benchmark.rs
@@ -0,0 +1,25 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use tokenizers::tokenizer::Tokenizer;

fn decode(tokenizer: &Tokenizer, ids_slice: Vec<u32>, skip_special_tokens: bool) -> String {
tokenizer
.decode(&ids_slice, skip_special_tokens)
.expect("failed to decode input")
}

fn criterion_benchmark(c: &mut Criterion) {
let tokenizer =
Tokenizer::from_file("data/bert-wiki.json").expect("failed to create tokenizer");
c.bench_function("decode", |b| {
b.iter(|| {
decode(
&tokenizer,
black_box([2829, 4419, 14523, 2058, 1996, 13971, 3899].to_vec()),
black_box(true),
)
})
});
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
59 changes: 59 additions & 0 deletions tokenizers/benches/important_tokenizer_benchmark.rs
@@ -0,0 +1,59 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use tokenizers::tokenizer::Tokenizer;
extern crate criterion;

mod common;

use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::Path;

use tokenizers::tokenizer::EncodeInput;

use common::{iter_bench_encode, iter_bench_encode_batch};
use std::ops::Deref;
use std::time::{Duration, Instant};

static BATCH_SIZE: usize = usize::MAX;

fn bench_inference(c: &mut Criterion) {
let tokenizer = Tokenizer::from_pretrained("mistralai/Mistral-7B-v0.1", None).unwrap();
let mut lines: Vec<EncodeInput> = vec![];
let mut batches: Vec<Vec<EncodeInput>> = vec![vec![]];
for line in BufReader::new(File::open(Path::new("data/big.txt")).unwrap()).lines() {
let line: EncodeInput = line.unwrap().into();
lines.push(line.clone());
if batches.last().unwrap().len() >= BATCH_SIZE {
batches.push(vec![]);
}
batches.last_mut().unwrap().push(line);
}

c.bench_function("mistral encode long input", |b| {
b.iter_custom(|iters| iter_bench_encode(iters, tokenizer.deref(), &lines))
});

c.bench_function("encode single batch of very long input", |b| {
b.iter_custom(|iters| iter_bench_encode_batch(iters, tokenizer.deref(), &batches))
});

c.bench_function("decode long input", |b| {
b.iter_custom(|iters| {
let mut duration = Duration::new(0, 0);
let mut line_index: usize = 0;
for _i in 0..iters {
if line_index >= lines.len() {
line_index = 0;
}
let input = batches[0].clone();
let start = Instant::now();
let _ = black_box(tokenizer.encode_batch(input, false));
duration = duration.checked_add(start.elapsed()).unwrap();
}
duration
})
});
}

criterion_group!(benches, bench_inference);
criterion_main!(benches);
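Because `BATCH_SIZE` is `usize::MAX`, the batching loop in `bench_inference` puts every line of `data/big.txt` into a single batch. A tiny standalone sketch of the same grouping logic with a small batch size, just to make the behaviour visible (the data here is made up):

```rust
// Same grouping logic as bench_inference, shrunk down so the batch boundaries are visible.
fn main() {
    const BATCH_SIZE: usize = 3;
    let lines = ["a", "b", "c", "d", "e", "f", "g"];
    let mut batches: Vec<Vec<&str>> = vec![vec![]];
    for line in lines {
        if batches.last().unwrap().len() >= BATCH_SIZE {
            batches.push(vec![]);
        }
        batches.last_mut().unwrap().push(line);
    }
    // Seven lines with a batch size of three fall into batches of 3, 3 and 1.
    assert_eq!(batches.iter().map(Vec::len).collect::<Vec<_>>(), vec![3usize, 3, 1]);
}
```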
33 changes: 31 additions & 2 deletions tokenizers/src/pre_tokenizers/byte_level.rs
@@ -41,6 +41,14 @@ lazy_static! {
r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"
)
.unwrap();
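    // One compiled copy of the byte-level pattern per thread slot; `pre_tokenize` below selects
    // a slot with `hash_current_thread()` so each thread reuses its own regex instance.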
static ref RE_VEC: Vec<SysRegex> = {
let pattern = r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+";
let mut vec = Vec::with_capacity(MAX_NUM_THREADS);
for _ in 0..MAX_NUM_THREADS {
vec.push(SysRegex::new(pattern).unwrap());
}
vec
};
static ref BYTES_CHAR: HashMap<u8, char> = bytes_char();
static ref CHAR_BYTES: HashMap<char, u8> =
bytes_char().into_iter().map(|(c, b)| (b, c)).collect();
@@ -111,12 +119,31 @@ impl ByteLevel {
}
}

use std::num::NonZeroU64;
use std::thread;

pub struct FakeThreadId(NonZeroU64);

fn hash_current_thread() -> usize {
    // It's easier to use unsafe here than to rely on a nightly-only API. Rust keeps a nice
    // monotonically increasing u64 thread-id counter that works great for our use case of
    // avoiding collisions in our array. Unfortunately, it's private. However, there are only
    // so many ways you can lay out a u64, so just transmute it out.
    // https://github.com/rust-lang/rust/issues/67939
    // These compile-time size checks fail if `ThreadId` ever stops being a single u64.
    const _: [u8; 8] = [0; std::mem::size_of::<thread::ThreadId>()];
    const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()];
    let x =
        unsafe { std::mem::transmute::<thread::ThreadId, FakeThreadId>(thread::current().id()).0 };
    u64::from(x) as usize - 1
}

const MAX_NUM_THREADS: usize = 128;

/// As a `PreTokenizer`, `ByteLevel` is in charge of transforming all the unicode characters into
/// their byte-level counterpart. It also splits the input according to the configured regex.
// TODO: Give the ability to modify this regex
impl PreTokenizer for ByteLevel {
fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
let re_ref: &SysRegex = &RE;
let re_ref: &SysRegex = &RE_VEC[hash_current_thread() % MAX_NUM_THREADS]; // TODO use the thread thing here as well!
pretokenized.split(|_, mut normalized| {
if self.add_prefix_space && !normalized.get().starts_with(' ') {
normalized.prepend(" ");
@@ -142,7 +169,8 @@ impl PreTokenizer for ByteLevel {
.map(|(i, b)| (BYTES_CHAR[b], isize::from(i > 0))),
);
}
normalized.transform(transformations, 0);
            // normalized.transform(transformations, 0); // TODO: what would happen if we ignored
            // alignments here?
Ok(())
})
}
Expand Down Expand Up @@ -172,6 +200,7 @@ impl Decoder for ByteLevel {
}
}

// TODO: this is another place we may want to skip entirely in the fast path
/// As a `PostProcessor`, `ByteLevel` is in charge of trimming the offsets if necessary.
impl PostProcessor for ByteLevel {
fn added_tokens(&self, _is_pair: bool) -> usize {
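For comparison, the per-thread regex table plus `hash_current_thread()` above could also be expressed with `thread_local!` on stable Rust, trading the up-front array of 128 compiled regexes for one lazy compilation per thread. This is a sketch of that alternative, not what the PR does, and it uses the `fancy-regex` crate directly (one of the crate's optional regex backends) so it stays self-contained:

```rust
// Sketch of a thread_local! alternative to RE_VEC + hash_current_thread(); not part of the PR.
use fancy_regex::Regex;

thread_local! {
    // Compiled once per thread, on that thread's first use.
    static LOCAL_RE: Regex = Regex::new(
        r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+",
    )
    .unwrap();
}

fn count_pieces(text: &str) -> usize {
    // Each thread works with its own regex instance, so there is no shared matcher to contend on.
    LOCAL_RE.with(|re| re.find_iter(text).count())
}

fn main() {
    println!("{}", count_pieces("Hello there, I'm benchmarking fast encode!"));
}
```

The trade-off is a small start-up cost per thread instead of a fixed 128-entry table, and no dependence on the layout of `std::thread::ThreadId`.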
46 changes: 41 additions & 5 deletions tokenizers/src/tokenizer/added_vocabulary.rs
@@ -92,6 +92,25 @@ }
}
}

use std::num::NonZeroU64;
use std::thread;

pub struct FakeThreadId(NonZeroU64);

fn hash_current_thread() -> usize {
    // It's easier to use unsafe here than to rely on a nightly-only API. Rust keeps a nice
    // monotonically increasing u64 thread-id counter that works great for our use case of
    // avoiding collisions in our array. Unfortunately, it's private. However, there are only
    // so many ways you can lay out a u64, so just transmute it out.
    // https://github.com/rust-lang/rust/issues/67939
    // These compile-time size checks fail if `ThreadId` ever stops being a single u64.
    const _: [u8; 8] = [0; std::mem::size_of::<thread::ThreadId>()];
    const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()];
    let x =
        unsafe { std::mem::transmute::<thread::ThreadId, FakeThreadId>(thread::current().id()).0 };
    u64::from(x) as usize - 1
}

const MAX_NUM_THREADS: usize = 128;

type MatchingSet = (AhoCorasick, Vec<u32>);

lazy_static! {
@@ -156,11 +175,16 @@ pub struct AddedVocabulary {
/// us remove them easily with an O(1) complexity.
special_tokens_set: HashSet<String>,

    /// A RegexSet containing all the non-normalized patterns used to split on AddedTokens
split_trie: MatchingSet,
/// A RegexSet containing all the normalized patterns used to split on AddedTokens
split_normalized_trie: MatchingSet,

    /// Per-thread copies of `split_trie`, indexed via `hash_current_thread()`
    split_trie_vec: Vec<MatchingSet>,
    /// Per-thread copies of `split_normalized_trie`, indexed via `hash_current_thread()`
    split_normalized_trie_vec: Vec<MatchingSet>,

    /// Whether or not special tokens should be split when encoding. This is equivalent to ignoring them
encode_special_tokens: bool,
}
@@ -181,8 +205,10 @@ impl AddedVocabulary {
added_tokens: vec![],
special_tokens: vec![],
special_tokens_set: HashSet::new(),
split_trie: (trie, vec![]),
split_normalized_trie: (normalized_trie, vec![]),
split_trie: (trie.clone(), vec![]),
split_normalized_trie: (normalized_trie.clone(), vec![]),
split_trie_vec: vec![(trie, vec![]); MAX_NUM_THREADS],
split_normalized_trie_vec: vec![(normalized_trie, vec![]); MAX_NUM_THREADS],
encode_special_tokens: false,
}
}
@@ -345,6 +371,7 @@ impl AddedVocabulary {
.build(tokens.iter().map(|token| &token.content))
.expect("Failed to build tried when refreshing tokens");
self.split_trie = (trie, ids);
self.split_trie_vec = vec![self.split_trie.clone(); MAX_NUM_THREADS];

let (ntokens, nids): (Vec<&AddedToken>, Vec<u32>) = normalized.into_iter().unzip();
let patterns: Vec<_> = ntokens
@@ -362,6 +389,7 @@
.build(patterns.iter().map(|content| content.get()))
.expect("Failed to build tried when refreshing tokens (normalized)");
self.split_normalized_trie = (normalized_trie, nids);
self.split_normalized_trie_vec = vec![self.split_normalized_trie.clone(); MAX_NUM_THREADS];
}

/// Find any AddedToken in the given sentence, using the provided MatchingSet.
@@ -465,7 +493,12 @@ impl AddedVocabulary {

// 1. We extract all the non-normalized tokens from the non-normalized string
pretokenized
.split(|_, sequence| Ok(self.split_with_indices(sequence, &self.split_trie)))
.split(|_, sequence| {
Ok(self.split_with_indices(
sequence,
&self.split_trie_vec[hash_current_thread() % MAX_NUM_THREADS],
))
})
.expect("AddedVocabulary bad split");

// <s> normalized = False
Expand All @@ -484,7 +517,10 @@ impl AddedVocabulary {
pretokenized
.split(|_, mut sequence| {
normalizer.map(|n| n.normalize(&mut sequence));
Ok(self.split_with_indices(sequence, &self.split_normalized_trie))
Ok(self.split_with_indices(
sequence,
&self.split_normalized_trie_vec[hash_current_thread() % MAX_NUM_THREADS],
))
})
.expect("AddedVocabulary bad split");

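For context on what the cloned `MatchingSet`s hold, here is a minimal sketch of the underlying aho-corasick matching that `split_with_indices` drives: locate added tokens in a sentence and map each hit back to its pattern. It assumes aho-corasick 1.x and uses made-up token strings:

```rust
// Sketch of the automaton inside a MatchingSet; the token list here is illustrative only.
use aho_corasick::{AhoCorasick, MatchKind};

fn main() {
    let added_tokens = ["<s>", "</s>", "<unk>"];
    let ac = AhoCorasick::builder()
        .match_kind(MatchKind::LeftmostLongest)
        .build(added_tokens)
        .unwrap();
    let text = "<s> hello world </s>";
    for m in ac.find_iter(text) {
        // m.pattern() indexes the pattern list, mirroring the ids stored alongside the automaton.
        println!("{:?} at {}..{}", added_tokens[m.pattern().as_usize()], m.start(), m.end());
    }
}
```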