diff --git a/gluestick/reader.py b/gluestick/reader.py
index e164a6d..900c3ae 100644
--- a/gluestick/reader.py
+++ b/gluestick/reader.py
@@ -1,8 +1,10 @@
 import os
 import json
 import pandas as pd
+from pandas.io.parsers import TextFileReader
 import pyarrow as pa
 import pyarrow.parquet as pq
+import pyarrow.dataset as ds
 
 class Reader:
     """A reader for gluestick ETL files."""
@@ -34,8 +36,89 @@ def __str__(self):
     def __repr__(self):
         return str(list(self.input_files.keys()))
 
+    def get_in_chunks(self, stream, default=None, catalog_types=False, schema=None, **kwargs):
+        """Read the selected file in chunks, yielding one pandas DataFrame per chunk."""
+        if not kwargs.get("chunksize"):
+            print("This method requires a chunksize. Assuming chunksize=1000000.")
+            kwargs["chunksize"] = 1000000
+        filepath = self.input_files.get(stream)
+        if not filepath:
+            # NOTE: in a generator this simply ends iteration; "default" is only
+            # reachable through StopIteration.value
+            return default
+        if filepath.endswith(".parquet"):
+            catalog = self.read_catalog()
+            if catalog and catalog_types:
+                try:
+                    dataset = ds.dataset(filepath, schema=schema, format="parquet")
+                    for batch in dataset.to_batches(batch_size=kwargs["chunksize"]):
+                        df = batch.to_pandas(safe=False)
+                        headers = df.columns.tolist()
+                        types_params = self.get_types_from_catalog(catalog, stream, headers=headers)
+                        dtype_dict = types_params.get("dtype")
+                        parse_dates = types_params.get("parse_dates") or []
+
+                        # Mapping pandas dtypes to pyarrow types
+                        type_mapping = {
+                            "int64": pa.int64(),
+                            "float64": pa.float64(),
+                            "": pa.float64(),
+                            "string": pa.string(),
+                            "object": pa.string(),
+                            "datetime64[ns]": pa.timestamp("ns"),
+                            "bool": pa.bool_(),
+                            "boolean": pa.bool_(),
+                            # TODO: Add more mappings as needed
+                        }
+
+                        if dtype_dict:
+                            if not schema:
+                                # Convert the dtype dictionary to a pyarrow schema and
+                                # re-read the dataset with it, so every batch is cast
+                                # consistently from the first batch onward
+                                fields = [(col, type_mapping[str(dtype).lower()]) for col, dtype in dtype_dict.items()]
+                                fields.extend([(col, pa.timestamp("ns")) for col in parse_dates])
+                                schema = pa.schema(fields)
+                                yield from self.get_in_chunks(stream, default, catalog_types, schema, **kwargs)
+                                # the recursive call has already yielded every chunk;
+                                # return here so the data is not yielded twice
+                                return
+                            for col, dtype in dtype_dict.items():
+                                # NOTE: bools require explicit conversion at the end because
+                                # if there are empty values (NaN) pyarrow/pd defaults to
+                                # converting them to string
+                                if str(dtype).lower() in ["bool", "boolean"]:
+                                    df[col] = df[col].astype("boolean")
+                                elif str(dtype).lower() == "int64":
+                                    df[col] = df[col].astype("Int64")
+                        yield df
+                except Exception:
+                    # NOTE: silencing errors to avoid breaking existing workflows;
+                    # fall back to reading without the catalog-derived schema
+                    print(f"Failed to parse catalog_types for {stream}. Ignoring.")
+                    dataset = ds.dataset(filepath, format="parquet")
+                    for batch in dataset.to_batches(batch_size=kwargs["chunksize"]):
+                        yield batch.to_pandas(safe=False)
+            else:
+                dataset = ds.dataset(filepath, schema=schema, format="parquet")
+                for batch in dataset.to_batches(batch_size=kwargs["chunksize"]):
+                    yield batch.to_pandas(safe=False)
+        elif filepath.endswith(".csv"):
+            catalog = self.read_catalog()
+            if catalog and catalog_types:
+                types_params = self.get_types_from_catalog(catalog, stream)
+                kwargs.update(types_params)
+            chunks_generator: TextFileReader = pd.read_csv(filepath, **kwargs)
+            # if a date field value is empty read_csv will read it as "object",
+            # so make sure all date fields are typed as dates
+            for chunk in chunks_generator:
+                for date_col in kwargs.get("parse_dates", []):
+                    chunk[date_col] = pd.to_datetime(chunk[date_col], errors="coerce")
+                yield chunk
+        else:
+            raise ValueError(f"Unsupported file type: {filepath}")
+
     def get(self, stream, default=None, catalog_types=False, **kwargs):
         """Read the selected file."""
+        if kwargs.get("chunksize"):
+            print("This method does not support chunksize. Please use get_in_chunks instead. Ignoring chunksize.")
+            kwargs.pop("chunksize")
         filepath = self.input_files.get(stream)
         if not filepath:
             return default
@@ -85,7 +168,7 @@ def get(self, stream, default=None, catalog_types=False, **kwargs):
             if catalog and catalog_types:
                 types_params = self.get_types_from_catalog(catalog, stream)
                 kwargs.update(types_params)
-            df = pd.read_csv(filepath, **kwargs)
+            df: pd.DataFrame = pd.read_csv(filepath, **kwargs)
             # if a date field value is empty read_csv will read it as "object"
             # make sure all date fields are typed as date
             for date_col in kwargs.get("parse_dates", []):