Add benchmarks (#32)
* Add benchmarks, update Readme, add utility dict generating function
f-forcher authored Aug 2, 2024
1 parent 4926e29 commit f1368de
Showing 8 changed files with 648 additions and 39 deletions.
385 changes: 385 additions & 0 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions Cargo.toml
@@ -19,4 +19,8 @@ regex-automata = { git = 'https://github.com/f-forcher/regex', branch = 'expose-

[dev-dependencies]
insta = "1.39.0"
criterion = "0.5.1"

[[bench]]
name = "benchmarks"
harness = false
76 changes: 49 additions & 27 deletions README.md
@@ -32,7 +32,7 @@ Usage: structured-gen-rust [OPTIONS]
Options:
-m, --model <MODEL>
Name of language model to use. At the moment only "mock" options are available

[default: random-sample]

Possible values:
@@ -43,49 +43,47 @@ Options:
[default: indexed-fsm]

Possible values:

- no-masking: Do not perform structured generation,
mask will allow all tokens
- naive: Use naive `O(N)` pattern matching algorithm, i.e. check for each
token if the resulting completed output would still validate the pattern
- indexed-fsm: The algorithm from the paper
[Efficient Guided Generation for Large Language Models](https://arxiv.org/abs/2307.09702),
precomputing the token vocabulary with a hashmap from the pattern
FSM states to valid tokens.
The masking step is now O(1), independent of the current output sequence length

- no-masking: Do not perform structured generation, mask will allow all tokens
- naive: Use naive `O(N)` pattern matching algorithm, i.e. check for each token if the
resulting completed output would still validate the pattern (see the sketch after this options list)
- indexed-fsm: The algorithm from the paper [Efficient Guided Generation for Large Language Models](https://arxiv.org/abs/2307.09702),
precomputing the token vocabulary with a hashmap from the pattern FSM states to valid tokens.
The masking step is now O(1), independent of the current output sequence length

-v, --vocab <VOCAB>...
The model vocabulary as a space separated list of words. Example:

-v A B 3 ...


If not present, the default vocabulary
["A", "3", ".", "42", "B", ".2", "1"] will be used.
If not present, the default vocabulary ["A", "3", ".", "42", "B", ".2", "1"] will be used.

If set, it overrides the --gen-vocab option.

-i, --input <INPUT>
The input prompt to the model. Keep in mind that the
whole text completion, including the prompt,
must conform to the pattern. The default is an empty string
The input prompt to the model. Keep in mind that the whole text completion, including the
prompt, must conform to the pattern. The default is an empty string

[default: ]

-p, --pattern <PATTERN>
The regex pattern to which the model output should conform. Usually you want to anchor
it at both ends, i.e. `^...$`. Default is the float regex `^([0-9]*)?\.?[0-9]*$`

The regex pattern according to which the model output should conform.
Usually you want to anchor it at both ends, i.e. `^...$`. Default is
the float regex `^([0-9]*)?\.?[0-9]*$`

[default: ^([0-9]*)?\.?[0-9]*$]

-n, --n-tokens <N_TOKENS>
The max number of tokens to produce

[default: 15]

-g, --gen-vocab <GEN_VOCAB>
Set this to generate a vocabulary with the given number of tokens.

The dictionary consists of the single chars `a-z A-Z 0-9 . ; , ! ?` and every
two-character combination of these, generating up to `gen_vocab` tokens.

If neither this nor `--vocab` is set, the default vocabulary will be used (see `--vocab` for more details).

-h, --help
Print help (see a summary with '-h')
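
For reference, a minimal sketch of the naive check described under `--algo`. This is illustrative only: the crate's actual implementation is `naive_mask_from_pattern` in `src/lib.rs`, and the helper name and `Vec<bool>` mask shape here are assumptions.

```
use regex::Regex;

// Sketch of the naive O(N) masking step: tentatively append each token to the
// output generated so far and allow it only if the concatenation still
// matches the pattern. The cost grows with the current output length.
fn naive_mask(vocabulary: &[String], previous_samples: &str, pattern: &str) -> Vec<bool> {
    let re = Regex::new(pattern).expect("Invalid regex");
    vocabulary
        .iter()
        .map(|token| re.is_match(&format!("{previous_samples}{token}")))
        .collect()
}
```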

@@ -116,4 +114,28 @@ about 0.04 seconds (on my machine):
```
time cargo run --release -- -m deterministic -a indexed-fsm -n 10000 -v A B 3 . 45
```
## Benchmarks
Benchmarks can be run with `cargo bench`.
An HTML report containing several plots and statistics will be
generated at `target/criterion/report/index.html`.
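For example, to run only the comparison group defined in `benches/benchmarks.rs` (Criterion accepts a name filter after `--`):
```
cargo bench -- "Sampling algo comparison"
```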
## Debug and logs
Debug logs can be enabled by running with a debug profile (i.e. `cargo run` without the `--release` flag) and setting the `RUST_LOG` env variable.
The log levels are, in order: `error`, `warn`, `info`, `debug`, `trace`.
Set the `RUST_LOG` env variable to one of them to enable more (or
less) detailed logging. Example:
```
RUST_LOG=debug cargo run -- -m deterministic -a indexed-fsm -g 500 -n 500
```
## Notes
The `regex-automata` public API does not expose the internal states
of the automata, so a [fork](https://github.com/f-forcher/regex/tree/expose-state-iter) of the Rust `regex` crate repo has been made and its internals exposed.
## Known issues
Using very large dictionaries may result in failure to produce the right output. The issue seems to be deterministic and independent of the
specific struct-gen algo used, depending only on the dictionary size. It may be connected to the `regex` crate's internal implementation of FSM states.
90 changes: 90 additions & 0 deletions benches/benchmarks.rs
@@ -0,0 +1,90 @@
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
use structured_gen_rust::{
sample_model,
util::{generate_dict, DeterministicModel},
LangModel, MaskingAlgorithmConfig,
};

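/// Shared setup using the small hand-picked default vocabulary.
/// Currently unused (hence the leading underscore) but kept alongside
/// the generated-dictionary setup below.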
fn _bench_setup_small_dict() -> (
Vec<String>,
&'static str,
impl LangModel,
String,
MaskingAlgorithmConfig<'static>,
) {
let tokens = vec!["A", "3", ".", "42", "B", ".2", "1"];
let vocabulary: Vec<String> = tokens.into_iter().map(|s| s.to_owned()).collect();

let pattern = r"^([0-9]*)?\.?[0-9]*$";
let prompt = String::from("");
let algo = MaskingAlgorithmConfig::IndexedFSM(pattern);

let model = DeterministicModel::new(&vocabulary);

(vocabulary, pattern, model, prompt, algo)
}

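/// Shared setup that generates a `vocab_size`-token vocabulary via
/// `generate_dict`, paired with the float pattern and a `DeterministicModel`.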
fn bench_setup_gen_dict(
vocab_size: usize,
) -> (
Vec<String>,
&'static str,
impl LangModel,
String,
MaskingAlgorithmConfig<'static>,
) {
let vocabulary: Vec<String> = generate_dict(vocab_size);
let pattern = r"^([0-9]*)?\.?[0-9]*$";
let prompt = String::from("");
let algo = MaskingAlgorithmConfig::IndexedFSM(pattern);

let model = DeterministicModel::new(&vocabulary);

(vocabulary, pattern, model, prompt, algo)
}

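/// Benchmark the indexed-FSM algorithm alone: a 500-token generated
/// vocabulary, sampling up to 3000 tokens.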
pub fn simple_benchmark(c: &mut Criterion) {
let vocab_size = 500;
let (_, _, mut model, prompt, algo) = bench_setup_gen_dict(vocab_size);
let max_tokens: usize = 3000;

c.bench_function(
&format!("sample determ indexed-fsm v={vocab_size} n={max_tokens}"),
|b| b.iter(|| sample_model(&mut model, black_box(max_tokens), &prompt, &algo).unwrap()),
);
}

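/// Compare the naive and indexed-FSM algorithms on the same generated
/// vocabulary as the output length grows.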
fn naive_vs_fsm(c: &mut Criterion) {
let vocab_size = 500;
let mut group = c.benchmark_group("Sampling algo comparison");
group.sample_size(10);

let (_, pattern, mut model, prompt, _) = bench_setup_gen_dict(vocab_size);
let naive_algo = MaskingAlgorithmConfig::Naive(pattern);
let fsm_algo = MaskingAlgorithmConfig::IndexedFSM(pattern);

for max_tokens in (10..50).step_by(10) {
group.bench_with_input(
BenchmarkId::new("Naive", max_tokens),
&max_tokens,
|b, max_tokens| {
b.iter(|| {
sample_model(&mut model, black_box(*max_tokens), &prompt, &naive_algo).unwrap()
})
},
);
group.bench_with_input(
BenchmarkId::new("Indexed FSM", max_tokens),
&max_tokens,
|b, max_tokens: &usize| {
b.iter(|| {
sample_model(&mut model, black_box(*max_tokens), &prompt, &fsm_algo).unwrap()
})
},
);
}
group.finish();
}

criterion_group!(benches, simple_benchmark, naive_vs_fsm);
criterion_main!(benches);
4 changes: 2 additions & 2 deletions src/lib.rs
@@ -328,9 +328,9 @@ fn naive_mask_from_pattern(vocabulary: &[Token], previous_samples: &str, pattern

let re = Regex::new(pattern).expect("Invalid regex");
if re.is_match(possible_input) {
trace!("pattern {pattern} matches completion {possible_completion}");
mask.inner[i] = 1;
} else {
trace!("pattern {pattern} does not match completion {possible_completion}");
mask.inner[i] = 0;
}
}
@@ -419,6 +419,6 @@ fn find_subsequences(fsm: &dense::DFA<Vec<u32>>, token: &Token) -> Result<Vec<St
all_subseqs.push(state_sequence);
}

debug!("Token {token} has subsequences {all_subseqs:?}");
trace!("Token {token} has subsequences {all_subseqs:?}");
Ok(all_subseqs)
}
28 changes: 25 additions & 3 deletions src/main.rs
@@ -1,9 +1,10 @@
use anyhow::Result;

use log::trace;
use rand::thread_rng;
use structured_gen_rust::{
sample_model,
util::{DeterministicModel, RandomSampleModel},
util::{generate_dict, DeterministicModel, RandomSampleModel},
MaskingAlgorithmConfig,
};

@@ -25,6 +26,8 @@ struct Cli {
///
/// If not present, the default vocabulary ["A", "3", ".", "42", "B", ".2", "1"]
/// will be used.
///
/// If set, it overrides the --gen-vocab option.
#[arg(short, long, value_parser, num_args = 1.., value_delimiter = ' ')]
vocab: Option<Vec<String>>,

@@ -33,7 +36,7 @@ struct Cli {
#[arg(short, long, default_value_t = String::from(r""))]
input: String,

/// The regex pattern according to which the model output should conform.
/// The regex pattern to which the model output should conform.
/// Usually you want to anchor it at both ends, i.e. `^...$`.
/// Default is the float regex `^([0-9]*)?\.?[0-9]*$`
#[arg(short, long, default_value_t = String::from(r"^([0-9]*)?\.?[0-9]*$"))]
@@ -42,6 +45,17 @@ struct Cli {
/// The max number of tokens to produce
#[arg(short, long, default_value_t = 15)]
n_tokens: usize,

/// Set this to generate a vocabulary with the given number of tokens.
///
/// The dictionary consists of the
/// single chars `a-z A-Z 0-9 . ; , ! ?` and every two-character
/// combination of these, generating up to `gen_vocab` tokens.
///
/// If neither this nor `--vocab` is set, the default vocabulary will be used
/// (see `--vocab` for more details).
#[arg(short, long)]
gen_vocab: Option<usize>,
}

fn default_small_dict() -> Vec<String> {
@@ -82,7 +96,13 @@ fn main() -> Result<()> {

let cli = Cli::parse();

let vocabulary = cli.vocab.unwrap_or_else(default_small_dict);
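// `--vocab` takes precedence over `--gen-vocab` when both are given.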
let vocabulary = match (cli.vocab, cli.gen_vocab) {
(Some(vocab), Some(_)) => vocab,
(Some(vocab), None) => vocab,
(None, None) => default_small_dict(),
(None, Some(gen_tokens)) => generate_dict(gen_tokens),
};

let input_prompt = cli.input;
let pattern = &cli.pattern[..];
let max_tokens = cli.n_tokens;
@@ -107,6 +127,8 @@ }
}
};

trace!("Vocabulary: {vocabulary:?}");

println!("Model output: \"{output}\"");

Ok(())
51 changes: 49 additions & 2 deletions src/util/mod.rs
@@ -50,6 +50,17 @@ impl<R: Rng> LangModel for RandomSampleModel<R> {
Err(WeightedError::AllWeightsZero) => return EOS_TOKEN,
_ => todo!("error handling"),
};

debug!(
"Next tokens allowed: {:?}",
mask.inner
.iter()
.enumerate()
.filter(|&(_i, m)| (*m != 0))
.map(|(i, _m)| self.get_vocabulary()[i].clone())
.collect::<Vec<_>>()
);

&self.vocabulary[self.dist.sample(&mut self.rng)]
}

@@ -79,10 +90,10 @@ impl LangModel for DeterministicModel {
let mut out_token = EOS_TOKEN;

for i in 0..self.vocabulary_size() {
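// Probe indices cyclically, starting just after the last emitted token's index.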
let cyclic_idx = (self.idx + i) % self.vocabulary_size();
let cyclic_idx = (self.idx + i + 1) % self.vocabulary_size();
if mask.inner[cyclic_idx] != 0 {
out_token = &self.vocabulary[cyclic_idx];
self.idx = (self.idx + 1) % self.vocabulary_size();
self.idx = cyclic_idx % self.vocabulary_size();
break;
}
}
@@ -104,3 +115,39 @@
&self.vocabulary
}
}

/// Helper function to generate a dictionary with up to `num_tokens`
/// tokens, for testing and benchmarking. The dictionary consists of the
/// single chars `a-z A-Z 0-9 . ; , ! ?` and every two-character
/// combination of these, up to `num_tokens`.
pub fn generate_dict(num_tokens: usize) -> Vec<Token> {
// Create a vector to hold the characters
let mut chars: Vec<String> = Vec::new();

for c in 'a'..='z' {
chars.push(c.to_string());
}
for c in 'A'..='Z' {
chars.push(c.to_string());
}
for c in '0'..='9' {
chars.push(c.to_string());
}
let punctuation = ['.', ';', ',', '!', '?'];
for &c in &punctuation {
chars.push(c.to_string());
}

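// Singles first, then every two-character combination, truncated to `num_tokens`.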
chars
.clone()
.into_iter()
.chain(chars.iter().flat_map(|prefix| {
chars.iter().map(|suffix| {
let mut new_string = prefix.clone();
new_string.push_str(suffix);
new_string
})
}))
.take(num_tokens)
.collect()
}
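
As an illustrative check of the ordering (not part of the commit): the 67 single characters come first, then the two-character combinations.
```
let dict = generate_dict(70);
// 26 + 26 + 10 + 5 = 67 single-char tokens; index 67 onward are pairs: "aa", "ab", ...
assert_eq!(dict.len(), 70);
assert_eq!(dict[0], "a");
assert_eq!(dict[67], "aa");
```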