diff --git a/arrow-cast/src/cast/mod.rs b/arrow-cast/src/cast/mod.rs index 5e307022df7..03a92384706 100644 --- a/arrow-cast/src/cast/mod.rs +++ b/arrow-cast/src/cast/mod.rs @@ -197,13 +197,18 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (Struct(_), _) => false, (_, Struct(_)) => false, (_, Boolean) => { - DataType::is_integer(from_type) || - DataType::is_floating(from_type) + DataType::is_integer(from_type) + || DataType::is_floating(from_type) + || from_type == &Utf8View || from_type == &Utf8 || from_type == &LargeUtf8 } (Boolean, _) => { - DataType::is_integer(to_type) || DataType::is_floating(to_type) || to_type == &Utf8 || to_type == &LargeUtf8 + DataType::is_integer(to_type) + || DataType::is_floating(to_type) + || to_type == &Utf8View + || to_type == &Utf8 + || to_type == &LargeUtf8 } (Binary, LargeBinary | Utf8 | LargeUtf8 | FixedSizeBinary(_) | BinaryView | Utf8View ) => true, @@ -1202,6 +1207,7 @@ pub fn cast_with_options( Float16 => cast_numeric_to_bool::(array), Float32 => cast_numeric_to_bool::(array), Float64 => cast_numeric_to_bool::(array), + Utf8View => cast_utf8view_to_boolean(array, cast_options), Utf8 => cast_utf8_to_boolean::(array, cast_options), LargeUtf8 => cast_utf8_to_boolean::(array, cast_options), _ => Err(ArrowError::CastError(format!( @@ -1220,6 +1226,7 @@ pub fn cast_with_options( Float16 => cast_bool_to_numeric::(array, cast_options), Float32 => cast_bool_to_numeric::(array, cast_options), Float64 => cast_bool_to_numeric::(array, cast_options), + Utf8View => value_to_string_view(array, cast_options), Utf8 => value_to_string::(array, cast_options), LargeUtf8 => value_to_string::(array, cast_options), _ => Err(ArrowError::CastError(format!( @@ -3845,6 +3852,14 @@ mod tests { assert_eq!(*as_boolean_array(&casted), expected); } + #[test] + fn test_cast_utf8view_to_bool() { + let strings = StringViewArray::from(vec!["true", "false", "invalid", " Y ", ""]); + let casted = cast(&strings, &DataType::Boolean).unwrap(); + let expected = BooleanArray::from(vec![Some(true), Some(false), None, Some(true), None]); + assert_eq!(*as_boolean_array(&casted), expected); + } + #[test] fn test_cast_with_options_utf8_to_bool() { let strings = StringArray::from(vec!["true", "false", "invalid", " Y ", ""]); @@ -3876,6 +3891,16 @@ mod tests { assert!(!c.is_valid(2)); } + #[test] + fn test_cast_bool_to_utf8view() { + let array = BooleanArray::from(vec![Some(true), Some(false), None]); + let b = cast(&array, &DataType::Utf8View).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + assert_eq!("true", c.value(0)); + assert_eq!("false", c.value(1)); + assert!(!c.is_valid(2)); + } + #[test] fn test_cast_bool_to_utf8() { let array = BooleanArray::from(vec![Some(true), Some(false), None]); diff --git a/arrow-cast/src/cast/string.rs b/arrow-cast/src/cast/string.rs index 07366a785af..7f22c4fd64d 100644 --- a/arrow-cast/src/cast/string.rs +++ b/arrow-cast/src/cast/string.rs @@ -368,19 +368,14 @@ pub(crate) fn cast_binary_to_string( } } -/// Casts Utf8 to Boolean -pub(crate) fn cast_utf8_to_boolean( - from: &dyn Array, +/// Casts string to boolean +fn cast_string_to_boolean<'a, StrArray>( + array: &StrArray, cast_options: &CastOptions, ) -> Result where - OffsetSize: OffsetSizeTrait, + StrArray: StringArrayType<'a>, { - let array = from - .as_any() - .downcast_ref::>() - .unwrap(); - let output_array = array .iter() .map(|value| match value { @@ -402,3 +397,27 @@ where Ok(Arc::new(output_array)) } + +pub(crate) fn cast_utf8_to_boolean( + from: &dyn Array, + cast_options: &CastOptions, +) -> Result +where + OffsetSize: OffsetSizeTrait, +{ + let array = from + .as_any() + .downcast_ref::>() + .unwrap(); + + cast_string_to_boolean(&array, cast_options) +} + +pub(crate) fn cast_utf8view_to_boolean( + from: &dyn Array, + cast_options: &CastOptions, +) -> Result { + let array = from.as_any().downcast_ref::().unwrap(); + + cast_string_to_boolean(&array, cast_options) +}