Skip to content

Commit

Permalink
add missing type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
wirthual committed Dec 6, 2024
1 parent a0b5cc4 commit 7917c56
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
6 changes: 3 additions & 3 deletions libs/infinity_emb/infinity_emb/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ async def astop(self):
await engine.astop()

async def embed(
self, *, model: str, sentences: list[str], matryoshka_dim=None
self, *, model: str, sentences: list[str], matryoshka_dim: Optional[int]=None
) -> tuple[list["EmbeddingReturnType"], int]:
"""embed multiple sentences
Expand Down Expand Up @@ -393,7 +393,7 @@ async def classify(
return await self[model].classify(sentences=sentences, raw_scores=raw_scores)

async def image_embed(
self, *, model: str, images: list[Union[str, "ImageClassType"]], matryoshka_dim=None
self, *, model: str, images: list[Union[str, "ImageClassType"]], matryoshka_dim:Optional[int]=None
) -> tuple[list["EmbeddingReturnType"], int]:
"""embed multiple images
Expand Down Expand Up @@ -432,7 +432,7 @@ def __getitem__(self, index_or_name: Union[str, int]) -> "AsyncEmbeddingEngine":
)

async def audio_embed(
self, *, model: str, audios: list[Union[str, bytes]], matryoshka_dim=None
self, *, model: str, audios: list[Union[str, bytes]], matryoshka_dim:Optional[int]=None
) -> tuple[list["EmbeddingReturnType"], int]:
"""embed multiple audios
Expand Down
6 changes: 3 additions & 3 deletions libs/infinity_emb/infinity_emb/inference/batch_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def __init__(
)

async def embed(
self, sentences: list[str], matryoshka_dim=None
self, sentences: list[str], matryoshka_dim:Optional[int]=None
) -> tuple[list["EmbeddingReturnType"], int]:
"""Schedule a sentence to be embedded. Awaits until embedded.
Expand Down Expand Up @@ -240,7 +240,7 @@ async def classify(
return classifications, usage

async def image_embed(
self, *, images: list[Union[str, "ImageClassType", bytes]], matryoshka_dim=None
self, *, images: list[Union[str, "ImageClassType", bytes]], matryoshka_dim:Optional[int]=None
) -> tuple[list["EmbeddingReturnType"], int]:
"""Schedule a images and sentences to be embedded. Awaits until embedded.
Expand Down Expand Up @@ -269,7 +269,7 @@ async def image_embed(
return embeddings, usage

async def audio_embed(
self, *, audios: list[Union[str, bytes]], matryoshka_dim=None
self, *, audios: list[Union[str, bytes]], matryoshka_dim:Optional[int]=None
) -> tuple[list["EmbeddingReturnType"], int]:
"""Schedule audios and sentences to be embedded. Awaits until embedded.
Expand Down
6 changes: 3 additions & 3 deletions libs/infinity_emb/infinity_emb/sync_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def stop(self):
self.async_run(self.async_engine_array.astop).result()

@add_start_docstrings(AsyncEngineArray.embed.__doc__)
def embed(self, *, model: str, sentences: list[str], matryoshka_dim=None):
def embed(self, *, model: str, sentences: list[str], matryoshka_dim:Optional[int]=None):
"""sync interface of AsyncEngineArray"""
return self.async_run(
self.async_engine_array.embed,
Expand Down Expand Up @@ -211,7 +211,7 @@ def classify(self, *, model: str, sentences: list[str], raw_scores: bool = False
)

@add_start_docstrings(AsyncEngineArray.image_embed.__doc__)
def image_embed(self, *, model: str, images: list[Union[str, bytes]], matryoshka_dim=None):
def image_embed(self, *, model: str, images: list[Union[str, bytes]], matryoshka_dim:Optional[int]=None):
"""sync interface of AsyncEngineArray"""
return self.async_run(
self.async_engine_array.image_embed,
Expand All @@ -221,7 +221,7 @@ def image_embed(self, *, model: str, images: list[Union[str, bytes]], matryoshka
)

@add_start_docstrings(AsyncEngineArray.audio_embed.__doc__)
def audio_embed(self, *, model: str, audios: list[Union[str, bytes]], matryoshka_dim=None):
def audio_embed(self, *, model: str, audios: list[Union[str, bytes]], matryoshka_dim:Optional[int]=None):
"""sync interface of AsyncEngineArray"""
return self.async_run(
self.async_engine_array.audio_embed,
Expand Down

0 comments on commit 7917c56

Please sign in to comment.