diff --git a/gluestick/reader.py b/gluestick/reader.py
index e164a6d..33d48d5 100644
--- a/gluestick/reader.py
+++ b/gluestick/reader.py
@@ -1,8 +1,10 @@
 import os
 import json
 import pandas as pd
+from pandas.io.parsers import TextFileReader
 import pyarrow as pa
 import pyarrow.parquet as pq
+import pyarrow.dataset as ds
 
 class Reader:
     """A reader for gluestick ETL files."""
@@ -34,64 +36,128 @@ def __str__(self):
     def __repr__(self):
         return str(list(self.input_files.keys()))
 
+    def get_in_chunks(self, stream, default=None, catalog_types=False, schema=None, **kwargs):
+        """Read the selected file in chunks, yielding pandas DataFrames."""
+        if not kwargs.get("chunksize"):
+            print("This method requires a chunksize. Assuming chunksize=1000000.")
+            kwargs["chunksize"] = 1000000
+        filepath = self.input_files.get(stream)
+        if not filepath:
+            return default
+        if filepath.endswith(".parquet"):
+            return self._process_parquet_in_chunks(filepath, stream, catalog_types, schema, **kwargs)
+        elif filepath.endswith(".csv"):
+            return self._process_csv_in_chunks(filepath, stream, catalog_types, **kwargs)
+        else:
+            raise ValueError(f"Unsupported file type: {filepath}")
+
     def get(self, stream, default=None, catalog_types=False, **kwargs):
         """Read the selected file."""
+        if kwargs.get("chunksize"):
+            print("This method does not support chunksize. Please use get_in_chunks instead. Ignoring chunksize.")
+            kwargs.pop("chunksize")
         filepath = self.input_files.get(stream)
         if not filepath:
             return default
         if filepath.endswith(".parquet"):
-            catalog = self.read_catalog()
-            if catalog and catalog_types:
-                try:
-                    headers = pq.read_table(filepath).to_pandas(safe=False).columns.tolist()
-                    types_params = self.get_types_from_catalog(catalog, stream, headers=headers)
-                    dtype_dict = types_params.get('dtype')
-                    parse_dates = types_params.get('parse_dates')
-
-                    # Mapping pandas dtypes to pyarrow types
-                    type_mapping = {
-                        'int64': pa.int64(),
-                        'float64': pa.float64(),
-                        "": pa.float64(),
-                        'string': pa.string(),
-                        'object': pa.string(),
-                        'datetime64[ns]': pa.timestamp('ns'),
-                        'bool': pa.bool_(),
-                        'boolean': pa.bool_(),
-                        # TODO: Add more mappings as needed
-                    }
-
-                    if dtype_dict:
-                        # Convert dtype dictionary to pyarrow schema
-                        fields = [(col, type_mapping[str(dtype).lower()]) for col, dtype in dtype_dict.items()]
-                        fields.extend([(col, pa.timestamp('ns')) for col in parse_dates])
-                        schema = pa.schema(fields)
-                        df = pq.read_table(filepath, schema=schema).to_pandas(safe=False)
-                        for col, dtype in dtype_dict.items():
-                            # NOTE: bools require explicit conversion at the end because if there are empty values (NaN)
-                            # pyarrow/pd defaults to convert to string
-                            if str(dtype).lower() in ["bool", "boolean"]:
-                                df[col] = df[col].astype('boolean')
-                            elif str(dtype).lower() in ["int64"]:
-                                df[col] = df[col].astype('Int64')
-                        return df
-                except:
-                    # NOTE: silencing errors to avoid breaking existing workflow
-                    print(f"Failed to parse catalog_types for {stream}. Ignoring.")
-                    pass
-
-            return pq.read_table(filepath).to_pandas(safe=False)
+            return self._process_parquet(filepath, stream, catalog_types, **kwargs)
+        return self._process_csv(filepath, stream, catalog_types, **kwargs)
+
+    def _process_parquet_in_chunks(self, filepath, stream, catalog_types, schema, **kwargs):
+        catalog = self.read_catalog()
+        if catalog and catalog_types:
+            try:
+                dataset = ds.dataset(filepath, schema=schema, format="parquet")
+                for batch in dataset.to_batches(batch_size=kwargs["chunksize"]):
+                    yield from self._apply_catalog_types(batch, catalog, stream, schema, **kwargs)
+                    if not schema:
+                        # the first batch restarts the read with a catalog-derived schema
+                        # inside _apply_catalog_types, so stop this scan to avoid duplicates
+                        return
+                return
+            except:
+                # NOTE: silencing errors to avoid breaking existing workflows
+                print(f"Failed to parse catalog_types for {stream}. Ignoring.")
+        dataset = ds.dataset(filepath, schema=schema, format="parquet")
+        for batch in dataset.to_batches(batch_size=kwargs["chunksize"]):
+            yield batch.to_pandas(safe=False)
+
+    def _process_parquet(self, filepath, stream, catalog_types, **kwargs):
+        catalog = self.read_catalog()
+        if catalog and catalog_types:
+            try:
+                headers = pq.read_table(filepath).to_pandas(safe=False).columns.tolist()
+                types_params = self.get_types_from_catalog(catalog, stream, headers=headers)
+                dtype_dict = types_params.get('dtype')
+                parse_dates = types_params.get('parse_dates')
+
+                schema = self._create_schema(dtype_dict, parse_dates)
+                df = pq.read_table(filepath, schema=schema).to_pandas(safe=False)
+                return self._convert_dtypes(df, dtype_dict)
+            except:
+                # NOTE: silencing errors to avoid breaking existing workflows
+                print(f"Failed to parse catalog_types for {stream}. Ignoring.")
+        return pq.read_table(filepath).to_pandas(safe=False)
+
+    def _process_csv_in_chunks(self, filepath, stream, catalog_types, **kwargs):
+        catalog = self.read_catalog()
+        if catalog and catalog_types:
+            types_params = self.get_types_from_catalog(catalog, stream)
+            kwargs.update(types_params)
+        chunks_generator: TextFileReader = pd.read_csv(filepath, **kwargs)
+        for chunk in chunks_generator:
+            # if a date field value is empty, read_csv types the column as "object";
+            # make sure all date fields are typed as datetime
+            for date_col in kwargs.get("parse_dates", []):
+                chunk[date_col] = pd.to_datetime(chunk[date_col], errors='coerce')
+            yield chunk
+
+    def _process_csv(self, filepath, stream, catalog_types, **kwargs):
         catalog = self.read_catalog()
         if catalog and catalog_types:
             types_params = self.get_types_from_catalog(catalog, stream)
             kwargs.update(types_params)
-        df = pd.read_csv(filepath, **kwargs)
-        # if a date field value is empty read_csv will read it as "object"
-        # make sure all date fields are typed as date
+        df: pd.DataFrame = pd.read_csv(filepath, **kwargs)
+        # if a date field value is empty, read_csv types the column as "object";
+        # make sure all date fields are typed as datetime
         for date_col in kwargs.get("parse_dates", []):
             df[date_col] = pd.to_datetime(df[date_col], errors='coerce')
         return df
 
+    def _apply_catalog_types(self, batch, catalog, stream, schema, **kwargs):
+        df = batch.to_pandas(safe=False)
+        headers = df.columns.tolist()
+        types_params = self.get_types_from_catalog(catalog, stream, headers=headers)
+        dtype_dict = types_params.get('dtype')
+        parse_dates = types_params.get('parse_dates')
+
+        if not schema:
+            if dtype_dict:
+                # build a schema from the catalog and re-read the stream with it
+                schema = self._create_schema(dtype_dict, parse_dates)
+                yield from self.get_in_chunks(stream, catalog_types=True, schema=schema, **kwargs)
+            else:
+                # no catalog types for this stream; re-read it untyped
+                yield from self.get_in_chunks(stream, catalog_types=False, **kwargs)
+            return
+        if dtype_dict:
+            df = self._convert_dtypes(df, dtype_dict)
+        yield df
+
+    def _create_schema(self, dtype_dict, parse_dates):
+        # Convert the dtype dictionary to a pyarrow schema
+        fields = [(col, self._type_mapping(str(dtype).lower())) for col, dtype in dtype_dict.items()]
+        fields.extend([(col, pa.timestamp('ns')) for col in parse_dates])
+        return pa.schema(fields)
+
+    def _convert_dtypes(self, df, dtype_dict):
+        # NOTE: bools require explicit conversion because if there are empty values (NaN)
+        # pyarrow/pd defaults to convert to string
+        for col, dtype in dtype_dict.items():
+            if str(dtype).lower() in ["bool", "boolean"]:
+                df[col] = df[col].astype('boolean')
+            elif str(dtype).lower() in ["int64"]:
+                df[col] = df[col].astype('Int64')
+        return df
+
+    def _type_mapping(self, dtype):
+        # Mapping pandas dtypes to pyarrow types
+        type_mapping = {
+            'int64': pa.int64(),
+            'float64': pa.float64(),
+            "": pa.float64(),
+            'string': pa.string(),
+            'object': pa.string(),
+            'datetime64[ns]': pa.timestamp('ns'),
+            'bool': pa.bool_(),
+            'boolean': pa.bool_(),
+            # TODO: Add more mappings as needed
+        }
+        return type_mapping.get(dtype, pa.string())
+
     def get_metadata(self, stream):
         """Get metadata from parquet file."""
         file = self.input_files.get(stream)
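For reviewers, a minimal usage sketch of the new chunked API. It is illustrative only: the stream name "invoices", the chunk size, and `handle_chunk` are hypothetical, and it assumes `Reader` is still constructed and exported as before this change.

```python
from gluestick import Reader  # assumes Reader remains exported at the package root

reader = Reader()

# Stream a large parquet/CSV input without loading it into memory at once;
# each chunk is a pandas DataFrame with catalog types applied when a catalog exists.
for chunk in reader.get_in_chunks("invoices", catalog_types=True, chunksize=50_000):
    handle_chunk(chunk)  # hypothetical downstream processing
```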