From f7af294547c141f0510db196dbd4a248f748f639 Mon Sep 17 00:00:00 2001 From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> Date: Sat, 28 Dec 2024 12:10:56 +0100 Subject: [PATCH] feat: reads using global ctx --- python/datafusion/__init__.py | 6 + python/datafusion/io.py | 181 ++++++++++++++++++++++++++ python/tests/test_context.py | 1 + python/tests/test_io.py | 97 ++++++++++++++ python/tests/test_wrapper_coverage.py | 2 + src/context.rs | 12 +- src/utils.rs | 8 ++ 7 files changed, 305 insertions(+), 2 deletions(-) create mode 100644 python/datafusion/io.py create mode 100644 python/tests/test_io.py diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index e0bc57f44..0dca880d8 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -48,6 +48,8 @@ from .dataframe import DataFrame +from .io import read_parquet, read_avro, read_csv, read_json + from .expr import ( Expr, WindowFrame, @@ -89,6 +91,10 @@ "functions", "object_store", "substrait", + "read_parquet", + "read_avro", + "read_csv", + "read_json", ] diff --git a/python/datafusion/io.py b/python/datafusion/io.py new file mode 100644 index 000000000..f4f441975 --- /dev/null +++ b/python/datafusion/io.py @@ -0,0 +1,181 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""IO read functions using global context.""" + +import pathlib + +from datafusion.dataframe import DataFrame +from datafusion.expr import Expr +import pyarrow +from ._internal import SessionContext as SessionContextInternal + + +def read_parquet( + path: str | pathlib.Path, + table_partition_cols: list[tuple[str, str]] | None = None, + parquet_pruning: bool = True, + file_extension: str = ".parquet", + skip_metadata: bool = True, + schema: pyarrow.Schema | None = None, + file_sort_order: list[list[Expr]] | None = None, +) -> DataFrame: + """Read a Parquet source into a :py:class:`~datafusion.dataframe.Dataframe`. + + Args: + path: Path to the Parquet file. + table_partition_cols: Partition columns. + parquet_pruning: Whether the parquet reader should use the predicate + to prune row groups. + file_extension: File extension; only files with this extension are + selected for data input. + skip_metadata: Whether the parquet reader should skip any metadata + that may be in the file schema. This can help avoid schema + conflicts due to metadata. + schema: An optional schema representing the parquet files. If None, + the parquet reader will try to infer it based on data in the + file. + file_sort_order: Sort order for the file. + + Returns: + DataFrame representation of the read Parquet files + """ + if table_partition_cols is None: + table_partition_cols = [] + return DataFrame( + SessionContextInternal._global_ctx().read_parquet( + str(path), + table_partition_cols, + parquet_pruning, + file_extension, + skip_metadata, + schema, + file_sort_order, + ) + ) + + +def read_json( + path: str | pathlib.Path, + schema: pyarrow.Schema | None = None, + schema_infer_max_records: int = 1000, + file_extension: str = ".json", + table_partition_cols: list[tuple[str, str]] | None = None, + file_compression_type: str | None = None, +) -> DataFrame: + """Read a line-delimited JSON data source. + + Args: + path: Path to the JSON file. + schema: The data source schema. + schema_infer_max_records: Maximum number of rows to read from JSON + files for schema inference if needed. + file_extension: File extension; only files with this extension are + selected for data input. + table_partition_cols: Partition columns. + file_compression_type: File compression type. + + Returns: + DataFrame representation of the read JSON files. + """ + if table_partition_cols is None: + table_partition_cols = [] + return DataFrame( + SessionContextInternal._global_ctx().read_json( + str(path), + schema, + schema_infer_max_records, + file_extension, + table_partition_cols, + file_compression_type, + ) + ) + + +def read_csv( + path: str | pathlib.Path | list[str] | list[pathlib.Path], + schema: pyarrow.Schema | None = None, + has_header: bool = True, + delimiter: str = ",", + schema_infer_max_records: int = 1000, + file_extension: str = ".csv", + table_partition_cols: list[tuple[str, str]] | None = None, + file_compression_type: str | None = None, +) -> DataFrame: + """Read a CSV data source. + + Args: + path: Path to the CSV file + schema: An optional schema representing the CSV files. If None, the + CSV reader will try to infer it based on data in file. + has_header: Whether the CSV file have a header. If schema inference + is run on a file with no headers, default column names are + created. + delimiter: An optional column delimiter. + schema_infer_max_records: Maximum number of rows to read from CSV + files for schema inference if needed. + file_extension: File extension; only files with this extension are + selected for data input. + table_partition_cols: Partition columns. + file_compression_type: File compression type. + + Returns: + DataFrame representation of the read CSV files + """ + if table_partition_cols is None: + table_partition_cols = [] + + path = [str(p) for p in path] if isinstance(path, list) else str(path) + + return DataFrame( + SessionContextInternal._global_ctx().read_csv( + path, + schema, + has_header, + delimiter, + schema_infer_max_records, + file_extension, + table_partition_cols, + file_compression_type, + ) + ) + + +def read_avro( + path: str | pathlib.Path, + schema: pyarrow.Schema | None = None, + file_partition_cols: list[tuple[str, str]] | None = None, + file_extension: str = ".avro", +) -> DataFrame: + """Create a :py:class:`DataFrame` for reading Avro data source. + + Args: + path: Path to the Avro file. + schema: The data source schema. + file_partition_cols: Partition columns. + file_extension: File extension to select. + + Returns: + DataFrame representation of the read Avro file + """ + if file_partition_cols is None: + file_partition_cols = [] + return DataFrame( + SessionContextInternal._global_ctx().read_avro( + str(path), schema, file_partition_cols, file_extension + ) + ) diff --git a/python/tests/test_context.py b/python/tests/test_context.py index ab86faa9d..2d60cca8b 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -23,6 +23,7 @@ import pyarrow.dataset as ds import pytest + from datafusion import ( DataFrame, RuntimeConfig, diff --git a/python/tests/test_io.py b/python/tests/test_io.py new file mode 100644 index 000000000..606e17755 --- /dev/null +++ b/python/tests/test_io.py @@ -0,0 +1,97 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import os +import pathlib + +from datafusion import column +import pyarrow as pa + + +from datafusion.io import read_avro, read_csv, read_json, read_parquet + + +def test_read_json_global_ctx(ctx): + path = os.path.dirname(os.path.abspath(__file__)) + + # Default + test_data_path = os.path.join(path, "data_test_context", "data.json") + df = read_json(test_data_path) + result = df.collect() + + assert result[0].column(0) == pa.array(["a", "b", "c"]) + assert result[0].column(1) == pa.array([1, 2, 3]) + + # Schema + schema = pa.schema( + [ + pa.field("A", pa.string(), nullable=True), + ] + ) + df = read_json(test_data_path, schema=schema) + result = df.collect() + + assert result[0].column(0) == pa.array(["a", "b", "c"]) + assert result[0].schema == schema + + # File extension + test_data_path = os.path.join(path, "data_test_context", "data.json") + df = read_json(test_data_path, file_extension=".json") + result = df.collect() + + assert result[0].column(0) == pa.array(["a", "b", "c"]) + assert result[0].column(1) == pa.array([1, 2, 3]) + + +def test_read_parquet_global(): + parquet_df = read_parquet(path="parquet/data/alltypes_plain.parquet") + parquet_df.show() + assert parquet_df is not None + + path = pathlib.Path.cwd() / "parquet/data/alltypes_plain.parquet" + parquet_df = read_parquet(path=path) + assert parquet_df is not None + + +def test_read_csv(): + csv_df = read_csv(path="testing/data/csv/aggregate_test_100.csv") + csv_df.select(column("c1")).show() + + +def test_read_csv_list(): + csv_df = read_csv(path=["testing/data/csv/aggregate_test_100.csv"]) + expected = csv_df.count() * 2 + + double_csv_df = read_csv( + path=[ + "testing/data/csv/aggregate_test_100.csv", + "testing/data/csv/aggregate_test_100.csv", + ] + ) + actual = double_csv_df.count() + + double_csv_df.select(column("c1")).show() + assert actual == expected + + +def test_read_avro(): + avro_df = read_avro(path="testing/data/avro/alltypes_plain.avro") + avro_df.show() + assert avro_df is not None + + path = pathlib.Path.cwd() / "testing/data/avro/alltypes_plain.avro" + avro_df = read_avro(path=path) + assert avro_df is not None diff --git a/python/tests/test_wrapper_coverage.py b/python/tests/test_wrapper_coverage.py index 86f2d57f2..ac064ba95 100644 --- a/python/tests/test_wrapper_coverage.py +++ b/python/tests/test_wrapper_coverage.py @@ -34,6 +34,8 @@ def missing_exports(internal_obj, wrapped_obj) -> None: return for attr in dir(internal_obj): + if attr in ["_global_ctx"]: + continue assert attr in dir(wrapped_obj) internal_attr = getattr(internal_obj, attr) diff --git a/src/context.rs b/src/context.rs index 8675e97df..87ba14ed3 100644 --- a/src/context.rs +++ b/src/context.rs @@ -43,7 +43,7 @@ use crate::store::StorageContexts; use crate::udaf::PyAggregateUDF; use crate::udf::PyScalarUDF; use crate::udwf::PyWindowUDF; -use crate::utils::{get_tokio_runtime, wait_for_future}; +use crate::utils::{get_global_ctx, get_tokio_runtime, wait_for_future}; use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion::arrow::pyarrow::PyArrowType; use datafusion::arrow::record_batch::RecordBatch; @@ -68,7 +68,7 @@ use datafusion::prelude::{ AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions, }; use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider}; -use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple}; +use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType}; use tokio::task::JoinHandle; /// Configuration options for a SessionContext @@ -299,6 +299,14 @@ impl PySessionContext { }) } + #[classmethod] + #[pyo3(signature = ())] + fn _global_ctx(_cls: &Bound<'_, PyType>) -> PyResult { + Ok(Self { + ctx: get_global_ctx().clone(), + }) + } + /// Register an object store with the given name #[pyo3(signature = (scheme, store, host=None))] pub fn register_object_store( diff --git a/src/utils.rs b/src/utils.rs index 7fb23cafe..8d9010939 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -17,6 +17,7 @@ use crate::errors::DataFusionError; use crate::TokioRuntime; +use datafusion::execution::context::SessionContext; use datafusion::logical_expr::Volatility; use pyo3::prelude::*; use std::future::Future; @@ -35,6 +36,13 @@ pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime { RUNTIME.get_or_init(|| TokioRuntime(tokio::runtime::Runtime::new().unwrap())) } +/// Utility to get the Global Datafussion CTX +#[inline] +pub(crate) fn get_global_ctx() -> &'static SessionContext { + static CTX: OnceLock = OnceLock::new(); + CTX.get_or_init(|| SessionContext::new()) +} + /// Utility to collect rust futures with GIL released pub fn wait_for_future(py: Python, f: F) -> F::Output where