Skip to content

Commit

Permalink
Update find_by sorting (jefersondaniel#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
Artucuno committed Dec 7, 2023
1 parent 5de2c5c commit 4e0c7fe
Showing 1 changed file with 26 additions and 15 deletions.
41 changes: 26 additions & 15 deletions pydantic_mongo/abstract_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
TypeVar,
Union,
Mapping,
List,
)

import asyncio
Expand Down Expand Up @@ -37,7 +38,7 @@
T = TypeVar("T", bound=BaseModel)
OutputT = TypeVar("OutputT", bound=BaseModel)

Sort = Sequence[Tuple[str, int]]
Sort = Union[str, Sequence[Tuple[str, int]], Tuple[str, int]]


class AbstractRepository(Generic[T]):
Expand Down Expand Up @@ -98,14 +99,19 @@ def __map_id(self, data: dict) -> dict:
query["_id"] = query.pop("id")
return query

def __map_sort(self, sort: Sort) -> Optional[Sort]:
def __map_sort(self, sort: Sort) -> str | list[tuple] | list[tuple[str | Any, Any]]:
result = []
for item in sort:
key = item[0]
ordering = item[1]
if key == "id":
key = "_id"
result.append((key, ordering))
if isinstance(sort, str):
return sort
elif isinstance(sort, tuple):
return [sort]
elif isinstance(sort, list):
for item in sort:
key = item[0]
ordering = item[1]
if key == "id":
key = "_id"
result.append((key, ordering))
return result

def to_model_custom(self, output_type: Type[OutputT], data: Union[dict, Mapping[str, Any]]) -> OutputT:
Expand Down Expand Up @@ -422,14 +428,19 @@ def __map_id(self, data: dict) -> dict:
query["_id"] = query.pop("id")
return query

def __map_sort(self, sort: Sort) -> Optional[Sort]:
def __map_sort(self, sort: Sort) -> str | list[tuple] | list[tuple[str | Any, Any]]:
result = []
for item in sort:
key = item[0]
ordering = item[1]
if key == "id":
key = "_id"
result.append((key, ordering))
if isinstance(sort, str):
return sort
elif isinstance(sort, tuple):
return [sort]
elif isinstance(sort, list):
for item in sort:
key = item[0]
ordering = item[1]
if key == "id":
key = "_id"
result.append((key, ordering))
return result

def to_model_custom(self, output_type: Type[OutputT], data: Union[dict, Mapping[str, Any]]) -> OutputT:
Expand Down

0 comments on commit 4e0c7fe

Please sign in to comment.