Skip to content

Commit

Permalink
Fix deserialization of empty PostgreSQL ranges
Browse files Browse the repository at this point in the history
In diesel empty ranges are represented as:

```
(
  Bound::Excluded(T::default()),
  Bound::Excluded(T::default()),
)
```
  • Loading branch information
mpflanzer committed Oct 4, 2024
1 parent 381be19 commit 2b97dd1
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 7 deletions.
3 changes: 2 additions & 1 deletion diesel/src/pg/types/floats/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::error::Error;
#[cfg(feature = "quickcheck")]
mod quickcheck_impls;

#[derive(Debug, Clone, PartialEq, Eq, AsExpression, FromSqlRow)]
#[derive(Default, Debug, Clone, PartialEq, Eq, AsExpression, FromSqlRow)]
#[diesel(sql_type = sql_types::Numeric)]
/// Represents a NUMERIC value, closely mirroring the PG wire protocol
/// representation
Expand All @@ -33,6 +33,7 @@ pub enum PgNumeric {
digits: Vec<i16>,
},
/// Not a number
#[default]
NaN,
}

Expand Down
2 changes: 1 addition & 1 deletion diesel/src/pg/types/multirange.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ multirange_as_expressions!(std::ops::RangeTo<T>);
#[cfg(feature = "postgres_backend")]
impl<T, ST> FromSql<Multirange<ST>, Pg> for Vec<(Bound<T>, Bound<T>)>
where
T: FromSql<ST, Pg>,
T: FromSql<ST, Pg> + std::default::Default,
{
fn from_sql(value: PgValue<'_>) -> deserialize::Result<Self> {
let mut bytes = value.as_bytes();
Expand Down
12 changes: 8 additions & 4 deletions diesel/src/pg/types/ranges.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,17 @@ range_as_expression!(&'a std::ops::RangeTo<T>; Nullable<Range<ST>>);
#[cfg(feature = "postgres_backend")]
impl<T, ST> FromSql<Range<ST>, Pg> for (Bound<T>, Bound<T>)
where
T: FromSql<ST, Pg>,
T: FromSql<ST, Pg> + std::default::Default,
{
fn from_sql(value: PgValue<'_>) -> deserialize::Result<Self> {
let mut bytes = value.as_bytes();
let flags: RangeFlags = RangeFlags::from_bits_truncate(bytes.read_u8()?);
let mut lower_bound = Bound::Unbounded;
let mut upper_bound = Bound::Unbounded;

if !flags.contains(RangeFlags::LB_INF) {
if flags.contains(RangeFlags::EMPTY) {
lower_bound = Bound::Excluded(T::default());
} else if !flags.contains(RangeFlags::LB_INF) {
let elem_size = bytes.read_i32::<NetworkEndian>()?;
let (elem_bytes, new_bytes) = bytes.split_at(elem_size.try_into()?);
bytes = new_bytes;
Expand All @@ -95,7 +97,9 @@ where
};
}

if !flags.contains(RangeFlags::UB_INF) {
if flags.contains(RangeFlags::EMPTY) {
upper_bound = Bound::Excluded(T::default());
} else if !flags.contains(RangeFlags::UB_INF) {
let _size = bytes.read_i32::<NetworkEndian>()?;
let value = T::from_sql(PgValue::new_internal(bytes, &value))?;

Expand All @@ -113,7 +117,7 @@ where
#[cfg(feature = "postgres_backend")]
impl<T, ST> Queryable<Range<ST>, Pg> for (Bound<T>, Bound<T>)
where
T: FromSql<ST, Pg>,
T: FromSql<ST, Pg> + std::default::Default,
{
type Row = Self;

Expand Down
37 changes: 36 additions & 1 deletion diesel_tests/tests/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1422,10 +1422,45 @@ fn test_range_from_sql() {
query_single_value::<Range<Int4>, (Bound<i32>, Bound<i32>)>(query)
);

let query = "SELECT '(1,1]'::int4range";
let query = "SELECT '[2,1)'::int4range";
assert!(sql::<Range<Int4>>(query)
.load::<(Bound<i32>, Bound<i32>)>(connection)
.is_err());

let query = "'empty'::int4range";
let expected_value = (Bound::Excluded(0), Bound::Excluded(0));
assert_eq!(
expected_value,
query_single_value::<Range<Int4>, (Bound<i32>, Bound<i32>)>(query)
);

let query = "SELECT '(1,1)'::int4range";
let expected_value = (Bound::Excluded(0), Bound::Excluded(0));
assert_eq!(
expected_value,
query_single_value::<Range<Int4>, (Bound<i32>, Bound<i32>)>(query)
);

let query = "SELECT '(1,1]'::int4range";
let expected_value = (Bound::Excluded(0), Bound::Excluded(0));
assert_eq!(
expected_value,
query_single_value::<Range<Int4>, (Bound<i32>, Bound<i32>)>(query)
);

let query = "SELECT '[1,1)'::int4range";
let expected_value = (Bound::Excluded(0), Bound::Excluded(0));
assert_eq!(
expected_value,
query_single_value::<Range<Int4>, (Bound<i32>, Bound<i32>)>(query)
);

let query = "SELECT '[1,1]'::int4range";
let expected_value = (Bound::Included(1), Bound::Included(1));
assert_eq!(
expected_value,
query_single_value::<Range<Int4>, (Bound<i32>, Bound<i32>)>(query)
);
}

#[cfg(feature = "postgres")]
Expand Down

0 comments on commit 2b97dd1

Please sign in to comment.