Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve map task issues for batchable list #1772

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 60 additions & 36 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,62 +1041,86 @@ def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]:
except Exception as e:
raise ValueError(f"Type of Generic List type is not supported, {e}")

@staticmethod
def is_batchable(t: Type):
def _is_batchable(self, sub_type: Type[T]) -> bool:
"""
This function evaluates whether the provided type is batchable or not.
It returns True only if the type is either List or Annotated(List) and the List subtype is FlytePickle.
Determines whether the list is batchable, given its subtype.
Returns True if the subtype is transformed using FlytePickleTransformer, otherwise False.
"""
from flytekit.types.pickle import FlytePickle
from flytekit.types.pickle.pickle import FlytePickleTransformer

if is_annotated(t):
return ListTransformer.is_batchable(get_args(t)[0])
if get_origin(t) is list:
subtype = get_args(t)[0]
if subtype == FlytePickle or (hasattr(subtype, "__origin__") and subtype.__origin__ == FlytePickle):
return True
return False
return isinstance(TypeEngine.get_transformer(sub_type), FlytePickleTransformer)

def _get_batch_size(self, python_val: list, python_type: Type[T]) -> int:
"""
Retrieves the batch size for a list eligible for batching.
This function helps in determining the number of items to store in a single pickle file.
By default, all items in the list are stored in a single pickle file.
However, users can specify a different batch size using the `BatchSize` annotation.
An example annotation `Annotated[List[Any], BatchSize(2)]` would set the batch size to 2.
"""
from flytekit.types.pickle.pickle import BatchSize

batch_size = max(len(python_val), 1) # default batch size
# parse annotated to get the number of items saved in a pickle file.
if is_annotated(python_type):
for annotation in get_args(python_type)[1:]:
if isinstance(annotation, BatchSize) and annotation.val >= 1:
batch_size = annotation.val
break
return batch_size

def _get_placeholder(self) -> Literal:
"""
Returns a placeholder in the form of a None literal.
"""
return Literal(scalar=Scalar(none_type=Void()))

def _is_placeholder(self, lit: Literal) -> bool:
"""
Verifies if the provided literal is a placeholder (None literal).
"""
return lit.scalar.none_type is not None

def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
if type(python_val) != list:
raise TypeTransformerFailedError("Expected a list")

if ListTransformer.is_batchable(python_type):
from flytekit.types.pickle.pickle import BatchSize, FlytePickle
sub_type = self.get_sub_type(python_type)

if self._is_batchable(sub_type):
batch_size = self._get_batch_size(python_val, python_type)
lit_list = [TypeEngine.to_literal(ctx, python_val[i : i + batch_size], sub_type, expected.collection_type) for i in range(0, len(python_val), batch_size)] # type: ignore
# Add placeholders to preserve the original list length for map task compatibility. Map task requires the list length unchanged.
num_placeholders = len(python_val) - len(lit_list)
lit_list += [self._get_placeholder()] * num_placeholders

batch_size = len(python_val) # default batch size
# parse annotated to get the number of items saved in a pickle file.
if is_annotated(python_type):
for annotation in get_args(python_type)[1:]:
if isinstance(annotation, BatchSize):
batch_size = annotation.val
break
if batch_size > 0:
lit_list = [TypeEngine.to_literal(ctx, python_val[i : i + batch_size], FlytePickle, expected.collection_type) for i in range(0, len(python_val), batch_size)] # type: ignore
else:
lit_list = []
else:
t = self.get_sub_type(python_type)
lit_list = [TypeEngine.to_literal(ctx, x, t, expected.collection_type) for x in python_val] # type: ignore
lit_list = [TypeEngine.to_literal(ctx, x, sub_type, expected.collection_type) for x in python_val] # type: ignore

return Literal(collection=LiteralCollection(literals=lit_list))

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> typing.List[typing.Any]: # type: ignore
try:
lits = lv.collection.literals
except AttributeError:
raise TypeTransformerFailedError()
if self.is_batchable(expected_python_type):
from flytekit.types.pickle import FlytePickle

batch_list = [TypeEngine.to_python_value(ctx, batch, FlytePickle) for batch in lits]
if len(batch_list) > 0 and type(batch_list[0]) is list:
sub_type = self.get_sub_type(expected_python_type)

if self._is_batchable(sub_type):
batches = []
for lit in lits:
if self._is_placeholder(lit):
break
batches.append(TypeEngine.to_python_value(ctx, lit, sub_type))
if len(batches) > 0 and type(batches[0]) is list:
# Make it have backward compatibility. The upstream task may use old version of Flytekit that
# won't merge the elements in the list. Therefore, we should check if the batch_list[0] is the list first.
return [item for batch in batch_list for item in batch]
return batch_list
# won't merge the elements in the list. Therefore, we should check if the batches[0] is the list first.
return [item for batch in batches for item in batch]
return batches

else:
st = self.get_sub_type(expected_python_type)
return [TypeEngine.to_python_value(ctx, x, st) for x in lits]
return [TypeEngine.to_python_value(ctx, x, sub_type) for x in lits]

def guess_python_type(self, literal_type: LiteralType) -> list: # type: ignore
if literal_type.collection_type:
Expand Down
68 changes: 25 additions & 43 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1636,55 +1636,37 @@ def test_file_ext_with_flyte_file_wrong_type():
assert str(e.value) == "Underlying type of File Extension must be of type <str>"


def test_is_batchable():
Copy link
Member Author

@Yicheng-Lu-llll Yicheng-Lu-llll Aug 1, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, whether a list is batchable depends solely on whether its subtype uses FlytePickleTransformer. There appears to be no need to test the is_batchable function. The correctness is now handled by the tests for TypeEngine.get_transformer.

assert ListTransformer.is_batchable(typing.List[int]) is False
assert ListTransformer.is_batchable(typing.List[str]) is False
assert ListTransformer.is_batchable(typing.List[typing.Dict]) is False
assert ListTransformer.is_batchable(typing.List[typing.Dict[str, FlytePickle]]) is False
assert ListTransformer.is_batchable(typing.List[typing.List[FlytePickle]]) is False

assert ListTransformer.is_batchable(typing.List[FlytePickle]) is True
assert ListTransformer.is_batchable(Annotated[typing.List[FlytePickle], BatchSize(3)]) is True
assert (
ListTransformer.is_batchable(Annotated[typing.List[FlytePickle], HashMethod(function=str), BatchSize(3)])
is True
)


@pytest.mark.parametrize(
"python_val, python_type, expected_list_length",
"python_val, python_type, pickle_literal_length",
[
# Case 1: List of FlytePickle objects with default batch size.
# (By default, the batch_size is set to the length of the whole list.)
# After converting to literal, the result will be [batched_FlytePickle(5 items)].
# Therefore, the expected list length is [1].
([{"foo"}] * 5, typing.List[FlytePickle], [1]),
# Case 2: List of FlytePickle objects with batch size 2.
# After converting to literal, the result will be
# [batched_FlytePickle(2 items), batched_FlytePickle(2 items), batched_FlytePickle(1 item)].
# Therefore, the expected list length is [3].
(["foo"] * 5, Annotated[typing.List[FlytePickle], HashMethod(function=str), BatchSize(2)], [3]),
# Case 3: Nested list of FlytePickle objects with batch size 2.
# After converting to literal, the result will be
# [[batched_FlytePickle(3 items)], [batched_FlytePickle(3 items)]]
# Therefore, the expected list length is [2, 1] (the length of the outer list remains the same, the inner list is batched).
([["foo", "foo", "foo"]] * 2, typing.List[Annotated[typing.List[FlytePickle], BatchSize(3)]], [2, 1]),
# Case 4: Empty list
([[], typing.List[FlytePickle], []]),
# Test case 1: A list of batchable objects with a default batch size.
# By default, the batch size is set to the total length of the list.
# Once transformed into a literal, the output will be
# [pickle literal containing 3 items, followed by two none literals].
# Hence, the length of the pickle literal list is 1.
([{"foo"}] * 3, typing.List[typing.Any], 1),
# Test case 2: A list of batchable objects with a batch size of 2.
# Once transformed into a literal, the output will be
# [pickle literal containing 2 items, pickle literal containing 2 items,
# pickle literal containing 1 item, followed by two none literals].
# Therefore, the length of the pickle literal list is 3.
(["foo"] * 5, Annotated[typing.List[typing.Any], HashMethod(function=str), BatchSize(2)], 3),
# Test case 3: An empty list
# In this scenario, the length of the pickle literal list should be 0.
([], typing.List[typing.Any], 0),
],
)
def test_batch_pickle_list(python_val, python_type, expected_list_length):
def test_batchable_list(python_val, python_type, pickle_literal_length):
ctx = FlyteContext.current_context()
expected = TypeEngine.to_literal_type(python_type)
lv = TypeEngine.to_literal(ctx, python_val, python_type, expected)

tmp_lv = lv
for length in expected_list_length:
# Check that after converting to literal, the length of the literal list is equal to:
# - the length of the original list divided by the batch size if not nested
# - the length of the original list if it contains a nested list
assert len(tmp_lv.collection.literals) == length
tmp_lv = tmp_lv.collection.literals[0]
# Verifying that the length of the literal list matches the length of the original list for map task compatibility.
assert len(lv.collection.literals) == len(python_val)
# Confirming that the number of pickle literals in the literal list is equal to provided pickle literal length.
assert all(lit.scalar.blob is not None for lit in lv.collection.literals[:pickle_literal_length])
# Confirming that all the remaining literals in the literal list are none literals.
assert all(lit.scalar.none_type is not None for lit in lv.collection.literals[pickle_literal_length:])

pv = TypeEngine.to_python_value(ctx, lv, python_type)
# Check that after converting literal to Python value, the result is equal to the original python values.
Expand All @@ -1695,8 +1677,8 @@ def test_batch_pickle_list(python_val, python_type, expected_list_length):
# to the original input values. This is used to simulate the following case:
# @workflow
# def wf():
# data = task0() # task0() -> Annotated[typing.List[FlytePickle], BatchSize(2)]
# task1(data=data) # task1(data: typing.List[FlytePickle])
# data = task0() # task0() -> Annotated[typing.List[typing.Any], BatchSize(2)]
# task1(data=data) # task1(data: typing.List[typing.Any])
assert pv == python_val


Expand Down