diff --git a/llmclient/embeddings.py b/llmclient/embeddings.py index ce15503..e03de2e 100644 --- a/llmclient/embeddings.py +++ b/llmclient/embeddings.py @@ -1,10 +1,11 @@ import asyncio from abc import ABC, abstractmethod +from collections import Counter from enum import StrEnum +from itertools import chain from typing import Any import litellm -import numpy as np import tiktoken from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator @@ -171,13 +172,9 @@ async def embed_documents(self, texts: list[str]) -> list[list[float]]: enc_batch = self.enc.encode_ordinary_batch(texts) # now get frequency of each token rel to length return [ - ( - np.bincount([xi % self.ndim for xi in x], minlength=self.ndim).astype( - float - ) - / len(x) - ).tolist() + [token_counts.get(xi, 0) / len(x) for xi in range(self.ndim)] for x in enc_batch + if (token_counts := Counter(xi % self.ndim for xi in x)) ] @@ -199,7 +196,11 @@ async def embed_documents(self, texts: list[str]) -> list[list[float]]: all_embeds = await asyncio.gather( *[m.embed_documents(texts) for m in self.models] ) - return np.concatenate(all_embeds, axis=1).tolist() + + return [ + list(chain.from_iterable(embed_group)) + for embed_group in zip(*all_embeds, strict=True) + ] def set_mode(self, mode: EmbeddingModes) -> None: # Set mode for all component models @@ -217,6 +218,7 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel): def __init__(self, **kwargs): super().__init__(**kwargs) try: + import numpy as np # noqa: F401 from sentence_transformers import SentenceTransformer except ImportError as exc: raise ImportError( @@ -240,6 +242,8 @@ async def embed_documents(self, texts: list[str]) -> list[list[float]]: Returns: A list of embedding vectors. """ + import numpy as np + # Extract additional configurations if needed batch_size = self.config.get("batch_size", 32) device = self.config.get("device", "cpu") diff --git a/pyproject.toml b/pyproject.toml index 8c6d278..a85772d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,6 @@ dependencies = [ "fhaviary>=0.8.2", # For core namespace "limits", "litellm>=1.44", # For LITELLM_LOG addition - "numpy", "pydantic~=2.0,>=2.10.1,<2.10.2", "tiktoken>=0.4.0", "typing-extensions; python_version <= '3.11'", # for typing.override @@ -40,7 +39,7 @@ requires-python = ">=3.11" [project.optional-dependencies] dev = [ - "fh-llm-client[image,local]", + "fh-llm-client[local]", "fhaviary[xml]", "ipython>=8", # Pin to keep recent "mypy>=1.8", # Pin for mutable-override @@ -58,10 +57,8 @@ dev = [ "python-dotenv", "refurb>=2", # Pin to keep recent ] -image = [ - "Pillow", -] local = [ + "numpy", "sentence-transformers", ] diff --git a/uv.lock b/uv.lock index 8d118c6..a161538 100644 --- a/uv.lock +++ b/uv.lock @@ -563,14 +563,13 @@ wheels = [ [[package]] name = "fh-llm-client" -version = "0.0.4.dev6+g6de1e91.d20241206" +version = "0.0.4.dev3+g418fa3b.d20241209" source = { editable = "." } dependencies = [ { name = "coredis" }, { name = "fhaviary" }, { name = "limits" }, { name = "litellm" }, - { name = "numpy" }, { name = "pydantic" }, { name = "tiktoken" }, { name = "typing-extensions", marker = "python_full_version < '3.12'" }, @@ -581,7 +580,7 @@ dev = [ { name = "fhaviary", extra = ["xml"] }, { name = "ipython" }, { name = "mypy" }, - { name = "pillow" }, + { name = "numpy" }, { name = "pre-commit" }, { name = "pylint-pydantic" }, { name = "pytest" }, @@ -597,10 +596,8 @@ dev = [ { name = "refurb" }, { name = "sentence-transformers" }, ] -image = [ - { name = "pillow" }, -] local = [ + { name = "numpy" }, { name = "sentence-transformers" }, ] @@ -610,7 +607,7 @@ codeflash = [ { name = "fhaviary", extra = ["xml"] }, { name = "ipython" }, { name = "mypy" }, - { name = "pillow" }, + { name = "numpy" }, { name = "pre-commit" }, { name = "pylint-pydantic" }, { name = "pytest" }, @@ -630,7 +627,7 @@ dev = [ { name = "fhaviary", extra = ["xml"] }, { name = "ipython" }, { name = "mypy" }, - { name = "pillow" }, + { name = "numpy" }, { name = "pre-commit" }, { name = "pylint-pydantic" }, { name = "pytest" }, @@ -650,15 +647,14 @@ dev = [ [package.metadata] requires-dist = [ { name = "coredis" }, - { name = "fh-llm-client", extras = ["image", "local"], marker = "extra == 'dev'" }, + { name = "fh-llm-client", extras = ["local"], marker = "extra == 'dev'" }, { name = "fhaviary", specifier = ">=0.8.2" }, { name = "fhaviary", extras = ["xml"], marker = "extra == 'dev'" }, { name = "ipython", marker = "extra == 'dev'", specifier = ">=8" }, { name = "limits" }, { name = "litellm", specifier = ">=1.44" }, { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.8" }, - { name = "numpy" }, - { name = "pillow", marker = "extra == 'image'" }, + { name = "numpy", marker = "extra == 'local'" }, { name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.4" }, { name = "pydantic", specifier = "~=2.0,>=2.10.1,<2.10.2" }, { name = "pylint-pydantic", marker = "extra == 'dev'" },