forked from frictionlessdata/frictionless-py
Commit: Initial implementation of polars plugin
Showing 11 changed files with 415 additions and 2 deletions.
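Before the per-file diffs, a quick orientation: the behaviour the new plugin enables, reconstructed from the tests in this commit (data/table.csv is one of the repository's test fixtures).

import polars as pl
from frictionless.resources import TableResource

# Read: a polars DataFrame is accepted directly as resource data
df = pl.DataFrame(data={"id": [1, 2], "name": ["english", "中国人"]})
with TableResource(data=df) as resource:
    assert resource.header == ["id", "name"]

# Write: any table resource can be materialized as a polars DataFrame
source = TableResource(path="data/table.csv")
target = source.write(format="polars")
print(target.data.to_dicts())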
@@ -0,0 +1,3 @@
from .control import PolarsControl as PolarsControl
from .parser import PolarsParser as PolarsParser
from .plugin import PolarsPlugin as PolarsPlugin
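(The import X as X form marks these names as explicit public re-exports, which keeps strict type checkers such as mypy's no-implicit-reexport mode from flagging them.)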
@@ -0,0 +1,200 @@
from datetime import datetime, time
from decimal import Decimal

import isodate
import polars as pl
import pytz
from dateutil.tz import tzoffset, tzutc
from frictionless import Package
from frictionless.resources import TableResource

# Read

def test_polars_parser():
    dataframe = pl.DataFrame(data={"id": [1, 2], "name": ["english", "中国人"]})
    with TableResource(data=dataframe) as resource:
        assert resource.header == ["id", "name"]
        assert resource.read_rows() == [
            {"id": 1, "name": "english"},
            {"id": 2, "name": "中国人"},
        ]

def test_polars_parser_from_dataframe_with_datetime():
    # Polars does not have the concept of an index!
    df = pl.read_csv("data/vix.csv", separator=";", try_parse_dates=True)  # type: ignore
    with TableResource(data=df) as resource:
        # Assert meta
        assert resource.schema.to_descriptor() == {
            "fields": [
                {"name": "Date", "type": "datetime"},
                {"name": "VIXClose", "type": "number"},
                {"name": "VIXHigh", "type": "number"},
                {"name": "VIXLow", "type": "number"},
                {"name": "VIXOpen", "type": "number"},
            ]
        }
        rows = resource.read_rows()
        # Assert rows
        assert rows == [
            {
                "Date": datetime(2004, 1, 5, tzinfo=pytz.utc),
                "VIXClose": Decimal("17.49"),
                "VIXHigh": Decimal("18.49"),
                "VIXLow": Decimal("17.44"),
                "VIXOpen": Decimal("18.45"),
            },
            {
                "Date": datetime(2004, 1, 6, tzinfo=pytz.utc),
                "VIXClose": Decimal("16.73"),
                "VIXHigh": Decimal("17.67"),
                "VIXLow": Decimal("16.19"),
                "VIXOpen": Decimal("17.66"),
            },
        ]

# Write

def test_polars_parser_write():
    source = TableResource(path="data/table.csv")
    target = source.write(format="polars")
    assert target.data.to_dicts() == [  # type: ignore
        {"id": 1, "name": "english"},
        {"id": 2, "name": "中国人"},
    ]

def test_polars_parser_nan_in_integer_resource_column():
    # see issue 1109
    res = TableResource(
        data=[
            ["int", "number", "string"],
            ["1", "2.3", "string"],
            ["", "4.3", "string"],
            ["3", "3.14", "string"],
        ]
    )
    df = res.to_polars()
    assert df.dtypes == [pl.Int64, pl.Float64, pl.String]  # type: ignore

def test_polars_parser_nan_in_integer_csv_column():
    res = TableResource(path="data/issue-1109.csv")
    df = res.to_polars()
    assert df.dtypes == [pl.Int64, pl.Float64, pl.String]  # type: ignore

def test_polars_parser_write_types():
    source = Package("data/storage/types.json").get_table_resource("types")
    target = source.write(format="polars")
    with target:
        # Assert schema
        assert target.schema.to_descriptor() == {
            "fields": [
                {"name": "any", "type": "string"},  # type fallback
                {"name": "array", "type": "array"},
                {"name": "boolean", "type": "boolean"},
                {"name": "date", "type": "datetime"},  # type downgrade
                {"name": "date_year", "type": "datetime"},  # type downgrade/fmt removal
                {"name": "datetime", "type": "datetime"},
                {"name": "duration", "type": "duration"},
                {"name": "geojson", "type": "object"},
                {"name": "geopoint", "type": "array"},
                {"name": "integer", "type": "integer"},
                {"name": "number", "type": "number"},
                {"name": "object", "type": "object"},
                {"name": "string", "type": "string"},
                {"name": "time", "type": "time"},
                {"name": "year", "type": "integer"},  # type downgrade
                {"name": "yearmonth", "type": "array"},  # type downgrade
            ],
        }

        # Assert rows
        assert target.read_rows() == [
            {
                "any": "中国人",
                "array": ["Mike", "John"],
                "boolean": True,
                "date": datetime(2015, 1, 1),
                "date_year": datetime(2015, 1, 1),
                "datetime": datetime(2015, 1, 1, 3, 0),
                "duration": isodate.parse_duration("P1Y1M"),
                "geojson": {"type": "Point", "coordinates": [33, 33.33]},
                "geopoint": [30, 70],
                "integer": 1,
                "number": 7,
                "object": {"chars": 560},
                "string": "english",
                "time": time(3, 0),
                "year": 2015,
                "yearmonth": [2015, 1],
            },
        ]

def test_polars_write_constraints():
    source = Package("data/storage/constraints.json").get_table_resource("constraints")
    target = source.write(format="polars")
    with target:
        # Assert schema
        assert target.schema.to_descriptor() == {
            "fields": [
                {"name": "required", "type": "string"},  # constraint removal
                {"name": "minLength", "type": "string"},  # constraint removal
                {"name": "maxLength", "type": "string"},  # constraint removal
                {"name": "pattern", "type": "string"},  # constraint removal
                {"name": "enum", "type": "string"},  # constraint removal
                {"name": "minimum", "type": "integer"},  # constraint removal
                {"name": "maximum", "type": "integer"},  # constraint removal
            ],
        }

        # Assert rows
        assert target.read_rows() == [
            {
                "required": "passing",
                "minLength": "passing",
                "maxLength": "passing",
                "pattern": "passing",
                "enum": "passing",
                "minimum": 5,
                "maximum": 5,
            },
        ]

def test_polars_parser_write_timezone():
    source = TableResource(path="data/timezone.csv")
    target = source.write(format="polars")
    with target:
        # Assert schema
        assert target.schema.to_descriptor() == {
            "fields": [
                {"name": "datetime", "type": "datetime"},
                {"name": "time", "type": "time"},
            ],
        }

        # Assert rows
        assert target.read_rows() == [
            {
                "datetime": datetime(2020, 1, 1, 15),
                "time": time(15),
            },
            {
                "datetime": datetime(2020, 1, 1, 15, 0, tzinfo=tzutc()),
                "time": time(15, 0, tzinfo=tzutc()),
            },
            {
                "datetime": datetime(2020, 1, 1, 15, 0, tzinfo=tzoffset(None, 10800)),
                "time": time(15, 0, tzinfo=tzoffset(None, 10800)),
            },
            {
                "datetime": datetime(2020, 1, 1, 15, 0, tzinfo=tzoffset(None, -10800)),
                "time": time(15, 0, tzinfo=tzoffset(None, -10800)),
            },
        ]
@@ -0,0 +1,12 @@
from __future__ import annotations

import attrs

from ...dialect import Control


@attrs.define(kw_only=True, repr=False)
class PolarsControl(Control):
    """Polars dialect representation"""

    type = "polars"
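PolarsControl defines no options yet; it only registers the "polars" dialect type so descriptors like {"type": "polars"} resolve to it. A minimal sketch of constructing it explicitly, assuming the control= keyword that frictionless resources accept for other formats (the import path is a guess, since the diff does not show where the package is mounted):

from frictionless.resources import TableResource
from frictionless.formats.polars import PolarsControl  # assumed import path

control = PolarsControl()  # type is fixed to "polars"; no options yet
resource = TableResource(path="data/table.csv", control=control)  # control= assumed
target = resource.write(format="polars")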
@@ -0,0 +1,142 @@
from __future__ import annotations

import datetime
import decimal
from typing import TYPE_CHECKING, Any, List, Optional, Tuple

from dateutil.tz import tzoffset

from ... import types
from ...platform import platform
from ...schema import Field, Schema
from ...system import Parser

if TYPE_CHECKING:
    from ...resources import TableResource

class PolarsParser(Parser):
    """Polars parser implementation."""

    supported_types = [
        "array",
        "boolean",
        "datetime",
        "date",
        "duration",
        "integer",
        "number",
        "object",
        "string",
        "time",
    ]

    # Read

    def read_cell_stream_create(self):
        pl = platform.polars
        assert isinstance(self.resource.data, pl.DataFrame)
        dataframe = self.resource.data

        # Schema
        schema = self.__read_convert_schema()
        if not self.resource.schema:
            self.resource.schema = schema

        # Lists
        yield schema.field_names
        for row in dataframe.iter_rows():  # type: ignore
            # iter_rows already yields None for nulls; the pl.Null guard is defensive
            cells: List[Any] = [v if v is not pl.Null else None for v in row]
            yield cells

    def __read_convert_schema(self):
        pl = platform.polars
        dataframe = self.resource.data
        schema = Schema()

        # Fields
        for name, dtype in zip(dataframe.columns, dataframe.dtypes):  # type: ignore
            sample = dataframe.select(pl.first(name)).item() if len(dataframe) else None  # type: ignore
            type = self.__read_convert_type(dtype, sample=sample)  # type: ignore
            field = Field.from_descriptor({"name": name, "type": type})
            schema.add_field(field)

        # Return schema
        return schema

    def __read_convert_type(self, _: Any, sample: Optional[types.ISample] = None):
        # The polars dtype is ignored; the type is inferred from a sample value
        if sample is not None:
            if isinstance(sample, bool):  # type: ignore  # bool before int: bool subclasses int
                return "boolean"
            elif isinstance(sample, int):  # type: ignore
                return "integer"
            elif isinstance(sample, float):  # type: ignore
                return "number"
            elif isinstance(sample, (list, tuple)):  # type: ignore
                return "array"
            elif isinstance(sample, datetime.datetime):
                return "datetime"
            elif isinstance(sample, datetime.date):
                return "date"
            elif isinstance(sample, platform.isodate.Duration):  # type: ignore
                return "duration"
            elif isinstance(sample, dict):
                return "object"
            elif isinstance(sample, str):
                return "string"
            elif isinstance(sample, datetime.time):
                return "time"

        # Default
        return "string"

    # Write

    def write_row_stream(self, source: TableResource):
        pl = platform.polars
        data_rows: List[Tuple[Any, ...]] = []
        with source:
            for row in source.row_stream:
                data_values: List[Any] = []
                for field in source.schema.fields:
                    value = row[field.name]
                    # Polars has no object dtype, so serialize mappings
                    if isinstance(value, dict):
                        value = str(value)
                    if isinstance(value, decimal.Decimal):
                        value = float(value)
                    # Normalize aware datetimes to UTC
                    if isinstance(value, datetime.datetime) and value.tzinfo:
                        value = value.astimezone(datetime.timezone.utc)
                    # Pin aware times to a fixed offset
                    if isinstance(value, datetime.time) and value.tzinfo:
                        value = value.replace(
                            tzinfo=tzoffset(
                                None,
                                value.utcoffset().total_seconds(),  # type: ignore
                            )
                        )
                    data_values.append(value)
                data_rows.append(tuple(data_values))

        # Create columns (polars has no index, so primary-key fields stay regular columns)
        columns: List[str] = [field.name for field in source.schema.fields]

        # Create/set dataframe
        dataframe = pl.DataFrame(data_rows, orient="row")
        dataframe.columns = columns

        # Keep integer fields integer even when the column contains nulls
        for field in source.schema.fields:
            if (
                field.type == "integer"
                and field.name in dataframe.columns
                and dataframe.select(field.name).dtypes[0] != pl.Int64
            ):
                dataframe = dataframe.with_columns(pl.col(field.name).cast(pl.Int64))

        self.resource.data = dataframe
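__init__.py above also exports PolarsPlugin from plugin.py, which is among the 11 changed files but not captured here. Purely as a hedged sketch, based on the hooks frictionless's other format plugins implement (create_parser, select_control_class) and confirmed by nothing in the visible diff, the missing file plausibly looks like:

from __future__ import annotations

from ...system import Plugin
from .control import PolarsControl
from .parser import PolarsParser


class PolarsPlugin(Plugin):
    """Plugin for polars (sketch; the actual plugin.py is not shown here)."""

    def create_parser(self, resource):
        # Route resources declared with format "polars" to the new parser
        if resource.format == "polars":
            return PolarsParser(resource)

    def select_control_class(self, type):
        # Resolve {"type": "polars"} control descriptors to PolarsControl
        if type == "polars":
            return PolarsControl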