Skip to content

Commit

Permalink
add support for Decimal128(S)
Browse files Browse the repository at this point in the history
  • Loading branch information
mel-mel-king committed Nov 27, 2023
1 parent f224fed commit 8ce22e2
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 31 deletions.
2 changes: 1 addition & 1 deletion src/types/block/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,6 @@ mod test {
block.columns[14].sql_type(),
SqlType::DateTime(DateTimeType::Chrono)
);
assert_eq!(block.columns[15].sql_type(), SqlType::Decimal(18, 4));
assert_eq!(block.columns[15].sql_type(), SqlType::Decimal(38, 4));
}
}
19 changes: 16 additions & 3 deletions src/types/column/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ impl DecimalColumnData {
let type_name = match nobits {
NoBits::N32 => "Int32",
NoBits::N64 => "Int64",
NoBits::N128 => "Int128",
};
let inner =
<dyn ColumnData>::load_data::<BoxColumnWrapper, _>(reader, type_name, size, tz)?;
Expand Down Expand Up @@ -161,16 +162,21 @@ impl ColumnData for DecimalColumnData {
let internal: i64 = decimal.internal();
self.inner.push(internal.into())
}
NoBits::N128 => {
let internal: i128 = decimal.internal();
self.inner.push(internal.into())
}
}
} else {
panic!("value should be decimal ({value:?})");
}
}

fn at(&self, index: usize) -> ValueRef {
let underlying: i64 = match self.nobits {
NoBits::N32 => i64::from(i32::from(self.inner.at(index))),
NoBits::N64 => i64::from(self.inner.at(index)),
let underlying: i128 = match self.nobits {
NoBits::N32 => i128::from(i32::from(self.inner.at(index))),
NoBits::N64 => i128::from(i64::from(self.inner.at(index))),
NoBits::N128 => i128::from(self.inner.at(index)),
};

ValueRef::Decimal(Decimal {
Expand Down Expand Up @@ -224,6 +230,10 @@ impl<K: ColumnType> ColumnData for DecimalAdapter<K> {
let internal: i64 = decimal.internal();
encoder.write(internal);
}
NoBits::N128 => {
let internal: i128 = decimal.internal();
encoder.write(internal);
}
}
} else {
panic!("should be decimal");
Expand Down Expand Up @@ -276,6 +286,9 @@ impl<K: ColumnType> ColumnData for NullableDecimalAdapter<K> {
encoder.write(underlying as i32);
}
NoBits::N64 => {
encoder.write(underlying as i64);
}
NoBits::N128 => {
encoder.write(underlying);
}
}
Expand Down
4 changes: 3 additions & 1 deletion src/types/column/factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ impl dyn ColumnData {
let inner_type = match nobits {
NoBits::N32 => SqlType::Int32,
NoBits::N64 => SqlType::Int64,
NoBits::N128 => SqlType::Int128,
};

W::wrap(DecimalColumnData {
Expand Down Expand Up @@ -382,6 +383,7 @@ fn parse_decimal(source: &str) -> Option<(u8, u8, NoBits)> {
let precision = match bits {
NoBits::N32 => 9,
NoBits::N64 => 18,
NoBits::N128 => 38,
};
Some((precision, scale, bits))
}
Expand Down Expand Up @@ -554,7 +556,7 @@ mod test {
fn test_parse_decimal() {
assert_eq!(parse_decimal("Decimal(9, 4)"), Some((9, 4, NoBits::N32)));
assert_eq!(parse_decimal("Decimal(10, 4)"), Some((10, 4, NoBits::N64)));
assert_eq!(parse_decimal("Decimal(20, 4)"), None);
assert_eq!(parse_decimal("Decimal(20, 4)"), Some((20, 4, NoBits::N128)));
assert_eq!(parse_decimal("Decimal(2000, 4)"), None);
assert_eq!(parse_decimal("Decimal(3, 4)"), None);
assert_eq!(parse_decimal("Decimal(20, -4)"), None);
Expand Down
6 changes: 5 additions & 1 deletion src/types/column/iter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ impl<'a> DecimalIterator<'a> {
unsafe fn next_unchecked_<T>(&mut self) -> Decimal
where
T: Copy + Sized,
i64: From<T>,
i128: From<T>,
{
let current_value = *(self.ptr as *const T);
self.ptr = (self.ptr as *const T).offset(1) as *const u8;
Expand All @@ -327,6 +327,7 @@ impl<'a> DecimalIterator<'a> {
match self.nobits {
NoBits::N32 => self.next_unchecked_::<i32>(),
NoBits::N64 => self.next_unchecked_::<i64>(),
NoBits::N128 => self.next_unchecked_::<i128>(),
}
}

Expand All @@ -336,6 +337,7 @@ impl<'a> DecimalIterator<'a> {
match self.nobits {
NoBits::N32 => self.ptr = (self.ptr as *const i32).add(n) as *const u8,
NoBits::N64 => self.ptr = (self.ptr as *const i64).add(n) as *const u8,
NoBits::N128 => self.ptr = (self.ptr as *const i128).add(n) as *const u8,
}
}
}
Expand All @@ -347,6 +349,7 @@ impl<'a> ExactSizeIterator for DecimalIterator<'a> {
let size = match self.nobits {
NoBits::N32 => mem::size_of::<i32>(),
NoBits::N64 => mem::size_of::<i64>(),
NoBits::N128 => mem::size_of::<i128>(),
};
(self.end as usize - self.ptr as usize) / size
}
Expand Down Expand Up @@ -983,6 +986,7 @@ impl<'a> Iterable<'a, Simple> for Decimal {
match nobits {
NoBits::N32 => (ptr as *const u32).add(size) as *const u8,
NoBits::N64 => (ptr as *const u64).add(size) as *const u8,
NoBits::N128 => (ptr as *const u128).add(size) as *const u8,
}
};

Expand Down
80 changes: 59 additions & 21 deletions src/types/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{
hash::{Hash, Hasher},
};

static FACTORS10: &[i64] = &[
static FACTORS10: &[i128] = &[
1,
10,
100,
Expand All @@ -24,27 +24,48 @@ static FACTORS10: &[i64] = &[
10_000_000_000_000_000,
100_000_000_000_000_000,
1_000_000_000_000_000_000,
10_000_000_000_000_000_000,
100_000_000_000_000_000_000,
1_000_000_000_000_000_000_000,
10_000_000_000_000_000_000_000,
100_000_000_000_000_000_000_000,
1_000_000_000_000_000_000_000_000,
10_000_000_000_000_000_000_000_000,
100_000_000_000_000_000_000_000_000,
1_000_000_000_000_000_000_000_000_000,
10_000_000_000_000_000_000_000_000_000,
100_000_000_000_000_000_000_000_000_000,
1_000_000_000_000_000_000_000_000_000_000,
10_000_000_000_000_000_000_000_000_000_000,
100_000_000_000_000_000_000_000_000_000_000,
1_000_000_000_000_000_000_000_000_000_000_000,
10_000_000_000_000_000_000_000_000_000_000_000,
100_000_000_000_000_000_000_000_000_000_000_000,
1_000_000_000_000_000_000_000_000_000_000_000_000,
10_000_000_000_000_000_000_000_000_000_000_000_000,
100_000_000_000_000_000_000_000_000_000_000_000_000,
];

pub trait Base {
fn scale(self, scale: i64) -> i64;
fn scale(self, scale: i128) -> i128;
}

pub trait InternalResult {
fn get(underlying: i64) -> Self;
fn get(underlying: i128) -> Self;
}

#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub(crate) enum NoBits {
N32,
N64,
N128,
}

/// Provides arbitrary-precision floating point decimal.
#[derive(Clone)]
pub struct Decimal {
pub(crate) underlying: i64,
pub(crate) nobits: NoBits, // its domain is {32, 64}
pub(crate) underlying: i128,
pub(crate) nobits: NoBits, // its domain is {32, 64, 128}
pub(crate) precision: u8,
pub(crate) scale: u8,
}
Expand Down Expand Up @@ -73,8 +94,8 @@ macro_rules! base_for {
( $( $t:ty: $cast:expr ),* ) => {
$(
impl Base for $t {
fn scale(self, scale: i64) -> i64 {
$cast(self * (scale as $t)) as i64
fn scale(self, scale: i128) -> i128 {
$cast(self * (scale as $t)) as i128
}
}
)*
Expand All @@ -84,26 +105,35 @@ macro_rules! base_for {
base_for! {
f32: std::convert::identity,
f64: std::convert::identity,
i8: i64::from,
i16: i64::from,
i32: i64::from,
i64: std::convert::identity,
u8: i64::from,
u16: i64::from,
u32: i64::from,
u64 : std::convert::identity
i8: i128::from,
i16: i128::from,
i32: i128::from,
i64: i128::from,
i128: std::convert::identity,
u8: i128::from,
u16: i128::from,
u32: i128::from,
u64: i128::from,
u128: std::convert::identity
}

impl InternalResult for i32 {
#[inline(always)]
fn get(underlying: i64) -> Self {
fn get(underlying: i128) -> Self {
underlying as Self
}
}

impl InternalResult for i64 {
#[inline(always)]
fn get(underlying: i64) -> Self {
fn get(underlying: i128) -> Self {
underlying as Self
}
}

impl InternalResult for i128 {
#[inline(always)]
fn get(underlying: i128) -> Self {
underlying
}
}
Expand All @@ -114,6 +144,8 @@ impl NoBits {
Some(NoBits::N32)
} else if precision <= 18 {
Some(NoBits::N64)
} else if precision <= 38 {
Some(NoBits::N128)
} else {
None
}
Expand Down Expand Up @@ -177,8 +209,8 @@ impl From<Decimal> for f64 {

impl Decimal {
/// Method of creating a Decimal.
pub fn new(underlying: i64, scale: u8) -> Decimal {
let precision = 18;
pub fn new(underlying: i128, scale: u8) -> Decimal {
let precision = 38;
if scale > precision {
panic!("scale can't be greater than 18");
}
Expand All @@ -192,7 +224,7 @@ impl Decimal {
}

pub fn of<B: Base>(source: B, scale: u8) -> Decimal {
let precision = 18;
let precision = 38;
if scale > precision {
panic!("scale can't be greater than 18");
}
Expand All @@ -210,7 +242,7 @@ impl Decimal {
}
}

/// Get the internal representation of decimal as [`i32`] or [`i64`].
/// Get the internal representation of decimal as [`i32`] or [`i64`] or [`i128`].
///
/// example:
/// ```rust
Expand Down Expand Up @@ -305,6 +337,12 @@ mod test {
assert_eq!(internal, 20000_i64);
}

#[test]
fn test_internal128() {
let internal: i128 = Decimal::of(2, 4).internal();
assert_eq!(internal, 20000_i128);
}

#[test]
fn test_scale() {
assert_eq!(Decimal::of(2, 4).scale(), 4);
Expand Down
2 changes: 1 addition & 1 deletion src/types/value_ref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,7 @@ mod test {

assert_eq!(
SqlType::from(ValueRef::Decimal(Decimal::of(2.0_f64, 4))),
SqlType::Decimal(18, 4)
SqlType::Decimal(38, 4)
);

assert_eq!(
Expand Down
10 changes: 7 additions & 3 deletions tests/clickhouse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1712,14 +1712,16 @@ async fn test_decimal() -> Result<(), Error> {
let ddl = "
CREATE TABLE clickhouse_decimal (
x Decimal(8, 3),
ox Nullable(Decimal(10, 2))
ox Nullable(Decimal(10, 2)),
xx Decimal(30, 4)
) Engine=Memory";

let query = "SELECT x, ox FROM clickhouse_decimal";
let query = "SELECT x, ox, xx FROM clickhouse_decimal";

let block = Block::new()
.column("x", vec![Decimal::of(1.234, 3), Decimal::of(5, 3)])
.column("ox", vec![None, Some(Decimal::of(1.23, 2))]);
.column("ox", vec![None, Some(Decimal::of(1.23, 2))])
.column("xx", vec![Decimal::of(1.23456, 4), Decimal::of(5, 4)]);

let pool = Pool::new(database_url());

Expand All @@ -1732,11 +1734,13 @@ async fn test_decimal() -> Result<(), Error> {
let x: Decimal = block.get(0, "x")?;
let ox: Option<Decimal> = block.get(1, "ox")?;
let ox0: Option<Decimal> = block.get(0, "ox")?;
let xx: Decimal = block.get(0, "xx")?;

assert_eq!(2, block.row_count());
assert_eq!(1.234, x.into());
assert_eq!(Some(1.23), ox.map(|v| v.into()));
assert_eq!(None, ox0);
assert_eq!(1.2345, xx.into());

Ok(())
}
Expand Down

0 comments on commit 8ce22e2

Please sign in to comment.