diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index 0fd034ee..d7489d2d 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -41,6 +41,8 @@ # excerpt from deltalake # https://github.com/apache/datafusion-python/pull/981#discussion_r1905619163 class Compression(Enum): + """Enum representing the available compression types for Parquet files.""" + UNCOMPRESSED = "uncompressed" SNAPPY = "snappy" GZIP = "gzip" @@ -52,6 +54,17 @@ class Compression(Enum): @classmethod def from_str(cls, value: str) -> "Compression": + """Convert a string to a Compression enum value. + + Args: + value (str): The string representation of the compression type. + + Returns: + Compression: The corresponding Compression enum value. + + Raises: + ValueError: If the string does not match any Compression enum value. + """ try: return cls(value.lower()) except ValueError: @@ -60,6 +73,14 @@ def from_str(cls, value: str) -> "Compression": ) def get_default_level(self) -> int: + """Get the default compression level for the compression type. + + Returns: + int: The default compression level. + + Raises: + KeyError: If the compression type does not have a default level. + """ # GZIP, BROTLI defaults from deltalake # https://github.com/apache/datafusion-python/pull/981#discussion_r1905619163 if self == Compression.GZIP: diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index c13cc6fd..248af0a2 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -1113,7 +1113,9 @@ def test_write_compressed_parquet_invalid_compression(df, tmp_path, compression) df.write_parquet(str(path), compression=compression) -# test write_parquet with zstd, brotli default compression level, should complete without error +# Test write_parquet with zstd, brotli default compression level, +# ie don't specify compression level +# should complete without error @pytest.mark.parametrize("compression", ["zstd", "brotli"]) def test_write_compressed_parquet_default_compression_level(df, tmp_path, compression): path = tmp_path