From 205f0c706befa39992153f6124ae31cb621117b3 Mon Sep 17 00:00:00 2001 From: Garance Date: Tue, 1 Oct 2024 17:16:54 -0400 Subject: [PATCH] feat: merge all arrow columns into JSON document in deserializer --- rig-lancedb/src/utils/deserializer.rs | 140 +++++++++++++++++++------- 1 file changed, 104 insertions(+), 36 deletions(-) diff --git a/rig-lancedb/src/utils/deserializer.rs b/rig-lancedb/src/utils/deserializer.rs index 96c7ce7b..94d28c40 100644 --- a/rig-lancedb/src/utils/deserializer.rs +++ b/rig-lancedb/src/utils/deserializer.rs @@ -287,27 +287,14 @@ impl RecordBatchDeserializer for RecordBatch { .collect() } // Not yet fully supported - DataType::BinaryView => { + DataType::BinaryView + | DataType::Utf8View + | DataType::ListView(..) + | DataType::LargeListView(..) => { todo!() } - // Not yet fully supported - DataType::Utf8View => { - todo!() - } - // Not yet fully supported - DataType::ListView(..) => { - todo!() - } - // Not yet fully supported - DataType::LargeListView(..) => { - todo!() - } - // f16 currently unstable - DataType::Float16 => { - todo!() - } - // i256 currently unstable - DataType::Decimal256(..) => { + // Currently unstable + DataType::Float16 | DataType::Decimal256(..) => { todo!() } _ => { @@ -317,13 +304,32 @@ impl RecordBatchDeserializer for RecordBatch { } } + let binding = self.schema(); + let column_names = binding + .fields() + .iter() + .map(|field| field.name()) + .collect::>(); + let columns = self .columns() .iter() .map(type_matcher) .collect::, _>>()?; - serde_json::to_value(&columns).map_err(serde_to_rig_error) + Ok(Value::Object((0..self.num_rows()).fold( + serde_json::Map::new(), + |mut acc, row_i| { + columns.iter().enumerate().for_each(|(col_i, col)| { + acc.entry(column_names[col_i].to_string()).and_modify(|v| { + if let Value::Array(v_arr) = v { + v_arr.push(col[row_i].clone()) + } + }).or_insert(Value::Array(vec![col[row_i].clone()])); + }); + acc + }, + ))) } } @@ -707,22 +713,84 @@ mod tests { assert_eq!( record_batch.deserialize().unwrap(), - json!([ - [0.0, 1.0], - [0.0, 1.0], - [0, -1], - [0, 1], - [0, -1], - [0, 1], - [0, 1], - [0, 1], - [0, 1], - [0, 1], - ["Marty", "Tony"], - ["Jerry", "Freddy"], - [[97, 98, 99], [100, 101, 102]], - [[104, 101, 108, 108, 111], [119, 111, 114, 108, 100]] - ]) + json!({ + "binary": [ + [ + 104, + 101, + 108, + 108, + 111 + ], + [ + 119, + 111, + 114, + 108, + 100 + ] + ], + "float_32": [ + 0.0, + 1.0 + ], + "float_64": [ + 0.0, + 1.0 + ], + "int_16": [ + 0, + 1 + ], + "int_32": [ + 0, + -1 + ], + "int_64": [ + 0, + 1 + ], + "int_8": [ + 0, + -1 + ], + "large_binary": [ + [ + 97, + 98, + 99 + ], + [ + 100, + 101, + 102 + ] + ], + "large_string": [ + "Jerry", + "Freddy" + ], + "string": [ + "Marty", + "Tony" + ], + "uint_16": [ + 0, + 1 + ], + "uint_32": [ + 0, + 1 + ], + "uint_64": [ + 0, + 1 + ], + "uint_8": [ + 0, + 1 + ] + }) ) }