diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py
index 16d65f68..69062fd3 100644
--- a/python/datafusion/__init__.py
+++ b/python/datafusion/__init__.py
@@ -30,13 +30,12 @@
 except ImportError:
     import importlib_metadata
 
-from datafusion.col import col, column
-
 from . import functions, object_store, substrait, unparser
 
 # The following imports are okay to remain as opaque to the user.
 from ._internal import Config
 from .catalog import Catalog, Database, Table
+from .col import col, column
 from .common import (
     DFSchema,
 )
diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py
index 1fd63bdc..49c0b2f8 100644
--- a/python/datafusion/dataframe.py
+++ b/python/datafusion/dataframe.py
@@ -191,6 +191,7 @@ def __init__(
         writer_version: str = "1.0",
         skip_arrow_metadata: bool = False,
         compression: Optional[str] = "zstd(3)",
+        compression_level: Optional[int] = None,
         dictionary_enabled: Optional[bool] = True,
         dictionary_page_size_limit: int = 1024 * 1024,
         statistics_enabled: Optional[str] = "page",
@@ -213,7 +214,10 @@ def __init__(
         self.write_batch_size = write_batch_size
         self.writer_version = writer_version
         self.skip_arrow_metadata = skip_arrow_metadata
-        self.compression = compression
+        if compression_level is not None:
+            self.compression = f"{compression}({compression_level})"
+        else:
+            self.compression = compression
         self.dictionary_enabled = dictionary_enabled
         self.dictionary_page_size_limit = dictionary_page_size_limit
         self.statistics_enabled = statistics_enabled
@@ -870,10 +874,34 @@ def write_csv(self, path: str | pathlib.Path, with_header: bool = False) -> None
         """
         self.df.write_csv(str(path), with_header)
 
+    @overload
+    def write_parquet(
+        self,
+        path: str | pathlib.Path,
+        compression: str,
+        compression_level: int | None = None,
+    ) -> None: ...
+
+    @overload
+    def write_parquet(
+        self,
+        path: str | pathlib.Path,
+        compression: Compression = Compression.ZSTD,
+        compression_level: int | None = None,
+    ) -> None: ...
+
+    @overload
+    def write_parquet(
+        self,
+        path: str | pathlib.Path,
+        compression: ParquetWriterOptions,
+        compression_level: None = None,
+    ) -> None: ...
+
     def write_parquet(
         self,
         path: str | pathlib.Path,
-        compression: Union[str, Compression] = Compression.ZSTD,
+        compression: Union[str, Compression, ParquetWriterOptions] = Compression.ZSTD,
         compression_level: int | None = None,
     ) -> None:
         """Execute the :py:class:`DataFrame` and write the results to a Parquet file.
@@ -894,7 +922,13 @@ def write_parquet(
                 recommended range is 1 to 22, with the default being 4. Higher
                 levels provide better compression but slower speed.
         """
-        # Convert string to Compression enum if necessary
+        if isinstance(compression, ParquetWriterOptions):
+            if compression_level is not None:
+                msg = "compression_level should be None when using ParquetWriterOptions"
+                raise ValueError(msg)
+            self.write_parquet_with_options(path, compression)
+            return
+
         if isinstance(compression, str):
             compression = Compression.from_str(compression)
 
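A usage sketch of the updated write_parquet surface, for review context; the
session setup, sample data, and output paths below are illustrative and not
taken from this patch:

from datafusion import SessionContext
from datafusion.dataframe import Compression, ParquetWriterOptions

ctx = SessionContext()
df = ctx.from_pydict({"a": [1, 2, 3]})  # hypothetical sample data

# Pre-existing forms, still supported: a codec name or a Compression enum
# member, optionally qualified with a level.
df.write_parquet("out_str", compression="gzip", compression_level=6)
df.write_parquet("out_enum", compression=Compression.ZSTD, compression_level=4)

# New form added by this patch: pass ParquetWriterOptions for full writer
# control. Per the new __init__ logic, the level folds into the codec string
# (compression="gzip(6)" here). compression_level must stay None in the
# write_parquet call itself, or the new guard raises ValueError.
options = ParquetWriterOptions(compression="gzip", compression_level=6)
df.write_parquet("out_opts", options)
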
""" - # Convert string to Compression enum if necessary + if isinstance(compression, ParquetWriterOptions): + if compression_level is not None: + msg = "compression_level should be None when using ParquetWriterOptions" + raise ValueError(msg) + self.write_parquet_with_options(path, compression) + return + if isinstance(compression, str): compression = Compression.from_str(compression) diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index 3c9b97f2..deaa30b3 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -2038,6 +2038,22 @@ def test_write_parquet_with_options_column_options(df, tmp_path): assert col["encodings"] == result["encodings"] +def test_write_parquet_options(df, tmp_path): + options = ParquetWriterOptions(compression="gzip", compression_level=6) + df.write_parquet(str(tmp_path), options) + + result = pq.read_table(str(tmp_path)).to_pydict() + expected = df.to_pydict() + + assert result == expected + + +def test_write_parquet_options_error(df, tmp_path): + options = ParquetWriterOptions(compression="gzip", compression_level=6) + with pytest.raises(ValueError): + df.write_parquet(str(tmp_path), options, compression_level=1) + + def test_dataframe_export(df) -> None: # Guarantees that we have the canonical implementation # reading our dataframe export