diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index aa4e5f3150..220fc3fb89 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -127,7 +127,11 @@ def resolve_attr_path_in_promise(p: Promise) -> Promise: break # If the current value is a dataclass, resolve the dataclass with the remaining path - if type(curr_val.value) is _literals_models.Scalar and type(curr_val.value.value) is _struct.Struct: + if ( + len(p.attr_path) > 0 + and type(curr_val.value) is _literals_models.Scalar + and type(curr_val.value.value) is _struct.Struct + ): st = curr_val.value.value new_st = resolve_attr_path_in_pb_struct(st, attr_path=p.attr_path[used:]) literal_type = TypeEngine.to_literal_type(type(new_st)) @@ -729,7 +733,7 @@ def binding_data_from_python_std( lit = TypeEngine.to_literal(ctx, t_value, type(t_value), expected_literal_type) return _literals_models.BindingData(scalar=lit.scalar) else: - _, v_type = DictTransformer.get_dict_types(t_value_type) + _, v_type = DictTransformer.extract_types_or_metadata(t_value_type) m = _literals_models.BindingDataMap( bindings={ k: binding_data_from_python_std( diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 9f9ed0765a..4b1d144c88 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -12,6 +12,7 @@ import textwrap import typing from abc import ABC, abstractmethod +from collections import OrderedDict from functools import lru_cache from typing import Dict, List, NamedTuple, Optional, Type, cast @@ -713,7 +714,7 @@ def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any: return list(map(lambda x: self._fix_val_int(ListTransformer.get_sub_type(t), x), val)) if isinstance(val, dict): - ktype, vtype = DictTransformer.get_dict_types(t) + ktype, vtype = DictTransformer.extract_types_or_metadata(t) # Handle nested Dict. e.g. {1: {2: 3}, 4: {5: 6}}) return { self._fix_val_int(cast(type, ktype), k): self._fix_val_int(cast(type, vtype), v) for k, v in val.items() @@ -1660,13 +1661,10 @@ class DictTransformer(TypeTransformer[dict]): """ def __init__(self): - super().__init__("Typed Dict", dict) + super().__init__("Python Dictionary", dict) @staticmethod - def get_dict_types(t: Optional[Type[dict]]) -> typing.Tuple[Optional[type], Optional[type]]: - """ - Return the generic Type T of the Dict - """ + def extract_types_or_metadata(t: Optional[Type[dict]]) -> typing.Tuple: _origin = get_origin(t) _args = get_args(t) if _origin is not None: @@ -1679,22 +1677,60 @@ def get_dict_types(t: Optional[Type[dict]]) -> typing.Tuple[Optional[type], Opti raise ValueError( f"Flytekit does not currently have support for FlyteAnnotations applied to dicts. {t} cannot be parsed." ) - if _origin is dict and _args is not None: + if _origin in [dict, Annotated] and _args is not None: return _args # type: ignore return None, None @staticmethod - def dict_to_generic_literal(v: dict) -> Literal: + def dict_to_generic_literal(v: dict, allow_pickle: bool) -> Literal: """ Creates a flyte-specific ``Literal`` value from a native python dictionary. 
""" - return Literal(scalar=Scalar(generic=_json_format.Parse(json.dumps(v), _struct.Struct()))) + from flytekit.types.pickle import FlytePickle + + try: + return Literal( + scalar=Scalar(generic=_json_format.Parse(json.dumps(v), _struct.Struct())), + metadata={"format": "json"}, + ) + except TypeError as e: + if allow_pickle: + remote_path = FlytePickle.to_pickle(v) + return Literal( + scalar=Scalar( + generic=_json_format.Parse(json.dumps({"pickle_file": remote_path}), _struct.Struct()) + ), + metadata={"format": "pickle"}, + ) + raise e + + @staticmethod + def is_pickle(python_type: Type[dict]) -> typing.Tuple[bool, Type]: + base_type, *metadata = DictTransformer.extract_types_or_metadata(python_type) + + for each_metadata in metadata: + if isinstance(each_metadata, OrderedDict): + allow_pickle = each_metadata.get("allow_pickle", False) + return allow_pickle, base_type + + return False, base_type + + @staticmethod + def dict_types(python_type: Type) -> typing.Tuple[typing.Any, ...]: + if get_origin(python_type) is Annotated: + base_type, *_ = DictTransformer.extract_types_or_metadata(python_type) + tp = get_args(base_type) + else: + tp = DictTransformer.extract_types_or_metadata(python_type) + + return tp def get_literal_type(self, t: Type[dict]) -> LiteralType: """ Transforms a native python dictionary to a flyte-specific ``LiteralType`` """ - tp = self.get_dict_types(t) + tp = self.dict_types(t) + if tp: if tp[0] == str: try: @@ -1710,21 +1746,33 @@ def to_literal( if type(python_val) != dict: raise TypeTransformerFailedError("Expected a dict") + allow_pickle = False + base_type = None + + if get_origin(python_type) is Annotated: + allow_pickle, base_type = DictTransformer.is_pickle(python_type) + if expected and expected.simple and expected.simple == SimpleType.STRUCT: - return self.dict_to_generic_literal(python_val) + return self.dict_to_generic_literal(python_val, allow_pickle) lit_map = {} for k, v in python_val.items(): if type(k) != str: raise ValueError("Flyte MapType expects all keys to be strings") # TODO: log a warning for Annotated objects that contain HashMethod - k_type, v_type = self.get_dict_types(python_type) + + if base_type: + _, v_type = get_args(base_type) + else: + _, v_type = self.extract_types_or_metadata(python_type) + lit_map[k] = TypeEngine.to_literal(ctx, v, cast(type, v_type), expected.map_value_type) return Literal(map=LiteralMap(literals=lit_map)) def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[dict]) -> dict: if lv and lv.map and lv.map.literals is not None: - tp = self.get_dict_types(expected_python_type) + tp = self.dict_types(expected_python_type) + if tp is None or tp[0] is None: raise TypeError( "TypeMismatch: Cannot convert to python dictionary from Flyte Literal Dictionary as the given " @@ -1741,10 +1789,17 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: # for empty generic we have to explicitly test for lv.scalar.generic is not None as empty dict # evaluates to false if lv and lv.scalar and lv.scalar.generic is not None: - try: - return json.loads(_json_format.MessageToJson(lv.scalar.generic)) - except TypeError: - raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") + if lv.metadata["format"] == "json": + try: + return json.loads(_json_format.MessageToJson(lv.scalar.generic)) + except TypeError: + raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") + elif lv.metadata["format"] == "pickle": + from 
flytekit.types.pickle import FlytePickle + + uri = json.loads(_json_format.MessageToJson(lv.scalar.generic)).get("pickle_file") + return FlytePickle.from_pickle(uri) + raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") def guess_python_type(self, literal_type: LiteralType) -> Union[Type[dict], typing.Dict[Type, Type]]: diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_agent.py index ca605a103d..314358c267 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_agent.py @@ -1,7 +1,10 @@ from typing import Optional from flyteidl.core.execution_pb2 import TaskExecution +from typing_extensions import Annotated +from flytekit import FlyteContextManager, kwtypes +from flytekit.core.type_engine import TypeEngine from flytekit.extend.backend.base_agent import ( AgentRegistry, Resource, @@ -54,9 +57,19 @@ async def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = N inputs=inputs, ) - outputs = None + outputs = {"result": {"result": None}} if result: - outputs = {"result": result} + ctx = FlyteContextManager.current_context() + outputs = LiteralMap( + literals={ + "result": TypeEngine.to_literal( + ctx, + result, + Annotated[dict, kwtypes(allow_pickle=True)], + TypeEngine.to_literal_type(dict), + ) + } + ) return Resource(phase=TaskExecution.SUCCEEDED, outputs=outputs) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_task.py index 2e7c8f5b7b..1cb59eab08 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_task.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_task.py @@ -32,7 +32,10 @@ def __init__( name=name, task_config=task_config, task_type=self._TASK_TYPE, - interface=Interface(inputs=inputs, outputs=kwtypes(result=dict)), + interface=Interface( + inputs=inputs, + outputs=kwtypes(result=dict), + ), **kwargs, ) diff --git a/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py b/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py index 2974711f88..ad72a0b7ac 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py @@ -1,10 +1,11 @@ -from datetime import timedelta +from datetime import datetime, timedelta from unittest import mock import pytest from flyteidl.core.execution_pb2 import TaskExecution from flytekit.extend.backend.base_agent import AgentRegistry +from flytekit.interaction.string_literals import literal_map_string_repr from flytekit.interfaces.cli_identifiers import Identifier from flytekit.models import literals from flytekit.models.core.identifier import ResourceType @@ -12,24 +13,51 @@ @pytest.mark.asyncio +@pytest.mark.parametrize( + "mock_return_value", + [ + ( + { + "ResponseMetadata": { + "RequestId": "66f80391-348a-4ee0-9158-508914d16db2", + "HTTPStatusCode": 200.0, + "RetryAttempts": 0.0, + "HTTPHeaders": { + "content-type": "application/x-amz-json-1.1", + "date": "Wed, 31 Jan 2024 16:43:52 GMT", + "x-amzn-requestid": "66f80391-348a-4ee0-9158-508914d16db2", + "content-length": "114", + }, + }, + "EndpointConfigArn": 
"arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config", + } + ), + ( + { + "ResponseMetadata": { + "RequestId": "66f80391-348a-4ee0-9158-508914d16db2", + "HTTPStatusCode": 200.0, + "RetryAttempts": 0.0, + "HTTPHeaders": { + "content-type": "application/x-amz-json-1.1", + "date": "Wed, 31 Jan 2024 16:43:52 GMT", + "x-amzn-requestid": "66f80391-348a-4ee0-9158-508914d16db2", + "content-length": "114", + }, + }, + "pickle_check": datetime(2024, 5, 5), + "EndpointConfigArn": "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config", + } + ), + (None), + ], +) @mock.patch( "flytekitplugins.awssagemaker_inference.boto3_agent.Boto3AgentMixin._call", - return_value={ - "ResponseMetadata": { - "RequestId": "66f80391-348a-4ee0-9158-508914d16db2", - "HTTPStatusCode": 200.0, - "RetryAttempts": 0.0, - "HTTPHeaders": { - "content-type": "application/x-amz-json-1.1", - "date": "Wed, 31 Jan 2024 16:43:52 GMT", - "x-amzn-requestid": "66f80391-348a-4ee0-9158-508914d16db2", - "content-length": "114", - }, - }, - "EndpointConfigArn": "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config", - }, ) -async def test_agent(mock_boto_call): +async def test_agent(mock_boto_call, mock_return_value): + mock_boto_call.return_value = mock_return_value + agent = AgentRegistry.get_agent("boto") task_id = Identifier( resource_type=ResourceType.TASK, @@ -88,9 +116,16 @@ async def test_agent(mock_boto_call): ) resource = await agent.do(task_template, task_inputs) - assert resource.phase == TaskExecution.SUCCEEDED - assert ( - resource.outputs["result"]["EndpointConfigArn"] - == "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config" - ) + + if mock_return_value: + outputs = literal_map_string_repr(resource.outputs) + if "pickle_check" in mock_return_value: + assert "pickle_file" in outputs["result"] + else: + assert ( + outputs["result"]["EndpointConfigArn"] + == "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config" + ) + elif mock_return_value is None: + assert resource.outputs["result"] == {"result": None} diff --git a/tests/flytekit/unit/core/test_local_cache.py b/tests/flytekit/unit/core/test_local_cache.py index 4799ce3f64..72a95ac2dd 100644 --- a/tests/flytekit/unit/core/test_local_cache.py +++ b/tests/flytekit/unit/core/test_local_cache.py @@ -529,7 +529,7 @@ def test_stable_cache_key(): } ) key = _calculate_cache_key("task_name_1", "31415", lm) - assert key == "task_name_1-31415-404b45f8556276183621d4bf37f50049" + assert key == "task_name_1-31415-189e755a8f41c006889c291fcaedb4eb" @pytest.mark.skipif("pandas" not in sys.modules, reason="Pandas is not installed.") diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 4df877a84a..d4db4f34fe 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -51,10 +51,22 @@ from flytekit.models import types as model_types from flytekit.models.annotation import TypeAnnotation from flytekit.models.core.types import BlobType -from flytekit.models.literals import Blob, BlobMetadata, Literal, LiteralCollection, LiteralMap, Primitive, Scalar, Void +from flytekit.models.literals import ( + Blob, + BlobMetadata, + Literal, + LiteralCollection, + LiteralMap, + Primitive, + Scalar, + Void, +) from flytekit.models.types import LiteralType, SimpleType, TypeStructure, UnionType from flytekit.types.directory import 
TensorboardLogs -from flytekit.types.directory.types import FlyteDirectory, FlyteDirToMultipartBlobTransformer +from flytekit.types.directory.types import ( + FlyteDirectory, + FlyteDirToMultipartBlobTransformer, +) from flytekit.types.file import FileExt, JPEGImageFile from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer, noop from flytekit.types.pickle import FlytePickle @@ -90,6 +102,7 @@ def test_type_resolution(): assert type(TypeEngine.get_transformer(typing.Dict[str, int])) == DictTransformer assert type(TypeEngine.get_transformer(typing.Dict)) == DictTransformer assert type(TypeEngine.get_transformer(dict)) == DictTransformer + assert type(TypeEngine.get_transformer(Annotated[dict, kwtypes(allow_pickle=True)])) == DictTransformer assert type(TypeEngine.get_transformer(int)) == SimpleTransformer assert type(TypeEngine.get_transformer(datetime.date)) == SimpleTransformer @@ -130,7 +143,12 @@ def test_file_format_getting_python_value(): with open(file_path, "w") as file1: file1.write("hello world") lv = Literal( - scalar=Scalar(blob=Blob(metadata=BlobMetadata(type=BlobType(format="txt", dimensionality=0)), uri=file_path)) + scalar=Scalar( + blob=Blob( + metadata=BlobMetadata(type=BlobType(format="txt", dimensionality=0)), + uri=file_path, + ) + ) ) pv = transformer.to_python_value(ctx, lv, expected_python_type=FlyteFile["txt"]) @@ -202,7 +220,8 @@ def test_list_of_single_dataclassjsonmixin(): def test_annotated_type(): class JsonTypeTransformer(TypeTransformer[T]): LiteralType = LiteralType( - simple=SimpleType.STRING, annotation=TypeAnnotation(annotations=dict(protocol="json")) + simple=SimpleType.STRING, + annotation=TypeAnnotation(annotations=dict(protocol="json")), ) def get_literal_type(self, t: Type[T]) -> LiteralType: @@ -212,7 +231,11 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: return json.loads(lv.scalar.primitive.string_value) def to_literal( - self, ctx: FlyteContext, python_val: T, python_type: typing.Type[T], expected: LiteralType + self, + ctx: FlyteContext, + python_val: T, + python_type: typing.Type[T], + expected: LiteralType, ) -> Literal: return Literal(scalar=Scalar(primitive=Primitive(string_value=json.dumps(python_val)))) @@ -239,7 +262,12 @@ def __class_getitem__(cls, item: Type[T]): ) assert ( - TypeEngine.to_literal(FlyteContext.current_context(), test_dict, MyJsonDict, JsonTypeTransformer.LiteralType) + TypeEngine.to_literal( + FlyteContext.current_context(), + test_dict, + MyJsonDict, + JsonTypeTransformer.LiteralType, + ) == test_literal ) @@ -262,7 +290,14 @@ class Foo(DataClassJsonMixin): y: typing.Dict[str, str] z: Bar - foo = Foo(u=5, v=None, w=1, x=[1], y={"hello": "10"}, z=Bar(v=3, w=None, x=1.0, y="hello", z={"world": False})) + foo = Foo( + u=5, + v=None, + w=1, + x=[1], + y={"hello": "10"}, + z=Bar(v=3, w=None, x=1.0, y="hello", z={"world": False}), + ) generic = _json_format.Parse(typing.cast(DataClassJsonMixin, foo).to_json(), _struct.Struct()) lv = Literal(collection=LiteralCollection(literals=[Literal(scalar=Scalar(generic=generic))])) @@ -365,7 +400,12 @@ def test_file_no_downloader_default(): file.write("hello world") lv = Literal( - scalar=Scalar(blob=Blob(metadata=BlobMetadata(type=BlobType(format="", dimensionality=0)), uri=local_file)) + scalar=Scalar( + blob=Blob( + metadata=BlobMetadata(type=BlobType(format="", dimensionality=0)), + uri=local_file, + ) + ) ) pv = transformer.to_python_value(ctx, lv, expected_python_type=FlyteFile) @@ -383,7 +423,12 @@ def 
test_dir_no_downloader_default(): local_dir = tempfile.mkdtemp(prefix="temp_example_") lv = Literal( - scalar=Scalar(blob=Blob(metadata=BlobMetadata(type=BlobType(format="", dimensionality=1)), uri=local_dir)) + scalar=Scalar( + blob=Blob( + metadata=BlobMetadata(type=BlobType(format="", dimensionality=1)), + uri=local_dir, + ) + ) ) pv = transformer.to_python_value(ctx, lv, expected_python_type=FlyteDirectory) @@ -408,7 +453,12 @@ def assert_struct(lit: LiteralType): assert lit is not None assert lit.simple == SimpleType.STRUCT - def recursive_assert(lit: LiteralType, expected: LiteralType, expected_depth: int = 1, curr_depth: int = 0): + def recursive_assert( + lit: LiteralType, + expected: LiteralType, + expected_depth: int = 1, + curr_depth: int = 0, + ): assert curr_depth <= expected_depth assert lit is not None if lit.map_value_type is None: @@ -418,13 +468,29 @@ def recursive_assert(lit: LiteralType, expected: LiteralType, expected_depth: in # Type inference assert_struct(d.get_literal_type(dict)) + assert_struct(d.get_literal_type(Annotated[dict, kwtypes(allow_pickle=True)])) assert_struct(d.get_literal_type(typing.Dict[int, int])) recursive_assert(d.get_literal_type(typing.Dict[str, str]), LiteralType(simple=SimpleType.STRING)) - recursive_assert(d.get_literal_type(typing.Dict[str, int]), LiteralType(simple=SimpleType.INTEGER)) - recursive_assert(d.get_literal_type(typing.Dict[str, datetime.datetime]), LiteralType(simple=SimpleType.DATETIME)) - recursive_assert(d.get_literal_type(typing.Dict[str, datetime.timedelta]), LiteralType(simple=SimpleType.DURATION)) - recursive_assert(d.get_literal_type(typing.Dict[str, datetime.date]), LiteralType(simple=SimpleType.DATETIME)) - recursive_assert(d.get_literal_type(typing.Dict[str, dict]), LiteralType(simple=SimpleType.STRUCT)) + recursive_assert( + d.get_literal_type(typing.Dict[str, int]), + LiteralType(simple=SimpleType.INTEGER), + ) + recursive_assert( + d.get_literal_type(typing.Dict[str, datetime.datetime]), + LiteralType(simple=SimpleType.DATETIME), + ) + recursive_assert( + d.get_literal_type(typing.Dict[str, datetime.timedelta]), + LiteralType(simple=SimpleType.DURATION), + ) + recursive_assert( + d.get_literal_type(typing.Dict[str, datetime.date]), + LiteralType(simple=SimpleType.DATETIME), + ) + recursive_assert( + d.get_literal_type(typing.Dict[str, dict]), + LiteralType(simple=SimpleType.STRUCT), + ) recursive_assert( d.get_literal_type(typing.Dict[str, typing.Dict[str, str]]), LiteralType(simple=SimpleType.STRING), @@ -471,6 +537,23 @@ def recursive_assert(lit: LiteralType, expected: LiteralType, expected_depth: in with pytest.raises(TypeError): d.to_python_value(ctx, Literal(map=LiteralMap(literals={"x": None})), typing.Dict[int, str]) + with pytest.raises(TypeError): + d.to_literal( + ctx, + {"x": datetime.datetime(2024, 5, 5)}, + dict, + LiteralType(simple=SimpleType.STRUCT), + ) + + lv = d.to_literal( + ctx, + {"x": datetime.datetime(2024, 5, 5)}, + Annotated[dict, kwtypes(allow_pickle=True)], + LiteralType(simple=SimpleType.STRUCT), + ) + assert lv.metadata["format"] == "pickle" + assert d.to_python_value(ctx, lv, dict) == {"x": datetime.datetime(2024, 5, 5)} + d.to_python_value( ctx, Literal(map=LiteralMap(literals={"x": Literal(scalar=Scalar(primitive=Primitive(integer=1)))})), @@ -582,7 +665,8 @@ def test_guessing_basic(): lt = model_types.LiteralType( blob=BlobType( - format=FlytePickleTransformer.PYTHON_PICKLE_FORMAT, dimensionality=BlobType.BlobDimensionality.SINGLE + 
format=FlytePickleTransformer.PYTHON_PICKLE_FORMAT, + dimensionality=BlobType.BlobDimensionality.SINGLE, ) ) pt = TypeEngine.guess_python_type(lt) @@ -676,8 +760,16 @@ def test_dataclass_transformer(): "TeststructSchema": { "additionalProperties": False, "properties": { - "m": {"additionalProperties": {"title": "m", "type": "string"}, "title": "m", "type": "object"}, - "s": {"$ref": "#/definitions/InnerstructSchema", "field_many": False, "type": "object"}, + "m": { + "additionalProperties": {"title": "m", "type": "string"}, + "title": "m", + "type": "object", + }, + "s": { + "$ref": "#/definitions/InnerstructSchema", + "field_many": False, + "type": "object", + }, }, "type": "object", }, @@ -765,7 +857,11 @@ def test_dataclass_transformer_with_dataclassjsonmixin(): "additionalProperties": False, "required": ["a", "b", "c"], }, - "m": {"type": "object", "additionalProperties": {"type": "string"}, "propertyNames": {"type": "string"}}, + "m": { + "type": "object", + "additionalProperties": {"type": "string"}, + "propertyNames": {"type": "string"}, + }, }, "additionalProperties": False, "required": ["s", "m"], @@ -803,7 +899,10 @@ def test_dataclass_int_preserving(): assert ot == o o = TestStructB( - s=InnerStruct(a=5, b=None, c=[1, 2, 3]), m={5: "b"}, n=[[1, 2, 3], [4, 5, 6]], o={1: {2: 3}, 4: {5: 6}} + s=InnerStruct(a=5, b=None, c=[1, 2, 3]), + m={5: "b"}, + n=[[1, 2, 3], [4, 5, 6]], + o={1: {2: 3}, 4: {5: 6}}, ) lv = tf.to_literal(ctx, o, TestStructB, tf.get_literal_type(TestStructB)) ot = tf.to_python_value(ctx, lv=lv, expected_python_type=TestStructB) @@ -1008,7 +1107,13 @@ class TestFileStruct(DataClassJsonMixin): f2._remote_source = remote_path o = TestFileStruct( a=f1, - b=TestInnerFileStruct(a=JPEGImageFile("s3://tmp/file.jpeg"), b=[f1], c={"hello": f1}, d=[f2], e={"hello": f2}), + b=TestInnerFileStruct( + a=JPEGImageFile("s3://tmp/file.jpeg"), + b=[f1], + c={"hello": f1}, + d=[f2], + e={"hello": f2}, + ), ) ctx = FlyteContext.current_context() @@ -1054,7 +1159,11 @@ def test_flyte_file_in_dataclassjsonmixin(): o = TestFileStruct_flyte_file( a=f1, b=TestInnerFileStruct_flyte_file( - a=JPEGImageFile("s3://tmp/file.jpeg"), b=[f1], c={"hello": f1}, d=[f2], e={"hello": f2} + a=JPEGImageFile("s3://tmp/file.jpeg"), + b=[f1], + c={"hello": f1}, + d=[f2], + e={"hello": f2}, ), ) @@ -1099,7 +1208,13 @@ class TestFileStruct(DataClassJsonMixin): f2 = FlyteDirectory(remote_path) o = TestFileStruct( a=f1, - b=TestInnerFileStruct(a=TensorboardLogs("s3://tensorboard"), b=[f1], c={"hello": f1}, d=[f2], e={"hello": f2}), + b=TestInnerFileStruct( + a=TensorboardLogs("s3://tensorboard"), + b=[f1], + c={"hello": f1}, + d=[f2], + e={"hello": f2}, + ), ) ctx = FlyteContext.current_context() @@ -1148,7 +1263,11 @@ def test_flyte_directory_in_dataclassjsonmixin(): o = TestFileStruct_flyte_directory( a=f1, b=TestInnerFileStruct_flyte_directory( - a=TensorboardLogs("s3://tensorboard"), b=[f1], c={"hello": f1}, d=[f2], e={"hello": f2} + a=TensorboardLogs("s3://tensorboard"), + b=[f1], + c={"hello": f1}, + d=[f2], + e={"hello": f2}, ), ) @@ -1342,7 +1461,11 @@ def test_enum_type(): assert v == "red" with pytest.raises(ValueError): - TypeEngine.to_python_value(ctx, Literal(scalar=Scalar(primitive=Primitive(string_value=str(Color.RED)))), Color) + TypeEngine.to_python_value( + ctx, + Literal(scalar=Scalar(primitive=Primitive(string_value=str(Color.RED)))), + Color, + ) with pytest.raises(ValueError): TypeEngine.to_python_value(ctx, Literal(scalar=Scalar(primitive=Primitive(string_value="bad"))), Color) @@ -1416,7 
+1539,8 @@ class Bar(DataClassJsonMixin): pv = Bar(x=3) with pytest.raises( - TypeTransformerFailedError, match="Type of Val '' is not an instance of " + TypeTransformerFailedError, + match="Type of Val '' is not an instance of ", ): DataclassTransformer().assert_type(gt, pv) @@ -1441,7 +1565,13 @@ class Args(DataClassJsonMixin): df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) sd = StructuredDataset(dataframe=df, file_format="parquet") # Test when v is a dict - vd = {"x": 3, "y": "hello", "file": FlyteFile(pv), "dataset": sd, "another_dataclass": {"z": 4}} + vd = { + "x": 3, + "y": "hello", + "file": FlyteFile(pv), + "dataset": sd, + "another_dataclass": {"z": 4}, + } DataclassTransformer().assert_type(Args, vd) # Test when v is a dict but missing Optional keys and other keys from dataclass @@ -1449,7 +1579,12 @@ class Args(DataClassJsonMixin): DataclassTransformer().assert_type(Args, md) # Test when v is a dict but missing non-Optional keys from dataclass - md = {"y": "hello", "file": FlyteFile(pv), "dataset": sd, "another_dataclass": {"z": 4}} + md = { + "y": "hello", + "file": FlyteFile(pv), + "dataset": sd, + "another_dataclass": {"z": 4}, + } with pytest.raises( TypeTransformerFailedError, match=re.escape("The original fields are missing the following keys from the dataclass fields: ['x']"), @@ -1457,7 +1592,14 @@ class Args(DataClassJsonMixin): DataclassTransformer().assert_type(Args, md) # Test when v is a dict but has extra keys that are not in dataclass - ed = {"x": 3, "y": "hello", "file": FlyteFile(pv), "dataset": sd, "another_dataclass": {"z": 4}, "z": "extra"} + ed = { + "x": 3, + "y": "hello", + "file": FlyteFile(pv), + "dataset": sd, + "another_dataclass": {"z": 4}, + "z": "extra", + } with pytest.raises( TypeTransformerFailedError, match=re.escape("The original fields have the following extra keys that are not in dataclass fields: ['z']"), @@ -1465,9 +1607,16 @@ class Args(DataClassJsonMixin): DataclassTransformer().assert_type(Args, ed) # Test when the type of value in the dict does not match the expected_type in the dataclass - td = {"x": "3", "y": "hello", "file": FlyteFile(pv), "dataset": sd, "another_dataclass": {"z": 4}} + td = { + "x": "3", + "y": "hello", + "file": FlyteFile(pv), + "dataset": sd, + "another_dataclass": {"z": 4}, + } with pytest.raises( - TypeTransformerFailedError, match="Type of Val '' is not an instance of " + TypeTransformerFailedError, + match="Type of Val '' is not an instance of ", ): DataclassTransformer().assert_type(Args, td) @@ -1542,7 +1691,10 @@ def test_union_guess_type(): t = ut.guess_python_type( LiteralType( union_type=UnionType( - variants=[LiteralType(simple=SimpleType.STRING), LiteralType(simple=SimpleType.INTEGER)] + variants=[ + LiteralType(simple=SimpleType.STRING), + LiteralType(simple=SimpleType.INTEGER), + ] ) ) ) @@ -1551,15 +1703,20 @@ def test_union_guess_type(): def test_union_type_with_annotated(): pt = typing.Union[ - Annotated[str, FlyteAnnotation({"hello": "world"})], Annotated[int, FlyteAnnotation({"test": 123})] + Annotated[str, FlyteAnnotation({"hello": "world"})], + Annotated[int, FlyteAnnotation({"test": 123})], ] lt = TypeEngine.to_literal_type(pt) assert lt.union_type.variants == [ LiteralType( - simple=SimpleType.STRING, structure=TypeStructure(tag="str"), annotation=TypeAnnotation({"hello": "world"}) + simple=SimpleType.STRING, + structure=TypeStructure(tag="str"), + annotation=TypeAnnotation({"hello": "world"}), ), LiteralType( - simple=SimpleType.INTEGER, structure=TypeStructure(tag="int"), 
annotation=TypeAnnotation({"test": 123}) + simple=SimpleType.INTEGER, + structure=TypeStructure(tag="int"), + annotation=TypeAnnotation({"test": 123}), ), ] assert union_type_tags_unique(lt) @@ -1736,7 +1893,11 @@ def get_literal_type(self, t: typing.Type[T]) -> LiteralType: return LiteralType(simple=SimpleType.INTEGER) def to_literal( - self, ctx: FlyteContext, python_val: T, python_type: typing.Type[T], expected: LiteralType + self, + ctx: FlyteContext, + python_val: T, + python_type: typing.Type[T], + expected: LiteralType, ) -> Literal: if type(python_val) != int: raise TypeTransformerFailedError("Expected an integer") @@ -1846,7 +2007,8 @@ def __init__(self, number: int): TypeEngine.to_literal(ctx, None, FlytePickle, lt) with pytest.raises( - AssertionError, match="Expected value of type but got '1' of type " + AssertionError, + match="Expected value of type but got '1' of type ", ): lt = TypeEngine.to_literal_type(typing.Optional[typing.Any]) TypeEngine.to_literal(ctx, 1, type(None), lt) @@ -1944,7 +2106,10 @@ class Datum(DataClassJSONMixin): generic=_json_format.Parse( typing.cast( DataClassJsonMixin, - TestStructD(s=InnerStruct(a=5, b=None, c=[1, 2, 3]), m={"a": [5]}), + TestStructD( + s=InnerStruct(a=5, b=None, c=[1, 2, 3]), + m={"a": [5]}, + ), ).to_json(), _struct.Struct(), ) @@ -1962,7 +2127,10 @@ class Datum(DataClassJSONMixin): scalar=Scalar( blob=Blob( metadata=BlobMetadata( - type=BlobType(format="jpeg", dimensionality=BlobType.BlobDimensionality.SINGLE) + type=BlobType( + format="jpeg", + dimensionality=BlobType.BlobDimensionality.SINGLE, + ) ), uri="s3://tmp/file.jpeg", ) @@ -2044,7 +2212,10 @@ def test_literal_hash_int_can_be_set(): """ ctx = FlyteContext.current_context() lv = TypeEngine.to_literal( - ctx, 42, Annotated[int, HashMethod(str)], LiteralType(simple=model_types.SimpleType.INTEGER) + ctx, + 42, + Annotated[int, HashMethod(str)], + LiteralType(simple=model_types.SimpleType.INTEGER), ) assert lv.scalar.primitive.integer == 42 assert lv.hash == "42" @@ -2085,8 +2256,14 @@ def _check_annotation(t, annotation): assert isinstance(lt.annotation, TypeAnnotation) assert lt.annotation.annotations == annotation - _check_annotation(typing_extensions.Annotated[int, FlyteAnnotation({"foo": "bar"})], {"foo": "bar"}) - _check_annotation(typing_extensions.Annotated[int, FlyteAnnotation(["foo", "bar"])], ["foo", "bar"]) + _check_annotation( + typing_extensions.Annotated[int, FlyteAnnotation({"foo": "bar"})], + {"foo": "bar"}, + ) + _check_annotation( + typing_extensions.Annotated[int, FlyteAnnotation(["foo", "bar"])], + ["foo", "bar"], + ) _check_annotation( typing_extensions.Annotated[int, FlyteAnnotation({"d": {"test": "data"}, "l": ["nested", ["list"]]})], {"d": {"test": "data"}, "l": ["nested", ["list"]]}, @@ -2171,17 +2348,32 @@ class AnnotatedDataclassTest(DataClassJsonMixin): [ (dict, LiteralType(simple=SimpleType.STRUCT)), # Annotations are not being copied over to the LiteralType - (typing_extensions.Annotated[dict, "a-tag"], LiteralType(simple=SimpleType.STRUCT)), + ( + typing_extensions.Annotated[dict, "a-tag"], + LiteralType(simple=SimpleType.STRUCT), + ), (typing.Dict[int, str], LiteralType(simple=SimpleType.STRUCT)), - (typing.Dict[str, int], LiteralType(map_value_type=LiteralType(simple=SimpleType.INTEGER))), - (typing.Dict[str, str], LiteralType(map_value_type=LiteralType(simple=SimpleType.STRING))), + ( + typing.Dict[str, int], + LiteralType(map_value_type=LiteralType(simple=SimpleType.INTEGER)), + ), + ( + typing.Dict[str, str], + 
LiteralType(map_value_type=LiteralType(simple=SimpleType.STRING)), + ), ( typing.Dict[str, typing.List[int]], LiteralType(map_value_type=LiteralType(collection_type=LiteralType(simple=SimpleType.INTEGER))), ), (typing.Dict[int, typing.List[int]], LiteralType(simple=SimpleType.STRUCT)), - (typing.Dict[int, typing.Dict[int, int]], LiteralType(simple=SimpleType.STRUCT)), - (typing.Dict[str, typing.Dict[int, int]], LiteralType(map_value_type=LiteralType(simple=SimpleType.STRUCT))), + ( + typing.Dict[int, typing.Dict[int, int]], + LiteralType(simple=SimpleType.STRUCT), + ), + ( + typing.Dict[str, typing.Dict[int, int]], + LiteralType(map_value_type=LiteralType(simple=SimpleType.STRUCT)), + ), ( typing.Dict[str, typing.Dict[str, int]], LiteralType(map_value_type=LiteralType(map_value_type=LiteralType(simple=SimpleType.INTEGER))), @@ -2325,8 +2517,6 @@ class Result_dataclassjsonmixin(DataClassJSONMixin): def test_schema_in_dataclassjsonmixin(): import pandas as pd - from flytekit.types.schema.types_pandas import PandasSchemaReader, PandasSchemaWriter # noqa: F401 - schema = TestSchema() df = pd.DataFrame(data={"some_str": ["a", "b", "c"]}) schema.open().write(df) @@ -2458,12 +2648,20 @@ def test_is_batchable(): # After converting to literal, the result will be # [batched_FlytePickle(2 items), batched_FlytePickle(2 items), batched_FlytePickle(1 item)]. # Therefore, the expected list length is [3]. - (["foo"] * 5, Annotated[typing.List[FlytePickle], HashMethod(function=str), BatchSize(2)], [3]), + ( + ["foo"] * 5, + Annotated[typing.List[FlytePickle], HashMethod(function=str), BatchSize(2)], + [3], + ), # Case 3: Nested list of FlytePickle objects with batch size 2. # After converting to literal, the result will be # [[batched_FlytePickle(3 items)], [batched_FlytePickle(3 items)]] # Therefore, the expected list length is [2, 1] (the length of the outer list remains the same, the inner list is batched). - ([["foo", "foo", "foo"]] * 2, typing.List[Annotated[typing.List[FlytePickle], BatchSize(3)]], [2, 1]), + ( + [["foo", "foo", "foo"]] * 2, + typing.List[Annotated[typing.List[FlytePickle], BatchSize(3)]], + [2, 1], + ), # Case 4: Empty list ([[], typing.List[FlytePickle], []]), ], @@ -2526,12 +2724,15 @@ def test_get_underlying_type(t, expected): (None, (None, None)), (typing.Dict, ()), (typing.Dict[str, str], (str, str)), - (Annotated[typing.Dict, "a-tag"], (None, None)), + ( + Annotated[typing.Dict[str, str], kwtypes(allow_pickle=True)], + (typing.Dict[str, str], kwtypes(allow_pickle=True)), + ), (typing.Dict[Annotated[str, "a-tag"], int], (Annotated[str, "a-tag"], int)), ], ) def test_dict_get(t, expected): - assert DictTransformer.get_dict_types(t) == expected + assert DictTransformer.extract_types_or_metadata(t) == expected def test_DataclassTransformer_get_literal_type(): @@ -2590,7 +2791,10 @@ class MyDataClass: assert lv_mashumaro.scalar.generic["x"] == 5 lv_mashumaro_orjson = transformer.to_literal( - ctx, my_dat_class_mashumaro_orjson, MyDataClassMashumaroORJSON, MyDataClassMashumaroORJSON + ctx, + my_dat_class_mashumaro_orjson, + MyDataClassMashumaroORJSON, + MyDataClassMashumaroORJSON, ) assert lv_mashumaro_orjson is not None assert lv_mashumaro_orjson.scalar.generic["x"] == 5
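
The core of this change is the new ``allow_pickle`` escape hatch in ``DictTransformer``: a dict annotated with ``kwtypes(allow_pickle=True)`` falls back to ``FlytePickle`` when its values are not JSON-serializable, and the resulting literal is tagged with ``metadata={"format": "pickle"}`` so ``to_python_value`` can round-trip it. A minimal sketch of that call pattern, mirroring the new unit test and the SageMaker agent usage above (the datetime payload, variable names, and script form are illustrative only, not part of the patch):

import datetime

from typing_extensions import Annotated

from flytekit import FlyteContextManager, kwtypes
from flytekit.core.type_engine import TypeEngine

ctx = FlyteContextManager.current_context()
value = {"created_at": datetime.datetime(2024, 5, 5)}  # not JSON-serializable

# Plain ``dict`` keeps the old behavior and rejects the value.
try:
    TypeEngine.to_literal(ctx, value, dict, TypeEngine.to_literal_type(dict))
except Exception as e:  # json.dumps raises TypeError for datetime values
    print(f"plain dict still fails: {e}")

# Annotating the dict opts into the pickle fallback added in this patch.
pickle_dict = Annotated[dict, kwtypes(allow_pickle=True)]
lv = TypeEngine.to_literal(ctx, value, pickle_dict, TypeEngine.to_literal_type(dict))
print(lv.metadata)  # {'format': 'pickle'}
print(TypeEngine.to_python_value(ctx, lv, dict))  # {'created_at': datetime.datetime(2024, 5, 5)}

The boto3 agent change relies on the same path: SageMaker responses can carry datetime objects (the ``pickle_check`` field in the parametrized test), so the agent serializes its result with the annotated type instead of failing on ``json.dumps``.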