Commit 9d9b2e1

save
dsmilkov committed Feb 27, 2024
1 parent 24c91f6 commit 9d9b2e1
Showing 7 changed files with 43 additions and 14 deletions.
4 changes: 3 additions & 1 deletion lilac/concepts/concept.py
@@ -66,8 +66,10 @@ class ExampleIn(BaseModel):
 
   @field_validator('text')
   @classmethod
-  def parse_text(cls, text: str) -> str:
+  def parse_text(cls, text: Optional[str]) -> Optional[str]:
     """Fixes surrogate errors in text: https://github.com/ijl/orjson/blob/master/README.md#str ."""
+    if not text:
+      return None
     return text.encode('utf-8', 'replace').decode('utf-8')


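The net effect: parse_text now tolerates a missing value instead of calling .encode on None. A minimal runnable sketch of the new behavior, assuming only that the model declares `text` as Optional (other ExampleIn fields are omitted):

from typing import Optional

from pydantic import BaseModel, field_validator


class ExampleIn(BaseModel):
  # Assumed declaration: the validator only needs `text` to be Optional.
  text: Optional[str] = None

  @field_validator('text')
  @classmethod
  def parse_text(cls, text: Optional[str]) -> Optional[str]:
    """Fixes surrogate errors in text; passes empty/None values through as None."""
    if not text:
      return None
    return text.encode('utf-8', 'replace').decode('utf-8')


print(ExampleIn(text='ok\ud800').text)  # 'ok?': the lone surrogate is replaced.
print(ExampleIn(text=None).text)        # None: no AttributeError on .encode.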
8 changes: 6 additions & 2 deletions lilac/embeddings/bge.py
@@ -1,9 +1,10 @@
 """General Text Embeddings (GTE) model. Open-source model, designed to run on device."""
 import gc
-from typing import TYPE_CHECKING, ClassVar, Optional
+from typing import TYPE_CHECKING, Callable, ClassVar, Optional, cast
 
 from typing_extensions import override
 
+from ..splitters.chunk_splitter import TextChunk
 from ..utils import log
 
 if TYPE_CHECKING:
@@ -69,7 +70,10 @@ def compute(self, docs: list[str]) -> list[Optional[Item]]:
     # While we get docs in batches of 1024, the chunker expands that by a factor of 3-10.
     # The sentence transformer API actually does batching internally, so we pass
     # local_batch_size * 16 to allow the library to see all the chunks at once.
-    chunker = clustering_spacy_chunker if self._split else identity_chunker
+    chunker = cast(
+      Callable[[str], list[TextChunk]],
+      clustering_spacy_chunker if self._split else identity_chunker,
+    )
     return chunked_compute_embedding(
       lambda docs: self._model.encode(docs)['dense_vecs'],
       docs,
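The same edit repeats across the embedding files: the branchy chunker assignment is wrapped in a cast so a type checker sees one common signature rather than the join of two distinct function types. A standalone sketch of the idea (the TextChunk shape and both chunker bodies are assumptions, not copied from lilac):

from typing import Callable, cast

# Assumed shape: the chunk text plus its (start, end) span in the source doc.
TextChunk = tuple[str, tuple[int, int]]


def identity_chunker(doc: str) -> list[TextChunk]:
  """Treat the whole document as one chunk (assumed body)."""
  return [(doc, (0, len(doc)))]


def clustering_spacy_chunker(doc: str) -> list[TextChunk]:
  """Stand-in for the real spaCy-based chunker."""
  return [(doc, (0, len(doc)))]


split = True
# Without the cast, mypy infers the ternary's type from the two concrete
# functions; the cast pins the Callable the downstream API expects.
chunker = cast(
  Callable[[str], list[TextChunk]],
  clustering_spacy_chunker if split else identity_chunker,
)
print(chunker('hello world'))  # [('hello world', (0, 11))]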
8 changes: 6 additions & 2 deletions lilac/embeddings/cohere.py
@@ -1,12 +1,13 @@
 """Cohere embeddings."""
-from typing import TYPE_CHECKING, ClassVar, Optional
+from typing import TYPE_CHECKING, Callable, ClassVar, Optional, cast
 
 import numpy as np
 from typing_extensions import override
 
 from ..env import env
 from ..schema import Item
 from ..signal import TextEmbeddingSignal
+from ..splitters.chunk_splitter import TextChunk
 from ..splitters.spacy_splitter import clustering_spacy_chunker
 from ..tasks import TaskExecutionType
 from .embedding import chunked_compute_embedding, identity_chunker
@@ -65,5 +66,8 @@ def _embed_fn(docs: list[str]) -> list[np.ndarray]:
       ).embeddings
     ]
 
-    chunker = clustering_spacy_chunker if self._split else identity_chunker
+    chunker = cast(
+      Callable[[str], list[TextChunk]],
+      clustering_spacy_chunker if self._split else identity_chunker,
+    )
     return chunked_compute_embedding(_embed_fn, docs, self.local_batch_size, chunker=chunker)
12 changes: 9 additions & 3 deletions lilac/embeddings/gte.py
@@ -1,7 +1,7 @@
 """General Text Embeddings (GTE) model. Open-source model, designed to run on device."""
 import gc
 import itertools
-from typing import TYPE_CHECKING, ClassVar, Iterator, Optional
+from typing import TYPE_CHECKING, Callable, ClassVar, Iterator, Optional, cast
 
 import modal
 from typing_extensions import override
@@ -69,7 +69,10 @@ def compute(self, docs: list[str]) -> list[Optional[Item]]:
     # While we get docs in batches of 1024, the chunker expands that by a factor of 3-10.
     # The sentence transformer API actually does batching internally, so we pass
     # local_batch_size * 16 to allow the library to see all the chunks at once.
-    chunker = clustering_spacy_chunker if self._split else identity_chunker
+    chunker = cast(
+      Callable[[str], list[TextChunk]],
+      clustering_spacy_chunker if self._split else identity_chunker,
+    )
     return chunked_compute_embedding(
       self._model.encode, docs, self.local_batch_size * 16, chunker=chunker
     )
@@ -79,7 +82,10 @@ def compute_garden(self, docs: Iterator[str]) -> Iterator[Item]:
     # Trim the docs to the max context size.
 
     trimmed_docs = (doc[:GTE_CONTEXT_SIZE] for doc in docs)
-    chunker = clustering_spacy_chunker if self._split else identity_chunker
+    chunker = cast(
+      Callable[[str], list[TextChunk]],
+      clustering_spacy_chunker if self._split else identity_chunker,
+    )
     text_chunks: Iterator[tuple[int, TextChunk]] = (
       (i, chunk) for i, doc in enumerate(trimmed_docs) for chunk in chunker(doc)
     )
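In compute_garden, the chunker feeds a flattened generator of (doc index, chunk) pairs, so per-chunk embeddings can later be regrouped by document. A sketch of that flattening (the GTE_CONTEXT_SIZE value and the chunker body are assumed for illustration):

from typing import Iterator

TextChunk = tuple[str, tuple[int, int]]  # assumed: (text, (start, end))
GTE_CONTEXT_SIZE = 512  # assumed value


def identity_chunker(doc: str) -> list[TextChunk]:
  return [(doc, (0, len(doc)))]


docs = iter(['first document', 'second document'])
trimmed_docs = (doc[:GTE_CONTEXT_SIZE] for doc in docs)
# Each chunk carries the index of its source doc, so the batched results can
# be regrouped into per-document items afterwards.
text_chunks: Iterator[tuple[int, TextChunk]] = (
  (i, chunk) for i, doc in enumerate(trimmed_docs) for chunk in identity_chunker(doc)
)
for i, (text, span) in text_chunks:
  print(i, span, text)  # 0 (0, 14) first document ...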
9 changes: 7 additions & 2 deletions lilac/embeddings/nomic_embed.py
@@ -1,10 +1,12 @@
 """General Text Embeddings (GTE) model. Open-source model, designed to run on device."""
 import gc
-from typing import TYPE_CHECKING, ClassVar, Optional
+from typing import TYPE_CHECKING, Callable, ClassVar, Optional, cast
 
 import numpy as np
 from typing_extensions import override
 
+from ..splitters.chunk_splitter import TextChunk
+
 if TYPE_CHECKING:
   from sentence_transformers import SentenceTransformer

@@ -76,7 +78,10 @@ def _encode(doc: list[str]) -> list[np.ndarray]:
     # While we get docs in batches of 1024, the chunker expands that by a factor of 3-10.
     # The sentence transformer API actually does batching internally, so we pass
     # local_batch_size * 16 to allow the library to see all the chunks at once.
-    chunker = clustering_spacy_chunker if self._split else identity_chunker
+    chunker = cast(
+      Callable[[str], list[TextChunk]],
+      clustering_spacy_chunker if self._split else identity_chunker,
+    )
     return chunked_compute_embedding(_encode, docs, self.local_batch_size * 16, chunker=chunker)
 
   @override
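The comment repeated in these files carries a small piece of arithmetic worth making explicit: chunking multiplies the work, and the * 16 factor hands the sentence-transformers call enough chunks to batch internally. A quick worked version (the local_batch_size value is assumed):

local_batch_size = 64  # assumed per-model setting, for illustration only
docs_per_batch = 1024
expansion_low, expansion_high = 3, 10  # chunker expansion quoted in the comment

# 1024 docs can become roughly 3k-10k chunks; passing local_batch_size * 16
# chunks per call lets the library batch internally instead of many tiny calls.
print(docs_per_batch * expansion_low, docs_per_batch * expansion_high)  # 3072 10240
print(local_batch_size * 16)  # 1024 chunks visible to the library per call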
8 changes: 6 additions & 2 deletions lilac/embeddings/openai.py
@@ -1,5 +1,5 @@
 """OpenAI embeddings."""
-from typing import ClassVar, Optional
+from typing import Callable, ClassVar, Optional, cast
 
 import numpy as np
 from tenacity import retry, stop_after_attempt, wait_random_exponential
@@ -8,6 +8,7 @@
 from ..env import env
 from ..schema import Item
 from ..signal import TextEmbeddingSignal
+from ..splitters.chunk_splitter import TextChunk
 from ..splitters.spacy_splitter import clustering_spacy_chunker
 from ..tasks import TaskExecutionType
 from .embedding import chunked_compute_embedding, identity_chunker
@@ -92,5 +93,8 @@ def embed_fn(texts: list[str]) -> list[np.ndarray]:
       )
       return [np.array(embedding.embedding, dtype=np.float32) for embedding in response.data]
 
-    chunker = clustering_spacy_chunker if self._split else identity_chunker
+    chunker = cast(
+      Callable[[str], list[TextChunk]],
+      clustering_spacy_chunker if self._split else identity_chunker,
+    )
     return chunked_compute_embedding(embed_fn, docs, self.local_batch_size, chunker=chunker)
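All of these call sites hand the chunker to chunked_compute_embedding. lilac's actual implementation is not shown in this diff; a hedged sketch of what such a helper plausibly does with the `chunker` parameter (names, shapes, and return type here are assumptions):

from typing import Callable

import numpy as np

TextChunk = tuple[str, tuple[int, int]]  # assumed: (text, (start, end))


def chunked_compute_embedding_sketch(
  embed_fn: Callable[[list[str]], list[np.ndarray]],
  docs: list[str],
  batch_size: int,
  chunker: Callable[[str], list[TextChunk]],
) -> list[list[dict]]:
  """Chunk each doc, embed the chunk texts in batches, regroup per document."""
  results: list[list[dict]] = []
  for doc in docs:
    chunks = chunker(doc)
    texts = [text for text, _ in chunks]
    vectors: list[np.ndarray] = []
    for start in range(0, len(texts), batch_size):
      vectors.extend(embed_fn(texts[start : start + batch_size]))
    results.append(
      [{'span': span, 'embedding': vec} for (_, span), vec in zip(chunks, vectors)]
    )
  return results


# Usage with a dummy embedder and a one-chunk chunker:
def fake_embed(texts: list[str]) -> list[np.ndarray]:
  return [np.zeros(4, dtype=np.float32) for _ in texts]


def one_chunk(doc: str) -> list[TextChunk]:
  return [(doc, (0, len(doc)))]


print(chunked_compute_embedding_sketch(fake_embed, ['hello'], 16, one_chunk))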
8 changes: 6 additions & 2 deletions lilac/embeddings/sbert.py
@@ -1,8 +1,9 @@
 """Sentence-BERT embeddings. Open-source models, designed to run on device."""
-from typing import TYPE_CHECKING, ClassVar, Optional
+from typing import TYPE_CHECKING, Callable, ClassVar, Optional, cast
 
 from typing_extensions import override
 
+from ..splitters.chunk_splitter import TextChunk
 from ..tasks import TaskExecutionType
 
 if TYPE_CHECKING:
@@ -47,7 +48,10 @@ def compute(self, docs: list[str]) -> list[Optional[Item]]:
     # While we get docs in batches of 1024, the chunker expands that by a factor of 3-10.
     # The sentence transformer API actually does batching internally, so we pass
     # local_batch_size * 16 to allow the library to see all the chunks at once.
-    chunker = clustering_spacy_chunker if self._split else identity_chunker
+    chunker = cast(
+      Callable[[str], list[TextChunk]],
+      clustering_spacy_chunker if self._split else identity_chunker,
+    )
     return chunked_compute_embedding(
       self._model.encode, docs, self.local_batch_size * 16, chunker=chunker
     )
