From f1368de55041cdf7c2186dd542c67037aafc9830 Mon Sep 17 00:00:00 2001 From: Francesco Forcher Date: Sat, 3 Aug 2024 01:42:41 +0200 Subject: [PATCH] Add benchmarks (#32) * Add benchmarks, update Readme, add utility dict generating function --- Cargo.lock | 385 ++++++++++++++++++++++++++++++++++++++++++ Cargo.toml | 4 + README.md | 76 ++++++--- benches/benchmarks.rs | 90 ++++++++++ src/lib.rs | 4 +- src/main.rs | 28 ++- src/util/mod.rs | 51 +++++- tests/tests.rs | 49 +++++- 8 files changed, 648 insertions(+), 39 deletions(-) create mode 100644 benches/benchmarks.rs diff --git a/Cargo.lock b/Cargo.lock index dff131e..5871e21 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11,6 +11,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + [[package]] name = "anstream" version = "0.6.15" @@ -66,18 +72,63 @@ version = "1.0.86" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" +[[package]] +name = "autocfg" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" + +[[package]] +name = "bumpalo" +version = "3.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" + [[package]] name = "byteorder" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + [[package]] name = "cfg-if" version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + [[package]] name = "clap" version = "4.5.13" @@ -136,6 +187,79 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "is-terminal", + "itertools", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" + +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + +[[package]] +name = "either" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" + [[package]] name = "encode_unicode" version = "0.3.6" @@ -176,12 +300,28 @@ dependencies = [ "wasi", ] +[[package]] +name = "half" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" +dependencies = [ + "cfg-if", + "crunchy", +] + [[package]] name = "heck" version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hermit-abi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" + [[package]] name = "humantime" version = "2.1.0" @@ -200,12 +340,47 @@ dependencies = [ "similar", ] +[[package]] +name = "is-terminal" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f23ff5ef2b80d608d61efee834934d862cd92461afc0560dedf493e4c033738b" +dependencies = [ + "hermit-abi", + "libc", + "windows-sys", +] + [[package]] name = "is_terminal_polyfill" version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" + +[[package]] +name = "js-sys" +version = "0.3.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" +dependencies = [ + "wasm-bindgen", +] + [[package]] name = "lazy_static" version = "1.5.0" @@ -236,6 +411,55 @@ version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" + +[[package]] +name = "oorandom" +version = "11.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" + +[[package]] +name = "plotters" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a15b6eccb8484002195a3e44fe65a4ce8e93a625797a063735536fd59cb01cf3" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "414cec62c6634ae900ea1c56128dfe87cf63e7caece0852ec76aba307cebadb7" + +[[package]] +name = "plotters-svg" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81b30686a7d9c3e010b84284bdd26a29f2138574f52f5eb6f794fc0ad924e705" +dependencies = [ + "plotters-backend", +] + [[package]] name = "ppv-lite86" version = "0.2.18" @@ -293,6 +517,26 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rayon" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + [[package]] name = "regex" version = "1.10.5" @@ -326,6 +570,53 @@ name = "regex-syntax" version = "0.8.4" source = "git+https://github.com/f-forcher/regex?branch=expose-state-iter#cdc3eb10549fc9ce0da2dd6a047c1e26aebdc9ef" +[[package]] +name = "ryu" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" + +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "serde" +version = "1.0.204" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc76f558e0cbb2a839d37354c575f1dc3fdc6546b5be373ba43d95f231bf7c12" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.204" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0cd7e117be63d3c3678776753929474f3b04a43a080c744d6b0ae2a8c28e222" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.122" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "784b6203951c57ff748476b126ccb5e8e2959a5c19e5c617ab1956be3dbc68da" +dependencies = [ + "itoa", + "memchr", + "ryu", + "serde", +] + [[package]] name = "similar" version = "2.6.0" @@ -344,6 +635,7 @@ version = "0.1.0" dependencies = [ "anyhow", "clap", + "criterion", "env_logger", "insta", "log", @@ -362,6 +654,16 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "unicode-ident" version = "1.0.12" @@ -374,12 +676,95 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasm-bindgen" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" +dependencies = [ + "cfg-if", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" +dependencies = [ + "bumpalo", + "log", + "once_cell", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" + +[[package]] +name = "web-sys" +version = "0.3.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "winapi-util" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d4cc384e1e73b93bafa6fb4f1df8c41695c8a91cf9c4c64358067d15a7b6c6b" +dependencies = [ + "windows-sys", +] + [[package]] name = "windows-sys" version = "0.52.0" diff --git a/Cargo.toml b/Cargo.toml index b682cab..ecce7ae 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,4 +19,8 @@ regex-automata = { git = 'https://github.com/f-forcher/regex', branch = 'expose- [dev-dependencies] insta = "1.39.0" +criterion = "0.5.1" +[[bench]] +name = "benchmarks" +harness = false \ No newline at end of file diff --git a/README.md b/README.md index 1124cbb..5318ed9 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ Usage: structured-gen-rust [OPTIONS] Options: -m, --model Name of language model to use. At the moment only "mock" options are available - + [default: random-sample] Possible values: @@ -43,49 +43,47 @@ Options: [default: indexed-fsm] Possible values: - - - no-masking: Do not perform structured generation, - mask will allow all tokens - - naive: Use naive `O(N)` pattern matching algorithm, i.e. check for each - token if the resulting completed output would still validate the pattern - - indexed-fsm: The algorithm from the paper - [Efficient Guided Generation for Large Language Models](https://arxiv.org/abs/2307.09702), - precomputing the token vocabulary with a hashmap from the pattern - FSM states to valid tokens. - The masking step is now O(1), indepentent of the current output sequence length - + - no-masking: Do not perform structured generation, mask will allow all tokens + - naive: Use naive `O(N)` pattern matching algorithm, i.e. check for each token if the + resulting completed output would still validate the pattern + - indexed-fsm: The algorithm from the paper [Efficient Guided Generation for Large Language Models](https://arxiv.org/abs/2307.09702), + precomputing the token vocabulary with a hashmap from the pattern FSM states to valid tokens. + The masking step is now O(1), indepentent of the current output sequence length -v, --vocab ... The model vocabulary as a space separated list of words. Example: - + -v A B 3 ... - - If not present, the default vocabulary - ["A", "3", ".", "42", "B", ".2", "1"] will be used. + If not present, the default vocabulary ["A", "3", ".", "42", "B", ".2", "1"] will be used. + + If set, it overrides the --gen-vocab option. -i, --input - The input prompt to the model. Keep in mind that the - whole text completion including the prompt, - must conform to the pattern. The default is an empty string + The input prompt to the model. Keep in mind that the whole text completion including the + prompt, must conform to the pattern. The default is an empty string - [default: ] -p, --pattern + The regex pattern to which the model output should conform. Usually you want to anchor + it at both ends, i.e. `^...$`. Default is the float regex `^([0-9]*)?\.?[0-9]*$` - The regex pattern according to which the model output should conform. - Usually you want to anchor it at both ends, i.e. `^...$`. Default is - the float regex `^([0-9]*)?\.?[0-9]*$` - - [default: ^([0-9]*)?\.?[0-9]*$] -n, --n-tokens The max amount of tokens to produce - + [default: 15] + -g, --gen-vocab + You can set this to generate a vocabulary with `usize` tokens inside. + + The dictionary consists of the single chars `a-z A-Z 0-9 . : , ! ?` and every multiple + char cartesian product combination of these, generating up to `gen_vocab` tokens. + + If neither this or `--vocab` is set, the default vocabulary will be used (see `--vocab` for more details). + -h, --help Print help (see a summary with '-h') @@ -116,4 +114,28 @@ about 0.04 seconds (on my machine): ``` time cargo run --release -- -m deterministic -a indexed-fsm -n 10000 -v A B 3 . 45 - ``` \ No newline at end of file + ``` + +## Benchmarks +Benchmarks can be run by using the command `cargo bench`. +In `target/criterion/report/index.html` an HTML plot will be generated +containing several plots and statistics. + +## Debug and logs +Debug logs can be enabled by running with a debug profile (i.e. `cargo run` without the `--release` flag) and setting the `RUST_LOG` env variable. + +The log levels are, in order: `error`, `warn`, `info`, `debug`, `trace`. +Set the `RUST_LOG` env variable to one of them to enable more (or +less) detailed logging. Example: + +``` +RUST_LOG=debug cargo run -- -m deterministic -a indexed-fsm -g 500 -n 500 +``` + +## Notes +The `regex-automata` public API does not expose the internal states +of the automata, so a [fork](https://github.com/f-forcher/regex/tree/expose-state-iter) of the Rust stdlib `regex` repo has been made and its internals exposed. + +## Known issues +Using very large dictionaries may result in failure to produce the right output. The issue seems to be deterministic and independent of the +specific struct-gen algo used, only depending on the dictionary size. It could be connected to the `regex` crate internal implementation of FSM states. \ No newline at end of file diff --git a/benches/benchmarks.rs b/benches/benchmarks.rs new file mode 100644 index 0000000..4c6b6c8 --- /dev/null +++ b/benches/benchmarks.rs @@ -0,0 +1,90 @@ +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; +use structured_gen_rust::{ + sample_model, + util::{generate_dict, DeterministicModel}, + LangModel, MaskingAlgorithmConfig, +}; + +fn _bench_setup_small_dict() -> ( + Vec, + &'static str, + impl LangModel, + String, + MaskingAlgorithmConfig<'static>, +) { + let tokens = vec!["A", "3", ".", "42", "B", ".2", "1"]; + let vocabulary: Vec = tokens.into_iter().map(|s| s.to_owned()).collect(); + + let pattern = r"^([0-9]*)?\.?[0-9]*$"; + let prompt = String::from(""); + let algo = MaskingAlgorithmConfig::IndexedFSM(pattern); + + let model = DeterministicModel::new(&vocabulary); + + (vocabulary, pattern, model, prompt, algo) +} + +fn bench_setup_gen_dict( + vocab_size: usize, +) -> ( + Vec, + &'static str, + impl LangModel, + String, + MaskingAlgorithmConfig<'static>, +) { + let vocabulary: Vec = generate_dict(vocab_size); + let pattern = r"^([0-9]*)?\.?[0-9]*$"; + let prompt = String::from(""); + let algo = MaskingAlgorithmConfig::IndexedFSM(pattern); + + let model = DeterministicModel::new(&vocabulary); + + (vocabulary, pattern, model, prompt, algo) +} + +pub fn simple_benchmark(c: &mut Criterion) { + let vocab_size = 500; + let (_, _, mut model, prompt, algo) = bench_setup_gen_dict(vocab_size); + let max_tokens: usize = 3000; + + c.bench_function( + &format!("sample determ indexed-fsm v={vocab_size} n={max_tokens}"), + |b| b.iter(|| sample_model(&mut model, black_box(max_tokens), &prompt, &algo).unwrap()), + ); +} + +fn naive_vs_fsm(c: &mut Criterion) { + let vocab_size = 500; + let mut group = c.benchmark_group("Sampling algo comparison"); + group.sample_size(10); + + let (_, pattern, mut model, prompt, _) = bench_setup_gen_dict(vocab_size); + let naive_algo = MaskingAlgorithmConfig::Naive(pattern); + let fsm_algo = MaskingAlgorithmConfig::IndexedFSM(pattern); + + for max_tokens in (10..50).step_by(10) { + group.bench_with_input( + BenchmarkId::new("Naive", max_tokens), + &max_tokens, + |b, max_tokens| { + b.iter(|| { + sample_model(&mut model, black_box(*max_tokens), &prompt, &naive_algo).unwrap() + }) + }, + ); + group.bench_with_input( + BenchmarkId::new("Indexed FSM", max_tokens), + &max_tokens, + |b, max_tokens: &usize| { + b.iter(|| { + sample_model(&mut model, black_box(*max_tokens), &prompt, &fsm_algo).unwrap() + }) + }, + ); + } + group.finish(); +} + +criterion_group!(benches, simple_benchmark, naive_vs_fsm); +criterion_main!(benches); diff --git a/src/lib.rs b/src/lib.rs index 1ceed90..806c599 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -328,9 +328,9 @@ fn naive_mask_from_pattern(vocabulary: &[Token], previous_samples: &str, pattern let re = Regex::new(pattern).expect("Invalid regex"); if re.is_match(possible_input) { + trace!("pattern {pattern} matches completion {possible_completion}"); mask.inner[i] = 1; } else { - trace!("pattern {pattern} does not match completion {possible_completion}"); mask.inner[i] = 0; } } @@ -419,6 +419,6 @@ fn find_subsequences(fsm: &dense::DFA>, token: &Token) -> Result>, @@ -33,7 +36,7 @@ struct Cli { #[arg(short, long, default_value_t = String::from(r""))] input: String, - /// The regex pattern according to which the model output should conform. + /// The regex pattern to which the model output should conform. /// Usually you want to anchor it at both ends, i.e. `^...$`. /// Default is the float regex `^([0-9]*)?\.?[0-9]*$` #[arg(short, long, default_value_t = String::from(r"^([0-9]*)?\.?[0-9]*$"))] @@ -42,6 +45,17 @@ struct Cli { /// The max amount of tokens to produce #[arg(short, long, default_value_t = 15)] n_tokens: usize, + + /// You can set this to generate a vocabulary with `usize` tokens inside. + /// + /// The dictionary consists of the + /// single chars `a-z A-Z 0-9 . : , ! ?` and every multiple char cartesian + /// product combination of these, generating up to `gen_vocab` tokens. + /// + /// If neither this or `--vocab` is set, the default vocabulary will be used + /// (see `--vocab` for more details). + #[arg(short, long)] + gen_vocab: Option, } fn default_small_dict() -> Vec { @@ -82,7 +96,13 @@ fn main() -> Result<()> { let cli = Cli::parse(); - let vocabulary = cli.vocab.unwrap_or_else(default_small_dict); + let vocabulary = match (cli.vocab, cli.gen_vocab) { + (Some(vocab), Some(_)) => vocab, + (Some(vocab), None) => vocab, + (None, None) => default_small_dict(), + (None, Some(gen_tokens)) => generate_dict(gen_tokens), + }; + let input_prompt = cli.input; let pattern = &cli.pattern[..]; let max_tokens = cli.n_tokens; @@ -107,6 +127,8 @@ fn main() -> Result<()> { } }; + trace!("Vocabulary: {vocabulary:?}"); + println!("Model output: \"{output}\""); Ok(()) diff --git a/src/util/mod.rs b/src/util/mod.rs index 8cd4b90..24ef74a 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -50,6 +50,17 @@ impl LangModel for RandomSampleModel { Err(WeightedError::AllWeightsZero) => return EOS_TOKEN, _ => todo!("error handling"), }; + + debug!( + "Next tokens allowed: {:?}", + mask.inner + .iter() + .enumerate() + .filter(|&(_i, m)| (*m != 0)) + .map(|(i, _m)| self.get_vocabulary()[i].clone()) + .collect::>() + ); + &self.vocabulary[self.dist.sample(&mut self.rng)] } @@ -79,10 +90,10 @@ impl LangModel for DeterministicModel { let mut out_token = EOS_TOKEN; for i in 0..self.vocabulary_size() { - let cyclic_idx = (self.idx + i) % self.vocabulary_size(); + let cyclic_idx = (self.idx + i + 1) % self.vocabulary_size(); if mask.inner[cyclic_idx] != 0 { out_token = &self.vocabulary[cyclic_idx]; - self.idx = (self.idx + 1) % self.vocabulary_size(); + self.idx = cyclic_idx % self.vocabulary_size(); break; } } @@ -104,3 +115,39 @@ impl LangModel for DeterministicModel { &self.vocabulary } } + +/// Helper function to generate a dictionary with up to num_tokens, +/// for testing and benchmarking. The dictionary consists of the +/// single chars a-z A-Z 0-9 . : , ! ? and every multiple char +/// combination of these, up to `num_tokens`. +pub fn generate_dict(num_tokens: usize) -> Vec { + // Create a vector to hold the characters + let mut chars: Vec = Vec::new(); + + for c in 'a'..='z' { + chars.push(c.to_string()); + } + for c in 'A'..='Z' { + chars.push(c.to_string()); + } + for c in '0'..='9' { + chars.push(c.to_string()); + } + let punctuation = ['.', ';', ',', '!', '?']; + for &c in &punctuation { + chars.push(c.to_string()); + } + + chars + .clone() + .into_iter() + .chain(chars.iter().flat_map(|prefix| { + chars.iter().map(|suffix| { + let mut new_string = prefix.clone(); + new_string.push_str(suffix); + new_string + }) + })) + .take(num_tokens) + .collect() +} diff --git a/tests/tests.rs b/tests/tests.rs index 1b73fb0..38f669e 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -1,7 +1,7 @@ use rand::{rngs::SmallRng, SeedableRng}; use structured_gen_rust::{ sample_model, - util::{DeterministicModel, RandomSampleModel}, + util::{generate_dict, DeterministicModel, RandomSampleModel}, MaskingAlgorithmConfig, }; @@ -13,6 +13,24 @@ fn small_default_setup() -> (Vec, usize, &'static str) { (vocabulary, max_samples, pattern) } +// Function to pretty print long dicts +fn print_combinations(combinations: Vec, values_per_row: usize) -> String { + let mut out = String::new(); + let mut count = 0; + for combination in combinations { + out.push_str(&format!("{}, ", combination)); + count += 1; + if count % values_per_row == 0 { + out.push('\n'); + } + } + + if count % values_per_row != 0 { + out.push('\n'); + } + out +} + #[test] fn unmasked() { let (vocabulary, max_samples, _) = small_default_setup(); @@ -28,7 +46,7 @@ fn unmasked() { ) .unwrap(); - insta::assert_snapshot!(out, @"A3.42B.21A3.42B.21A"); + insta::assert_snapshot!(out, @"3.42B.21A3.42B.21A3"); // RandomSampleModel let rng = SmallRng::seed_from_u64(42); @@ -60,7 +78,7 @@ fn naive_mask() { ) .unwrap(); - insta::assert_snapshot!(out, @"33.421113342421113"); + insta::assert_snapshot!(out, @"3.421342134213421342"); // RandomSampleModel let rng = SmallRng::seed_from_u64(42); @@ -92,7 +110,7 @@ fn indexed_fsm_mask() { ) .unwrap(); - insta::assert_snapshot!(out, @"33.421113342421113"); + insta::assert_snapshot!(out, @"3.421342134213421342"); // RandomSampleModel let rng = SmallRng::seed_from_u64(42); @@ -132,7 +150,7 @@ fn fsm_with_input_shorter() { .unwrap(); assert!(out.len() < max_samples); - insta::assert_snapshot!(out, @"AAAAABBBBB"); + insta::assert_snapshot!(out, @"AAAABBBBB"); // RandomSampleModel let rng = SmallRng::seed_from_u64(42); @@ -149,3 +167,24 @@ fn fsm_with_input_shorter() { assert!(out_rng.len() < max_samples); insta::assert_snapshot!(out_rng, @"AAAAABBBBB"); } + +#[test] +fn test_generate_dict() { + let max_tokens = 200; + + let tokens = generate_dict(max_tokens); + + let out = print_combinations(tokens, 20); + insta::assert_snapshot!(out, @r###" + a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, + u, v, w, x, y, z, A, B, C, D, E, F, G, H, I, J, K, L, M, N, + O, P, Q, R, S, T, U, V, W, X, Y, Z, 0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, ., ;, ,, !, ?, aa, ab, ac, ad, ae, af, ag, ah, ai, aj, ak, al, am, + an, ao, ap, aq, ar, as, at, au, av, aw, ax, ay, az, aA, aB, aC, aD, aE, aF, aG, + aH, aI, aJ, aK, aL, aM, aN, aO, aP, aQ, aR, aS, aT, aU, aV, aW, aX, aY, aZ, a0, + a1, a2, a3, a4, a5, a6, a7, a8, a9, a., a;, a,, a!, a?, ba, bb, bc, bd, be, bf, + bg, bh, bi, bj, bk, bl, bm, bn, bo, bp, bq, br, bs, bt, bu, bv, bw, bx, by, bz, + bA, bB, bC, bD, bE, bF, bG, bH, bI, bJ, bK, bL, bM, bN, bO, bP, bQ, bR, bS, bT, + bU, bV, bW, bX, bY, bZ, b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b., b;, b,, b!, + "###); +}