Annotated StructuredDataset: support nested_types (#2252)
Signed-off-by: Austin Liu <[email protected]>
Signed-off-by: Kevin Su <[email protected]>
Co-authored-by: Kevin Su <[email protected]>
austin362667 and pingsutw authored Apr 30, 2024
1 parent 0116ff3 commit 0cc8bbc
Showing 3 changed files with 297 additions and 3 deletions.
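In short, this change lets the column annotation on a StructuredDataset nest dicts and dataclasses; the type engine flattens them into dot-separated column names. A minimal sketch of the new usage, grounded in the test file added below (the task name read_contacts is hypothetical):

    from dataclasses import dataclass

    import pandas as pd
    from typing_extensions import Annotated

    from flytekit import StructuredDataset, kwtypes, task

    @dataclass
    class ContactsField:
        email: str
        tel: str

    # Nested annotation: flattened to the columns
    # "info.contacts.email" and "info.contacts.tel".
    MyNestedDataClassDataset = Annotated[StructuredDataset, kwtypes(info=kwtypes(contacts=ContactsField))]

    @task
    def read_contacts(sd: MyNestedDataClassDataset) -> pd.DataFrame:
        return sd.open(pd.DataFrame).all()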
33 changes: 30 additions & 3 deletions flytekit/types/structured/structured_dataset.py
@@ -4,7 +4,7 @@
import types
import typing
from abc import ABC, abstractmethod
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, is_dataclass
from typing import Dict, Generator, Optional, Type, Union

import _datetime
@@ -114,6 +114,22 @@ def iter(self) -> Generator[DF, None, None]:
        )


# Flatten the nested column map recursively.
def flatten_dict(sub_dict: dict, parent_key: str = "") -> typing.Dict:
    result = {}
    for key, value in sub_dict.items():
        current_key = f"{parent_key}.{key}" if parent_key else key
        if isinstance(value, dict):
            result.update(flatten_dict(sub_dict=value, parent_key=current_key))
        elif is_dataclass(value):
            fields = getattr(value, "__dataclass_fields__")
            d = {k: v.type for k, v in fields.items()}
            result.update(flatten_dict(sub_dict=d, parent_key=current_key))
        else:
            result[current_key] = value
    return result
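A quick sketch of what the helper produces, using the nested column map exercised by the new tests (hypothetical inline example):

    cols = {"info": {"contacts": {"tel": str, "email": str}}}
    flatten_dict(cols)
    # -> {"info.contacts.tel": str, "info.contacts.email": str}

Dataclass values are expanded through their __dataclass_fields__, so a value like ContactsField contributes one flattened entry per field.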


def extract_cols_and_format(
    t: typing.Any,
) -> typing.Tuple[Type[T], Optional[typing.OrderedDict[str, Type]], Optional[str], Optional["pa.lib.Schema"]]:
@@ -142,7 +158,17 @@ def extract_cols_and_format(
    if get_origin(t) is Annotated:
        base_type, *annotate_args = get_args(t)
        for aa in annotate_args:
-           if isinstance(aa, StructuredDatasetFormat):
+           if hasattr(aa, "__annotations__"):
+               # handle dataclass argument
+               d = collections.OrderedDict()
+               dm = vars(aa)
+               d.update(dm["__annotations__"])
+               ordered_dict_cols = d
+           elif isinstance(aa, dict):
+               d = collections.OrderedDict()
+               d.update(aa)
+               ordered_dict_cols = d
+           elif isinstance(aa, StructuredDatasetFormat):
                if fmt != "":
                    raise ValueError(f"A format was already specified {fmt}, cannot use {aa}")
                fmt = aa
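The two new branches normalize a dataclass or plain-dict annotation into the same ordered column map that kwtypes builds. A minimal sketch of the equivalence, assuming the CompanyField dataclass defined in the new test file below:

    # vars(CompanyField)["__annotations__"] preserves declaration order, so
    # Annotated[StructuredDataset, CompanyField] yields
    #     OrderedDict(location=str, info=InfoField, company=str)
    # the same map kwtypes(location=str, info=InfoField, company=str) would build;
    # flatten_dict then expands the nested values into dotted column names.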
@@ -826,7 +852,8 @@ def _convert_ordered_dict_of_columns_to_list(
        converted_cols: typing.List[StructuredDatasetType.DatasetColumn] = []
        if column_map is None or len(column_map) == 0:
            return converted_cols
-       for k, v in column_map.items():
+       flat_column_map = flatten_dict(column_map)
+       for k, v in flat_column_map.items():
            lt = self._get_dataset_column_literal_type(v)
            converted_cols.append(StructuredDatasetType.DatasetColumn(name=k, literal_type=lt))
        return converted_cols
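The practical effect: nested keys become dot-separated column names in the serialized StructuredDatasetType. A hedged example of the resulting columns, mirroring the tests below:

    # column_map = OrderedDict(info={"contacts": {"tel": str}})
    # After flatten_dict, the literal type carries a single column:
    #     DatasetColumn(name="info.contacts.tel", literal_type=<string literal type>)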
@@ -1,5 +1,6 @@
import os
import typing
from dataclasses import dataclass

import numpy as np
import pyarrow as pa
@@ -28,7 +29,16 @@
NUMPY_PATH = FlyteContextManager.current_context().file_access.get_random_local_directory()
BQ_PATH = "bq://flyte-dataset:flyte.table"


@dataclass
class MyCols:
    Name: str
    Age: int


my_cols = kwtypes(Name=str, Age=int)
my_dataclass_cols = MyCols
my_dict_cols = {"Name": str, "Age": int}
fields = [("Name", pa.string()), ("Age", pa.int32())]
arrow_schema = pa.schema(fields)
pd_df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})
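These three bindings declare the same Name/Age column pair in three forms; the a/b task variants added below (t4a/t4b through t12a/t12b) mirror the existing kwtypes-based tasks to exercise each form. A sketch of the equivalence, assuming extract_cols_and_format from the change above:

    # All three annotations normalize to OrderedDict(Name=str, Age=int):
    #   my_cols           via kwtypes(...)
    #   my_dataclass_cols via the new __annotations__ (dataclass) branch
    #   my_dict_cols      via the new plain-dict branch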
@@ -157,6 +167,18 @@ def t4(dataset: Annotated[StructuredDataset, my_cols]) -> pd.DataFrame:
    return dataset.open(pd.DataFrame).all()


@task
def t4a(dataset: Annotated[StructuredDataset, my_dataclass_cols]) -> pd.DataFrame:
    # s3 (parquet) -> pandas -> s3 (parquet)
    return dataset.open(pd.DataFrame).all()


@task
def t4b(dataset: Annotated[StructuredDataset, my_dict_cols]) -> pd.DataFrame:
    # s3 (parquet) -> pandas -> s3 (parquet)
    return dataset.open(pd.DataFrame).all()


@task
def t5(dataframe: pd.DataFrame) -> Annotated[StructuredDataset, my_cols]:
    # s3 (parquet) -> pandas -> bq
@@ -170,6 +192,20 @@ def t6(dataset: Annotated[StructuredDataset, my_cols]) -> pd.DataFrame:
    return df


@task
def t6a(dataset: Annotated[StructuredDataset, my_dataclass_cols]) -> pd.DataFrame:
    # bq -> pandas -> s3 (parquet)
    df = dataset.open(pd.DataFrame).all()
    return df


@task
def t6b(dataset: Annotated[StructuredDataset, my_dict_cols]) -> pd.DataFrame:
    # bq -> pandas -> s3 (parquet)
    df = dataset.open(pd.DataFrame).all()
    return df


@task
def t7(
    df1: pd.DataFrame, df2: pd.DataFrame
@@ -193,6 +229,20 @@ def t8a(dataframe: pa.Table) -> pa.Table:
    return dataframe


@task
def t8b(dataframe: pa.Table) -> Annotated[StructuredDataset, my_dataclass_cols]:
    # Arrow table -> s3 (parquet)
    print(dataframe.columns)
    return StructuredDataset(dataframe=dataframe)


@task
def t8c(dataframe: pa.Table) -> Annotated[StructuredDataset, my_dict_cols]:
    # Arrow table -> s3 (parquet)
    print(dataframe.columns)
    return StructuredDataset(dataframe=dataframe)


@task
def t9(dataframe: np.ndarray) -> Annotated[StructuredDataset, my_cols]:
    # numpy -> Arrow table -> s3 (parquet)
@@ -206,6 +256,20 @@ def t10(dataset: Annotated[StructuredDataset, my_cols]) -> np.ndarray:
    return np_array


@task
def t10a(dataset: Annotated[StructuredDataset, my_dataclass_cols]) -> np.ndarray:
    # s3 (parquet) -> Arrow table -> numpy
    np_array = dataset.open(np.ndarray).all()
    return np_array


@task
def t10b(dataset: Annotated[StructuredDataset, my_dict_cols]) -> np.ndarray:
    # s3 (parquet) -> Arrow table -> numpy
    np_array = dataset.open(np.ndarray).all()
    return np_array


StructuredDatasetTransformerEngine.register(PandasToCSVEncodingHandler())
StructuredDatasetTransformerEngine.register(CSVToPandasDecodingHandler())

@@ -223,6 +287,20 @@ def t12(dataset: Annotated[StructuredDataset, my_cols]) -> pd.DataFrame:
    return df


@task
def t12a(dataset: Annotated[StructuredDataset, my_dataclass_cols]) -> pd.DataFrame:
    # csv -> pandas
    df = dataset.open(pd.DataFrame).all()
    return df


@task
def t12b(dataset: Annotated[StructuredDataset, my_dict_cols]) -> pd.DataFrame:
    # csv -> pandas
    df = dataset.open(pd.DataFrame).all()
    return df


@task
def generate_pandas() -> pd.DataFrame:
    return pd_df
@@ -249,15 +327,25 @@ def wf():
    t3(dataset=StructuredDataset(uri=PANDAS_PATH))
    t3a(dataset=StructuredDataset(uri=PANDAS_PATH))
    t4(dataset=StructuredDataset(uri=PANDAS_PATH))
    t4a(dataset=StructuredDataset(uri=PANDAS_PATH))
    t4b(dataset=StructuredDataset(uri=PANDAS_PATH))
    t5(dataframe=df)
    t6(dataset=StructuredDataset(uri=BQ_PATH))
    t6a(dataset=StructuredDataset(uri=BQ_PATH))
    t6b(dataset=StructuredDataset(uri=BQ_PATH))
    t7(df1=df, df2=df)
    t8(dataframe=arrow_df)
    t8a(dataframe=arrow_df)
    t8b(dataframe=arrow_df)
    t8c(dataframe=arrow_df)
    t9(dataframe=np_array)
    t10(dataset=StructuredDataset(uri=NUMPY_PATH))
    t10a(dataset=StructuredDataset(uri=NUMPY_PATH))
    t10b(dataset=StructuredDataset(uri=NUMPY_PATH))
    t11(dataframe=df)
    t12(dataset=StructuredDataset(uri=PANDAS_PATH))
    t12a(dataset=StructuredDataset(uri=PANDAS_PATH))
    t12b(dataset=StructuredDataset(uri=PANDAS_PATH))


def test_structured_dataset_wf():
@@ -0,0 +1,179 @@
from dataclasses import dataclass

import pyarrow as pa
import pytest
from typing_extensions import Annotated

from flytekit import FlyteContextManager, StructuredDataset, kwtypes, task, workflow

pd = pytest.importorskip("pandas")

PANDAS_PATH = FlyteContextManager.current_context().file_access.get_random_local_directory()
NUMPY_PATH = FlyteContextManager.current_context().file_access.get_random_local_directory()
BQ_PATH = "bq://flyte-dataset:flyte.table"


data = [
    {
        "company": "XYZ pvt ltd",
        "location": "London",
        "info": {"president": "Rakesh Kapoor", "contacts": {"email": "[email protected]", "tel": "9876543210"}},
    },
    {
        "company": "ABC pvt ltd",
        "location": "USA",
        "info": {"president": "Kapoor Rakesh", "contacts": {"email": "[email protected]", "tel": "0123456789"}},
    },
]


@dataclass
class ContactsField:
    email: str
    tel: str


@dataclass
class InfoField:
    president: str
    contacts: ContactsField


@dataclass
class CompanyField:
    location: str
    info: InfoField
    company: str


MyArgDataset = Annotated[StructuredDataset, kwtypes(company=str)]
MyDictDataset = Annotated[StructuredDataset, kwtypes(info={"contacts": {"tel": str}})]
MyDictListDataset = Annotated[StructuredDataset, kwtypes(info={"contacts": {"tel": str, "email": str}})]
MyTopDataClassDataset = Annotated[StructuredDataset, CompanyField]
MyTopDictDataset = Annotated[StructuredDataset, {"company": str, "location": str}]
MySecondDataClassDataset = Annotated[StructuredDataset, kwtypes(info=InfoField)]
MyNestedDataClassDataset = Annotated[StructuredDataset, kwtypes(info=kwtypes(contacts=ContactsField))]
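For orientation, the flattened columns each alias declares, derived from the definitions above (a sketch, not part of the diff):

    # MyArgDataset             -> company
    # MyDictDataset            -> info.contacts.tel
    # MyDictListDataset        -> info.contacts.tel, info.contacts.email
    # MyTopDataClassDataset    -> location, info.president, info.contacts.email, info.contacts.tel, company
    # MyTopDictDataset         -> company, location
    # MySecondDataClassDataset -> info.president, info.contacts.email, info.contacts.tel
    # MyNestedDataClassDataset -> info.contacts.email, info.contacts.tel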


@task()
def create_pd_table() -> StructuredDataset:
    df = pd.json_normalize(data, max_level=0)
    print("original dataframe: \n", df)

    return StructuredDataset(dataframe=df, uri=PANDAS_PATH)


@task()
def create_bq_table() -> StructuredDataset:
    df = pd.json_normalize(data, max_level=0)
    print("original dataframe: \n", df)

    # Optionally replace `uri` with your own Google Cloud endpoint.
    return StructuredDataset(dataframe=df, uri=BQ_PATH)


@task()
def create_np_table() -> StructuredDataset:
    df = pd.json_normalize(data, max_level=0)
    print("original dataframe: \n", df)

    return StructuredDataset(dataframe=df, uri=NUMPY_PATH)


@task()
def create_ar_table() -> StructuredDataset:
    df = pa.Table.from_pandas(pd.json_normalize(data, max_level=0))
    print("original dataframe: \n", df)

    return StructuredDataset(
        dataframe=df,
    )


@task()
def print_table_by_arg(sd: MyArgDataset) -> pd.DataFrame:
    t = sd.open(pd.DataFrame).all()
    print("MyArgDataset dataframe: \n", t)
    return t


@task()
def print_table_by_dict(sd: MyDictDataset) -> pd.DataFrame:
    t = sd.open(pd.DataFrame).all()
    print("MyDictDataset dataframe: \n", t)
    return t


@task()
def print_table_by_list_dict(sd: MyDictListDataset) -> pd.DataFrame:
    t = sd.open(pd.DataFrame).all()
    print("MyDictListDataset dataframe: \n", t)
    return t


@task()
def print_table_by_top_dataclass(sd: MyTopDataClassDataset) -> pd.DataFrame:
    t = sd.open(pd.DataFrame).all()
    print("MyTopDataClassDataset dataframe: \n", t)
    return t


@task()
def print_table_by_top_dict(sd: MyTopDictDataset) -> pd.DataFrame:
    t = sd.open(pd.DataFrame).all()
    print("MyTopDictDataset dataframe: \n", t)
    return t


@task()
def print_table_by_second_dataclass(sd: MySecondDataClassDataset) -> pd.DataFrame:
    t = sd.open(pd.DataFrame).all()
    print("MySecondDataClassDataset dataframe: \n", t)
    return t


@task()
def print_table_by_nested_dataclass(sd: MyNestedDataClassDataset) -> pd.DataFrame:
    t = sd.open(pd.DataFrame).all()
    print("MyNestedDataClassDataset dataframe: \n", t)
    return t


@workflow
def wf():
    pd_sd = create_pd_table()
    print_table_by_arg(sd=pd_sd)
    print_table_by_dict(sd=pd_sd)
    print_table_by_list_dict(sd=pd_sd)
    print_table_by_top_dataclass(sd=pd_sd)
    print_table_by_top_dict(sd=pd_sd)
    print_table_by_second_dataclass(sd=pd_sd)
    print_table_by_nested_dataclass(sd=pd_sd)
    bq_sd = create_pd_table()
    print_table_by_arg(sd=bq_sd)
    print_table_by_dict(sd=bq_sd)
    print_table_by_list_dict(sd=bq_sd)
    print_table_by_top_dataclass(sd=bq_sd)
    print_table_by_top_dict(sd=bq_sd)
    print_table_by_second_dataclass(sd=bq_sd)
    print_table_by_nested_dataclass(sd=bq_sd)
    np_sd = create_pd_table()
    print_table_by_arg(sd=np_sd)
    print_table_by_dict(sd=np_sd)
    print_table_by_list_dict(sd=np_sd)
    print_table_by_top_dataclass(sd=np_sd)
    print_table_by_top_dict(sd=np_sd)
    print_table_by_second_dataclass(sd=np_sd)
    print_table_by_nested_dataclass(sd=np_sd)
    ar_sd = create_pd_table()
    print_table_by_arg(sd=ar_sd)
    print_table_by_dict(sd=ar_sd)
    print_table_by_list_dict(sd=ar_sd)
    print_table_by_top_dataclass(sd=ar_sd)
    print_table_by_top_dict(sd=ar_sd)
    print_table_by_second_dataclass(sd=ar_sd)
    print_table_by_nested_dataclass(sd=ar_sd)


def test_structured_dataset_wf():
    wf()
