Add o200k_base encoding support.

Summary of this patch:
  * lib/tiktoken.ex            - map "o200k_base" to Tiktoken.O200K.
  * lib/tiktoken/native.ex     - add NIF stubs for the four o200k_* functions.
  * lib/tiktoken/o200k.ex      - new Tiktoken.Encoding implementation that
                                 delegates to the o200k_* NIFs.
  * native/tiktoken/Cargo.toml - bump rustler 0.31.0 -> 0.33.0 and
                                 tiktoken-rs 0.5.8 -> 0.5.9.
  * native/tiktoken/Cargo.lock - matching lockfile updates (heck 0.5.0,
                                 rustler_sys 2.4.1, lazy_static dropped from
                                 rustler's dependencies).
  * native/tiktoken/src/lib.rs - map Tokenizer::O200kBase to "o200k_base",
                                 add the o200k_* NIFs and register them.

NOTE(review): context-line indentation follows the usual formatter output
(mix format / rustfmt / cargo); if it does not match the tree exactly,
apply with `git apply --ignore-whitespace`. The token-id type in the Rust
signatures is assumed to be usize (tiktoken-rs 0.5.x) - confirm against the
existing cl100k_* functions.

diff --git a/lib/tiktoken.ex b/lib/tiktoken.ex
index 14a0fe7..a611ecf 100644
--- a/lib/tiktoken.ex
+++ b/lib/tiktoken.ex
@@ -7,7 +7,8 @@ defmodule Tiktoken do
     "p50k_base" => Tiktoken.P50K,
     "p50k_edit" => Tiktoken.P50KEdit,
     "r50k_base" => Tiktoken.R50K,
-    "cl100k_base" => Tiktoken.CL100K
+    "cl100k_base" => Tiktoken.CL100K,
+    "o200k_base" => Tiktoken.O200K
   }
 
   def encoding_for_model(model) do
diff --git a/lib/tiktoken/native.ex b/lib/tiktoken/native.ex
index a41de1d..8ebdf8a 100644
--- a/lib/tiktoken/native.ex
+++ b/lib/tiktoken/native.ex
@@ -32,6 +32,11 @@ defmodule Tiktoken.Native do
   def cl100k_encode_with_special_tokens(_input), do: err()
   def cl100k_decode(_ids), do: err()
 
+  def o200k_encode_ordinary(_input), do: err()
+  def o200k_encode(_input, _allowed_special), do: err()
+  def o200k_encode_with_special_tokens(_input), do: err()
+  def o200k_decode(_ids), do: err()
+
   def context_size_for_model(_model), do: err()
 
   defp err, do: :erlang.nif_error(:nif_not_loaded)
diff --git a/lib/tiktoken/o200k.ex b/lib/tiktoken/o200k.ex
new file mode 100644
index 0000000..030de90
--- /dev/null
+++ b/lib/tiktoken/o200k.ex
@@ -0,0 +1,23 @@
+defmodule Tiktoken.O200K do
+  @behaviour Tiktoken.Encoding
+
+  @impl Tiktoken.Encoding
+  def encode_ordinary(text) do
+    Tiktoken.Native.o200k_encode_ordinary(text)
+  end
+
+  @impl Tiktoken.Encoding
+  def encode(text, allowed_special \\ []) do
+    Tiktoken.Native.o200k_encode(text, allowed_special)
+  end
+
+  @impl Tiktoken.Encoding
+  def encode_with_special_tokens(text) do
+    Tiktoken.Native.o200k_encode_with_special_tokens(text)
+  end
+
+  @impl Tiktoken.Encoding
+  def decode(ids) do
+    Tiktoken.Native.o200k_decode(ids)
+  end
+end
diff --git a/native/tiktoken/Cargo.lock b/native/tiktoken/Cargo.lock
index 6e888fd..3140b1f 100644
--- a/native/tiktoken/Cargo.lock
+++ b/native/tiktoken/Cargo.lock
@@ -79,9 +79,9 @@ dependencies = [
 
 [[package]]
 name = "heck"
-version = "0.4.1"
+version = "0.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
+checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
 
 [[package]]
 name = "lazy_static"
@@ -204,20 +204,19 @@ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
 
 [[package]]
 name = "rustler"
-version = "0.31.0"
+version = "0.33.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a75d458f38f550976d0e4b347ca57241c192019777e46af7af73b27783287088"
+checksum = "45d51ae0239c57c3a3e603dd855ace6795078ef33c95c85d397a100ac62ed352"
 dependencies = [
- "lazy_static",
  "rustler_codegen",
  "rustler_sys",
 ]
 
 [[package]]
 name = "rustler_codegen"
-version = "0.31.0"
+version = "0.33.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "dbd46408f51c0ca6a68dc36aa4f90e3554960bd1b7cc513e6ff2ccad7dd92aff"
+checksum = "27061f1a2150ad64717dca73902678c124b0619b0d06563294df265bc84759e1"
 dependencies = [
  "heck",
  "proc-macro2",
@@ -227,9 +226,9 @@ dependencies = [
 
 [[package]]
 name = "rustler_sys"
-version = "2.3.2"
+version = "2.4.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ff76ba8524729d7c9db2b3e80f2269d1fdef39b5a60624c33fd794797e69b558"
+checksum = "2062df0445156ae93cf695ef38c00683848d956b30507592143c01fe8fb52fda"
 dependencies = [
  "regex",
  "unreachable",
@@ -274,9 +273,9 @@ dependencies = [
 
 [[package]]
 name = "tiktoken-rs"
-version = "0.5.8"
+version = "0.5.9"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "40894b788eb28bbb7e36bdc8b7b1b1488b9c93fa3730f315ab965330c94c0842"
+checksum = "c314e7ce51440f9e8f5a497394682a57b7c323d0f4d0a6b1b13c429056e0e234"
 dependencies = [
  "anyhow",
  "base64",
diff --git a/native/tiktoken/Cargo.toml b/native/tiktoken/Cargo.toml
index b630fa3..88185da 100644
--- a/native/tiktoken/Cargo.toml
+++ b/native/tiktoken/Cargo.toml
@@ -10,5 +10,5 @@ path = "src/lib.rs"
 crate-type = ["cdylib"]
 
 [dependencies]
-rustler = "0.31.0"
-tiktoken-rs = "0.5.8"
+rustler = "0.33.0"
+tiktoken-rs = "0.5.9"
diff --git a/native/tiktoken/src/lib.rs b/native/tiktoken/src/lib.rs
index 3a3defd..4d37c7b 100644
--- a/native/tiktoken/src/lib.rs
+++ b/native/tiktoken/src/lib.rs
@@ -4,6 +4,7 @@ use std::vec::Vec;
 #[rustler::nif]
 fn encoding_for_model(model: &str) -> Option<&str> {
     match tiktoken_rs::tokenizer::get_tokenizer(model) {
+        Some(tiktoken_rs::tokenizer::Tokenizer::O200kBase) => Some("o200k_base"),
         Some(tiktoken_rs::tokenizer::Tokenizer::Cl100kBase) => Some("cl100k_base"),
         Some(tiktoken_rs::tokenizer::Tokenizer::P50kBase) => Some("p50k_base"),
         Some(tiktoken_rs::tokenizer::Tokenizer::R50kBase) => Some("r50k_base"),
@@ -185,6 +186,48 @@ fn context_size_for_model(model: &str) -> usize {
     tiktoken_rs::model::get_context_size(model)
 }
 
+// o200k
+
+#[rustler::nif]
+fn o200k_encode_ordinary(text: &str) -> Result<Vec<usize>, String> {
+    let bpe = tiktoken_rs::o200k_base_singleton();
+    {
+        let guard = bpe.lock();
+        Ok(guard.encode_ordinary(text))
+    }
+}
+
+#[rustler::nif]
+fn o200k_encode(text: &str, allowed_special: Vec<&str>) -> Result<Vec<usize>, String> {
+    let set = HashSet::from_iter(allowed_special.iter().cloned());
+    let bpe = tiktoken_rs::o200k_base_singleton();
+    {
+        let guard = bpe.lock();
+        Ok(guard.encode(text, set))
+    }
+}
+
+#[rustler::nif]
+fn o200k_encode_with_special_tokens(text: &str) -> Result<Vec<usize>, String> {
+    let bpe = tiktoken_rs::o200k_base_singleton();
+    {
+        let guard = bpe.lock();
+        Ok(guard.encode_with_special_tokens(text))
+    }
+}
+
+#[rustler::nif]
+fn o200k_decode(ids: Vec<usize>) -> Result<String, String> {
+    let bpe = tiktoken_rs::o200k_base_singleton();
+    {
+        let guard = bpe.lock();
+        match guard.decode(ids) {
+            Ok(text) => Ok(text),
+            Err(e) => Err(e.to_string()),
+        }
+    }
+}
+
 rustler::init!(
     "Elixir.Tiktoken.Native",
     [
@@ -205,6 +248,10 @@ rustler::init!(
         cl100k_encode,
         cl100k_encode_with_special_tokens,
         cl100k_decode,
+        o200k_encode_ordinary,
+        o200k_encode,
+        o200k_encode_with_special_tokens,
+        o200k_decode,
         context_size_for_model
     ]
 );