Skip to content

Commit

Permalink
Merge pull request #6 from Zinnia-LLC/main
Browse files Browse the repository at this point in the history
added 200k context BPE
  • Loading branch information
connorjacobsen authored Jun 21, 2024
2 parents b78a873 + 69bfc61 commit 4e687e4
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 14 deletions.
3 changes: 2 additions & 1 deletion lib/tiktoken.ex
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ defmodule Tiktoken do
"p50k_base" => Tiktoken.P50K,
"p50k_edit" => Tiktoken.P50KEdit,
"r50k_base" => Tiktoken.R50K,
"cl100k_base" => Tiktoken.CL100K
"cl100k_base" => Tiktoken.CL100K,
"o200k_base" => Tiktoken.O200K
}

def encoding_for_model(model) do
Expand Down
5 changes: 5 additions & 0 deletions lib/tiktoken/native.ex
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ defmodule Tiktoken.Native do
def cl100k_encode_with_special_tokens(_input), do: err()
def cl100k_decode(_ids), do: err()

# o200k_base NIF placeholders. Each body only calls err/0, which raises
# :nif_not_loaded; the functions of the same names registered by the native
# crate's rustler::init! are expected to take their place once the NIF
# library is loaded.
def o200k_encode_ordinary(_input), do: err()
def o200k_encode(_input, _allowed_special), do: err()
def o200k_encode_with_special_tokens(_input), do: err()
def o200k_decode(_ids), do: err()

def context_size_for_model(_model), do: err()

# Shared body for every NIF placeholder in this module: signals that the
# native implementation was not loaded.
defp err, do: :erlang.nif_error(:nif_not_loaded)
Expand Down
23 changes: 23 additions & 0 deletions lib/tiktoken/o200k.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
defmodule Tiktoken.O200K do
  @moduledoc """
  `Tiktoken.Encoding` implementation for the `o200k_base` encoding.

  Every callback forwards its arguments unchanged to the matching
  `o200k_*` function in `Tiktoken.Native`.
  """

  @behaviour Tiktoken.Encoding

  alias Tiktoken.Native

  # Forwards to the native `o200k_encode_ordinary/1`.
  @impl Tiktoken.Encoding
  def encode_ordinary(text), do: Native.o200k_encode_ordinary(text)

  # `allowed_special` defaults to an empty list and is passed through as given.
  @impl Tiktoken.Encoding
  def encode(text, allowed_special \\ []), do: Native.o200k_encode(text, allowed_special)

  # Forwards to the native `o200k_encode_with_special_tokens/1`.
  @impl Tiktoken.Encoding
  def encode_with_special_tokens(text), do: Native.o200k_encode_with_special_tokens(text)

  # Forwards the token ids to the native `o200k_decode/1`.
  @impl Tiktoken.Encoding
  def decode(ids), do: Native.o200k_decode(ids)
end
21 changes: 10 additions & 11 deletions native/tiktoken/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions native/tiktoken/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ path = "src/lib.rs"
crate-type = ["cdylib"]

[dependencies]
rustler = "0.31.0"
tiktoken-rs = "0.5.8"
rustler = "0.33.0"
tiktoken-rs = "0.5.9"
47 changes: 47 additions & 0 deletions native/tiktoken/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::vec::Vec;
#[rustler::nif]
fn encoding_for_model(model: &str) -> Option<&str> {
match tiktoken_rs::tokenizer::get_tokenizer(model) {
Some(tiktoken_rs::tokenizer::Tokenizer::O200kBase) => Some("o200k_base"),
Some(tiktoken_rs::tokenizer::Tokenizer::Cl100kBase) => Some("cl100k_base"),
Some(tiktoken_rs::tokenizer::Tokenizer::P50kBase) => Some("p50k_base"),
Some(tiktoken_rs::tokenizer::Tokenizer::R50kBase) => Some("r50k_base"),
Expand Down Expand Up @@ -185,6 +186,48 @@ fn context_size_for_model(model: &str) -> usize {
tiktoken_rs::model::get_context_size(model)
}

// o200k

/// Runs `encode_ordinary` on the shared o200k_base encoder and returns the
/// resulting token ids.
///
/// The `Err` variant is never produced here; the `Result` return type is kept
/// so the signature stays unchanged for callers.
#[rustler::nif]
fn o200k_encode_ordinary(text: &str) -> Result<Vec<usize>, String> {
    let encoder = tiktoken_rs::o200k_base_singleton();
    // The lock guard is a temporary of this statement, so the encoder is
    // unlocked again before the result is returned.
    let tokens = encoder.lock().encode_ordinary(text);
    Ok(tokens)
}

/// Runs `encode` on the shared o200k_base encoder, passing `allowed_special`
/// as a set, and returns the resulting token ids.
///
/// Duplicate entries in `allowed_special` collapse when the set is built.
/// The `Err` variant is never produced here.
#[rustler::nif]
fn o200k_encode(text: &str, allowed_special: Vec<&str>) -> Result<Vec<usize>, String> {
    // The set is built before the encoder is fetched and locked, as before.
    let allowed: HashSet<&str> = allowed_special.into_iter().collect();
    let encoder = tiktoken_rs::o200k_base_singleton();
    // The lock guard only lives for the duration of this statement.
    let tokens = encoder.lock().encode(text, allowed);
    Ok(tokens)
}

/// Runs `encode_with_special_tokens` on the shared o200k_base encoder and
/// returns the resulting token ids.
///
/// The `Err` variant is never produced here; the `Result` return type is kept
/// so the signature stays unchanged for callers.
#[rustler::nif]
fn o200k_encode_with_special_tokens(text: &str) -> Result<Vec<usize>, String> {
    let encoder = tiktoken_rs::o200k_base_singleton();
    // The lock guard is dropped at the end of this statement.
    let tokens = encoder.lock().encode_with_special_tokens(text);
    Ok(tokens)
}

/// Runs `decode` on the shared o200k_base encoder to turn token ids back
/// into text.
///
/// # Errors
///
/// If the encoder's `decode` fails, its error is returned rendered as a
/// string via `to_string()`.
#[rustler::nif]
fn o200k_decode(ids: Vec<usize>) -> Result<String, String> {
    let encoder = tiktoken_rs::o200k_base_singleton();
    // The lock guard is dropped at the end of this statement, before the
    // error (if any) is converted to a string.
    let decoded = encoder.lock().decode(ids);
    decoded.map_err(|e| e.to_string())
}

rustler::init!(
"Elixir.Tiktoken.Native",
[
Expand All @@ -205,6 +248,10 @@ rustler::init!(
cl100k_encode,
cl100k_encode_with_special_tokens,
cl100k_decode,
o200k_encode_ordinary,
o200k_encode,
o200k_encode_with_special_tokens,
o200k_decode,
context_size_for_model
]
);

0 comments on commit 4e687e4

Please sign in to comment.