From 1fa55e4e9d91d349604fa8abc96389b4160f6f23 Mon Sep 17 00:00:00 2001
From: Austin Liu
Date: Tue, 30 Apr 2024 17:39:19 +0800
Subject: [PATCH] Annotated StructuredDataset: support `nested_types` (#2252)

Signed-off-by: Austin Liu
Signed-off-by: Kevin Su
Co-authored-by: Kevin Su
Signed-off-by: Jan Fiedler
---
 .../types/structured/structured_dataset.py    | 33 +++-
 .../test_structured_dataset_workflow.py       | 88 +++++++++
 ...tured_dataset_workflow_with_nested_type.py | 179 ++++++++++++++++++
 3 files changed, 297 insertions(+), 3 deletions(-)
 create mode 100644 tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow_with_nested_type.py

diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py
index 8faed9ff45..d6dc6b49e5 100644
--- a/flytekit/types/structured/structured_dataset.py
+++ b/flytekit/types/structured/structured_dataset.py
@@ -4,7 +4,7 @@
 import types
 import typing
 from abc import ABC, abstractmethod
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, is_dataclass
 from typing import Dict, Generator, Optional, Type, Union
 
 import _datetime
@@ -114,6 +114,22 @@ def iter(self) -> Generator[DF, None, None]:
         )
 
 
+# flatten the nested column map recursively
+def flatten_dict(sub_dict: dict, parent_key: str = "") -> typing.Dict:
+    result = {}
+    for key, value in sub_dict.items():
+        current_key = f"{parent_key}.{key}" if parent_key else key
+        if isinstance(value, dict):
+            result.update(flatten_dict(sub_dict=value, parent_key=current_key))
+        elif is_dataclass(value):
+            fields = getattr(value, "__dataclass_fields__")
+            d = {k: v.type for k, v in fields.items()}
+            result.update(flatten_dict(sub_dict=d, parent_key=current_key))
+        else:
+            result[current_key] = value
+    return result
+
+
 def extract_cols_and_format(
     t: typing.Any,
 ) -> typing.Tuple[Type[T], Optional[typing.OrderedDict[str, Type]], Optional[str], Optional["pa.lib.Schema"]]:
@@ -142,7 +158,17 @@ def 
extract_cols_and_format( if get_origin(t) is Annotated: base_type, *annotate_args = get_args(t) for aa in annotate_args: - if isinstance(aa, StructuredDatasetFormat): + if hasattr(aa, "__annotations__"): + # handle dataclass argument + d = collections.OrderedDict() + dm = vars(aa) + d.update(dm["__annotations__"]) + ordered_dict_cols = d + elif isinstance(aa, dict): + d = collections.OrderedDict() + d.update(aa) + ordered_dict_cols = d + elif isinstance(aa, StructuredDatasetFormat): if fmt != "": raise ValueError(f"A format was already specified {fmt}, cannot use {aa}") fmt = aa @@ -826,7 +852,8 @@ def _convert_ordered_dict_of_columns_to_list( converted_cols: typing.List[StructuredDatasetType.DatasetColumn] = [] if column_map is None or len(column_map) == 0: return converted_cols - for k, v in column_map.items(): + flat_column_map = flatten_dict(column_map) + for k, v in flat_column_map.items(): lt = self._get_dataset_column_literal_type(v) converted_cols.append(StructuredDatasetType.DatasetColumn(name=k, literal_type=lt)) return converted_cols diff --git a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py index 3b0bf96e7a..91fa72b526 100644 --- a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py +++ b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py @@ -1,5 +1,6 @@ import os import typing +from dataclasses import dataclass import numpy as np import pyarrow as pa @@ -28,7 +29,16 @@ NUMPY_PATH = FlyteContextManager.current_context().file_access.get_random_local_directory() BQ_PATH = "bq://flyte-dataset:flyte.table" + +@dataclass +class MyCols: + Name: str + Age: int + + my_cols = kwtypes(Name=str, Age=int) +my_dataclass_cols = MyCols +my_dict_cols = {"Name": str, "Age": int} fields = [("Name", pa.string()), ("Age", pa.int32())] arrow_schema = pa.schema(fields) pd_df = pd.DataFrame({"Name": 
["Tom", "Joseph"], "Age": [20, 22]}) @@ -157,6 +167,18 @@ def t4(dataset: Annotated[StructuredDataset, my_cols]) -> pd.DataFrame: return dataset.open(pd.DataFrame).all() +@task +def t4a(dataset: Annotated[StructuredDataset, my_dataclass_cols]) -> pd.DataFrame: + # s3 (parquet) -> pandas -> s3 (parquet) + return dataset.open(pd.DataFrame).all() + + +@task +def t4b(dataset: Annotated[StructuredDataset, my_dict_cols]) -> pd.DataFrame: + # s3 (parquet) -> pandas -> s3 (parquet) + return dataset.open(pd.DataFrame).all() + + @task def t5(dataframe: pd.DataFrame) -> Annotated[StructuredDataset, my_cols]: # s3 (parquet) -> pandas -> bq @@ -170,6 +192,20 @@ def t6(dataset: Annotated[StructuredDataset, my_cols]) -> pd.DataFrame: return df +@task +def t6a(dataset: Annotated[StructuredDataset, my_dataclass_cols]) -> pd.DataFrame: + # bq -> pandas -> s3 (parquet) + df = dataset.open(pd.DataFrame).all() + return df + + +@task +def t6b(dataset: Annotated[StructuredDataset, my_dict_cols]) -> pd.DataFrame: + # bq -> pandas -> s3 (parquet) + df = dataset.open(pd.DataFrame).all() + return df + + @task def t7( df1: pd.DataFrame, df2: pd.DataFrame @@ -193,6 +229,20 @@ def t8a(dataframe: pa.Table) -> pa.Table: return dataframe +@task +def t8b(dataframe: pa.Table) -> Annotated[StructuredDataset, my_dataclass_cols]: + # Arrow table -> s3 (parquet) + print(dataframe.columns) + return StructuredDataset(dataframe=dataframe) + + +@task +def t8c(dataframe: pa.Table) -> Annotated[StructuredDataset, my_dict_cols]: + # Arrow table -> s3 (parquet) + print(dataframe.columns) + return StructuredDataset(dataframe=dataframe) + + @task def t9(dataframe: np.ndarray) -> Annotated[StructuredDataset, my_cols]: # numpy -> Arrow table -> s3 (parquet) @@ -206,6 +256,20 @@ def t10(dataset: Annotated[StructuredDataset, my_cols]) -> np.ndarray: return np_array +@task +def t10a(dataset: Annotated[StructuredDataset, my_dataclass_cols]) -> np.ndarray: + # s3 (parquet) -> Arrow table -> numpy + np_array = 
dataset.open(np.ndarray).all() + return np_array + + +@task +def t10b(dataset: Annotated[StructuredDataset, my_dict_cols]) -> np.ndarray: + # s3 (parquet) -> Arrow table -> numpy + np_array = dataset.open(np.ndarray).all() + return np_array + + StructuredDatasetTransformerEngine.register(PandasToCSVEncodingHandler()) StructuredDatasetTransformerEngine.register(CSVToPandasDecodingHandler()) @@ -223,6 +287,20 @@ def t12(dataset: Annotated[StructuredDataset, my_cols]) -> pd.DataFrame: return df +@task +def t12a(dataset: Annotated[StructuredDataset, my_dataclass_cols]) -> pd.DataFrame: + # csv -> pandas + df = dataset.open(pd.DataFrame).all() + return df + + +@task +def t12b(dataset: Annotated[StructuredDataset, my_dict_cols]) -> pd.DataFrame: + # csv -> pandas + df = dataset.open(pd.DataFrame).all() + return df + + @task def generate_pandas() -> pd.DataFrame: return pd_df @@ -249,15 +327,25 @@ def wf(): t3(dataset=StructuredDataset(uri=PANDAS_PATH)) t3a(dataset=StructuredDataset(uri=PANDAS_PATH)) t4(dataset=StructuredDataset(uri=PANDAS_PATH)) + t4a(dataset=StructuredDataset(uri=PANDAS_PATH)) + t4b(dataset=StructuredDataset(uri=PANDAS_PATH)) t5(dataframe=df) t6(dataset=StructuredDataset(uri=BQ_PATH)) + t6a(dataset=StructuredDataset(uri=BQ_PATH)) + t6b(dataset=StructuredDataset(uri=BQ_PATH)) t7(df1=df, df2=df) t8(dataframe=arrow_df) t8a(dataframe=arrow_df) + t8b(dataframe=arrow_df) + t8c(dataframe=arrow_df) t9(dataframe=np_array) t10(dataset=StructuredDataset(uri=NUMPY_PATH)) + t10a(dataset=StructuredDataset(uri=NUMPY_PATH)) + t10b(dataset=StructuredDataset(uri=NUMPY_PATH)) t11(dataframe=df) t12(dataset=StructuredDataset(uri=PANDAS_PATH)) + t12a(dataset=StructuredDataset(uri=PANDAS_PATH)) + t12b(dataset=StructuredDataset(uri=PANDAS_PATH)) def test_structured_dataset_wf(): diff --git a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow_with_nested_type.py 
b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow_with_nested_type.py new file mode 100644 index 0000000000..62c0f6d651 --- /dev/null +++ b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow_with_nested_type.py @@ -0,0 +1,179 @@ +from dataclasses import dataclass + +import pyarrow as pa +import pytest +from typing_extensions import Annotated + +from flytekit import FlyteContextManager, StructuredDataset, kwtypes, task, workflow + +pd = pytest.importorskip("pandas") + +PANDAS_PATH = FlyteContextManager.current_context().file_access.get_random_local_directory() +NUMPY_PATH = FlyteContextManager.current_context().file_access.get_random_local_directory() +BQ_PATH = "bq://flyte-dataset:flyte.table" + + +data = [ + { + "company": "XYZ pvt ltd", + "location": "London", + "info": {"president": "Rakesh Kapoor", "contacts": {"email": "contact@xyz.com", "tel": "9876543210"}}, + }, + { + "company": "ABC pvt ltd", + "location": "USA", + "info": {"president": "Kapoor Rakesh", "contacts": {"email": "contact@abc.com", "tel": "0123456789"}}, + }, +] + + +@dataclass +class ContactsField: + email: str + tel: str + + +@dataclass +class InfoField: + president: str + contacts: ContactsField + + +@dataclass +class CompanyField: + location: str + info: InfoField + company: str + + +MyArgDataset = Annotated[StructuredDataset, kwtypes(company=str)] +MyDictDataset = Annotated[StructuredDataset, kwtypes(info={"contacts": {"tel": str}})] +MyDictListDataset = Annotated[StructuredDataset, kwtypes(info={"contacts": {"tel": str, "email": str}})] +MyTopDataClassDataset = Annotated[StructuredDataset, CompanyField] +MyTopDictDataset = Annotated[StructuredDataset, {"company": str, "location": str}] +MySecondDataClassDataset = Annotated[StructuredDataset, kwtypes(info=InfoField)] +MyNestedDataClassDataset = Annotated[StructuredDataset, kwtypes(info=kwtypes(contacts=ContactsField))] + + +@task() +def create_pd_table() -> StructuredDataset: + df = 
pd.json_normalize(data, max_level=0) + print("original dataframe: \n", df) + + return StructuredDataset(dataframe=df, uri=PANDAS_PATH) + + +@task() +def create_bq_table() -> StructuredDataset: + df = pd.json_normalize(data, max_level=0) + print("original dataframe: \n", df) + + # Enable one of GCP `uri` below if you want. You can replace `uri` with your own google cloud endpoints. + return StructuredDataset(dataframe=df, uri=BQ_PATH) + + +@task() +def create_np_table() -> StructuredDataset: + df = pd.json_normalize(data, max_level=0) + print("original dataframe: \n", df) + + return StructuredDataset(dataframe=df, uri=NUMPY_PATH) + + +@task() +def create_ar_table() -> StructuredDataset: + df = pa.Table.from_pandas(pd.json_normalize(data, max_level=0)) + print("original dataframe: \n", df) + + return StructuredDataset( + dataframe=df, + ) + + +@task() +def print_table_by_arg(sd: MyArgDataset) -> pd.DataFrame: + t = sd.open(pd.DataFrame).all() + print("MyArgDataset dataframe: \n", t) + return t + + +@task() +def print_table_by_dict(sd: MyDictDataset) -> pd.DataFrame: + t = sd.open(pd.DataFrame).all() + print("MyDictDataset dataframe: \n", t) + return t + + +@task() +def print_table_by_list_dict(sd: MyDictListDataset) -> pd.DataFrame: + t = sd.open(pd.DataFrame).all() + print("MyDictListDataset dataframe: \n", t) + return t + + +@task() +def print_table_by_top_dataclass(sd: MyTopDataClassDataset) -> pd.DataFrame: + t = sd.open(pd.DataFrame).all() + print("MyTopDataClassDataset dataframe: \n", t) + return t + + +@task() +def print_table_by_top_dict(sd: MyTopDictDataset) -> pd.DataFrame: + t = sd.open(pd.DataFrame).all() + print("MyTopDictDataset dataframe: \n", t) + return t + + +@task() +def print_table_by_second_dataclass(sd: MySecondDataClassDataset) -> pd.DataFrame: + t = sd.open(pd.DataFrame).all() + print("MySecondDataClassDataset dataframe: \n", t) + return t + + +@task() +def print_table_by_nested_dataclass(sd: MyNestedDataClassDataset) -> pd.DataFrame: + t = 
sd.open(pd.DataFrame).all() + print("MyNestedDataClassDataset dataframe: \n", t) + return t + + +@workflow +def wf(): + pd_sd = create_pd_table() + print_table_by_arg(sd=pd_sd) + print_table_by_dict(sd=pd_sd) + print_table_by_list_dict(sd=pd_sd) + print_table_by_top_dataclass(sd=pd_sd) + print_table_by_top_dict(sd=pd_sd) + print_table_by_second_dataclass(sd=pd_sd) + print_table_by_nested_dataclass(sd=pd_sd) + bq_sd = create_pd_table() + print_table_by_arg(sd=bq_sd) + print_table_by_dict(sd=bq_sd) + print_table_by_list_dict(sd=bq_sd) + print_table_by_top_dataclass(sd=bq_sd) + print_table_by_top_dict(sd=bq_sd) + print_table_by_second_dataclass(sd=bq_sd) + print_table_by_nested_dataclass(sd=bq_sd) + np_sd = create_pd_table() + print_table_by_arg(sd=np_sd) + print_table_by_dict(sd=np_sd) + print_table_by_list_dict(sd=np_sd) + print_table_by_top_dataclass(sd=np_sd) + print_table_by_top_dict(sd=np_sd) + print_table_by_second_dataclass(sd=np_sd) + print_table_by_nested_dataclass(sd=np_sd) + ar_sd = create_pd_table() + print_table_by_arg(sd=ar_sd) + print_table_by_dict(sd=ar_sd) + print_table_by_list_dict(sd=ar_sd) + print_table_by_top_dataclass(sd=ar_sd) + print_table_by_top_dict(sd=ar_sd) + print_table_by_second_dataclass(sd=ar_sd) + print_table_by_nested_dataclass(sd=ar_sd) + + +def test_structured_dataset_wf(): + wf()