diff --git a/Native/src/lib.rs b/Native/src/lib.rs index ae17eed..cb2b544 100644 --- a/Native/src/lib.rs +++ b/Native/src/lib.rs @@ -4,97 +4,108 @@ use std::slice; use tokenizers::tokenizer::Tokenizer; use tokenizers::Encoding; +// #[inline(always)] is used aggressively - Realistically we only have a few callsites. + +#[repr(C)] +union RawPointer +{ + mutable: *mut T, + readonly: *const T, +} + #[repr(C)] -pub struct Buffer +pub struct NativeBuffer { - pub ptr: *mut T, + pub ptr: RawPointer, pub length: usize, } -impl Buffer +impl NativeBuffer { - pub fn new(ptr: *mut T, length: usize) -> Self + #[inline(always)] + pub fn wrap_mutable_ptr(ptr: *mut T, length: usize) -> Self { - Buffer + NativeBuffer { - ptr, + ptr: RawPointer { mutable: ptr }, length, } } - pub fn from_slice(slice: &mut [T]) -> Self + #[inline(always)] + pub fn wrap_ptr(ptr: *const T, length: usize) -> Self { - Buffer + NativeBuffer { - ptr: slice.as_mut_ptr(), - length: slice.len(), + ptr: RawPointer { readonly: ptr }, + length, } } - pub unsafe fn to_slice(&self) -> &mut [T] + #[inline(always)] + pub fn from_slice(slice: &[T]) -> Self { - return slice::from_raw_parts_mut(self.ptr, self.length) + NativeBuffer + { + ptr: RawPointer { readonly: slice.as_ptr() }, + length: slice.len(), + } } - pub fn empty() -> Self + #[inline(always)] + pub fn from_mutable_slice(slice: &mut [T]) -> Self { - Buffer + NativeBuffer { - ptr: null_mut(), - length: 0, + ptr: RawPointer { mutable: slice.as_mut_ptr() }, + length: slice.len(), } } -} - -#[repr(C)] -pub struct ReadOnlyBuffer -{ - ptr: *const T, - pub length: usize, -} -impl ReadOnlyBuffer -{ - pub fn new(ptr: *const T, length: usize) -> Self + #[inline(always)] + pub unsafe fn as_slice(&self) -> &[T] { - ReadOnlyBuffer - { - ptr, - length, - } + return slice::from_raw_parts(self.ptr.readonly, self.length) } - pub fn from_slice(slice: &[T]) -> Self + #[inline(always)] + pub unsafe fn as_mutable_slice(&self) -> &mut [T] { - ReadOnlyBuffer - { - ptr: slice.as_ptr(), - length: slice.len(), - } + return slice::from_raw_parts_mut(self.ptr.mutable, self.length) } - pub unsafe fn as_slice(&self) -> &[T] + #[inline(always)] + pub fn from_vec(vec: &Vec) -> Self { - return slice::from_raw_parts(self.ptr, self.length) + let ptr = vec.as_ptr(); + let length = vec.len(); + + return NativeBuffer + { + ptr: RawPointer { readonly: ptr }, + length, + } } - pub fn from_vec(vec: &mut Vec) -> Self + #[inline(always)] + pub fn from_mutable_vec(vec: &mut Vec) -> Self { let ptr = vec.as_mut_ptr(); let length = vec.len(); - ReadOnlyBuffer + return NativeBuffer { - ptr, + ptr: RawPointer { mutable: ptr }, length, } } + #[inline(always)] pub fn empty() -> Self { - ReadOnlyBuffer + NativeBuffer { - ptr: null(), + ptr: RawPointer { mutable: null_mut() }, length: 0, } } @@ -109,6 +120,7 @@ pub struct DropHandle impl DropHandle { + #[inline(always)] pub unsafe fn from_value_and_allocate_box(value: T) -> *mut DropHandle { let val_box = Box::new(value); @@ -131,6 +143,7 @@ impl DropHandle return Box::into_raw(handle); } + #[inline(always)] pub unsafe fn from_handle(handle: *mut DropHandle) -> Box> { return Box::from_raw(handle); @@ -140,10 +153,10 @@ impl DropHandle #[repr(C)] pub struct TokenizeOutput { - pub ids: ReadOnlyBuffer, - pub attention_mask: ReadOnlyBuffer, - pub special_tokens_mask: ReadOnlyBuffer, - pub overflowing_tokens: ReadOnlyBuffer, + pub ids: NativeBuffer, + pub attention_mask: NativeBuffer, + pub special_tokens_mask: NativeBuffer, + pub overflowing_tokens: NativeBuffer, pub original_output_free_handle: *const DropHandle, pub overflowing_tokens_free_handle: *const DropHandle>, } @@ -155,13 +168,13 @@ impl TokenizeOutput { // println!("Offsets {:?}", encoded_tokens.get_offsets()); - let ids = ReadOnlyBuffer::from_slice(encoded_tokens.get_ids()); - let attention_mask = ReadOnlyBuffer::from_slice(encoded_tokens.get_attention_mask()); - let special_tokens_mask = ReadOnlyBuffer::from_slice(encoded_tokens.get_special_tokens_mask()); + let ids = NativeBuffer::from_slice(encoded_tokens.get_ids()); + let attention_mask = NativeBuffer::from_slice(encoded_tokens.get_attention_mask()); + let special_tokens_mask = NativeBuffer::from_slice(encoded_tokens.get_special_tokens_mask()); let overflowing_tokens_slice = encoded_tokens.get_overflowing(); - let overflowing_tokens: ReadOnlyBuffer; + let overflowing_tokens: NativeBuffer; let overflowing_tokens_free_handle: *const DropHandle>; @@ -175,7 +188,7 @@ impl TokenizeOutput // println!("Overflowing tokens: {:?}", overflowing_tokens.as_slice().len()); - overflowing_tokens = ReadOnlyBuffer::from_vec(&mut overflowing_tokens_vec); + overflowing_tokens = NativeBuffer::from_mutable_vec(&mut overflowing_tokens_vec); overflowing_tokens_free_handle = DropHandle::from_value_and_allocate_box( overflowing_tokens_vec @@ -184,7 +197,7 @@ impl TokenizeOutput else { - overflowing_tokens = ReadOnlyBuffer::empty(); + overflowing_tokens = NativeBuffer::empty(); overflowing_tokens_free_handle = null(); } @@ -206,9 +219,9 @@ impl TokenizeOutput #[repr(C)] pub struct TokenizeOutputOverflowedToken { - pub ids: ReadOnlyBuffer, - pub attention_mask: ReadOnlyBuffer, - pub special_tokens_mask: ReadOnlyBuffer, + pub ids: NativeBuffer, + pub attention_mask: NativeBuffer, + pub special_tokens_mask: NativeBuffer, } impl TokenizeOutputOverflowedToken @@ -216,9 +229,9 @@ impl TokenizeOutputOverflowedToken #[inline(always)] pub unsafe fn from_overflowing_encoded_tokens(encoded_tokens: &Encoding) -> Self { - let ids = ReadOnlyBuffer::from_slice(encoded_tokens.get_ids()); - let attention_mask = ReadOnlyBuffer::from_slice(encoded_tokens.get_attention_mask()); - let special_tokens_mask = ReadOnlyBuffer::from_slice(encoded_tokens.get_special_tokens_mask()); + let ids = NativeBuffer::from_slice(encoded_tokens.get_ids()); + let attention_mask = NativeBuffer::from_slice(encoded_tokens.get_attention_mask()); + let special_tokens_mask = NativeBuffer::from_slice(encoded_tokens.get_special_tokens_mask()); return TokenizeOutputOverflowedToken { @@ -231,7 +244,7 @@ impl TokenizeOutputOverflowedToken #[no_mangle] pub unsafe extern "C" fn allocate_tokenizer( - json_bytes: ReadOnlyBuffer) + json_bytes: NativeBuffer) -> *mut Tokenizer { let json_bytes = json_bytes.as_slice(); @@ -252,7 +265,7 @@ pub unsafe extern "C" fn free_tokenizer(tokenizer_handle: *mut Tokenizer) #[no_mangle] pub unsafe extern "C" fn tokenizer_encode( tokenizer_ptr: *mut Tokenizer, - text_buffer: ReadOnlyBuffer, + text_buffer: NativeBuffer, add_special_tokens: bool) -> TokenizeOutput { @@ -267,7 +280,7 @@ pub unsafe extern "C" fn tokenizer_encode( #[no_mangle] pub unsafe extern "C" fn tokenizer_encode_non_truncating( tokenizer_ptr: *mut Tokenizer, - text_buffer: ReadOnlyBuffer, + text_buffer: NativeBuffer, add_special_tokens: bool) -> TokenizeOutput { @@ -282,7 +295,7 @@ pub unsafe extern "C" fn tokenizer_encode_non_truncating( #[inline(always)] pub unsafe extern "C" fn tokenizer_encode_core( tokenizer_ptr: *mut Tokenizer, - text_buffer: ReadOnlyBuffer, + text_buffer: NativeBuffer, truncate: bool, add_special_tokens: bool) -> TokenizeOutput @@ -305,8 +318,8 @@ pub unsafe extern "C" fn tokenizer_encode_core( #[no_mangle] pub unsafe extern "C" fn tokenizer_encode_batch( tokenizer_ptr: *mut Tokenizer, - text_buffers: ReadOnlyBuffer>, - output_buffer: Buffer, + text_buffers: NativeBuffer>, + output_buffer: NativeBuffer, add_special_tokens: bool) { tokenizer_encode_batch_core( @@ -321,8 +334,8 @@ pub unsafe extern "C" fn tokenizer_encode_batch( #[no_mangle] pub unsafe extern "C" fn tokenizer_encode_batch_non_truncating( tokenizer_ptr: *mut Tokenizer, - text_buffers: ReadOnlyBuffer>, - output_buffer: Buffer, + text_buffers: NativeBuffer>, + output_buffer: NativeBuffer, add_special_tokens: bool) { tokenizer_encode_batch_core( @@ -337,8 +350,8 @@ pub unsafe extern "C" fn tokenizer_encode_batch_non_truncating( #[inline(always)] pub unsafe extern "C" fn tokenizer_encode_batch_core( tokenizer_ptr: *mut Tokenizer, - text_buffers: ReadOnlyBuffer>, - output_buffer: Buffer, + text_buffers: NativeBuffer>, + output_buffer: NativeBuffer, truncate: bool, add_special_tokens: bool) { @@ -355,12 +368,10 @@ pub unsafe extern "C" fn tokenizer_encode_batch_core( let encoded_tokens = match encoded_result { Ok(encoded) => encoded, - Err(err) => panic!("{}", err), + Err(error) => panic!("{}", error), }; - let mut current_ptr = output_buffer.ptr; - - // println!("{:?}", current_ptr); + let mut current_ptr = output_buffer.ptr.mutable; for encoded_token in encoded_tokens { @@ -373,7 +384,7 @@ pub unsafe extern "C" fn tokenizer_encode_batch_core( #[repr(C)] pub struct DecodeOutput { - pub text_buffer: ReadOnlyBuffer, + pub text_buffer: NativeBuffer, pub free_handle: *mut DropHandle } @@ -384,7 +395,7 @@ impl DecodeOutput { let text_bytes = text.as_mut_vec(); - let text_buffer = ReadOnlyBuffer::from_vec(text_bytes); + let text_buffer = NativeBuffer::from_mutable_vec(text_bytes); let free_handle = DropHandle::from_value_and_allocate_box(text); @@ -399,7 +410,7 @@ impl DecodeOutput #[no_mangle] pub unsafe extern "C" fn tokenizer_decode( tokenizer_ptr: *mut Tokenizer, - id_buffer: ReadOnlyBuffer) + id_buffer: NativeBuffer) -> DecodeOutput { return tokenizer_decode_core(tokenizer_ptr, id_buffer, false); @@ -408,7 +419,7 @@ pub unsafe extern "C" fn tokenizer_decode( #[no_mangle] pub unsafe extern "C" fn tokenizer_decode_skip_special_tokens( tokenizer_ptr: *mut Tokenizer, - id_buffer: ReadOnlyBuffer) + id_buffer: NativeBuffer) -> DecodeOutput { return tokenizer_decode_core(tokenizer_ptr, id_buffer, true); @@ -417,7 +428,7 @@ pub unsafe extern "C" fn tokenizer_decode_skip_special_tokens( #[inline(always)] pub unsafe extern "C" fn tokenizer_decode_core( tokenizer_ptr: *mut Tokenizer, - id_buffer: ReadOnlyBuffer, + id_buffer: NativeBuffer, skip_special_tokens: bool) -> DecodeOutput { @@ -434,7 +445,7 @@ pub unsafe extern "C" fn free_with_handle(handle: *mut DropHandle<()>) { let free_data = DropHandle::from_handle(handle); -// println!("Freeing memory at {:p}", free_data.ptr_to_box); + // println!("Freeing memory at {:p}", free_data.ptr_to_box); let drop_callback = free_data.drop_callback; @@ -442,7 +453,7 @@ pub unsafe extern "C" fn free_with_handle(handle: *mut DropHandle<()>) } #[no_mangle] -pub unsafe extern "C" fn free_with_multiple_handles(handle: ReadOnlyBuffer<*mut DropHandle<()>>) +pub unsafe extern "C" fn free_with_multiple_handles(handle: NativeBuffer<*mut DropHandle<()>>) { for free_data in handle.as_slice() { diff --git a/Tests/EncodeTests.cs b/Tests/EncodeTests.cs index 1727a1e..28c3d0e 100644 --- a/Tests/EncodeTests.cs +++ b/Tests/EncodeTests.cs @@ -69,7 +69,7 @@ public void EncodeOverflowing() TokenizeOutput tokenizeResult; - ReadOnlyNativeBuffer overflowingTokens; + NativeBuffer overflowingTokens; nuint numOverflowingTokensSegments; @@ -138,7 +138,7 @@ public void EncodeOverflowing() tokenizeResult.Dispose(); } - private static ulong[] WidenSafely(ReadOnlyNativeBuffer source) + private static ulong[] WidenSafely(NativeBuffer source) { var sourceSpan = source.AsReadOnlySpan(); @@ -165,7 +165,7 @@ public void EncodeOverflowingWiden() TokenizeOutput tokenizeResult; - ReadOnlyNativeBuffer overflowingTokens; + NativeBuffer overflowingTokens; nuint numOverflowingTokensSegments; @@ -258,5 +258,7 @@ public void EncodeWithMaxManagedLength() tokenizeResult.Dispose(); } } + + } } \ No newline at end of file diff --git a/Tests/SIMDHelpersTests.cs b/Tests/SIMDHelpersTests.cs index 2553dd7..847d098 100644 --- a/Tests/SIMDHelpersTests.cs +++ b/Tests/SIMDHelpersTests.cs @@ -46,7 +46,7 @@ public void WidenTest() slot = (uint) currentIndex++; } - srcBuffer.Buffer.AsReadOnly().Widen(destBuffer.Buffer); + srcBuffer.Buffer.Widen(destBuffer.Buffer); srcSpan.ToArray().Should().BeEquivalentTo(destSpan.ToArray()); } diff --git a/Tests/DebugHelpers.cs b/Tests/TestHelpers.cs similarity index 95% rename from Tests/DebugHelpers.cs rename to Tests/TestHelpers.cs index 0b41752..cc2cab9 100644 --- a/Tests/DebugHelpers.cs +++ b/Tests/TestHelpers.cs @@ -1,8 +1,8 @@ using System.Text; -namespace Sample +namespace Tests { - public static class DebugHelpers + public static class TestHelpers { public static string GetSpanPrintString(this Span span) { diff --git a/Tokenizers.NET/Collections/NativeBuffer.cs b/Tokenizers.NET/Collections/NativeBuffer.cs index f7d5dcd..d9f40a1 100644 --- a/Tokenizers.NET/Collections/NativeBuffer.cs +++ b/Tokenizers.NET/Collections/NativeBuffer.cs @@ -38,78 +38,6 @@ public Span AsSpan() return MemoryMarshal.CreateSpan(ref *Ptr, (int) Length); } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public ReadOnlyNativeBuffer AsReadOnly() - { - return Unsafe.BitCast, ReadOnlyNativeBuffer>(this); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public NativeBuffer Cast() where F: unmanaged - { - return new((F*) Ptr, UnsafeHelpers.CalculateCastLength(Length)); - } - - public struct Enumerator - { - private T* CurrentPtr; - - private readonly T* LastPtrOffsetByOne; - - public ref T Current - { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - get => ref *CurrentPtr; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal Enumerator(T* ptr, nuint length) - { - LastPtrOffsetByOne = ptr + length; - CurrentPtr = ptr - 1; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public bool MoveNext() - { - return ++CurrentPtr != LastPtrOffsetByOne; - } - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public Enumerator GetEnumerator() - { - return new(Ptr, Length); - } - } - - [StructLayout(LayoutKind.Sequential)] - public readonly unsafe struct ReadOnlyNativeBuffer(T* ptr, nuint length) where T: unmanaged - { - internal readonly T* Ptr = ptr; - public readonly nuint Length = length; - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public ReadOnlyNativeBuffer(T[] pinnedBuffer, nuint length) : - this(ref MemoryMarshal.GetArrayDataReference(pinnedBuffer), length) - { - // Nothing here - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public ReadOnlyNativeBuffer(ReadOnlySpan pinnedSpan) : - this(ref MemoryMarshal.GetReference(pinnedSpan), (nuint) pinnedSpan.Length) - { - // Nothing here - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public ReadOnlyNativeBuffer(ref T pinnedStart, nuint length) : - this((T*) Unsafe.AsPointer(ref pinnedStart), length) - { - // Nothing here - } - [MethodImpl(MethodImplOptions.AggressiveInlining)] public ReadOnlySpan AsReadOnlySpan() { @@ -117,13 +45,7 @@ public ReadOnlySpan AsReadOnlySpan() } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public NativeBuffer AsWritable() - { - return Unsafe.BitCast, NativeBuffer>(this); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public ReadOnlyNativeBuffer Cast() where F: unmanaged + public NativeBuffer Cast() where F: unmanaged { return new((F*) Ptr, UnsafeHelpers.CalculateCastLength(Length)); } @@ -134,7 +56,7 @@ public struct Enumerator private readonly T* LastPtrOffsetByOne; - public readonly ref T Current + public ref T Current { [MethodImpl(MethodImplOptions.AggressiveInlining)] get => ref *CurrentPtr; diff --git a/Tokenizers.NET/DecodeOutput.cs b/Tokenizers.NET/DecodeOutput.cs index 1447e45..8425e55 100644 --- a/Tokenizers.NET/DecodeOutput.cs +++ b/Tokenizers.NET/DecodeOutput.cs @@ -9,7 +9,7 @@ namespace Tokenizers.NET [StructLayout(LayoutKind.Sequential)] public readonly struct DecodeOutput: IDisposable { - public readonly ReadOnlyNativeBuffer TextBuffer; + public readonly NativeBuffer TextBuffer; public readonly nint FreeHandle; diff --git a/Tokenizers.NET/SIMDHelpers.cs b/Tokenizers.NET/SIMDHelpers.cs index 1c11b55..469c474 100644 --- a/Tokenizers.NET/SIMDHelpers.cs +++ b/Tokenizers.NET/SIMDHelpers.cs @@ -10,7 +10,7 @@ namespace Tokenizers.NET { public static unsafe class SIMDHelpers { - public static NativeMemory Widen(this ReadOnlyNativeBuffer srcBuffer) + public static NativeMemory Widen(this NativeBuffer srcBuffer) { var result = new NativeMemory(srcBuffer.Length); @@ -19,19 +19,19 @@ public static NativeMemory Widen(this ReadOnlyNativeBuffer srcBuffe return result; } - public static void Widen(this ReadOnlyNativeBuffer srcBuffer, NativeBuffer destBuffer) + public static void Widen(this NativeBuffer srcBuffer, NativeBuffer destBuffer) { srcBuffer.WidenInternal(destBuffer, performLengthCheck: true); } - public static void WidenUnsafely(this ReadOnlyNativeBuffer srcBuffer, NativeBuffer destBuffer) + public static void WidenUnsafely(this NativeBuffer srcBuffer, NativeBuffer destBuffer) { srcBuffer.WidenInternal(destBuffer, performLengthCheck: false); } [MethodImpl(MethodImplOptions.AggressiveInlining)] private static void WidenInternal( - this ReadOnlyNativeBuffer srcBuffer, + this NativeBuffer srcBuffer, NativeBuffer destBuffer, bool performLengthCheck) { diff --git a/Tokenizers.NET/TokenizeOutput.cs b/Tokenizers.NET/TokenizeOutput.cs index 483dd19..47e3829 100644 --- a/Tokenizers.NET/TokenizeOutput.cs +++ b/Tokenizers.NET/TokenizeOutput.cs @@ -8,60 +8,60 @@ namespace Tokenizers.NET { public interface ITokenizeOutput { - public ReadOnlyNativeBuffer IDs { get; } + public NativeBuffer IDs { get; } - public ReadOnlyNativeBuffer AttentionMask { get; } + public NativeBuffer AttentionMask { get; } - public ReadOnlyNativeBuffer SpecialTokensMask { get; } + public NativeBuffer SpecialTokensMask { get; } } [StructLayout(LayoutKind.Sequential)] public readonly struct TokenizeOutputOverflowedToken: ITokenizeOutput { - public readonly ReadOnlyNativeBuffer IDs; + public readonly NativeBuffer IDs; - public readonly ReadOnlyNativeBuffer AttentionMask; + public readonly NativeBuffer AttentionMask; - public readonly ReadOnlyNativeBuffer SpecialTokensMask; + public readonly NativeBuffer SpecialTokensMask; - ReadOnlyNativeBuffer ITokenizeOutput.IDs => IDs; + NativeBuffer ITokenizeOutput.IDs => IDs; - ReadOnlyNativeBuffer ITokenizeOutput.AttentionMask => AttentionMask; + NativeBuffer ITokenizeOutput.AttentionMask => AttentionMask; - ReadOnlyNativeBuffer ITokenizeOutput.SpecialTokensMask => SpecialTokensMask; + NativeBuffer ITokenizeOutput.SpecialTokensMask => SpecialTokensMask; } [StructLayout(LayoutKind.Sequential)] public readonly unsafe struct TokenizeOutput: ITokenizeOutput, IDisposable { - public readonly ReadOnlyNativeBuffer IDs; + public readonly NativeBuffer IDs; - public readonly ReadOnlyNativeBuffer AttentionMask; + public readonly NativeBuffer AttentionMask; - public readonly ReadOnlyNativeBuffer SpecialTokensMask; + public readonly NativeBuffer SpecialTokensMask; - public readonly ReadOnlyNativeBuffer OverflowingTokens; + public readonly NativeBuffer OverflowingTokens; public readonly nint OriginalOutputFreeHandle; private readonly nint OverflowingTokensFreeHandle; - ReadOnlyNativeBuffer ITokenizeOutput.IDs => IDs; + NativeBuffer ITokenizeOutput.IDs => IDs; - ReadOnlyNativeBuffer ITokenizeOutput.AttentionMask => AttentionMask; + NativeBuffer ITokenizeOutput.AttentionMask => AttentionMask; - ReadOnlyNativeBuffer ITokenizeOutput.SpecialTokensMask => SpecialTokensMask; + NativeBuffer ITokenizeOutput.SpecialTokensMask => SpecialTokensMask; private interface IGatherFieldAccessor { - public static abstract ReadOnlyNativeBuffer AccessField(T item) + public static abstract NativeBuffer AccessField(T item) where T: struct, ITokenizeOutput; } private struct AccessIDs: IGatherFieldAccessor { [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static ReadOnlyNativeBuffer AccessField(T item) + public static NativeBuffer AccessField(T item) where T: struct, ITokenizeOutput { return item.IDs; @@ -71,7 +71,7 @@ public static ReadOnlyNativeBuffer AccessField(T item) private struct AccessAttentionMask: IGatherFieldAccessor { [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static ReadOnlyNativeBuffer AccessField(T item) + public static NativeBuffer AccessField(T item) where T: struct, ITokenizeOutput { return item.AttentionMask; @@ -81,7 +81,7 @@ public static ReadOnlyNativeBuffer AccessField(T item) private struct AccessSpecialTokensMask: IGatherFieldAccessor { [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static ReadOnlyNativeBuffer AccessField(T item) + public static NativeBuffer AccessField(T item) where T: struct, ITokenizeOutput { return item.SpecialTokensMask; diff --git a/Tokenizers.NET/Tokenizer.cs b/Tokenizers.NET/Tokenizer.cs index ea616ce..7e90aa2 100644 --- a/Tokenizers.NET/Tokenizer.cs +++ b/Tokenizers.NET/Tokenizer.cs @@ -296,7 +296,7 @@ public Tokenizer() var expectedMaxBatches = config.ExpectedMaxBatches; - var rawTokenizerData = config.RawTokenizerData.Buffer.AsReadOnly(); + var rawTokenizerData = config.RawTokenizerData.Buffer; var u8StringBuffers = U8StringBuffers = AllocationHelpers.AllocatePinnedUninitialized>( expectedMaxBatches @@ -340,7 +340,7 @@ internal TokenizeOutput TokenizeInternal(string input, bool addSpecialTokens) var bytesWritten = Encoding.UTF8.GetBytes(input, allocation.AsSpan()); - var u8String = new ReadOnlyNativeBuffer(allocation.Ptr, (nuint) bytesWritten); + var u8String = new NativeBuffer(allocation.Ptr, (nuint) bytesWritten); var result = TokenizerNativeMethods.TokenizerEncode( TokenizerHandle, @@ -447,8 +447,8 @@ internal void TokenizeBatchInternal( currentU8String++; } - var readonlyU8Strings = new ReadOnlyNativeBuffer>( - (ReadOnlyNativeBuffer*) u8StringsPtr, + var readonlyU8Strings = new NativeBuffer>( + (NativeBuffer*) u8StringsPtr, numInputs ); @@ -466,7 +466,7 @@ internal void TokenizeBatchInternal( foreach (var buffer in readonlyU8Strings) { - if (!allocator.IsManagedAllocation(buffer.AsWritable())) + if (!allocator.IsManagedAllocation(buffer)) { NativeMemory.FreeWithPtrUnsafely(buffer.Ptr); } @@ -487,12 +487,12 @@ public DecodeOutput Decode(ReadOnlySpan ids, bool skipSpecialTokens) fixed(uint* ptr = &first) { - return Decode((ReadOnlyNativeBuffer) new(ptr, (nuint) ids.Length), skipSpecialTokens); + return Decode((NativeBuffer) new(ptr, (nuint) ids.Length), skipSpecialTokens); } } [MethodImpl(MethodImplOptions.AggressiveInlining)] - public DecodeOutput Decode(ReadOnlyNativeBuffer ids, bool skipSpecialTokens) + public DecodeOutput Decode(NativeBuffer ids, bool skipSpecialTokens) { var tokenizerHandle = TokenizerHandle; @@ -514,7 +514,7 @@ public DecodeOutput DecodeMutating(NativeBuffer ids, bool skipSpecialToke { var tokenizerHandle = TokenizerHandle; - var mutated = ids.NarrowMutating().AsReadOnly(); + var mutated = ids.NarrowMutating(); // The length should still be the same, even though the actual underlying length is double mutated = new(mutated.Ptr, ids.Length); diff --git a/Tokenizers.NET/TokenizerNativeMethods.cs b/Tokenizers.NET/TokenizerNativeMethods.cs index 0f2fc99..fb777eb 100644 --- a/Tokenizers.NET/TokenizerNativeMethods.cs +++ b/Tokenizers.NET/TokenizerNativeMethods.cs @@ -9,7 +9,7 @@ internal static unsafe partial class TokenizerNativeMethods private const string DLL_NAME = "tokenizers_net"; [LibraryImport(DLL_NAME, EntryPoint = "allocate_tokenizer")] - public static partial nint AllocateTokenizer(ReadOnlyNativeBuffer jsonBytes); + public static partial nint AllocateTokenizer(NativeBuffer jsonBytes); [LibraryImport(DLL_NAME, EntryPoint = "free_tokenizer")] public static partial void FreeTokenizer(nint tokenizerHandle); @@ -17,7 +17,7 @@ internal static unsafe partial class TokenizerNativeMethods [MethodImpl(MethodImplOptions.AggressiveInlining)] public static TokenizeOutput TokenizerEncode( nint tokenizerPtr, - ReadOnlyNativeBuffer textNativeBuffer, + NativeBuffer textNativeBuffer, bool addSpecialTokens, bool truncate) { @@ -40,7 +40,7 @@ public static TokenizeOutput TokenizerEncode( [SuppressGCTransition, MethodImpl(MethodImplOptions.AggressiveInlining)] private static partial TokenizeOutput TokenizerEncode( nint tokenizerPtr, - ReadOnlyNativeBuffer textNativeBuffer, + NativeBuffer textNativeBuffer, byte addSpecialTokens ); @@ -48,14 +48,14 @@ byte addSpecialTokens [SuppressGCTransition, MethodImpl(MethodImplOptions.AggressiveInlining)] private static partial TokenizeOutput TokenizerEncodeNonTruncating( nint tokenizerPtr, - ReadOnlyNativeBuffer textNativeBuffer, + NativeBuffer textNativeBuffer, byte addSpecialTokens ); [MethodImpl(MethodImplOptions.AggressiveInlining)] public static void TokenizerEncodeBatch( nint tokenizerPtr, - ReadOnlyNativeBuffer> textNativeBuffers, + NativeBuffer> textNativeBuffers, NativeBuffer outputNativeBuffer, bool addSpecialTokens, bool truncate) @@ -79,7 +79,7 @@ public static void TokenizerEncodeBatch( [SuppressGCTransition, MethodImpl(MethodImplOptions.AggressiveInlining)] private static partial void TokenizerEncodeBatch( nint tokenizerPtr, - ReadOnlyNativeBuffer> textNativeBuffers, + NativeBuffer> textNativeBuffers, NativeBuffer outputNativeBuffer, byte addSpecialTokens ); @@ -88,7 +88,7 @@ byte addSpecialTokens [SuppressGCTransition, MethodImpl(MethodImplOptions.AggressiveInlining)] private static partial void TokenizerEncodeBatchNonTruncating( nint tokenizerPtr, - ReadOnlyNativeBuffer> textNativeBuffers, + NativeBuffer> textNativeBuffers, NativeBuffer outputNativeBuffer, byte addSpecialTokens ); @@ -96,7 +96,7 @@ byte addSpecialTokens [MethodImpl(MethodImplOptions.AggressiveInlining)] public static DecodeOutput TokenizerDecode( nint tokenizerPtr, - ReadOnlyNativeBuffer ids, + NativeBuffer ids, bool skipSpecialTokens) { if (skipSpecialTokens) @@ -112,11 +112,11 @@ public static DecodeOutput TokenizerDecode( [LibraryImport(DLL_NAME, EntryPoint = "tokenizer_decode")] [SuppressGCTransition, MethodImpl(MethodImplOptions.AggressiveInlining)] - private static partial DecodeOutput TokenizerDecode(nint tokenizerPtr, ReadOnlyNativeBuffer idBuffer); + private static partial DecodeOutput TokenizerDecode(nint tokenizerPtr, NativeBuffer idBuffer); [LibraryImport(DLL_NAME, EntryPoint = "tokenizer_decode_skip_special_tokens")] [SuppressGCTransition, MethodImpl(MethodImplOptions.AggressiveInlining)] - private static partial DecodeOutput TokenizerDecodeSkipSpecialTokens(nint tokenizerPtr, ReadOnlyNativeBuffer idBuffer); + private static partial DecodeOutput TokenizerDecodeSkipSpecialTokens(nint tokenizerPtr, NativeBuffer idBuffer); [LibraryImport(DLL_NAME, EntryPoint = "free_with_handle")] [SuppressGCTransition, MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -124,6 +124,6 @@ public static DecodeOutput TokenizerDecode( [LibraryImport(DLL_NAME, EntryPoint = "free_with_multiple_handles")] [SuppressGCTransition, MethodImpl(MethodImplOptions.AggressiveInlining)] - public static partial void FreeWithMultipleHandles(ReadOnlyNativeBuffer handles); + public static partial void FreeWithMultipleHandles(NativeBuffer handles); } } \ No newline at end of file