From cff378b9e5c82d9f7cf466da75f37b1322353280 Mon Sep 17 00:00:00 2001 From: Eugenio Salvador Arellano Ruiz Date: Fri, 21 Jun 2024 17:55:02 +0200 Subject: [PATCH] Initial implementation of polars plugin --- frictionless/formats/__init__.py | 1 + frictionless/formats/pandas/plugin.py | 2 +- frictionless/formats/polars/__init__.py | 3 + .../formats/polars/__spec__/test_parser.py | 200 ++++++++++++++++++ frictionless/formats/polars/control.py | 12 ++ frictionless/formats/polars/parser.py | 142 +++++++++++++ frictionless/formats/polars/plugin.py | 41 ++++ frictionless/helpers/general.py | 2 +- frictionless/platform.py | 7 + frictionless/resources/table.py | 6 + pyproject.toml | 1 + 11 files changed, 415 insertions(+), 2 deletions(-) create mode 100644 frictionless/formats/polars/__init__.py create mode 100644 frictionless/formats/polars/__spec__/test_parser.py create mode 100644 frictionless/formats/polars/control.py create mode 100644 frictionless/formats/polars/parser.py create mode 100644 frictionless/formats/polars/plugin.py diff --git a/frictionless/formats/__init__.py b/frictionless/formats/__init__.py index 75d4902773..fa51004989 100644 --- a/frictionless/formats/__init__.py +++ b/frictionless/formats/__init__.py @@ -9,6 +9,7 @@ from .markdown import * from .ods import * from .pandas import * +from .polars import * from .parquet import * from .qsv import * from .spss import * diff --git a/frictionless/formats/pandas/plugin.py b/frictionless/formats/pandas/plugin.py index e933eb180e..267dd53f03 100644 --- a/frictionless/formats/pandas/plugin.py +++ b/frictionless/formats/pandas/plugin.py @@ -28,7 +28,7 @@ def create_parser(self, resource: Resource): def detect_resource(self, resource: Resource): if resource.data is not None: - if helpers.is_type(resource.data, "DataFrame"): + if helpers.is_type(resource.data, "pandas.core.frame.DataFrame"): resource.format = resource.format or "pandas" if resource.format == "pandas": if resource.data is None: diff --git 
a/frictionless/formats/polars/__init__.py b/frictionless/formats/polars/__init__.py new file mode 100644 index 0000000000..10a49142c7 --- /dev/null +++ b/frictionless/formats/polars/__init__.py @@ -0,0 +1,3 @@ +from .control import PolarsControl as PolarsControl +from .parser import PolarsParser as PolarsParser +from .plugin import PolarsPlugin as PolarsPlugin diff --git a/frictionless/formats/polars/__spec__/test_parser.py b/frictionless/formats/polars/__spec__/test_parser.py new file mode 100644 index 0000000000..d7064d01b1 --- /dev/null +++ b/frictionless/formats/polars/__spec__/test_parser.py @@ -0,0 +1,200 @@ +from datetime import datetime, time +from decimal import Decimal + +import isodate +import polars as pl +import pytz +from dateutil.tz import tzoffset, tzutc +from frictionless import Package +from frictionless.resources import TableResource + +# Read + + +def test_polars_parser(): + dataframe = pl.DataFrame(data={"id": [1, 2], "name": ["english", "中国人"]}) + with TableResource(data=dataframe) as resource: + assert resource.header == ["id", "name"] + assert resource.read_rows() == [ + {"id": 1, "name": "english"}, + {"id": 2, "name": "中国人"}, + ] + + +def test_polars_parser_from_dataframe_with_datetime(): + # Polars does not have the concept of an index! 
+ df = pl.read_csv("data/vix.csv", separator=";", try_parse_dates=True) # type: ignore + with TableResource(data=df) as resource: + # Assert meta + assert resource.schema.to_descriptor() == { + "fields": [ + {"name": "Date", "type": "datetime"}, + {"name": "VIXClose", "type": "number"}, + {"name": "VIXHigh", "type": "number"}, + {"name": "VIXLow", "type": "number"}, + {"name": "VIXOpen", "type": "number"}, + ] + } + rows = resource.read_rows() + # Assert rows + assert rows == [ + { + "Date": datetime(2004, 1, 5, tzinfo=pytz.utc), + "VIXClose": Decimal("17.49"), + "VIXHigh": Decimal("18.49"), + "VIXLow": Decimal("17.44"), + "VIXOpen": Decimal("18.45"), + }, + { + "Date": datetime(2004, 1, 6, tzinfo=pytz.utc), + "VIXClose": Decimal("16.73"), + "VIXHigh": Decimal("17.67"), + "VIXLow": Decimal("16.19"), + "VIXOpen": Decimal("17.66"), + }, + ] + + +# Write + + +def test_pandas_parser_write(): + source = TableResource(path="data/table.csv") + target = source.write(format="polars") + assert target.data.to_dicts() == [ # type: ignore + {"id": 1, "name": "english"}, + {"id": 2, "name": "中国人"}, + ] + + +def test_polars_parser_nan_in_integer_resource_column(): + # see issue 1109 + res = TableResource( + data=[ + ["int", "number", "string"], + ["1", "2.3", "string"], + ["", "4.3", "string"], + ["3", "3.14", "string"], + ] + ) + df = res.to_polars() + assert df.dtypes == [pl.Int64, pl.Float64, pl.String] # type: ignore + + +def test_pandas_parser_nan_in_integer_csv_column(): + res = TableResource(path="data/issue-1109.csv") + df = res.to_polars() + assert df.dtypes == [pl.Int64, pl.Float64, pl.String] # type: ignore + + +def test_pandas_parser_write_types(): + source = Package("data/storage/types.json").get_table_resource("types") + target = source.write(format="polars") + with target: + # Assert schema + assert target.schema.to_descriptor() == { + "fields": [ + {"name": "any", "type": "string"}, # type fallback + {"name": "array", "type": "array"}, + {"name": "boolean", 
"type": "boolean"}, + {"name": "date", "type": "datetime"}, # type downgrade + {"name": "date_year", "type": "datetime"}, # type downgrade/fmt removal + {"name": "datetime", "type": "datetime"}, + {"name": "duration", "type": "duration"}, + {"name": "geojson", "type": "object"}, + {"name": "geopoint", "type": "array"}, + {"name": "integer", "type": "integer"}, + {"name": "number", "type": "number"}, + {"name": "object", "type": "object"}, + {"name": "string", "type": "string"}, + {"name": "time", "type": "time"}, + {"name": "year", "type": "integer"}, # type downgrade + {"name": "yearmonth", "type": "array"}, # type downgrade + ], + } + + # Assert rows + assert target.read_rows() == [ + { + "any": "中国人", + "array": ["Mike", "John"], + "boolean": True, + "date": datetime(2015, 1, 1), + "date_year": datetime(2015, 1, 1), + "datetime": datetime(2015, 1, 1, 3, 0), + "duration": isodate.parse_duration("P1Y1M"), + "geojson": {"type": "Point", "coordinates": [33, 33.33]}, + "geopoint": [30, 70], + "integer": 1, + "number": 7, + "object": {"chars": 560}, + "string": "english", + "time": time(3, 0), + "year": 2015, + "yearmonth": [2015, 1], + }, + ] + + +def test_pandas_write_constraints(): + source = Package("data/storage/constraints.json").get_table_resource("constraints") + target = source.write(format="pandas") + with target: + # Assert schema + assert target.schema.to_descriptor() == { + "fields": [ + {"name": "required", "type": "string"}, # constraint removal + {"name": "minLength", "type": "string"}, # constraint removal + {"name": "maxLength", "type": "string"}, # constraint removal + {"name": "pattern", "type": "string"}, # constraint removal + {"name": "enum", "type": "string"}, # constraint removal + {"name": "minimum", "type": "integer"}, # constraint removal + {"name": "maximum", "type": "integer"}, # constraint removal + ], + } + + # Assert rows + assert target.read_rows() == [ + { + "required": "passing", + "minLength": "passing", + "maxLength": "passing", + 
"pattern": "passing", + "enum": "passing", + "minimum": 5, + "maximum": 5, + }, + ] + + +def test_pandas_parser_write_timezone(): + source = TableResource(path="data/timezone.csv") + target = source.write(format="pandas") + with target: + # Assert schema + assert target.schema.to_descriptor() == { + "fields": [ + {"name": "datetime", "type": "datetime"}, + {"name": "time", "type": "time"}, + ], + } + + # Assert rows + assert target.read_rows() == [ + { + "datetime": datetime(2020, 1, 1, 15), + "time": time(15), + }, + { + "datetime": datetime(2020, 1, 1, 15, 0, tzinfo=tzutc()), + "time": time(15, 0, tzinfo=tzutc()), + }, + { + "datetime": datetime(2020, 1, 1, 15, 0, tzinfo=tzoffset(None, 10800)), + "time": time(15, 0, tzinfo=tzoffset(None, 10800)), + }, + { + "datetime": datetime(2020, 1, 1, 15, 0, tzinfo=tzoffset(None, -10800)), + "time": time(15, 0, tzinfo=tzoffset(None, -10800)), + }, + ] diff --git a/frictionless/formats/polars/control.py b/frictionless/formats/polars/control.py new file mode 100644 index 0000000000..51d6a761a7 --- /dev/null +++ b/frictionless/formats/polars/control.py @@ -0,0 +1,12 @@ +from __future__ import annotations + +import attrs + +from ...dialect import Control + + +@attrs.define(kw_only=True, repr=False) +class PolarsControl(Control): + """Polars dialect representation""" + + type = "polars" diff --git a/frictionless/formats/polars/parser.py b/frictionless/formats/polars/parser.py new file mode 100644 index 0000000000..22b8dcf3b5 --- /dev/null +++ b/frictionless/formats/polars/parser.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +import datetime +import decimal +from typing import TYPE_CHECKING, Any, List, Optional, Tuple + +from dateutil.tz import tzoffset + +from ... 
import types +from ...platform import platform +from ...schema import Field, Schema +from ...system import Parser + +if TYPE_CHECKING: + from ...resources import TableResource + + +class PolarsParser(Parser): + """Polars parser implementation.""" + + supported_types = [ + "array", + "boolean", + "datetime", + "date", + "duration", + "integer", + "number", + "object", + "string", + "time", + ] + + # Read + + def read_cell_stream_create(self): + pl = platform.polars + assert isinstance(self.resource.data, pl.DataFrame) + dataframe = self.resource.data + + # Schema + schema = self.__read_convert_schema() + if not self.resource.schema: + self.resource.schema = schema + + # Lists + yield schema.field_names + for row in dataframe.iter_rows(): # type: ignore + cells: List[Any] = [v if v is not pl.Null else None for v in row] + yield cells + + def __read_convert_schema(self): + pl = platform.polars + dataframe = self.resource.data + schema = Schema() + + # Fields + for name, dtype in zip(dataframe.columns, dataframe.dtypes): # type: ignore + sample = dataframe.select(pl.first(name)).item() if len(dataframe) else None # type: ignore + type = self.__read_convert_type(dtype, sample=sample) # type: ignore + field = Field.from_descriptor({"name": name, "type": type}) + schema.add_field(field) + + # Return schema + return schema + + def __read_convert_type(self, _: Any, sample: Optional[types.ISample] = None): + + # Python types + if sample is not None: + if isinstance(sample, bool): # type: ignore + return "boolean" + elif isinstance(sample, int): # type: ignore + return "integer" + elif isinstance(sample, float): # type: ignore + return "number" + if isinstance(sample, (list, tuple)): # type: ignore + return "array" + elif isinstance(sample, datetime.datetime): + return "datetime" + elif isinstance(sample, datetime.date): + return "date" + elif isinstance(sample, platform.isodate.Duration): # type: ignore + return "duration" + elif isinstance(sample, dict): + return "object" + 
elif isinstance(sample, str): + return "string" + elif isinstance(sample, datetime.time): + return "time" + + # Default + return "string" + + # Write + + def write_row_stream(self, source: TableResource): + pl = platform.polars + data_rows: List[Tuple[Any]] = [] + fixed_types = {} + with source: + for row in source.row_stream: + data_values: List[Any] = [] + for field in source.schema.fields: + value = row[field.name] + if isinstance(value, dict): + value = str(value) + if isinstance(value, decimal.Decimal): + value = float(value) + if isinstance(value, datetime.datetime) and value.tzinfo: + value = value.astimezone(datetime.timezone.utc) + if isinstance(value, datetime.time) and value.tzinfo: + value = value.replace( + tzinfo=tzoffset( + datetime.timezone.utc, + value.utcoffset().total_seconds(), # type: ignore + ) + ) + if value is None and field.type in ("number", "integer"): + fixed_types[field.name] = "number" + value = None + data_values.append(value) + data_rows.append(tuple(data_values)) + # Create dtypes/columns + columns: List[str] = [] + for field in source.schema.fields: + if field.name not in source.schema.primary_key: + columns.append(field.name) + + # Create/set dataframe + dataframe = pl.DataFrame(data_rows) + dataframe.columns = columns + + for field in source.schema.fields: + if ( + field.type == "integer" + and field.name in dataframe.columns + and str(dataframe.select(field.name).dtypes[0]) != "int" + ): + dataframe = dataframe.with_columns(pl.col(field.name).cast(int)) + + self.resource.data = dataframe diff --git a/frictionless/formats/polars/plugin.py b/frictionless/formats/polars/plugin.py new file mode 100644 index 0000000000..14885f4f31 --- /dev/null +++ b/frictionless/formats/polars/plugin.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +from ... 
import helpers +from ...platform import platform +from ...system import Plugin +from .control import PolarsControl +from .parser import PolarsParser + +if TYPE_CHECKING: + from ...resource import Resource + + +# NOTE: +# We need to ensure that the way we detect polars dataframe is good enough. +# We don't want to be importing polars and checking the type without a good reason + + +class PolarsPlugin(Plugin): + """Plugin for Polars""" + + # Hooks + + def create_parser(self, resource: Resource): + if resource.format == "polars": + return PolarsParser(resource) + + def detect_resource(self, resource: Resource): + if resource.data is not None: + if helpers.is_type(resource.data, "polars.dataframe.frame.DataFrame"): + resource.format = resource.format or "polars" + if resource.format == "polars": + if resource.data is None: + resource.data = platform.polars.DataFrame() + resource.datatype = resource.datatype or "table" + resource.mediatype = resource.mediatype or "application/polars" + + def select_control_class(self, type: Optional[str] = None): + if type == "polars": + return PolarsControl diff --git a/frictionless/helpers/general.py b/frictionless/helpers/general.py index 45e27ce5b7..fefe7d7723 100644 --- a/frictionless/helpers/general.py +++ b/frictionless/helpers/general.py @@ -266,7 +266,7 @@ def is_zip_descriptor(descriptor: Union[str, Dict[str, Any]]): def is_type(object: type, name: str): - return type(object).__name__ == name + return type(object).__module__ + "." 
+ type(object).__name__ == name def parse_json_string(string: Optional[str]): diff --git a/frictionless/platform.py b/frictionless/platform.py index b88584d99d..c5fa31930d 100644 --- a/frictionless/platform.py +++ b/frictionless/platform.py @@ -325,6 +325,13 @@ def pandas(self): return pandas + @cached_property + @extras(name="polars") + def polars(self): + import polars # type: ignore + + return polars + @cached_property @extras(name="pandas") def pandas_core_dtypes_api(self): diff --git a/frictionless/resources/table.py b/frictionless/resources/table.py index 98a76f3473..de9355dfdc 100644 --- a/frictionless/resources/table.py +++ b/frictionless/resources/table.py @@ -618,6 +618,12 @@ def to_pandas(self, *, dialect: Optional[Dialect] = None): target = self.write(Resource(format="pandas", dialect=dialect)) # type: ignore return target.data + def to_polars(self, *, dialect: Optional[Dialect] = None): + """Helper to export resource as a Polars dataframe""" + dialect = dialect or Dialect() + target = self.write(Resource(format="polars", dialect=dialect)) # type: ignore + return target.data + def to_snap(self, *, json: bool = False): """Create a snapshot from the resource diff --git a/pyproject.toml b/pyproject.toml index 011a362302..7e9283ecb2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,6 +92,7 @@ html = ["pyquery>=1.4"] mysql = ["sqlalchemy>=1.4", "pymysql>=1.0"] ods = ["ezodf>=0.3", "lxml>=4.0"] pandas = ["pyarrow>=14.0", "pandas>=1.0"] +polars = ["polars>=0.20"] parquet = ["fastparquet>=0.8"] postgresql = ["sqlalchemy>=1.4", "psycopg>=3.0", "psycopg2>=2.9"] spss = ["savReaderWriter>=3.0"]