diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs index 9ff4da30ed8c..99abff34f47e 100644 --- a/arrow-ipc/src/reader.rs +++ b/arrow-ipc/src/reader.rs @@ -79,6 +79,7 @@ fn create_array( field: &Field, variadic_counts: &mut VecDeque, require_alignment: bool, + skip_validations: bool, ) -> Result { let data_type = field.data_type(); match data_type { @@ -91,6 +92,7 @@ fn create_array( reader.next_buffer()?, ], require_alignment, + skip_validations, ), BinaryView | Utf8View => { let count = variadic_counts @@ -107,6 +109,7 @@ fn create_array( data_type, &buffers, require_alignment, + skip_validations, ) } FixedSizeBinary(_) => create_primitive_array( @@ -114,29 +117,44 @@ fn create_array( data_type, &[reader.next_buffer()?, reader.next_buffer()?], require_alignment, + skip_validations, ), List(ref list_field) | LargeList(ref list_field) | Map(ref list_field, _) => { let list_node = reader.next_node(field)?; let list_buffers = [reader.next_buffer()?, reader.next_buffer()?]; - let values = create_array(reader, list_field, variadic_counts, require_alignment)?; + let values = create_array( + reader, + list_field, + variadic_counts, + require_alignment, + skip_validations, + )?; create_list_array( list_node, data_type, &list_buffers, values, require_alignment, + skip_validations, ) } FixedSizeList(ref list_field, _) => { let list_node = reader.next_node(field)?; let list_buffers = [reader.next_buffer()?]; - let values = create_array(reader, list_field, variadic_counts, require_alignment)?; + let values = create_array( + reader, + list_field, + variadic_counts, + require_alignment, + skip_validations, + )?; create_list_array( list_node, data_type, &list_buffers, values, require_alignment, + skip_validations, ) } Struct(struct_fields) => { @@ -148,7 +166,13 @@ fn create_array( // TODO investigate whether just knowing the number of buffers could // still work for struct_field in struct_fields { - let child = create_array(reader, struct_field, variadic_counts, 
require_alignment)?; + let child = create_array( + reader, + struct_field, + variadic_counts, + require_alignment, + skip_validations, + )?; struct_arrays.push(child); } let null_count = struct_node.null_count() as usize; @@ -172,9 +196,20 @@ fn create_array( } RunEndEncoded(run_ends_field, values_field) => { let run_node = reader.next_node(field)?; - let run_ends = - create_array(reader, run_ends_field, variadic_counts, require_alignment)?; - let values = create_array(reader, values_field, variadic_counts, require_alignment)?; + let run_ends = create_array( + reader, + run_ends_field, + variadic_counts, + require_alignment, + skip_validations, + )?; + let values = create_array( + reader, + values_field, + variadic_counts, + require_alignment, + skip_validations, + )?; let run_array_length = run_node.length() as usize; let builder = ArrayData::builder(data_type.clone()) @@ -183,7 +218,9 @@ fn create_array( .add_child_data(run_ends.into_data()) .add_child_data(values.into_data()); - let array_data = if require_alignment { + let array_data = if skip_validations { + unsafe { builder.build_unchecked() } + } else if require_alignment { builder.build()? } else { builder.build_aligned()? @@ -213,6 +250,7 @@ fn create_array( &index_buffers, value_array.clone(), require_alignment, + skip_validations, ) } Union(fields, mode) => { @@ -239,7 +277,13 @@ fn create_array( let mut children = Vec::with_capacity(fields.len()); for (_id, field) in fields.iter() { - let child = create_array(reader, field, variadic_counts, require_alignment)?; + let child = create_array( + reader, + field, + variadic_counts, + require_alignment, + skip_validations, + )?; children.push(child); } @@ -261,7 +305,9 @@ fn create_array( .len(length as usize) .offset(0); - let array_data = if require_alignment { + let array_data = if skip_validations { + unsafe { builder.build_unchecked() } + } else if require_alignment { builder.build()? } else { builder.build_aligned()? 
@@ -275,17 +321,36 @@ fn create_array( data_type, &[reader.next_buffer()?, reader.next_buffer()?], require_alignment, + skip_validations, ), } } /// Reads the correct number of buffers based on data type and null_count, and creates a /// primitive array ref +/// +/// # Arguments +/// +/// * `field_node` - A reference to the `FieldNode` which contains the length and null count of the array. +/// * `data_type` - The `DataType` of the array to be created. +/// * `buffers` - A slice of `Buffer` which contains the data for the array. +/// * `require_alignment` - A boolean indicating whether the buffers need to be aligned. +/// * `skip_validations` - A boolean indicating whether to skip validations. +/// +/// # Safety +/// +/// `skip_validations` allows the creation of an `ArrayData` without performing the +/// usual validations. This can lead to undefined behavior if the data is not +/// correctly formatted. Set `skip_validations` to true only if you are certain that the data is valid. +/// +/// # Notes +/// If `skip_validations` is true, `require_alignment` is ignored. fn create_primitive_array( field_node: &FieldNode, data_type: &DataType, buffers: &[Buffer], require_alignment: bool, + skip_validations: bool, ) -> Result { let length = field_node.length() as usize; let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone()); @@ -311,7 +376,9 @@ fn create_primitive_array( t => unreachable!("Data type {:?} either unsupported or not primitive", t), }; - let array_data = if require_alignment { + let array_data = if skip_validations { + unsafe { builder.build_unchecked() } + } else if require_alignment { builder.build()? } else { builder.build_aligned()? @@ -322,12 +389,21 @@ fn create_primitive_array( /// Reads the correct number of buffers based on list type and null_count, and creates a /// list array ref +/// +/// Safety: +/// `skip_validations` allows the creation of an `ArrayData` without performing the +/// usual validations. 
This can lead to undefined behavior if the data is not +/// correctly formatted. Set `skip_validations` to true only if you are certain. +/// +/// Notes: +/// * If `skip_validations` is true, `require_alignment` is ignored. fn create_list_array( field_node: &FieldNode, data_type: &DataType, buffers: &[Buffer], child_array: ArrayRef, require_alignment: bool, + skip_validations: bool, ) -> Result { let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone()); let length = field_node.length() as usize; @@ -347,7 +423,9 @@ fn create_list_array( _ => unreachable!("Cannot create list or map array from {:?}", data_type), }; - let array_data = if require_alignment { + let array_data = if skip_validations { + unsafe { builder.build_unchecked() } + } else if require_alignment { builder.build()? } else { builder.build_aligned()? @@ -358,12 +436,21 @@ fn create_list_array( /// Reads the correct number of buffers based on list type and null_count, and creates a /// list array ref +/// +/// Safety: +/// `skip_validations` allows the creation of an `ArrayData` without performing the +/// usual validations. This can lead to undefined behavior if the data is not +/// correctly formatted. Set `skip_validations` to true only if you are certain. +/// +/// Notes: +/// * If `skip_validations` is true, `require_alignment` is ignored. fn create_dictionary_array( field_node: &FieldNode, data_type: &DataType, buffers: &[Buffer], value_array: ArrayRef, require_alignment: bool, + skip_validations: bool, ) -> Result { if let Dictionary(_, _) = *data_type { let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone()); @@ -373,7 +460,9 @@ fn create_dictionary_array( .add_child_data(value_array.into_data()) .null_bit_buffer(null_buffer); - let array_data = if require_alignment { + let array_data = if skip_validations { + unsafe { builder.build_unchecked() } + } else if require_alignment { builder.build()? } else { builder.build_aligned()? 
@@ -521,6 +610,7 @@ pub fn read_record_batch( projection, metadata, false, + false, ) } @@ -533,7 +623,15 @@ pub fn read_dictionary( dictionaries_by_id: &mut HashMap, metadata: &MetadataVersion, ) -> Result<(), ArrowError> { - read_dictionary_impl(buf, batch, schema, dictionaries_by_id, metadata, false) + read_dictionary_impl( + buf, + batch, + schema, + dictionaries_by_id, + metadata, + false, + false, + ) } fn read_record_batch_impl( @@ -544,6 +642,7 @@ fn read_record_batch_impl( projection: Option<&[usize]>, metadata: &MetadataVersion, require_alignment: bool, + skip_validations: bool, ) -> Result { let buffers = batch.buffers().ok_or_else(|| { ArrowError::IpcError("Unable to get buffers from IPC RecordBatch".to_string()) @@ -577,8 +676,13 @@ fn read_record_batch_impl( for (idx, field) in schema.fields().iter().enumerate() { // Create array for projected field if let Some(proj_idx) = projection.iter().position(|p| p == &idx) { - let child = - create_array(&mut reader, field, &mut variadic_counts, require_alignment)?; + let child = create_array( + &mut reader, + field, + &mut variadic_counts, + require_alignment, + skip_validations, + )?; arrays.push((proj_idx, child)); } else { reader.skip_field(field, &mut variadic_counts)?; @@ -595,7 +699,13 @@ fn read_record_batch_impl( let mut children = vec![]; // keep track of index as lists require more than one node for field in schema.fields() { - let child = create_array(&mut reader, field, &mut variadic_counts, require_alignment)?; + let child = create_array( + &mut reader, + field, + &mut variadic_counts, + require_alignment, + skip_validations, + )?; children.push(child); } assert!(variadic_counts.is_empty()); @@ -610,6 +720,7 @@ fn read_dictionary_impl( dictionaries_by_id: &mut HashMap, metadata: &MetadataVersion, require_alignment: bool, + skip_validations: bool, ) -> Result<(), ArrowError> { if batch.isDelta() { return Err(ArrowError::InvalidArgumentError( @@ -641,6 +752,7 @@ fn read_dictionary_impl( None, 
metadata, require_alignment, + skip_validations, )?; Some(record_batch.column(0).clone()) } @@ -766,6 +878,7 @@ pub struct FileDecoder { version: MetadataVersion, projection: Option>, require_alignment: bool, + skip_validations: bool, } impl FileDecoder { @@ -777,6 +890,7 @@ impl FileDecoder { dictionaries: Default::default(), projection: None, require_alignment: false, + skip_validations: false, } } @@ -803,6 +917,19 @@ impl FileDecoder { self } + /// Specifies whether or not to skip validations when creating [`ArrayData`]. + /// This can lead to undefined behavior if the data is not correctly formatted. + /// Set `skip_validations` to true only if you are certain. + /// + /// Notes: + /// * If `skip_validations` is true, `require_alignment` is ignored. + /// * If `skip_validations` is true, it uses [`arrow_data::ArrayDataBuilder::build_unchecked`] to + /// construct [`arrow_data::ArrayData`] under the hood. + pub fn with_skip_validations(mut self, skip_validations: bool) -> Self { + self.skip_validations = skip_validations; + self + } + fn read_message<'a>(&self, buf: &'a [u8]) -> Result, ArrowError> { let message = parse_message(buf)?; @@ -828,6 +955,7 @@ impl FileDecoder { &mut self.dictionaries, &message.version(), self.require_alignment, + self.skip_validations, ) } t => Err(ArrowError::ParseError(format!( @@ -860,6 +988,7 @@ impl FileDecoder { self.projection.as_deref(), &message.version(), self.require_alignment, + self.skip_validations, ) .map(Some) } @@ -880,6 +1009,8 @@ pub struct FileReaderBuilder { max_footer_fb_tables: usize, /// Passed through to construct [`VerifierOptions`] max_footer_fb_depth: usize, + /// Skip validations when creating [`ArrayData`] + skip_validations: bool, } impl Default for FileReaderBuilder { @@ -889,6 +1020,7 @@ impl Default for FileReaderBuilder { max_footer_fb_tables: verifier_options.max_tables, max_footer_fb_depth: verifier_options.max_depth, projection: None, + skip_validations: false, } } } @@ -907,6 +1039,14 @@ impl 
FileReaderBuilder { self } + /// Skip validations when creating underlying [`ArrayData`]. + /// This can lead to undefined behavior if the data is not correctly formatted. + /// Set `skip_validations` to true only if you are certain. + pub fn with_skip_validations(mut self, skip_validations: bool) -> Self { + self.skip_validations = skip_validations; + self + } + /// Flatbuffers option for parsing the footer. Controls the max number of fields and /// metadata key-value pairs that can be parsed from the schema of the footer. /// @@ -989,7 +1129,8 @@ impl FileReaderBuilder { } } - let mut decoder = FileDecoder::new(Arc::new(schema), footer.version()); + let mut decoder = FileDecoder::new(Arc::new(schema), footer.version()) + .with_skip_validations(self.skip_validations); if let Some(projection) = self.projection { decoder = decoder.with_projection(projection) } @@ -1075,6 +1216,23 @@ impl FileReader { builder.build(reader) } + /// Try to create a new file reader without validations. + /// + /// This is useful when the file is known to be valid and trusted, and the user wants to avoid the + /// overhead of validating its content. Skipping validations can lead to undefined behavior if the + /// data is not correctly formatted. + pub fn try_new_unvalidated( + reader: R, + projection: Option>, + ) -> Result { + let builder = FileReaderBuilder { + projection, + skip_validations: true, + ..Default::default() + }; + builder.build(reader) + } + /// Return user defined customized metadata pub fn custom_metadata(&self) -> &HashMap { &self.custom_metadata @@ -1168,6 +1326,10 @@ pub struct StreamReader { /// Optional projection projection: Option<(Vec, Schema)>, + + /// Specifies whether or not to skip validations when creating underlying [`ArrayData`]. + /// This can lead to undefined behavior if the data is not correctly formatted. 
+ skip_validations: bool, } impl fmt::Debug for StreamReader { @@ -1247,6 +1409,7 @@ impl StreamReader { finished: false, dictionaries_by_id, projection, + skip_validations: false, }) } @@ -1269,6 +1432,16 @@ impl StreamReader { self.finished } + /// Specifies whether or not to skip validations when creating underlying [`ArrayData`]. + /// This can lead to undefined behavior if the data is not correctly formatted. + /// + /// Notes: + /// * If `skip_validations` is true, `require_alignment` is ignored. + pub fn with_skip_validations(mut self, skip_validations: bool) -> Self { + self.skip_validations = skip_validations; + self + } + fn maybe_next(&mut self) -> Result, ArrowError> { if self.finished { return Ok(None); @@ -1334,6 +1507,7 @@ impl StreamReader { self.projection.as_ref().map(|x| x.0.as_ref()), &message.version(), false, + self.skip_validations, ) .map(Some) } @@ -1354,6 +1528,7 @@ impl StreamReader { &mut self.dictionaries_by_id, &message.version(), false, + self.skip_validations, )?; // read the next message until we encounter a RecordBatch @@ -2184,6 +2359,7 @@ mod tests { None, &message.version(), false, + false, ) .unwrap(); assert_eq!(batch, roundtrip); @@ -2222,6 +2398,7 @@ mod tests { None, &message.version(), true, + false, ); let error = result.unwrap_err(); diff --git a/arrow-ipc/src/reader/stream.rs b/arrow-ipc/src/reader/stream.rs index 9b0eea9b6198..127d346cf8c5 100644 --- a/arrow-ipc/src/reader/stream.rs +++ b/arrow-ipc/src/reader/stream.rs @@ -42,6 +42,8 @@ pub struct StreamDecoder { buf: MutableBuffer, /// Whether or not array data in input buffers are required to be aligned require_alignment: bool, + /// Whether or not to skip validation for underlying array creations + skip_validations: bool, } #[derive(Debug)] @@ -102,6 +104,19 @@ impl StreamDecoder { self } + /// Specifies whether or not to skip validations when creating [`ArrayData`]. + /// This can lead to undefined behavior if the data is not correctly formatted. 
+ /// Set `skip_validations` to true only if you are certain. + /// + /// Notes: + /// * If `skip_validations` is true, `require_alignment` is ignored. + /// * If `skip_validations` is true, it uses [`arrow_data::ArrayDataBuilder::build_unchecked`] to + /// construct [`arrow_data::ArrayData`] under the hood. + pub fn with_skip_validations(mut self, skip_validations: bool) -> Self { + self.skip_validations = skip_validations; + self + } + /// Try to read the next [`RecordBatch`] from the provided [`Buffer`] /// /// [`Buffer::advance`] will be called on `buffer` for any consumed bytes. @@ -219,6 +234,7 @@ impl StreamDecoder { None, &version, self.require_alignment, + self.skip_validations, )?; self.state = DecoderState::default(); return Ok(Some(batch)); @@ -235,6 +251,7 @@ impl StreamDecoder { &mut self.dictionaries, &version, self.require_alignment, + self.skip_validations, )?; self.state = DecoderState::default(); }