diff --git a/rerun_py/tests/unit/test_dataframe.py b/rerun_py/tests/unit/test_dataframe.py index 93183f42f8d12..5a13feb13568c 100644 --- a/rerun_py/tests/unit/test_dataframe.py +++ b/rerun_py/tests/unit/test_dataframe.py @@ -59,6 +59,32 @@ def setup_method(self) -> None: self.recording = rr.dataframe.load_recording(rrd) + self.expected_index0 = pa.array( + [1], + type=pa.int64(), + ) + + self.expected_index1 = pa.array( + [7], + type=pa.int64(), + ) + + self.expected_pos0 = pa.array( + [ + [1, 2, 3], + [4, 5, 6], + [7, 8, 9], + ], + type=rr.components.Position3D.arrow_type(), + ) + + self.expected_pos1 = pa.array( + [ + [10, 11, 12], + ], + type=rr.components.Position3D.arrow_type(), + ) + def test_recording_info(self) -> None: assert self.recording.application_id() == APP_ID assert self.recording.recording_id() == str(RECORDING_ID) @@ -85,38 +111,53 @@ def test_select_columns(self) -> None: assert table.num_columns == 2 assert table.num_rows == 2 - expected_index0 = pa.array( - [1], - type=pa.int64(), - ) + assert table.column("my_index")[0].equals(self.expected_index0[0]) + assert table.column("my_index")[1].equals(self.expected_index1[0]) + assert table.column("/points:Position3D")[0].values.equals(self.expected_pos0) + assert table.column("/points:Position3D")[1].values.equals(self.expected_pos1) - expected_index1 = pa.array( - [7], - type=pa.int64(), - ) + def test_index_values(self) -> None: + view = self.recording.view(index="my_index", contents="points") + view = view.filter_index_values([1, 7, 9]) - expected_pos0 = pa.array( - [ - [1, 2, 3], - [4, 5, 6], - [7, 8, 9], - ], - type=rr.components.Position3D.arrow_type(), - ) + batches = view.select() + table = pa.Table.from_batches(batches, batches.schema) - expected_pos1 = pa.array( - [ - [10, 11, 12], - ], - type=rr.components.Position3D.arrow_type(), - ) + # my_index, log_time, log_tick, points, colors + assert table.num_columns == 5 + assert table.num_rows == 2 + + assert table.column("my_index")[0].equals(self.expected_index0[0]) + assert table.column("my_index")[1].equals(self.expected_index1[0]) + + # This is a chunked array + new_selection_chunked = table.column("my_index").take([1]) + + # This is a single array + new_selection = new_selection_chunked.combine_chunks() + + view2 = view.filter_index_values(new_selection_chunked) + batches = view2.select() + table2 = pa.Table.from_batches(batches, batches.schema) + + # my_index, log_time, log_tick, points, colors + assert table2.num_columns == 5 + assert table2.num_rows == 1 + + assert table2.column("my_index")[0].equals(self.expected_index1[0]) + + view3 = view.filter_index_values(new_selection) + batches = view3.select() + table3 = pa.Table.from_batches(batches, batches.schema) + + assert table3 == table2 - print(table.schema) + ## Manually create a pyarrow array with no matches + view4 = view.filter_index_values(pa.array([8], type=pa.int64())) + batches = view4.select() + table4 = pa.Table.from_batches(batches, batches.schema) - assert table.column("my_index")[0].equals(expected_index0[0]) - assert table.column("my_index")[1].equals(expected_index1[0]) - assert table.column("/points:Position3D")[0].values.equals(expected_pos0) - assert table.column("/points:Position3D")[1].values.equals(expected_pos1) + assert table4.num_rows == 0 def test_view_syntax(self) -> None: good_content_expressions = [