Commit d334fb4
add no pretokenizer bench
ArthurZucker committed Jan 3, 2025 (1 parent: dca4568)

Showing 1 changed file with 26 additions and 2 deletions.

tokenizers/benches/llama3.rs
@@ -4,6 +4,7 @@ extern crate criterion;
 use criterion::{Criterion, Throughput};
 use itertools::Itertools;
 use tokenizers::models::backtracking_bpe;
+use tokenizers::PreTokenizerWrapper;
 use tokenizers::Tokenizer;

 pub fn llama3(c: &mut Criterion) {
@@ -12,7 +13,7 @@ pub fn llama3(c: &mut Criterion) {
     group.throughput(Throughput::Bytes(data.bytes().len() as u64));

     group.bench_function("llama3-backtracking", |b| {
-        let tokenizer = Tokenizer::from_pretrained("gpt2", None).unwrap();
+        let mut tokenizer = Tokenizer::from_pretrained("gpt2", None).unwrap();
         // Convert the vocab HashMap into a Vec of (String, u32) tuples
         let mut vocab: Vec<_> = tokenizer.get_vocab(false).into_iter().collect();
         vocab.sort_by(|a, b| a.1.cmp(&b.1)); // Sort tokens by id
@@ -23,7 +24,30 @@ pub fn llama3(c: &mut Criterion) {
             .collect();
         let model: backtracking_bpe::BacktrackingBpe =
             backtracking_bpe::BacktrackingBpe::from_dictionary(vocab, None, None);
-        let tokenizer = Tokenizer::new(model);
+        tokenizer.with_model(model);
         let data: Vec<_> = data.lines().collect();
         let add_special_tokens = false;
         b.iter(|| {
             tokenizer
                 .encode_batch(criterion::black_box(data.clone()), add_special_tokens)
                 .unwrap()
         })
     });

group.bench_function("llama3-backtracking-no-pretok", |b| {
let mut tokenizer = Tokenizer::from_pretrained("gpt2", None).unwrap();
let mut vocab = &mut tokenizer.get_vocab(false).clone().into_iter().collect::<Vec<_>>(); // Convert HashMap into a Vec of (String, u32) tuples
//
vocab.sort_by(|a, b| a.1.cmp(&b.1));
vocab.truncate(vocab.len().saturating_sub(3));
let vocab: Vec<_> = vocab // Sort by u32 value
.into_iter() // IntoIterator to get the iterator of Vec<u8>
.map(|(tok, _)| Vec::from(tok.as_bytes()))
.collect();
let model: backtracking_bpe::BacktrackingBpe =
backtracking_bpe::BacktrackingBpe::from_dictionary(vocab, None, None);
tokenizer.with_model(model);
tokenizer.with_pre_tokenizer(None::<PreTokenizerWrapper>);
let data: Vec<_> = data.lines().collect();
let add_special_tokens = false;
b.iter(|| {
Expand Down
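
Note: the new bench differs from "llama3-backtracking" only in tokenizer.with_pre_tokenizer(None::<PreTokenizerWrapper>), so comparing the two isolates the cost of pre-tokenization relative to raw backtracking-BPE encoding. A minimal sketch for running both benches locally, assuming this file is wired up as a criterion bench target named llama3 in the crate's Cargo.toml (the invocation is illustrative, not part of the commit):

    cd tokenizers
    cargo bench --bench llama3 -- llama3-backtracking

The filter matches both "llama3-backtracking" and "llama3-backtracking-no-pretok"; since the group sets Throughput::Bytes, criterion reports bytes-per-second figures for each, which can be compared directly.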
