Add Dataframe dataset (#63)
# Description
## What is the current behavior?
Currently, the dataframe dataset is not supported.

closes:
#18
#21

## What is the new behavior?
Added the dataframe dataset, backed by a new pandas dataframe data provider.
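
As a rough illustration (not part of this diff), an in-memory dataframe can now be declared as a dataset and handed to the operator like any other source. The `UniversalTransferOperator` import path and its `source_dataset`/`destination_dataset` parameters are assumptions based on the existing operator interface:

```python
import pandas as pd

from universal_transfer_operator.datasets.dataframe.base import Dataframe
from universal_transfer_operator.datasets.table import Table
# Assumed import path for the operator entry point; not shown in this diff.
from universal_transfer_operator.universal_transfer_operator import UniversalTransferOperator

# Transfer an in-memory pandas dataframe into a SQLite table.
transfer = UniversalTransferOperator(
    task_id="dataframe_to_sqlite",
    source_dataset=Dataframe(dataframe=pd.DataFrame({"id": [1, 2], "name": ["a", "b"]})),
    destination_dataset=Table(name="example_table", conn_id="sqlite_default"),
)
```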

## Does this introduce a breaking change?
Nope


### Checklist
- [ ] Created tests which fail without the change (if possible)
- [ ] Extended the README / documentation, if necessary

---------

Co-authored-by: Wei Lee <[email protected]>
utkarsharma2 and Lee-W authored Aug 25, 2023
1 parent fedd884 commit 482c23a
Showing 18 changed files with 331 additions and 225 deletions.
24 changes: 15 additions & 9 deletions src/universal_transfer_operator/data_providers/__init__.py
@@ -7,6 +7,7 @@

from universal_transfer_operator.constants import TransferMode
from universal_transfer_operator.data_providers.base import DataProviders
from universal_transfer_operator.datasets.dataframe.base import Dataframe
from universal_transfer_operator.datasets.file.base import File
from universal_transfer_operator.datasets.table import Table
from universal_transfer_operator.integrations.base import TransferIntegrationOptions
@@ -23,31 +24,36 @@
("sqlite", Table): "universal_transfer_operator.data_providers.database.sqlite",
("snowflake", Table): "universal_transfer_operator.data_providers.database.snowflake",
(None, File): "universal_transfer_operator.data_providers.filesystem.local",
(None, Dataframe): "universal_transfer_operator.data_providers.dataframe.Pandasdataframe",
}


def create_dataprovider(
dataset: Table | File,
dataset: Table | File | Dataframe,
transfer_params: TransferIntegrationOptions | None = None,
transfer_mode: TransferMode = TransferMode.NONNATIVE,
) -> DataProviders:
class_ref = get_dataprovider_class(dataset=dataset)
if transfer_params is None:
transfer_params = TransferIntegrationOptions()
data_provider: DataProviders = class_ref(
dataset=dataset,
transfer_params=transfer_params,
transfer_mode=transfer_mode,
)

class_ref: type[DataProviders] = get_dataprovider_class(dataset=dataset)
if isinstance(dataset, (Table, File)):
data_provider = class_ref( # type: ignore
dataset=dataset,
transfer_params=transfer_params,
transfer_mode=transfer_mode,
)
elif isinstance(dataset, Dataframe):
data_provider = class_ref(dataset=dataset) # type: ignore
return data_provider


def get_dataprovider_class(dataset: Table | File) -> type[DataProviders]:
def get_dataprovider_class(dataset: Table | File | Dataframe) -> type[DataProviders]:
"""
Get dataprovider class based on the dataset
"""
conn_type = None
if dataset.conn_id:
if isinstance(dataset, (Table, File)) and getattr(dataset, "conn_id", None):
conn_type = BaseHook.get_connection(dataset.conn_id).conn_type
module_path = DATASET_CONN_ID_TO_DATAPROVIDER_MAPPING[(conn_type, type(dataset))]
module = importlib.import_module(module_path)
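
Since a `Dataframe` dataset carries no Airflow connection, `get_dataprovider_class` leaves `conn_type` as `None`, so the `(None, Dataframe)` entry above resolves to the pandas provider module. A minimal sketch of the new code path, assuming the package is importable:

```python
import pandas as pd

from universal_transfer_operator.data_providers import create_dataprovider
from universal_transfer_operator.datasets.dataframe.base import Dataframe

# No conn_id lookup (and hence no Airflow connection) is needed for a Dataframe;
# the provider class is imported lazily from the module registered in the mapping.
provider = create_dataprovider(dataset=Dataframe(dataframe=pd.DataFrame({"id": [1, 2]})))
```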
75 changes: 3 additions & 72 deletions src/universal_transfer_operator/data_providers/base.py
@@ -7,15 +7,12 @@

import attr
import pandas as pd
from airflow.hooks.base import BaseHook

from universal_transfer_operator.constants import Location
from universal_transfer_operator.datasets.dataframe.base import Dataframe
from universal_transfer_operator.datasets.file.base import File
from universal_transfer_operator.datasets.table import Table
from universal_transfer_operator.integrations.base import TransferIntegrationOptions
from universal_transfer_operator.utils import get_dataset_connection_type

DatasetType = TypeVar("DatasetType", File, Table)
DatasetType = TypeVar("DatasetType", File, Table, Dataframe)


@attr.define
@@ -37,82 +34,16 @@ class DataProviders(ABC, Generic[DatasetType]):
def __init__(
self,
dataset: DatasetType,
transfer_mode,
transfer_params: TransferIntegrationOptions = TransferIntegrationOptions(),
):
self.dataset: DatasetType = dataset
self.transfer_params = transfer_params
self.transfer_mode = transfer_mode
self.transfer_mapping: set[Location] = set()
self.LOAD_DATA_NATIVELY_FROM_SOURCE: dict = {}

def __repr__(self):
return f'{self.__class__.__name__}(conn_id="{self.dataset.conn_id})'

@property
def hook(self) -> BaseHook:
"""Return an instance of the Airflow hook."""
raise NotImplementedError

def check_if_exists(self) -> bool:
"""Return true if the dataset exists"""
raise NotImplementedError

def check_if_transfer_supported(self, source_dataset: DatasetType) -> bool:
"""
Checks if the transfer is supported from source to destination based on source_dataset.
"""
source_connection_type = get_dataset_connection_type(source_dataset)
return Location(source_connection_type) in self.transfer_mapping

def read(self) -> Iterator[pd.DataFrame] | Iterator[DataStream]:
"""Read from filesystem dataset or databases dataset and write to local reference locations or dataframes"""
raise NotImplementedError

def write(self, source_ref: pd.DataFrame | DataStream) -> str: # type: ignore
def write(self, source_ref: pd.DataFrame | DataStream) -> str | DataProviders: # type: ignore
"""Write the data from local reference location or a dataframe to the database dataset or filesystem dataset
:param source_ref: Stream of data to be loaded into output table or a pandas dataframe.
"""
raise NotImplementedError

@property
def openlineage_dataset_namespace(self) -> str:
"""
Returns the open lineage dataset namespace as per
https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md
"""
raise NotImplementedError

@property
def openlineage_dataset_name(self) -> str:
"""
Returns the open lineage dataset name as per
https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md
"""
raise NotImplementedError

@property
def openlineage_dataset_uri(self) -> str:
"""
Returns the open lineage dataset uri as per
https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md
"""
return f"{self.openlineage_dataset_namespace}{self.openlineage_dataset_name}"

def populate_metadata(self):
"""
Given a dataset, check if the dataset has metadata.
"""
raise NotImplementedError

def is_native_path_available( # skipcq: PYL-R0201
self,
source_dataset: File | Table, # skipcq: PYL-W0613
) -> bool:
"""
Check if there is an optimised path for source to destination.
:param source_dataset: File | Table from which we need to transfer data
"""
return False
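
With the provider-specific helpers removed, the base class is reduced to the generic contract: `read()` yields `pd.DataFrame` or `DataStream` chunks from the source, and `write()` consumes one chunk at a time on the destination. A hedged sketch of the transfer loop this contract implies (the actual loop lives in the operator, not in this file):

```python
from universal_transfer_operator.data_providers.base import DataProviders


def run_transfer(source_provider: DataProviders, destination_provider: DataProviders) -> None:
    # Stream whatever the source yields (dataframes or DataStream objects)
    # straight into the destination provider.
    for chunk in source_provider.read():
        destination_provider.write(source_ref=chunk)
```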
26 changes: 17 additions & 9 deletions src/universal_transfer_operator/data_providers/database/base.py
@@ -24,7 +24,6 @@
)
from universal_transfer_operator.data_providers.base import DataProviders, DataStream
from universal_transfer_operator.data_providers.filesystem import resolve_file_path_pattern
from universal_transfer_operator.datasets.dataframe.pandas import PandasDataframe
from universal_transfer_operator.datasets.file.base import File
from universal_transfer_operator.datasets.table import Metadata, Table
from universal_transfer_operator.settings import (
@@ -64,11 +63,9 @@ def __init__(
self.transfer_params = transfer_params
self.if_exists = self._if_exists
self.transfer_mode = transfer_mode
self.transfer_mapping = set()
self.transfer_mapping: set[Location] = set()
self.LOAD_DATA_NATIVELY_FROM_SOURCE: dict = {}
super().__init__(
dataset=self.dataset, transfer_mode=self.transfer_mode, transfer_params=self.transfer_params
)
super().__init__(dataset=self.dataset)

def __repr__(self):
return f'{self.__class__.__name__}(conn_id="{self.dataset.conn_id})'
@@ -227,7 +224,7 @@ def load_dataframe_to_table(
) -> str:
"""
Load content of dataframe in output_table.
:param input_dataframe: dataframe
:param output_table: Table to create
:param if_exists: Overwrite file if exists
@@ -381,6 +378,7 @@ def create_table_using_schema_autodetection(
source_dataframe = file.export_to_dataframe(nrows=LOAD_TABLE_AUTODETECT_ROWS_COUNT)

db = SQLDatabase(engine=self.sqlalchemy_engine)

db.prep_table(
source_dataframe,
table.name.lower(),
@@ -650,7 +648,7 @@ def load_pandas_dataframe_to_table(
)

@staticmethod
def _assert_not_empty_df(df):
def _assert_not_empty_df(df: pd.DataFrame) -> None:
"""Raise error if dataframe empty
param df: A dataframe
@@ -743,5 +741,15 @@ def export_table_to_pandas_dataframe(self) -> pd.DataFrame:
raise ValueError(f"The table {self.dataset.name} does not exist")

sqla_table = self.get_sqla_table(self.dataset)
df = pd.read_sql(sql=sqla_table.select(), con=self.sqlalchemy_engine)
return PandasDataframe.from_pandas_df(df)
return pd.read_sql(sql=sqla_table.select(), con=self.sqlalchemy_engine)

def is_native_path_available(
self,
source_dataset: Table, # skipcq PYL-W0613, PYL-R0201
) -> bool:
"""
Check if there is an optimised path for source to destination.
:param source_dataset: Dataframe from which we need to transfer data
"""
return False
@@ -3,16 +3,15 @@
import logging
import random
import string
from typing import TYPE_CHECKING, ClassVar
from typing import ClassVar, Iterator

import pandas as pd
from pandas import DataFrame, read_json

from universal_transfer_operator import settings
from universal_transfer_operator.constants import ColumnCapitalization, FileType

if TYPE_CHECKING:
from universal_transfer_operator.datasets.file.base import File
from universal_transfer_operator.constants import FileType
from universal_transfer_operator.data_providers.base import DataStream
from universal_transfer_operator.data_providers.dataframe.base import DataframeProvider
from universal_transfer_operator.datasets.dataframe.base import Dataframe
from universal_transfer_operator.datasets.file.base import File

logger = logging.getLogger(__name__)

@@ -49,11 +48,37 @@ def convert_dataframe_to_file(df: pd.DataFrame) -> File:
return file


class PandasDataframe(DataFrame):
"""Pandas-compatible dataframe class that can be serialized and deserialized into XCom by Airflow 2.5"""

class PandasdataframeDataProvider(DataframeProvider):
version: ClassVar[int] = 1

def read(self) -> Iterator[pd.DataFrame]:
"""Read from dataframe dataset and write to local reference locations or dataframes"""
yield self.dataset.dataframe

def write(self, source_ref: pd.DataFrame | DataStream) -> PandasdataframeDataProvider:
"""Write the data to the dataframe dataset or filesystem dataset"""
if isinstance(source_ref, pd.DataFrame):
return PandasdataframeDataProvider(dataset=Dataframe(dataframe=source_ref))
else:
return PandasdataframeDataProvider(
dataset=Dataframe(
dataframe=source_ref.actual_file.type.export_to_dataframe(
stream=source_ref.remote_obj_buffer
)
)
)

def equals(self, other: PandasdataframeDataProvider) -> bool:
"""Check equality of two PandasdataframeDataProvider"""
if isinstance(other, PandasdataframeDataProvider):
return bool(self.dataset.dataframe.equals(other.dataset.dataframe))
if isinstance(other, pd.DataFrame):
return self.dataset.equals(other)
return False

def __eq__(self, other) -> bool:
return self.equals(other)

def serialize(self):
# Store in the metadata DB if Dataframe < 100 kb
df_size = self.memory_usage(deep=True).sum()
@@ -72,36 +97,17 @@ def serialize(self):
@staticmethod
def deserialize(data: dict, version: int):
if version > 1:
raise TypeError(f"version > {PandasDataframe.version}")
raise TypeError(f"version > {PandasdataframeDataProvider.version}")
if isinstance(data, dict) and data.get("class", "") == "File":
file = File.from_json(data)
if file.is_dataframe:
logger.info("Retrieving file from %s using %s conn_id ", file.path, file.conn_id)
return file.export_to_dataframe()
return file
return PandasDataframe.from_pandas_df(read_json(data["data"]))
return PandasdataframeDataProvider.from_pandas_df(pd.read_json(data["data"]))

@classmethod
def from_pandas_df(cls, df: DataFrame) -> DataFrame | PandasDataframe:
def from_pandas_df(cls, df: pd.DataFrame) -> pd.DataFrame | PandasdataframeDataProvider:
if not settings.NEED_CUSTOM_SERIALIZATION:
return df
return cls(df)


def convert_columns_names_capitalization(
df: pd.DataFrame, columns_names_capitalization: ColumnCapitalization
):
"""
Convert cols of a dataframe to required case. Options - lower/Upper
:param df: dataframe whose cols will be altered
:param columns_names_capitalization: String Literal with possible values - lower/Upper
"""
if isinstance(df, pd.DataFrame):
df = PandasDataframe.from_pandas_df(df)
if columns_names_capitalization == "lower":
df.columns = [col_label.lower() for col_label in df.columns] # skipcq: PYL-W0201
elif columns_names_capitalization == "upper":
df.columns = [col_label.upper() for col_label in df.columns] # skipcq: PYL-W0201

return df
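
Taken on its own, the new pandas provider is a thin wrapper around the dataset's dataframe: `read()` yields it, `write()` returns a fresh provider holding the written data, and `__eq__` compares the underlying dataframes. A small sketch, importing the class from the module path registered in the dataprovider mapping:

```python
import pandas as pd

# Module path as registered in DATASET_CONN_ID_TO_DATAPROVIDER_MAPPING above.
from universal_transfer_operator.data_providers.dataframe.Pandasdataframe import (
    PandasdataframeDataProvider,
)
from universal_transfer_operator.datasets.dataframe.base import Dataframe

df = pd.DataFrame({"id": [1, 2]})
provider = PandasdataframeDataProvider(dataset=Dataframe(dataframe=df))

assert next(provider.read()).equals(df)   # read() yields the wrapped dataframe
written = provider.write(source_ref=df)   # write() returns a new provider
assert written == provider                # __eq__ delegates to equals()
```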
46 changes: 46 additions & 0 deletions src/universal_transfer_operator/data_providers/dataframe/base.py
@@ -0,0 +1,46 @@
from __future__ import annotations

from typing import Iterator

import pandas as pd
from universal_transfer_operator.data_providers.base import DataProviders, DataStream
from universal_transfer_operator.datasets.dataframe.base import Dataframe


class DataframeProvider(DataProviders[Dataframe]):
"""Base class to import dataframe implementation. We intend to support different implementation of dataframes."""

def __init__(
self,
dataset: Dataframe,
) -> None:
self.dataset = dataset
super().__init__(dataset=self.dataset)

def read(self) -> Iterator[Dataframe]:
"""Read from dataframe dataset and write to local reference locations or dataframes"""
raise NotImplementedError

def write(self, source_ref: pd.DataFrame | DataStream) -> DataframeProvider: # type: ignore
"""Write the data to the dataframe dataset or filesystem dataset"""
raise NotImplementedError

def serialize(self):
"""Store in the metadata DB if Dataframe"""
raise NotImplementedError

@staticmethod
def deserialize(data: dict, version: int):
"""Extract from metadata DB"""
raise NotImplementedError

def is_native_path_available(
self,
source_dataset: Dataframe, # skipcq PYL-W0613, PYL-R0201
) -> bool:
"""
Check if there is an optimised path for source to destination.
:param source_dataset: Dataframe from which we need to transfer data
"""
return False
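
The base class stays deliberately small so additional dataframe implementations can be registered later. A purely hypothetical sketch of what such a subclass could look like; a polars-backed provider is not part of this change, and it assumes the `Dataframe` dataset can hold an arbitrary dataframe object:

```python
from __future__ import annotations

from typing import Iterator

import pandas as pd
import polars as pl

from universal_transfer_operator.data_providers.base import DataStream
from universal_transfer_operator.data_providers.dataframe.base import DataframeProvider
from universal_transfer_operator.datasets.dataframe.base import Dataframe


class PolarsDataframeProvider(DataframeProvider):
    """Hypothetical provider backed by a polars DataFrame stored on the dataset."""

    def read(self) -> Iterator[pd.DataFrame]:
        # Convert to pandas so downstream providers keep working unchanged.
        yield self.dataset.dataframe.to_pandas()

    def write(self, source_ref: pd.DataFrame | DataStream) -> PolarsDataframeProvider:
        if isinstance(source_ref, pd.DataFrame):
            return PolarsDataframeProvider(dataset=Dataframe(dataframe=pl.from_pandas(source_ref)))
        raise NotImplementedError("DataStream handling is left out of this sketch")
```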