Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Asynchronous List Transformations #1738

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 27 additions & 46 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import collections
import copy
import dataclasses
Expand All @@ -8,9 +9,11 @@
import inspect
import json as _json
import mimetypes
import os
import textwrap
import typing
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache
from typing import Dict, NamedTuple, Optional, Type, cast

Expand All @@ -28,7 +31,7 @@
from flytekit.core.context_manager import FlyteContext
from flytekit.core.hash import HashMethod
from flytekit.core.type_helpers import load_type_from_tag
from flytekit.core.utils import timeit
from flytekit.core.utils import coroutine, timeit
from flytekit.exceptions import user as user_exceptions
from flytekit.lazy_import.lazy_module import is_imported
from flytekit.loggers import logger
Expand Down Expand Up @@ -1041,62 +1044,40 @@ def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]:
except Exception as e:
raise ValueError(f"Type of Generic List type is not supported, {e}")

@staticmethod
def is_batchable(t: Type):
    """
    Report whether a list of type ``t`` may be stored as batched pickles.

    Returns True only if ``t`` is a ``List`` (optionally wrapped in ``Annotated``) whose element type is
    ``FlytePickle`` or a parameterized ``FlytePickle``. Every other type yields False, including an
    unparameterized ``typing.List``, which carries no element type to inspect.
    """
    from flytekit.types.pickle import FlytePickle

    if is_annotated(t):
        # The first Annotated argument is the underlying type; the rest is metadata.
        return ListTransformer.is_batchable(get_args(t)[0])
    if get_origin(t) is list:
        type_args = get_args(t)
        if not type_args:
            # Bare ``typing.List`` has origin ``list`` but no arguments; indexing [0] would raise IndexError.
            return False
        subtype = type_args[0]
        if subtype == FlytePickle or getattr(subtype, "__origin__", None) == FlytePickle:
            return True
    return False

# Converts a Python list into a Literal wrapping a LiteralCollection of the converted elements.
# Raises TypeTransformerFailedError when python_val is not exactly a list (the type() comparison also
# rejects list subclasses).
# NOTE(review): two headers for to_literal appear back to back below (a plain def, then a decorated
# async def) - presumably a before/after pair; confirm which one is live.
def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
# NOTE(review): the timer label reads "to_python_value" although it decorates to_literal - looks like a
# copy-paste slip; confirm.
@timeit("ListTransformer: to_python_value")
@coroutine
async def to_literal(
self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType
) -> Literal:
if type(python_val) != list:
raise TypeTransformerFailedError("Expected a list")

# Batchable path: consecutive slices of batch_size items are each converted as one FlytePickle literal.
if ListTransformer.is_batchable(python_type):
from flytekit.types.pickle.pickle import BatchSize, FlytePickle

batch_size = len(python_val) # default batch size
# parse annotated to get the number of items saved in a pickle file.
if is_annotated(python_type):
for annotation in get_args(python_type)[1:]:
if isinstance(annotation, BatchSize):
batch_size = annotation.val
break
# A batch size of zero (e.g. an empty input list) would make range() fail, so it yields no literals.
if batch_size > 0:
lit_list = [TypeEngine.to_literal(ctx, python_val[i : i + batch_size], FlytePickle, expected.collection_type) for i in range(0, len(python_val), batch_size)] # type: ignore
else:
lit_list = []
else:
# Set maximum number of threads to the number of processors on the machine, multiplied by 5 since it is I/O bound task
# limit it to 32 to avoid consuming surprisingly large resource on many core machine.
with ThreadPoolExecutor(max_workers=min(32, (os.cpu_count() or 1) * 5)) as pool:
t = self.get_sub_type(python_type)
lit_list = [TypeEngine.to_literal(ctx, x, t, expected.collection_type) for x in python_val] # type: ignore
# Each element conversion is submitted to the thread pool and awaited together.
loop = asyncio.get_running_loop()
lit_future_list = [
loop.run_in_executor(pool, TypeEngine.to_literal, ctx, x, t, expected.collection_type) for x in python_val # type: ignore
]
# gather returns results in the order the futures were passed, so lit_list follows python_val order.
lit_list = await asyncio.gather(*lit_future_list)
return Literal(collection=LiteralCollection(literals=lit_list))

# Converts a collection Literal back into a Python list by converting each contained literal.
# Raises TypeTransformerFailedError when lv.collection.literals cannot be read (AttributeError).
# NOTE(review): two headers for to_python_value appear below (a plain def, then a decorated async def)
# - presumably a before/after pair; confirm which one is live.
def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> typing.List[typing.Any]: # type: ignore
@timeit("ListTransformer: to_python_value")
@coroutine
async def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> typing.List[typing.Any]: # type: ignore
try:
lits = lv.collection.literals
except AttributeError:
raise TypeTransformerFailedError()
# Batchable path: every literal is read back as a FlytePickle value, then flattened if it holds lists.
if self.is_batchable(expected_python_type):
from flytekit.types.pickle import FlytePickle

batch_list = [TypeEngine.to_python_value(ctx, batch, FlytePickle) for batch in lits]
if len(batch_list) > 0 and type(batch_list[0]) is list:
# Make it have backward compatibility. The upstream task may use old version of Flytekit that
# won't merge the elements in the list. Therefore, we should check if the batch_list[0] is the list first.
return [item for batch in batch_list for item in batch]
return batch_list
else:

# Set maximum number of threads to the number of processors on the machine, multiplied by 5 since it is I/O bound task
# limit it to 32 to avoid consuming surprisingly large resource on many core machine.
with ThreadPoolExecutor(max_workers=min(32, (os.cpu_count() or 1) * 5)) as pool:
st = self.get_sub_type(expected_python_type)
return [TypeEngine.to_python_value(ctx, x, st) for x in lits]
# Each literal conversion is submitted to the thread pool; gather preserves the order of lits.
loop = asyncio.get_running_loop()
val_future_list = [loop.run_in_executor(pool, TypeEngine.to_python_value, ctx, x, st) for x in lits]
return await asyncio.gather(*val_future_list)

def guess_python_type(self, literal_type: LiteralType) -> list: # type: ignore
if literal_type.collection_type:
Expand Down
9 changes: 9 additions & 0 deletions flytekit/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import datetime
import os as _os
import shutil as _shutil
Expand Down Expand Up @@ -331,3 +332,11 @@ def __exit__(self, exc_type, exc_val, exc_tb):
end_process_time - self._start_process_time,
)
)


def coroutine(func: Callable):
    """Wrap an ``async def`` function so that callers can invoke it synchronously.

    The returned wrapper builds the coroutine from the given arguments, runs it to completion and
    returns its result; exceptions raised by the coroutine propagate to the caller.

    ``asyncio.run`` raises ``RuntimeError`` when the calling thread already has a running event loop
    (for example inside a notebook or another coroutine). In that case the coroutine is driven by
    ``asyncio.run`` on a helper thread instead, and the caller blocks until it finishes.

    :param func: the coroutine function to adapt.
    :return: a synchronous callable with the same signature and metadata as ``func``.
    """
    import contextvars
    from concurrent.futures import ThreadPoolExecutor

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread, so asyncio.run may create and own one.
            return asyncio.run(func(*args, **kwargs))

        # A loop is already running here. Run the coroutine on its own loop in a worker thread,
        # carrying over the caller's context variables, which a new thread would otherwise not see.
        caller_context = contextvars.copy_context()

        def _run_in_fresh_loop():
            return asyncio.run(func(*args, **kwargs))

        with ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(caller_context.run, _run_in_fresh_loop).result()

    return wrapper
2 changes: 1 addition & 1 deletion flytekit/types/pickle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
FlytePickle
"""

from .pickle import BatchSize, FlytePickle
from .pickle import FlytePickle
13 changes: 0 additions & 13 deletions flytekit/types/pickle/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,6 @@
T = typing.TypeVar("T")


class BatchSize:
    """
    Annotation marker that carries a batch size for a list type.

    ``ListTransformer.to_literal`` looks for an instance of this class among the ``Annotated`` metadata
    of a batchable list type and uses ``val`` as the number of list items saved in each pickle file.
    """

    def __init__(self, val: int):
        self._val = val

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self._val!r})"

    @property
    def val(self) -> int:
        """The number of items per batch, as passed to the constructor."""
        return self._val


class FlytePickle(typing.Generic[T]):
"""
This type is only used by flytekit internally. User should not use this type.
Expand Down
7 changes: 1 addition & 6 deletions tests/flytekit/unit/core/test_flyte_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from flytekit.models.literals import BlobMetadata
from flytekit.models.types import LiteralType
from flytekit.tools.translator import get_serializable
from flytekit.types.pickle.pickle import BatchSize, FlytePickle, FlytePickleTransformer
from flytekit.types.pickle.pickle import FlytePickle, FlytePickleTransformer

default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = flytekit.configuration.SerializationSettings(
Expand Down Expand Up @@ -57,11 +57,6 @@ def test_get_literal_type():
assert lt == expected_lt


def test_batch_size():
    """``BatchSize`` exposes the value it was constructed with through ``val``."""
    assert BatchSize(5).val == 5


def test_nested():
class Foo(object):
def __init__(self, number: int):
Expand Down
6 changes: 2 additions & 4 deletions tests/flytekit/unit/core/test_promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import pytest
from dataclasses_json import dataclass_json
from typing_extensions import Annotated

from flytekit import LaunchPlan, task, workflow
from flytekit.core import context_manager
Expand All @@ -15,7 +14,6 @@
translate_inputs_to_literals,
)
from flytekit.exceptions.user import FlyteAssertion
from flytekit.types.pickle.pickle import BatchSize


def test_create_and_link_node():
Expand Down Expand Up @@ -101,11 +99,11 @@ class MyDataclass(object):

@pytest.mark.parametrize(
"input",
[2.0, MyDataclass(i=1, a=["h", "e"]), [1, 2, 3], ["foo"] * 5],
[2.0, MyDataclass(i=1, a=["h", "e"]), [1, 2, 3]],
)
def test_translate_inputs_to_literals(input):
@task
def t1(a: typing.Union[float, MyDataclass, Annotated[typing.List[typing.Any], BatchSize(2)]]):
def t1(a: typing.Union[float, typing.List[int], MyDataclass]):
print(a)

ctx = context_manager.FlyteContext.current_context()
Expand Down
68 changes: 2 additions & 66 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from marshmallow_enum import LoadDumpOptions
from marshmallow_jsonschema import JSONSchema
from pandas._testing import assert_frame_equal
from typing_extensions import Annotated, get_args, get_origin
from typing_extensions import Annotated, get_args

from flytekit import kwtypes
from flytekit.core.annotation import FlyteAnnotation
Expand Down Expand Up @@ -55,7 +55,7 @@
from flytekit.types.file import FileExt, JPEGImageFile
from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer, noop
from flytekit.types.pickle import FlytePickle
from flytekit.types.pickle.pickle import BatchSize, FlytePickleTransformer
from flytekit.types.pickle.pickle import FlytePickleTransformer
from flytekit.types.schema import FlyteSchema
from flytekit.types.schema.types_pandas import PandasDataFrameTransformer
from flytekit.types.structured.structured_dataset import StructuredDataset
Expand Down Expand Up @@ -1636,70 +1636,6 @@ def test_file_ext_with_flyte_file_wrong_type():
assert str(e.value) == "Underlying type of File Extension must be of type <str>"


def test_is_batchable():
    """Only ``List[FlytePickle]``, optionally wrapped in ``Annotated``, is reported as batchable."""
    rejected = (
        typing.List[int],
        typing.List[str],
        typing.List[typing.Dict],
        typing.List[typing.Dict[str, FlytePickle]],
        typing.List[typing.List[FlytePickle]],
    )
    accepted = (
        typing.List[FlytePickle],
        Annotated[typing.List[FlytePickle], BatchSize(3)],
        Annotated[typing.List[FlytePickle], HashMethod(function=str), BatchSize(3)],
    )

    for candidate in rejected:
        assert ListTransformer.is_batchable(candidate) is False
    for candidate in accepted:
        assert ListTransformer.is_batchable(candidate) is True


@pytest.mark.parametrize(
    "python_val, python_type, expected_list_length",
    [
        # Case 1: List of FlytePickle objects with default batch size.
        # (By default, the batch_size is set to the length of the whole list.)
        # After converting to literal, the result will be [batched_FlytePickle(5 items)].
        # Therefore, the expected list length is [1].
        ([{"foo"}] * 5, typing.List[FlytePickle], [1]),
        # Case 2: List of FlytePickle objects with batch size 2.
        # After converting to literal, the result will be
        # [batched_FlytePickle(2 items), batched_FlytePickle(2 items), batched_FlytePickle(1 item)].
        # Therefore, the expected list length is [3].
        (["foo"] * 5, Annotated[typing.List[FlytePickle], HashMethod(function=str), BatchSize(2)], [3]),
        # Case 3: Nested list of FlytePickle objects with batch size 3.
        # After converting to literal, the result will be
        # [[batched_FlytePickle(3 items)], [batched_FlytePickle(3 items)]]
        # Therefore, the expected list length is [2, 1]
        # (the length of the outer list remains the same, the inner list is batched).
        ([["foo", "foo", "foo"]] * 2, typing.List[Annotated[typing.List[FlytePickle], BatchSize(3)]], [2, 1]),
        # Case 4: Empty list. No length checks are performed; only the round trip is verified.
        ([], typing.List[FlytePickle], []),
    ],
)
def test_batch_pickle_list(python_val, python_type, expected_list_length):
    """Round-trip lists of FlytePickle through the type engine and check the batched literal layout.

    ``expected_list_length`` lists the expected number of literals at each nesting level, outermost first.
    """
    ctx = FlyteContext.current_context()
    expected = TypeEngine.to_literal_type(python_type)
    lv = TypeEngine.to_literal(ctx, python_val, python_type, expected)

    tmp_lv = lv
    for length in expected_list_length:
        # Check that after converting to literal, the length of the literal list is equal to:
        # - the length of the original list divided by the batch size if not nested
        # - the length of the original list if it contains a nested list
        assert len(tmp_lv.collection.literals) == length
        tmp_lv = tmp_lv.collection.literals[0]

    pv = TypeEngine.to_python_value(ctx, lv, python_type)
    # Check that after converting literal to Python value, the result is equal to the original python values.
    assert pv == python_val
    if get_origin(python_type) is Annotated:
        pv = TypeEngine.to_python_value(ctx, lv, get_args(python_type)[0])
        # Remove the annotation and check that after converting to Python value, the result is equal
        # to the original input values. This is used to simulate the following case:
        # @workflow
        # def wf():
        #     data = task0()  # task0() -> Annotated[typing.List[FlytePickle], BatchSize(2)]
        #     task1(data=data)  # task1(data: typing.List[FlytePickle])
        assert pv == python_val


@pytest.mark.parametrize(
"t,expected",
[
Expand Down