diff --git a/py/server/deephaven/table.py b/py/server/deephaven/table.py index 2eca7ae3417..df6cc4b3e93 100644 --- a/py/server/deephaven/table.py +++ b/py/server/deephaven/table.py @@ -1102,9 +1102,12 @@ def sort(self, order_by: Union[str, Sequence[str]], order_by = to_sequence(order_by) if not order: order = (SortDirection.ASCENDING,) * len(order_by) - order = to_sequence(order) - if len(order_by) != len(order): - raise DHError(message="The number of sort columns must be the same as the number of sort directions.") + else: + order = to_sequence(order) + if any([o not in (SortDirection.ASCENDING, SortDirection.DESCENDING) for o in order]): + raise DHError(message="The sort direction must be either 'ASCENDING' or 'DESCENDING'.") + if len(order_by) != len(order): + raise DHError(message="The number of sort columns must be the same as the number of sort directions.") sort_columns = [_sort_column(col, dir_) for col, dir_ in zip(order_by, order)] j_sc_list = j_array_list(sort_columns) @@ -2008,6 +2011,9 @@ def partition_by(self, by: Union[str, Sequence[str]], drop_keys: bool = False) - DHError """ try: + if not isinstance(drop_keys, bool): + raise DHError(message="drop_keys must be a boolean value.") + by = to_sequence(by) return PartitionedTable(j_partitioned_table=self.j_table.partitionBy(drop_keys, *by)) except Exception as e: @@ -2737,12 +2743,14 @@ def sort(self, order_by: Union[str, Sequence[str]], DHError """ try: - order_by = to_sequence(order_by) if not order: order = (SortDirection.ASCENDING,) * len(order_by) - order = to_sequence(order) - if len(order_by) != len(order): - raise ValueError("The number of sort columns must be the same as the number of sort directions.") + else: + order = to_sequence(order) + if any([o not in (SortDirection.ASCENDING, SortDirection.DESCENDING) for o in order]): + raise DHError(message="The sort direction must be either 'ASCENDING' or 'DESCENDING'.") + if len(order_by) != len(order): + raise DHError(message="The number of sort columns must be the same as the number of sort directions.") sort_columns = [_sort_column(col, dir_) for col, dir_ in zip(order_by, order)] j_sc_list = j_array_list(sort_columns) diff --git a/py/server/tests/test_parquet.py b/py/server/tests/test_parquet.py index 612d9bbc759..2d49f7c82cd 100644 --- a/py/server/tests/test_parquet.py +++ b/py/server/tests/test_parquet.py @@ -658,7 +658,7 @@ def verify_table_from_disk(table): self.assertTrue(len(table.columns)) self.assertTrue(table.columns[0].name == "X") self.assertTrue(table.columns[0].column_type == ColumnType.PARTITIONING) - self.assert_table_equals(table.select().sort("X", "Y"), source.sort("X", "Y")) + self.assert_table_equals(table.select().sort(["X", "Y"]), source.sort(["X", "Y"])) def verify_file_names(): partition_dir_path = os.path.join(root_dir, 'X=Aa') diff --git a/py/server/tests/test_table.py b/py/server/tests/test_table.py index 5e30966d906..c9abe604a08 100644 --- a/py/server/tests/test_table.py +++ b/py/server/tests/test_table.py @@ -1105,6 +1105,16 @@ def my_fn(vals): t = partitioned_by_formula() self.assertIsNotNone(t) + def test_arg_validation(self): + t = empty_table(1).update(["A=i", "B=i", "C=i"]) + with self.assertRaises(DHError) as cm: + t.sort("A", "B") + self.assertIn("The sort direction must be", str(cm.exception)) + + with self.assertRaises(DHError) as cm: + t.partition_by("A", "B") + self.assertIn("drop_keys must be", str(cm.exception)) + if __name__ == "__main__": unittest.main()