diff --git a/src/datachain/lib/convert/values_to_tuples.py b/src/datachain/lib/convert/values_to_tuples.py index 9f0b4cc07..9675a6c11 100644 --- a/src/datachain/lib/convert/values_to_tuples.py +++ b/src/datachain/lib/convert/values_to_tuples.py @@ -1,7 +1,12 @@ from collections.abc import Sequence from typing import Any, Union -from datachain.lib.data_model import DataType, DataTypeNames, is_chain_type +from datachain.lib.data_model import ( + DataType, + DataTypeNames, + DataValuesType, + is_chain_type, +) from datachain.lib.utils import DataChainParamsError @@ -15,7 +20,7 @@ def __init__(self, ds_name, msg): def values_to_tuples( # noqa: C901, PLR0912 ds_name: str = "", output: Union[None, DataType, Sequence[str], dict[str, DataType]] = None, - **fr_map, + **fr_map: Sequence[DataValuesType], ) -> tuple[Any, Any, Any]: if output: if not isinstance(output, (Sequence, str, dict)): @@ -47,10 +52,10 @@ def values_to_tuples( # noqa: C901, PLR0912 f" number of signals '{len(fr_map)}'", ) - types_map = {} + types_map: dict[str, type] = {} length = -1 for k, v in fr_map.items(): - if not isinstance(v, Sequence) or isinstance(v, str): + if not isinstance(v, Sequence) or isinstance(v, str): # type: ignore[unreachable] raise ValuesToTupleError(ds_name, f"signals '{k}' is not a sequence") len_ = len(v) @@ -64,15 +69,16 @@ def values_to_tuples( # noqa: C901, PLR0912 if len_ == 0: raise ValuesToTupleError(ds_name, f"signal '{k}' is empty list") - typ = type(v[0]) + first_element = next(iter(v)) + typ = type(first_element) if not is_chain_type(typ): raise ValuesToTupleError( ds_name, f"signal '{k}' has unsupported type '{typ.__name__}'." f" Please use DataModel types: {DataTypeNames}", ) - if typ is list: - types_map[k] = list[type(v[0][0])] # type: ignore[misc] + if isinstance(first_element, list): + types_map[k] = list[type(first_element[0])] # type: ignore[assignment, misc] else: types_map[k] = typ @@ -98,7 +104,7 @@ def values_to_tuples( # noqa: C901, PLR0912 if len(output) > 1: # type: ignore[arg-type] tuple_type = tuple(output_types) res_type = tuple[tuple_type] # type: ignore[valid-type] - res_values = list(zip(*fr_map.values())) + res_values: Sequence[Any] = list(zip(*fr_map.values())) else: res_type = output_types[0] # type: ignore[misc] res_values = next(iter(fr_map.values())) diff --git a/src/datachain/lib/data_model.py b/src/datachain/lib/data_model.py index 5afb0ac63..bcb58b915 100644 --- a/src/datachain/lib/data_model.py +++ b/src/datachain/lib/data_model.py @@ -18,6 +18,7 @@ ] DataType = Union[type[BaseModel], StandardType] DataTypeNames = "BaseModel, int, str, float, bool, list, dict, bytes, datetime" +DataValuesType = Union[BaseModel, int, str, float, bool, list, dict, bytes, datetime] class DataModel(BaseModel):