Get rid of readonly variant of NativeBuffer.cs, it is kinda pointless and cumbersome to maintain both structures
budgetdevv committed Oct 2, 2024
1 parent e3ff261 commit 32d6764
Showing 10 changed files with 147 additions and 212 deletions.
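For context: the change below folds the old `Buffer<T>` / `ReadOnlyBuffer<T>` pair into a single `NativeBuffer<T>` whose `ptr` field is a `RawPointer<T>` union with `mutable` and `readonly` variants. Here is a minimal usage sketch of the merged type, assuming crate-internal code such as a unit test; the constructors and accessors are the ones visible in the diff, and the values are illustrative only:

```rust
// Sketch only: exercises the unified NativeBuffer inside this crate.
// The buffer borrows its memory, so the source slice/Vec must outlive it.
fn native_buffer_round_trip() {
    let words: Vec<u32> = vec![1, 2, 3];

    // Read-only view over a shared slice.
    let readonly = NativeBuffer::from_slice(&words);
    let view: &[u32] = unsafe { readonly.as_slice() };
    assert_eq!(view, &words[..]);

    // Mutable view over an exclusive slice.
    let mut scratch = vec![0u32; 3];
    let mutable = NativeBuffer::from_mutable_slice(&mut scratch);
    unsafe { mutable.as_mutable_slice()[0] = 42 };

    // Note: the union does not record which variant was written, so calling
    // as_mutable_slice() on a buffer built from a shared slice would be
    // undefined behaviour; upholding that contract is left to the caller.
}
```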
175 changes: 93 additions & 82 deletions Native/src/lib.rs
@@ -4,97 +4,108 @@ use std::slice;
use tokenizers::tokenizer::Tokenizer;
use tokenizers::Encoding;

// #[inline(always)] is used aggressively - Realistically we only have a few callsites.

#[repr(C)]
union RawPointer<T>
{
mutable: *mut T,
readonly: *const T,
}

#[repr(C)]
pub struct Buffer<T>
pub struct NativeBuffer<T>
{
pub ptr: *mut T,
pub ptr: RawPointer<T>,
pub length: usize,
}

impl<T> Buffer<T>
impl<T> NativeBuffer<T>
{
pub fn new(ptr: *mut T, length: usize) -> Self
#[inline(always)]
pub fn wrap_mutable_ptr(ptr: *mut T, length: usize) -> Self
{
Buffer
NativeBuffer
{
ptr,
ptr: RawPointer { mutable: ptr },
length,
}
}

pub fn from_slice(slice: &mut [T]) -> Self
#[inline(always)]
pub fn wrap_ptr(ptr: *const T, length: usize) -> Self
{
Buffer
NativeBuffer
{
ptr: slice.as_mut_ptr(),
length: slice.len(),
ptr: RawPointer { readonly: ptr },
length,
}
}

pub unsafe fn to_slice(&self) -> &mut [T]
#[inline(always)]
pub fn from_slice(slice: &[T]) -> Self
{
return slice::from_raw_parts_mut(self.ptr, self.length)
NativeBuffer
{
ptr: RawPointer { readonly: slice.as_ptr() },
length: slice.len(),
}
}

pub fn empty() -> Self
#[inline(always)]
pub fn from_mutable_slice(slice: &mut [T]) -> Self
{
Buffer
NativeBuffer
{
ptr: null_mut(),
length: 0,
ptr: RawPointer { mutable: slice.as_mut_ptr() },
length: slice.len(),
}
}
}

#[repr(C)]
pub struct ReadOnlyBuffer<T>
{
ptr: *const T,
pub length: usize,
}

impl<T> ReadOnlyBuffer<T>
{
pub fn new(ptr: *const T, length: usize) -> Self
#[inline(always)]
pub unsafe fn as_slice(&self) -> &[T]
{
ReadOnlyBuffer
{
ptr,
length,
}
return slice::from_raw_parts(self.ptr.readonly, self.length)
}

pub fn from_slice(slice: &[T]) -> Self
#[inline(always)]
pub unsafe fn as_mutable_slice(&self) -> &mut [T]
{
ReadOnlyBuffer
{
ptr: slice.as_ptr(),
length: slice.len(),
}
return slice::from_raw_parts_mut(self.ptr.mutable, self.length)
}

pub unsafe fn as_slice(&self) -> &[T]
#[inline(always)]
pub fn from_vec(vec: &Vec<T>) -> Self
{
return slice::from_raw_parts(self.ptr, self.length)
let ptr = vec.as_ptr();
let length = vec.len();

return NativeBuffer
{
ptr: RawPointer { readonly: ptr },
length,
}
}

pub fn from_vec(vec: &mut Vec<T>) -> Self
#[inline(always)]
pub fn from_mutable_vec(vec: &mut Vec<T>) -> Self
{
let ptr = vec.as_mut_ptr();
let length = vec.len();

ReadOnlyBuffer
return NativeBuffer
{
ptr,
ptr: RawPointer { mutable: ptr },
length,
}
}

#[inline(always)]
pub fn empty() -> Self
{
ReadOnlyBuffer
NativeBuffer
{
ptr: null(),
ptr: RawPointer { mutable: null_mut() },
length: 0,
}
}
@@ -109,6 +120,7 @@ pub struct DropHandle<T=()>

impl <T> DropHandle<T>
{
#[inline(always)]
pub unsafe fn from_value_and_allocate_box(value: T) -> *mut DropHandle<T>
{
let val_box = Box::new(value);
@@ -131,6 +143,7 @@ impl <T> DropHandle<T>
return Box::into_raw(handle);
}

#[inline(always)]
pub unsafe fn from_handle(handle: *mut DropHandle<T>) -> Box<DropHandle<T>>
{
return Box::from_raw(handle);
@@ -140,10 +153,10 @@ impl <T> DropHandle<T>
#[repr(C)]
pub struct TokenizeOutput
{
pub ids: ReadOnlyBuffer<u32>,
pub attention_mask: ReadOnlyBuffer<u32>,
pub special_tokens_mask: ReadOnlyBuffer<u32>,
pub overflowing_tokens: ReadOnlyBuffer<TokenizeOutputOverflowedToken>,
pub ids: NativeBuffer<u32>,
pub attention_mask: NativeBuffer<u32>,
pub special_tokens_mask: NativeBuffer<u32>,
pub overflowing_tokens: NativeBuffer<TokenizeOutputOverflowedToken>,
pub original_output_free_handle: *const DropHandle<Encoding>,
pub overflowing_tokens_free_handle: *const DropHandle<Vec<TokenizeOutputOverflowedToken>>,
}
@@ -155,13 +168,13 @@ impl TokenizeOutput
{
// println!("Offsets {:?}", encoded_tokens.get_offsets());

let ids = ReadOnlyBuffer::from_slice(encoded_tokens.get_ids());
let attention_mask = ReadOnlyBuffer::from_slice(encoded_tokens.get_attention_mask());
let special_tokens_mask = ReadOnlyBuffer::from_slice(encoded_tokens.get_special_tokens_mask());
let ids = NativeBuffer::from_slice(encoded_tokens.get_ids());
let attention_mask = NativeBuffer::from_slice(encoded_tokens.get_attention_mask());
let special_tokens_mask = NativeBuffer::from_slice(encoded_tokens.get_special_tokens_mask());

let overflowing_tokens_slice = encoded_tokens.get_overflowing();

let overflowing_tokens: ReadOnlyBuffer<TokenizeOutputOverflowedToken>;
let overflowing_tokens: NativeBuffer<TokenizeOutputOverflowedToken>;

let overflowing_tokens_free_handle: *const DropHandle<Vec<TokenizeOutputOverflowedToken>>;

@@ -175,7 +188,7 @@

// println!("Overflowing tokens: {:?}", overflowing_tokens.as_slice().len());

overflowing_tokens = ReadOnlyBuffer::from_vec(&mut overflowing_tokens_vec);
overflowing_tokens = NativeBuffer::from_mutable_vec(&mut overflowing_tokens_vec);

overflowing_tokens_free_handle = DropHandle::from_value_and_allocate_box(
overflowing_tokens_vec
@@ -184,7 +197,7 @@

else
{
overflowing_tokens = ReadOnlyBuffer::empty();
overflowing_tokens = NativeBuffer::empty();
overflowing_tokens_free_handle = null();
}

@@ -206,19 +219,19 @@
#[repr(C)]
pub struct TokenizeOutputOverflowedToken
{
pub ids: ReadOnlyBuffer<u32>,
pub attention_mask: ReadOnlyBuffer<u32>,
pub special_tokens_mask: ReadOnlyBuffer<u32>,
pub ids: NativeBuffer<u32>,
pub attention_mask: NativeBuffer<u32>,
pub special_tokens_mask: NativeBuffer<u32>,
}

impl TokenizeOutputOverflowedToken
{
#[inline(always)]
pub unsafe fn from_overflowing_encoded_tokens(encoded_tokens: &Encoding) -> Self
{
let ids = ReadOnlyBuffer::from_slice(encoded_tokens.get_ids());
let attention_mask = ReadOnlyBuffer::from_slice(encoded_tokens.get_attention_mask());
let special_tokens_mask = ReadOnlyBuffer::from_slice(encoded_tokens.get_special_tokens_mask());
let ids = NativeBuffer::from_slice(encoded_tokens.get_ids());
let attention_mask = NativeBuffer::from_slice(encoded_tokens.get_attention_mask());
let special_tokens_mask = NativeBuffer::from_slice(encoded_tokens.get_special_tokens_mask());

return TokenizeOutputOverflowedToken
{
@@ -231,7 +244,7 @@ impl TokenizeOutputOverflowedToken

#[no_mangle]
pub unsafe extern "C" fn allocate_tokenizer(
json_bytes: ReadOnlyBuffer<u8>)
json_bytes: NativeBuffer<u8>)
-> *mut Tokenizer
{
let json_bytes = json_bytes.as_slice();
@@ -252,7 +265,7 @@ pub unsafe extern "C" fn free_tokenizer(tokenizer_handle: *mut Tokenizer)
#[no_mangle]
pub unsafe extern "C" fn tokenizer_encode(
tokenizer_ptr: *mut Tokenizer,
text_buffer: ReadOnlyBuffer<u8>,
text_buffer: NativeBuffer<u8>,
add_special_tokens: bool)
-> TokenizeOutput
{
@@ -267,7 +280,7 @@ pub unsafe extern "C" fn tokenizer_encode(
#[no_mangle]
pub unsafe extern "C" fn tokenizer_encode_non_truncating(
tokenizer_ptr: *mut Tokenizer,
text_buffer: ReadOnlyBuffer<u8>,
text_buffer: NativeBuffer<u8>,
add_special_tokens: bool)
-> TokenizeOutput
{
@@ -282,7 +295,7 @@ pub unsafe extern "C" fn tokenizer_encode_non_truncating(
#[inline(always)]
pub unsafe extern "C" fn tokenizer_encode_core(
tokenizer_ptr: *mut Tokenizer,
text_buffer: ReadOnlyBuffer<u8>,
text_buffer: NativeBuffer<u8>,
truncate: bool,
add_special_tokens: bool)
-> TokenizeOutput
@@ -305,8 +318,8 @@ pub unsafe extern "C" fn tokenizer_encode_core(
#[no_mangle]
pub unsafe extern "C" fn tokenizer_encode_batch(
tokenizer_ptr: *mut Tokenizer,
text_buffers: ReadOnlyBuffer<ReadOnlyBuffer<u8>>,
output_buffer: Buffer<TokenizeOutput>,
text_buffers: NativeBuffer<NativeBuffer<u8>>,
output_buffer: NativeBuffer<TokenizeOutput>,
add_special_tokens: bool)
{
tokenizer_encode_batch_core(
@@ -321,8 +334,8 @@ pub unsafe extern "C" fn tokenizer_encode_batch(
#[no_mangle]
pub unsafe extern "C" fn tokenizer_encode_batch_non_truncating(
tokenizer_ptr: *mut Tokenizer,
text_buffers: ReadOnlyBuffer<ReadOnlyBuffer<u8>>,
output_buffer: Buffer<TokenizeOutput>,
text_buffers: NativeBuffer<NativeBuffer<u8>>,
output_buffer: NativeBuffer<TokenizeOutput>,
add_special_tokens: bool)
{
tokenizer_encode_batch_core(
@@ -337,8 +350,8 @@ pub unsafe extern "C" fn tokenizer_encode_batch_non_truncating(
#[inline(always)]
pub unsafe extern "C" fn tokenizer_encode_batch_core(
tokenizer_ptr: *mut Tokenizer,
text_buffers: ReadOnlyBuffer<ReadOnlyBuffer<u8>>,
output_buffer: Buffer<TokenizeOutput>,
text_buffers: NativeBuffer<NativeBuffer<u8>>,
output_buffer: NativeBuffer<TokenizeOutput>,
truncate: bool,
add_special_tokens: bool)
{
@@ -355,12 +368,10 @@ pub unsafe extern "C" fn tokenizer_encode_batch_core(
let encoded_tokens = match encoded_result
{
Ok(encoded) => encoded,
Err(err) => panic!("{}", err),
Err(error) => panic!("{}", error),
};

let mut current_ptr = output_buffer.ptr;

// println!("{:?}", current_ptr);
let mut current_ptr = output_buffer.ptr.mutable;

for encoded_token in encoded_tokens
{
@@ -373,7 +384,7 @@ pub unsafe extern "C" fn tokenizer_encode_batch_core(
#[repr(C)]
pub struct DecodeOutput
{
pub text_buffer: ReadOnlyBuffer<u8>,
pub text_buffer: NativeBuffer<u8>,
pub free_handle: *mut DropHandle<String>
}

@@ -384,7 +395,7 @@ impl DecodeOutput
{
let text_bytes = text.as_mut_vec();

let text_buffer = ReadOnlyBuffer::from_vec(text_bytes);
let text_buffer = NativeBuffer::from_mutable_vec(text_bytes);

let free_handle = DropHandle::from_value_and_allocate_box(text);

@@ -399,7 +410,7 @@ impl DecodeOutput
#[no_mangle]
pub unsafe extern "C" fn tokenizer_decode(
tokenizer_ptr: *mut Tokenizer,
id_buffer: ReadOnlyBuffer<u32>)
id_buffer: NativeBuffer<u32>)
-> DecodeOutput
{
return tokenizer_decode_core(tokenizer_ptr, id_buffer, false);
@@ -408,7 +419,7 @@ pub unsafe extern "C" fn tokenizer_decode(
#[no_mangle]
pub unsafe extern "C" fn tokenizer_decode_skip_special_tokens(
tokenizer_ptr: *mut Tokenizer,
id_buffer: ReadOnlyBuffer<u32>)
id_buffer: NativeBuffer<u32>)
-> DecodeOutput
{
return tokenizer_decode_core(tokenizer_ptr, id_buffer, true);
@@ -417,7 +428,7 @@ pub unsafe extern "C" fn tokenizer_decode_skip_special_tokens(
#[inline(always)]
pub unsafe extern "C" fn tokenizer_decode_core(
tokenizer_ptr: *mut Tokenizer,
id_buffer: ReadOnlyBuffer<u32>,
id_buffer: NativeBuffer<u32>,
skip_special_tokens: bool)
-> DecodeOutput
{
@@ -434,15 +445,15 @@ pub unsafe extern "C" fn free_with_handle(handle: *mut DropHandle<()>)
{
let free_data = DropHandle::from_handle(handle);

// println!("Freeing memory at {:p}", free_data.ptr_to_box);
// println!("Freeing memory at {:p}", free_data.ptr_to_box);

let drop_callback = free_data.drop_callback;

drop_callback(free_data.ptr_to_box);
}

#[no_mangle]
pub unsafe extern "C" fn free_with_multiple_handles(handle: ReadOnlyBuffer<*mut DropHandle<()>>)
pub unsafe extern "C" fn free_with_multiple_handles(handle: NativeBuffer<*mut DropHandle<()>>)
{
for free_data in handle.as_slice()
{
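As a rough illustration of the merged API at the FFI boundary, here is a hypothetical Rust-side smoke test of the entry points shown in the diff above; in the real project these functions are called from the C# bindings, and the `tokenizer.json` path and the freeing order are assumptions rather than project code:

```rust
// Sketch only: drives the exported functions directly from Rust.
fn tokenizer_smoke_test() {
    // Assumed local file containing a HuggingFace tokenizer definition.
    let json = std::fs::read("tokenizer.json").expect("tokenizer.json not found");
    let text = "Hello world";

    unsafe {
        let tokenizer = allocate_tokenizer(NativeBuffer::from_slice(&json));

        let output = tokenizer_encode(
            tokenizer,
            NativeBuffer::from_slice(text.as_bytes()),
            true, // add_special_tokens
        );

        println!("ids: {:?}", output.ids.as_slice());

        // The Encoding backing `output` stays alive until its handle is dropped.
        free_with_handle(output.original_output_free_handle as *mut DropHandle<()>);
        free_tokenizer(tokenizer);
    }
}
```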