Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Runtime detection, take 2 #86

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ name = "value_operator"
harness = false

[features]
default = []
default = ["runtime-detection"]
Copy link
Collaborator

@liuq19 liuq19 Aug 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Runtime detection always has a few overheads, so I think it is better not to enable the feature by default.

runtime-detection = []

# Use an arbitrary precision number type representation when parsing JSON into `sonic_rs::Value`.
# This allows JSON numbers to be serialized without loss of precision.
Expand Down
4 changes: 2 additions & 2 deletions src/lazyvalue/get.rs
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ where
let lv = LazyValue::new(json.from_subset(sub), status == ParseStatus::HasEscaped)?;

// validate the UTF-8 if the input is a slice
let index = parser.read.index();
let index = parser.read().index();
if json.need_utf8_valid() {
from_utf8(&slice[..index])?;
}
Expand Down Expand Up @@ -429,7 +429,7 @@ where
let nodes = parser.get_many(tree, true)?;

// validate the UTF-8 if the input is a slice
let index = parser.read.index();
let index = parser.read().index();
if json.need_utf8_valid() {
from_utf8(&slice[..index])?;
}
Expand Down
12 changes: 6 additions & 6 deletions src/lazyvalue/iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,17 +98,17 @@ impl<'de> ObjectJsonIter<'de> {
if self.parser.is_none() {
let slice = self.json.as_ref();
let slice = unsafe { std::slice::from_raw_parts(slice.as_ptr(), slice.len()) };
let parser = Parser::new(Read::new(slice, check));
let mut parser = Parser::new(Read::new(slice, check));
// check invalid utf8
if let Err(err) = parser.read.check_utf8_final() {
if let Err(err) = parser.read().check_utf8_final() {
self.ending = true;
return Some(Err(err));
}
self.parser = Some(parser);
}

let parser = unsafe { self.parser.as_mut().unwrap_unchecked() };
unsafe { parser.read.update_slice(self.json.as_ref().as_ptr()) };
unsafe { parser.read().update_slice(self.json.as_ref().as_ptr()) };
match parser.parse_entry_lazy(&mut self.strbuf, &mut self.first, check) {
Ok(ret) => {
if let Some((key, val, has_escaped)) = ret {
Expand Down Expand Up @@ -146,17 +146,17 @@ impl<'de> ArrayJsonIter<'de> {
if self.parser.is_none() {
let slice = self.json.as_ref();
let slice = unsafe { std::slice::from_raw_parts(slice.as_ptr(), slice.len()) };
let parser = Parser::new(Read::new(slice, check));
let mut parser = Parser::new(Read::new(slice, check));
// check invalid utf8
if let Err(err) = parser.read.check_utf8_final() {
if let Err(err) = parser.read().check_utf8_final() {
self.ending = true;
return Some(Err(err));
}
self.parser = Some(parser);
}

let parser = self.parser.as_mut().unwrap();
unsafe { parser.read.update_slice(self.json.as_ref().as_ptr()) };
unsafe { parser.read().update_slice(self.json.as_ref().as_ptr()) };
match parser.parse_array_elem_lazy(&mut self.first, check) {
Ok(ret) => {
if let Some((ret, has_escaped)) = ret {
Expand Down
136 changes: 78 additions & 58 deletions src/parser.rs → src/parser/inner.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::{
marker::PhantomData,
num::NonZeroU8,
ops::Deref,
slice::{from_raw_parts, from_raw_parts_mut},
Expand All @@ -9,7 +10,7 @@ use faststr::FastStr;
use serde::de::{self, Expected, Unexpected};
use smallvec::SmallVec;

use super::reader::{Reader, Reference};
use super::{as_str, DEFAULT_KEY_BUF_CAPACITY};
#[cfg(all(target_feature = "neon", target_arch = "aarch64"))]
use crate::util::simd::bits::NeonBits;
use crate::{
Expand All @@ -27,18 +28,18 @@ use crate::{
arc::Arc,
arch::{get_nonspace_bits, prefix_xor},
num::{parse_number, ParserNumber},
simd::{i8x32, m8x32, u8x32, u8x64, Mask, Simd},
//simd::{i8x32, m8x32, u8x32, u8x64, Mask, Simd},
simd::{Mask, Simd},
string::*,
unicode::{codepoint_to_utf8, hex_to_u32_nocheck},
},
value::{shared::Shared, visitor::JsonVisitor},
JsonType, LazyValue,
};

pub(crate) const DEFAULT_KEY_BUF_CAPACITY: usize = 128;
pub(crate) fn as_str(data: &[u8]) -> &str {
unsafe { from_utf8_unchecked(data) }
}
use crate::{
reader::{Reader, Reference},
util::simd::BitMask,
};

#[inline(always)]
fn get_escaped_branchless_u32(prev_escaped: &mut u32, backslash: u32) -> u32 {
Expand Down Expand Up @@ -90,41 +91,49 @@ pub(crate) fn is_whitespace(ch: u8) -> bool {
}

#[inline(always)]
fn get_string_bits(data: &[u8; 64], prev_instring: &mut u64, prev_escaped: &mut u64) -> u64 {
let v = unsafe { u8x64::from_slice_unaligned_unchecked(data) };
fn get_string_bits<U8x64>(data: &[u8; 64], prev_instring: &mut u64, prev_escaped: &mut u64) -> u64
where
U8x64: Simd<Element = u8>,
<U8x64::Mask as Mask>::BitMask: BitMask<Primitive = u64>,
{
let v = unsafe { U8x64::from_slice_unaligned_unchecked(data) };

let bs_bits = (v.eq(&u8x64::splat(b'\\'))).bitmask();
let bs_bits = (v.eq(&U8x64::splat(b'\\'))).bitmask();
let escaped: u64;
if bs_bits != 0 {
escaped = get_escaped_branchless_u64(prev_escaped, bs_bits);
if !bs_bits.all_zero() {
escaped = get_escaped_branchless_u64(prev_escaped, bs_bits.as_primitive());
} else {
escaped = *prev_escaped;
*prev_escaped = 0;
}
let quote_bits = (v.eq(&u8x64::splat(b'"'))).bitmask() & !escaped;
let quote_bits = (v.eq(&U8x64::splat(b'"'))).bitmask().as_primitive() & !escaped;
let in_string = unsafe { prefix_xor(quote_bits) ^ *prev_instring };
*prev_instring = (in_string as i64 >> 63) as u64;
in_string
}

#[inline(always)]
fn skip_container_loop(
fn skip_container_loop<U8x64>(
input: &[u8; 64], /* a 64-bytes slice from json */
prev_instring: &mut u64, /* the bitmap of last string */
prev_escaped: &mut u64,
lbrace_num: &mut usize,
rbrace_num: &mut usize,
left: u8,
right: u8,
) -> Option<NonZeroU8> {
) -> Option<NonZeroU8>
where
U8x64: Simd<Element = u8>,
<U8x64::Mask as Mask>::BitMask: BitMask<Primitive = u64>,
{
// get the bitmap
let instring = get_string_bits(input, prev_instring, prev_escaped);
let instring = get_string_bits::<U8x64>(input, prev_instring, prev_escaped);
// SAFETY:
// the input is exactly 64 bytes, so loading `v` from it is always valid.
let v = unsafe { u8x64::from_slice_unaligned_unchecked(input) };
let v = unsafe { U8x64::from_slice_unaligned_unchecked(input) };
let last_lbrace_num = *lbrace_num;
let mut rbrace = (v.eq(&u8x64::splat(right))).bitmask() & !instring;
let lbrace = (v.eq(&u8x64::splat(left))).bitmask() & !instring;
let mut rbrace = (v.eq(&U8x64::splat(right))).bitmask().as_primitive() & !instring;
let lbrace = (v.eq(&U8x64::splat(left))).bitmask().as_primitive() & !instring;
while rbrace != 0 {
*rbrace_num += 1;
*lbrace_num = last_lbrace_num + (lbrace & (rbrace - 1)).count_ones() as usize;
Expand All @@ -140,12 +149,14 @@ fn skip_container_loop(
None
}

pub(crate) struct Parser<R> {
pub(crate) read: R,
error_index: usize, // mark the error position
nospace_bits: u64, // SIMD marked nospace bitmap
nospace_start: isize, // the start position of nospace_bits
pub(crate) shared: Option<Arc<Shared>>, // the shared allocator for `Value`
pub(crate) struct Parser<R, I8x32, U8x32, U8x64> {
read: R,
error_index: usize, // mark the error position
nospace_bits: u64, // SIMD marked nospace bitmap
nospace_start: isize, // the start position of nospace_bits
shared: Option<Arc<Shared>>, // the shared allocator for `Value`

_marker: PhantomData<(I8x32, U8x32, U8x64)>,
}

/// Records the parse status
Expand All @@ -155,9 +166,15 @@ pub(crate) enum ParseStatus {
HasEscaped,
}

impl<'de, R> Parser<R>
impl<'de, R, I8x32, U8x32, U8x64> Parser<R, I8x32, U8x32, U8x64>
where
R: Reader<'de>,
I8x32: Simd<Element = i8>,
<I8x32::Mask as Mask>::BitMask: BitMask<Primitive = u32>,
U8x32: Simd<Element = u8>,
<U8x32::Mask as Mask>::BitMask: BitMask<Primitive = u32>,
U8x64: Simd<Element = u8>,
<U8x64::Mask as Mask>::BitMask: BitMask<Primitive = u64>,
{
pub fn new(read: R) -> Self {
Self {
Expand All @@ -166,9 +183,16 @@ where
nospace_bits: 0,
nospace_start: -128,
shared: None,

_marker: PhantomData,
}
}

/// Returns a mutable reference to the reader stored in this parser's `read` field.
#[inline(always)]
pub(crate) fn read(&mut self) -> &mut R {
&mut self.read
}

#[inline(always)]
fn error_index(&self) -> usize {
// when parsing strings, we need to record the error position.
Expand Down Expand Up @@ -972,21 +996,20 @@ where
#[inline(always)]
fn get_next_token<const N: usize>(&mut self, tokens: [u8; N], advance: usize) -> Option<u8> {
let r = &mut self.read;
const LANS: usize = u8x32::lanes();
while let Some(chunk) = r.peek_n(LANS) {
let v = unsafe { u8x32::from_slice_unaligned_unchecked(chunk) };
let mut vor = m8x32::splat(false);
while let Some(chunk) = r.peek_n(U8x32::LANES) {
let v = unsafe { U8x32::from_slice_unaligned_unchecked(chunk) };
let mut vor = U8x32::Mask::splat(false);
for t in tokens.iter().take(N) {
vor |= v.eq(&u8x32::splat(*t));
vor |= v.eq(&U8x32::splat(*t));
}
let next = vor.bitmask();
if next != 0 {
let cnt = next.trailing_zeros() as usize;
if !next.all_zero() {
let cnt = next.as_primitive().trailing_zeros() as usize;
let ch = chunk[cnt];
r.eat(cnt + advance);
return Some(ch);
}
r.eat(LANS);
r.eat(U8x32::LANES);
}

while let Some(ch) = r.peek() {
Expand All @@ -1005,17 +1028,16 @@ where
// escaped status. skip_string always start with the quote marks.
#[inline(always)]
fn skip_string_impl(&mut self) -> Result<ParseStatus> {
const LANS: usize = u8x32::lanes();
let r = &mut self.read;
let mut quote_bits;
let mut escaped;
let mut prev_escaped = 0;
let mut status = ParseStatus::None;

while let Some(chunk) = r.peek_n(LANS) {
let v = unsafe { u8x32::from_slice_unaligned_unchecked(chunk) };
let bs_bits = (v.eq(&u8x32::splat(b'\\'))).bitmask();
quote_bits = (v.eq(&u8x32::splat(b'"'))).bitmask();
while let Some(chunk) = r.peek_n(U8x32::LANES) {
let v = unsafe { U8x32::from_slice_unaligned_unchecked(chunk) };
let bs_bits = (v.eq(&U8x32::splat(b'\\'))).bitmask().as_primitive();
quote_bits = (v.eq(&U8x32::splat(b'"'))).bitmask().as_primitive();
// maybe has escaped quotes
if ((quote_bits.wrapping_sub(1)) & bs_bits) != 0 || prev_escaped != 0 {
escaped = get_escaped_branchless_u32(&mut prev_escaped, bs_bits);
Expand All @@ -1028,7 +1050,7 @@ where
r.eat(quote_bits.trailing_zeros() as usize + 1);
return Ok(status);
}
r.eat(LANS)
r.eat(U8x32::LANES)
}

// skip the possible prev escaped quote
Expand Down Expand Up @@ -1084,19 +1106,17 @@ where
// skip_string skips a JSON string with validation.
#[inline(always)]
fn skip_string(&mut self) -> Result<ParseStatus> {
const LANS: usize = u8x32::lanes();

let mut status = ParseStatus::None;
while let Some(chunk) = self.read.peek_n(LANS) {
let v = unsafe { u8x32::from_slice_unaligned_unchecked(chunk) };
let v_bs = v.eq(&u8x32::splat(b'\\'));
let v_quote = v.eq(&u8x32::splat(b'"'));
let v_cc = v.le(&u8x32::splat(0x1f));
while let Some(chunk) = self.read.peek_n(U8x32::LANES) {
let v = unsafe { U8x32::from_slice_unaligned_unchecked(chunk) };
let v_bs = v.eq(&U8x32::splat(b'\\'));
let v_quote = v.eq(&U8x32::splat(b'"'));
let v_cc = v.le(&U8x32::splat(0x1f));
let mask = (v_bs | v_quote | v_cc).bitmask();

// check the mask
if mask != 0 {
let cnt = mask.trailing_zeros() as usize;
if !mask.all_zero() {
let cnt = mask.as_primitive().trailing_zeros() as usize;
self.read.eat(cnt + 1);

match chunk[cnt] {
Expand All @@ -1109,7 +1129,7 @@ where
_ => unreachable!(),
}
} else {
self.read.eat(LANS)
self.read.eat(U8x32::LANES)
}
}

Expand Down Expand Up @@ -1217,7 +1237,7 @@ where

while let Some(chunk) = reader.peek_n(64) {
let input = unsafe { &*(chunk.as_ptr() as *const [_; 64]) };
if let Some(count) = skip_container_loop(
if let Some(count) = skip_container_loop::<U8x64>(
input,
&mut prev_instring,
&mut prev_escaped,
Expand All @@ -1237,7 +1257,7 @@ where
let n = reader.remain();
remain[..n].copy_from_slice(reader.peek_n(n).unwrap_unchecked());
}
if let Some(count) = skip_container_loop(
if let Some(count) = skip_container_loop::<U8x64>(
&remain,
&mut prev_instring,
&mut prev_escaped,
Expand Down Expand Up @@ -1460,12 +1480,12 @@ where

// SIMD path for long number
while let Some(chunk) = self.read.peek_n(32) {
let v = unsafe { i8x32::from_slice_unaligned_unchecked(chunk) };
let zero = i8x32::splat(b'0' as i8);
let nine = i8x32::splat(b'9' as i8);
let v = unsafe { I8x32::from_slice_unaligned_unchecked(chunk) };
let zero = I8x32::splat(b'0' as i8);
let nine = I8x32::splat(b'9' as i8);
let nondigits = (zero.gt(&v) | v.gt(&nine)).bitmask();
if nondigits != 0 {
let cnt = nondigits.trailing_zeros() as usize;
if !nondigits.all_zero() {
let cnt = nondigits.as_primitive().trailing_zeros() as usize;
let ch = chunk[cnt];
if ch == b'.' && !is_float {
self.read.eat(cnt + 1);
Expand All @@ -1474,7 +1494,7 @@ where

let traversed = cnt + 2;
// check the remaining digits
let nondigts = nondigits.wrapping_shr((traversed) as u32);
let nondigts = nondigits.as_primitive().wrapping_shr((traversed) as u32);
if nondigts != 0 {
while let Some(ch) = self.read.peek() {
if ch == b'e' || ch == b'E' {
Expand Down
20 changes: 20 additions & 0 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
use std::str::from_utf8_unchecked;

mod inner;

// Select the concrete `Parser` type exposed to the rest of the crate:
// - with the `runtime-detection` feature, re-export `Parser` from the `runtime` module;
// - otherwise, alias `inner::Parser` instantiated with the `i8x32`, `u8x32` and `u8x64`
//   vector types from `crate::util::simd`.
cfg_if::cfg_if! {
if #[cfg(feature = "runtime-detection")] {
mod runtime;
pub(crate) use self::runtime::Parser;
} else {
use crate::util::simd::{i8x32, u8x32, u8x64};
pub(crate) type Parser<R> = self::inner::Parser<R, i8x32, u8x32, u8x64>;
}
}

pub(crate) use self::inner::ParseStatus;

pub(crate) const DEFAULT_KEY_BUF_CAPACITY: usize = 128;
/// Reinterprets `data` as a `&str` without validating it.
///
/// No UTF-8 check is performed: the bytes are passed straight to
/// `from_utf8_unchecked`, so callers must only pass bytes that are already
/// known to be valid UTF-8.
pub(crate) fn as_str(data: &[u8]) -> &str {
// SAFETY: relies on the caller passing valid UTF-8; nothing is checked here.
unsafe { from_utf8_unchecked(data) }
}
Loading