Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return Error From Postgres Float Deserialization #3944

Merged
merged 9 commits into from
May 24, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ Increasing the minimal supported Rust version will always be coupled at least wi

* The minimal officially supported rustc version is now 1.78.0
* Deprecated `sql_function!` in favour of `define_sql_function!` which provides compatibility with `#[dsl::auto_type]`
* Deserialization error messages now contain information about the field that failed to deserialize

## [2.1.0] 2023-05-26

Expand Down
8 changes: 7 additions & 1 deletion diesel/src/deserialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,13 @@ where
use crate::row::Field;

let field = row.get(0).ok_or(crate::result::UnexpectedEndOfRow)?;
T::from_nullable_sql(field.value())
T::from_nullable_sql(field.value()).map_err(|e| {
if e.is::<crate::result::UnexpectedNullError>() {
e
} else {
Box::new(crate::result::DeserializeFieldError::new(field, e))
}
})
}
}

Expand Down
1 change: 1 addition & 0 deletions diesel/src/pg/connection/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ impl PgResult {
)
}

#[inline(always)] // benchmarks indicate a ~1.7% improvement in instruction count for this
pub(super) fn column_name(&self, col_idx: usize) -> Option<&str> {
self.column_name_map
.get_or_init(|| {
Expand Down
54 changes: 34 additions & 20 deletions diesel/src/pg/types/floats/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,16 +117,23 @@ impl ToSql<sql_types::Numeric, Pg> for PgNumeric {
impl FromSql<sql_types::Float, Pg> for f32 {
fn from_sql(value: PgValue<'_>) -> deserialize::Result<Self> {
let mut bytes = value.as_bytes();
debug_assert!(
bytes.len() <= 4,
"Received more than 4 bytes while decoding \
an f32. Was a double accidentally marked as float?"
);
debug_assert!(
bytes.len() >= 4,
"Received less than 4 bytes while decoding \
an f32."
);

if bytes.len() < 4 {
return deserialize::Result::Err(
"Received less than 4 bytes while decoding an f32. \
Was a numeric accidentally marked as float?"
.into(),
);
}

if bytes.len() > 4 {
return deserialize::Result::Err(
"Received more than 4 bytes while decoding an f32. \
Was a double accidentally marked as float?"
.into(),
);
}

bytes
.read_f32::<NetworkEndian>()
.map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync>)
Expand All @@ -137,16 +144,23 @@ impl FromSql<sql_types::Float, Pg> for f32 {
impl FromSql<sql_types::Double, Pg> for f64 {
fn from_sql(value: PgValue<'_>) -> deserialize::Result<Self> {
let mut bytes = value.as_bytes();
debug_assert!(
bytes.len() <= 8,
"Received less than 8 bytes while decoding \
an f64. Was a float accidentally marked as double?"
);
debug_assert!(
bytes.len() >= 8,
"Received more than 8 bytes while decoding \
an f64. Was a numeric accidentally marked as double?"
);

if bytes.len() < 8 {
return deserialize::Result::Err(
"Received less than 8 bytes while decoding an f64. \
Was a float accidentally marked as double?"
.into(),
);
}

if bytes.len() > 8 {
return deserialize::Result::Err(
"Received more than 8 bytes while decoding an f64. \
Was a numeric accidentally marked as double?"
.into(),
);
}

bytes
.read_f64::<NetworkEndian>()
.map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync>)
Expand Down
79 changes: 48 additions & 31 deletions diesel/src/pg/types/integers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,22 @@ impl ToSql<sql_types::Oid, Pg> for u32 {

#[cfg(feature = "postgres_backend")]
impl FromSql<sql_types::SmallInt, Pg> for i16 {
#[inline(always)]
fn from_sql(value: PgValue<'_>) -> deserialize::Result<Self> {
let mut bytes = value.as_bytes();
debug_assert!(
bytes.len() <= 2,
"Received more than 2 bytes decoding i16. \
Was an Integer expression accidentally identified as SmallInt?"
);
debug_assert!(
bytes.len() >= 2,
"Received fewer than 2 bytes decoding i16. \
Was an expression of a different type accidentally identified \
as SmallInt?"
);
if bytes.len() < 2 {
return emit_size_error(
"Received less than 2 bytes while decoding an i16. \
Was an expression of a different type accidentally marked as SmallInt?",
);
}

if bytes.len() > 2 {
return emit_size_error(
"Received more than 2 bytes while decoding an i16. \
Was an Integer expression accidentally marked as SmallInt?",
);
}
bytes
.read_i16::<NetworkEndian>()
.map_err(|e| Box::new(e) as Box<_>)
Expand All @@ -44,38 +47,52 @@ impl FromSql<sql_types::SmallInt, Pg> for i16 {

#[cfg(feature = "postgres_backend")]
impl FromSql<sql_types::Integer, Pg> for i32 {
#[inline(always)]
fn from_sql(value: PgValue<'_>) -> deserialize::Result<Self> {
let mut bytes = value.as_bytes();
debug_assert!(
bytes.len() <= 4,
"Received more than 4 bytes decoding i32. \
Was a BigInt expression accidentally identified as Integer?"
);
debug_assert!(
bytes.len() >= 4,
"Received fewer than 4 bytes decoding i32. \
Was a SmallInt expression accidentally identified as Integer?"
);
if bytes.len() < 4 {
return emit_size_error(
"Received less than 4 bytes while decoding an i32. \
Was an SmallInt expression accidentally marked as Integer?",
);
}

if bytes.len() > 4 {
return emit_size_error(
"Received more than 4 bytes while decoding an i32. \
Was an BigInt expression accidentally marked as Integer?",
);
}
bytes
.read_i32::<NetworkEndian>()
.map_err(|e| Box::new(e) as Box<_>)
}
}

#[cold]
#[inline(never)]
fn emit_size_error<T>(var_name: &str) -> deserialize::Result<T> {
deserialize::Result::Err(var_name.into())
}

#[cfg(feature = "postgres_backend")]
impl FromSql<sql_types::BigInt, Pg> for i64 {
#[inline(always)]
fn from_sql(value: PgValue<'_>) -> deserialize::Result<Self> {
let mut bytes = value.as_bytes();
debug_assert!(
bytes.len() <= 8,
"Received more than 8 bytes decoding i64. \
Was an expression of a different type misidentified as BigInt?"
);
debug_assert!(
bytes.len() >= 8,
"Received fewer than 8 bytes decoding i64. \
Was an Integer expression misidentified as BigInt?"
);
if bytes.len() < 8 {
return emit_size_error(
"Received less than 8 bytes while decoding an i64. \
Was an Integer expression accidentally marked as BigInt?",
);
}

if bytes.len() > 8 {
return emit_size_error(
"Received more than 8 bytes while decoding an i64. \
Was an expression of a different type expression accidentally marked as BigInt?"
);
}
bytes
.read_i64::<NetworkEndian>()
.map_err(|e| Box::new(e) as Box<_>)
Expand Down
44 changes: 44 additions & 0 deletions diesel/src/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -455,3 +455,47 @@ impl fmt::Display for EmptyChangeset {
}

impl StdError for EmptyChangeset {}

/// An error occurred while deserializing a field
#[derive(Debug)]
#[non_exhaustive]
pub struct DeserializeFieldError {
weiznich marked this conversation as resolved.
Show resolved Hide resolved
/// The name of the field that failed to deserialize
pub field_name: Option<String>,
/// The error that occurred while deserializing the field
pub error: Box<dyn StdError + Send + Sync>,
}

impl DeserializeFieldError {
#[cold]
pub(crate) fn new<'a, F, DB>(field: F, error: Box<dyn std::error::Error + Send + Sync>) -> Self
where
DB: crate::backend::Backend,
F: crate::row::Field<'a, DB>,
{
DeserializeFieldError {
field_name: field.field_name().map(|s| s.to_string()),
error,
}
}
}

impl StdError for DeserializeFieldError {
fn source(&self) -> Option<&(dyn StdError + 'static)> {
Some(&*self.error)
}
}

impl fmt::Display for DeserializeFieldError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if let Some(ref field_name) = self.field_name {
write!(
f,
"Error deserializing field '{}': {}",
field_name, self.error
)
} else {
write!(f, "Error deserializing field: {}", self.error)
}
}
}
1 change: 1 addition & 0 deletions diesel_bench/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ fast_run = []

[profile.release]
lto = true
debug = true
codegen-units = 1

[patch.crates-io]
Expand Down
114 changes: 113 additions & 1 deletion diesel_tests/tests/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1388,7 +1388,7 @@ where

#[cfg(feature = "postgres")]
#[test]
#[should_panic(expected = "Received more than 4 bytes decoding i32")]
#[should_panic(expected = "Received more than 4 bytes while decoding an i32")]
fn debug_check_catches_reading_bigint_as_i32_when_using_raw_sql() {
use diesel::dsl::sql;
use diesel::sql_types::Integer;
Expand Down Expand Up @@ -1574,3 +1574,115 @@ fn citext_fields() {

assert_eq!(lowercase_in_db, Some("lowercase_value".to_string()));
}

#[test]
#[cfg(feature = "postgres")]
fn deserialize_wrong_primitive_gives_good_error() {
let conn = &mut connection();

diesel::sql_query(
"CREATE TABLE test_table(\
bool BOOLEAN,
small SMALLINT, \
int INTEGER, \
big BIGINT, \
float FLOAT4, \
double FLOAT8,
text TEXT)",
)
.execute(conn)
.unwrap();
diesel::sql_query("INSERT INTO test_table VALUES('t', 1, 1, 1, 1, 1, 'long text long text')")
.execute(conn)
.unwrap();

let res = diesel::dsl::sql::<SmallInt>("SELECT bool FROM test_table").get_result::<i16>(conn);
assert!(res.is_err());
assert_eq!(
res.unwrap_err().to_string(),
"Error deserializing field 'bool': \
Received less than 2 bytes while decoding an i16. \
Was an expression of a different type accidentally marked as SmallInt?"
);

let res = diesel::dsl::sql::<SmallInt>("SELECT int FROM test_table").get_result::<i16>(conn);
assert!(res.is_err());
assert_eq!(
res.unwrap_err().to_string(),
"Error deserializing field 'int': \
Received more than 2 bytes while decoding an i16. \
Was an Integer expression accidentally marked as SmallInt?"
);

let res = diesel::dsl::sql::<Integer>("SELECT small FROM test_table").get_result::<i32>(conn);
assert!(res.is_err());
assert_eq!(
res.unwrap_err().to_string(),
"Error deserializing field 'small': \
Received less than 4 bytes while decoding an i32. \
Was an SmallInt expression accidentally marked as Integer?"
);

let res = diesel::dsl::sql::<Integer>("SELECT big FROM test_table").get_result::<i32>(conn);
assert!(res.is_err());
assert_eq!(
res.unwrap_err().to_string(),
"Error deserializing field 'big': \
Received more than 4 bytes while decoding an i32. \
Was an BigInt expression accidentally marked as Integer?"
);

let res = diesel::dsl::sql::<BigInt>("SELECT int FROM test_table").get_result::<i64>(conn);
assert!(res.is_err());
assert_eq!(
res.unwrap_err().to_string(),
"Error deserializing field 'int': \
Received less than 8 bytes while decoding an i64. \
Was an Integer expression accidentally marked as BigInt?"
);

let res = diesel::dsl::sql::<BigInt>("SELECT text FROM test_table").get_result::<i64>(conn);
assert!(res.is_err());
assert_eq!(
res.unwrap_err().to_string(),
"Error deserializing field 'text': \
Received more than 8 bytes while decoding an i64. \
Was an expression of a different type expression accidentally marked as BigInt?"
);

let res = diesel::dsl::sql::<Float>("SELECT small FROM test_table").get_result::<f32>(conn);
assert!(res.is_err());
assert_eq!(
res.unwrap_err().to_string(),
"Error deserializing field 'small': \
Received less than 4 bytes while decoding an f32. \
Was a numeric accidentally marked as float?"
);

let res = diesel::dsl::sql::<Float>("SELECT double FROM test_table").get_result::<f32>(conn);
assert!(res.is_err());
assert_eq!(
res.unwrap_err().to_string(),
"Error deserializing field 'double': \
Received more than 4 bytes while decoding an f32. \
Was a double accidentally marked as float?"
);

let res = diesel::dsl::sql::<Double>("SELECT float FROM test_table").get_result::<f64>(conn);
assert!(res.is_err());
assert_eq!(
res.unwrap_err().to_string(),
"Error deserializing field 'float': \
Received less than 8 bytes while decoding an f64. \
Was a float accidentally marked as double?"
);

let res = diesel::dsl::sql::<Double>("SELECT text FROM test_table").get_result::<f64>(conn);
assert!(res.is_err());
assert_eq!(
res.unwrap_err().to_string(),
"Error deserializing field 'text': \
Received more than 8 bytes while decoding an f64. \
Was a numeric accidentally marked as double?"
);
}
Loading