Skip to content

Commit

Permalink
mypy: add more types
Browse files (browse the repository at this point in the history)
  • Loading branch information
shcheklein committed Aug 18, 2024
1 parent 1396894 commit f6ffa62
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
22 changes: 14 additions & 8 deletions src/datachain/lib/convert/values_to_tuples.py
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,12 @@
from collections.abc import Sequence
from typing import Any, Union

from datachain.lib.data_model import DataType, DataTypeNames, is_chain_type
from datachain.lib.data_model import (
DataType,
DataTypeNames,
DataValuesType,
is_chain_type,
)
from datachain.lib.utils import DataChainParamsError


Expand All @@ -15,7 +20,7 @@ def __init__(self, ds_name, msg):
def values_to_tuples( # noqa: C901, PLR0912
ds_name: str = "",
output: Union[None, DataType, Sequence[str], dict[str, DataType]] = None,
**fr_map,
**fr_map: Sequence[DataValuesType],
) -> tuple[Any, Any, Any]:
if output:
if not isinstance(output, (Sequence, str, dict)):
Expand Down Expand Up @@ -47,10 +52,10 @@ def values_to_tuples( # noqa: C901, PLR0912
f" number of signals '{len(fr_map)}'",
)

types_map = {}
types_map: dict[str, type] = {}
length = -1
for k, v in fr_map.items():
if not isinstance(v, Sequence) or isinstance(v, str):
if not isinstance(v, Sequence) or isinstance(v, str): # type: ignore[unreachable]
raise ValuesToTupleError(ds_name, f"signals '{k}' is not a sequence")
len_ = len(v)

Expand All @@ -64,15 +69,16 @@ def values_to_tuples( # noqa: C901, PLR0912
if len_ == 0:
raise ValuesToTupleError(ds_name, f"signal '{k}' is empty list")

typ = type(v[0])
first_element = next(iter(v))
typ = type(first_element)
if not is_chain_type(typ):
raise ValuesToTupleError(
ds_name,
f"signal '{k}' has unsupported type '{typ.__name__}'."
f" Please use DataModel types: {DataTypeNames}",
)
if typ is list:
types_map[k] = list[type(v[0][0])] # type: ignore[misc]
if isinstance(first_element, list):
types_map[k] = list[type(first_element[0])] # type: ignore[assignment, misc]
else:
types_map[k] = typ

Expand All @@ -98,7 +104,7 @@ def values_to_tuples( # noqa: C901, PLR0912
if len(output) > 1: # type: ignore[arg-type]
tuple_type = tuple(output_types)
res_type = tuple[tuple_type] # type: ignore[valid-type]
res_values = list(zip(*fr_map.values()))
res_values: Sequence[Any] = list(zip(*fr_map.values()))
else:
res_type = output_types[0] # type: ignore[misc]
res_values = next(iter(fr_map.values()))
Expand Down
1 change: 1 addition & 0 deletions src/datachain/lib/data_model.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -18,6 +18,7 @@
]
DataType = Union[type[BaseModel], StandardType]
DataTypeNames = "BaseModel, int, str, float, bool, list, dict, bytes, datetime"
DataValuesType = Union[BaseModel, int, str, float, bool, list, dict, bytes, datetime]


class DataModel(BaseModel):
Expand Down

0 comments on commit f6ffa62

Please sign in to comment.