diff --git a/pydantic_mongo/abstract_repository.py b/pydantic_mongo/abstract_repository.py index 00f5e46..4466f7e 100644 --- a/pydantic_mongo/abstract_repository.py +++ b/pydantic_mongo/abstract_repository.py @@ -10,6 +10,7 @@ TypeVar, Union, Mapping, + List, ) import asyncio @@ -37,7 +38,7 @@ T = TypeVar("T", bound=BaseModel) OutputT = TypeVar("OutputT", bound=BaseModel) -Sort = Sequence[Tuple[str, int]] +Sort = Union[str, Sequence[Tuple[str, int]], Tuple[str, int]] class AbstractRepository(Generic[T]): @@ -98,14 +99,19 @@ def __map_id(self, data: dict) -> dict: query["_id"] = query.pop("id") return query - def __map_sort(self, sort: Sort) -> Optional[Sort]: + def __map_sort(self, sort: Sort) -> str | list[tuple] | list[tuple[str | Any, Any]]: result = [] - for item in sort: - key = item[0] - ordering = item[1] - if key == "id": - key = "_id" - result.append((key, ordering)) + if isinstance(sort, str): + return sort + elif isinstance(sort, tuple): + return [sort] + elif isinstance(sort, list): + for item in sort: + key = item[0] + ordering = item[1] + if key == "id": + key = "_id" + result.append((key, ordering)) return result def to_model_custom(self, output_type: Type[OutputT], data: Union[dict, Mapping[str, Any]]) -> OutputT: @@ -422,14 +428,19 @@ def __map_id(self, data: dict) -> dict: query["_id"] = query.pop("id") return query - def __map_sort(self, sort: Sort) -> Optional[Sort]: + def __map_sort(self, sort: Sort) -> str | list[tuple] | list[tuple[str | Any, Any]]: result = [] - for item in sort: - key = item[0] - ordering = item[1] - if key == "id": - key = "_id" - result.append((key, ordering)) + if isinstance(sort, str): + return sort + elif isinstance(sort, tuple): + return [sort] + elif isinstance(sort, list): + for item in sort: + key = item[0] + ordering = item[1] + if key == "id": + key = "_id" + result.append((key, ordering)) return result def to_model_custom(self, output_type: Type[OutputT], data: Union[dict, Mapping[str, Any]]) -> OutputT: