diff --git a/rig-lancedb/src/utils/deserializer.rs b/rig-lancedb/src/utils/deserializer.rs index 88f0e8de..0b746de4 100644 --- a/rig-lancedb/src/utils/deserializer.rs +++ b/rig-lancedb/src/utils/deserializer.rs @@ -252,45 +252,57 @@ impl RecordBatchDeserializer for RecordBatch { )), ))), }, - DataType::RunEndEncoded(counter_type, ..) => { - let items: Vec> = match counter_type.data_type() { + DataType::RunEndEncoded(index_type, ..) => { + let items = match index_type.data_type() { DataType::Int16 => { - let (counter, v) = column + let (indexes, v) = column .to_run_end::() .map_err(arrow_to_rig_error)?; - counter - .into_iter() + let mut prev = vec![0]; + prev.extend(indexes.clone()); + + prev.iter() + .zip(indexes) + .map(|(prev, cur)| cur - prev) .zip(type_matcher(&v)?) - .map(|(n, value)| vec![value; n as usize]) - .collect() + .flat_map(|(n, value)| vec![value; n as usize]) + .collect::>() } DataType::Int32 => { - let (counter, v) = column + let (indexes, v) = column .to_run_end::() .map_err(arrow_to_rig_error)?; - counter - .into_iter() + let mut prev = vec![0]; + prev.extend(indexes.clone()); + + prev.iter() + .zip(indexes) + .map(|(prev, cur)| cur - prev) .zip(type_matcher(&v)?) - .map(|(n, value)| vec![value; n as usize]) - .collect() + .flat_map(|(n, value)| vec![value; n as usize]) + .collect::>() } DataType::Int64 => { - let (counter, v) = column + let (indexes, v) = column .to_run_end::() .map_err(arrow_to_rig_error)?; - counter - .into_iter() + let mut prev = vec![0]; + prev.extend(indexes.clone()); + + prev.iter() + .zip(indexes) + .map(|(prev, cur)| cur - prev) .zip(type_matcher(&v)?) - .map(|(n, value)| vec![value; n as usize]) - .collect() + .flat_map(|(n, value)| vec![value; n as usize]) + .collect::>() } _ => { return Err(VectorStoreError::DatastoreError(Box::new( ArrowError::CastError(format!( - "RunEndEncoded index type is not accepted: {counter_type:?}" + "RunEndEncoded index type is not accepted: {index_type:?}" )), ))) } @@ -867,9 +879,29 @@ mod tests { let array = builder.finish(); let record_batch = - RecordBatch::try_from_iter(vec![("some_dict", Arc::new(array) as ArrayRef)]).unwrap(); + RecordBatch::try_from_iter(vec![("some_run_end", Arc::new(array) as ArrayRef)]) + .unwrap(); - assert_eq!(record_batch.deserialize().unwrap(), vec![json!({})]) + assert_eq!( + record_batch.deserialize().unwrap(), + vec![ + json!({ + "some_run_end": "abc" + }), + json!({ + "some_run_end": "" + }), + json!({ + "some_run_end": "def" + }), + json!({ + "some_run_end": "def" + }), + json!({ + "some_run_end": "abc" + }) + ] + ) } #[tokio::test]