Skip to content

Commit

Permalink
Switch splitter to new algorithm, fixing last blocking bug. (#796)
Browse files (browse the repository at this point in the history)
  • Loading branch information
brilee authored Oct 26, 2023
1 parent 085cf14 commit abcf2b9
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 11 deletions.
4 changes: 2 additions & 2 deletions lilac/embeddings/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ..env import env
from ..schema import Item, RichData
from ..signal import TextEmbeddingSignal
from ..splitters.chunk_splitter import split_text
from ..splitters.spacy_splitter import clustering_spacy_chunker
from .embedding import compute_split_embeddings

if TYPE_CHECKING:
Expand Down Expand Up @@ -54,6 +54,6 @@ def embed_fn(texts: list[str]) -> list[np.ndarray]:
return self._model.embed(texts, truncate='END').embeddings

docs = cast(Iterable[str], docs)
split_fn = split_text if self._split else None
split_fn = clustering_spacy_chunker if self._split else None
yield from compute_split_embeddings(
docs, COHERE_BATCH_SIZE, embed_fn, split_fn, num_parallel_requests=NUM_PARALLEL_REQUESTS)
4 changes: 2 additions & 2 deletions lilac/embeddings/gte.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from ..schema import Item, RichData
from ..signal import TextEmbeddingSignal
from ..splitters.chunk_splitter import split_text
from ..splitters.spacy_splitter import clustering_spacy_chunker
from .embedding import compute_split_embeddings
from .transformer_utils import get_model

Expand Down Expand Up @@ -48,7 +48,7 @@ def compute(self, docs: Iterable[RichData]) -> Iterable[Item]:
"""Call the embedding function."""
batch_size, model = get_model(self._model_name, _OPTIMAL_BATCH_SIZES[self._model_name])
embed_fn = model.encode
split_fn = split_text if self._split else None
split_fn = clustering_spacy_chunker if self._split else None
docs = cast(Iterable[str], docs)
yield from compute_split_embeddings(docs, batch_size, embed_fn=embed_fn, split_fn=split_fn)

Expand Down
4 changes: 2 additions & 2 deletions lilac/embeddings/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ..env import env
from ..schema import Item, RichData
from ..signal import TextEmbeddingSignal
from ..splitters.chunk_splitter import split_text
from ..splitters.spacy_splitter import clustering_spacy_chunker
from .embedding import compute_split_embeddings

if TYPE_CHECKING:
Expand Down Expand Up @@ -63,6 +63,6 @@ def embed_fn(texts: list[str]) -> list[np.ndarray]:
return [np.array(embedding['embedding'], dtype=np.float32) for embedding in response['data']]

docs = cast(Iterable[str], docs)
split_fn = split_text if self._split else None
split_fn = clustering_spacy_chunker if self._split else None
yield from compute_split_embeddings(
docs, OPENAI_BATCH_SIZE, embed_fn, split_fn, num_parallel_requests=NUM_PARALLEL_REQUESTS)
4 changes: 2 additions & 2 deletions lilac/embeddings/palm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ..env import env
from ..schema import Item, RichData
from ..signal import TextEmbeddingSignal
from ..splitters.chunk_splitter import split_text
from ..splitters.spacy_splitter import clustering_spacy_chunker
from .embedding import compute_split_embeddings

if TYPE_CHECKING:
Expand Down Expand Up @@ -57,6 +57,6 @@ def embed_fn(texts: list[str]) -> list[np.ndarray]:
return [np.array(response['embedding'], dtype=np.float32)]

docs = cast(Iterable[str], docs)
split_fn = split_text if self._split else None
split_fn = clustering_spacy_chunker if self._split else None
yield from compute_split_embeddings(
docs, PALM_BATCH_SIZE, embed_fn, split_fn, num_parallel_requests=NUM_PARALLEL_REQUESTS)
4 changes: 2 additions & 2 deletions lilac/embeddings/sbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from ..schema import Item, RichData
from ..signal import TextEmbeddingSignal
from ..splitters.chunk_splitter import split_text
from ..splitters.spacy_splitter import clustering_spacy_chunker
from .embedding import compute_split_embeddings
from .transformer_utils import get_model

Expand Down Expand Up @@ -33,6 +33,6 @@ def compute(self, docs: Iterable[RichData]) -> Iterable[Item]:
"""Call the embedding function."""
batch_size, model = get_model(MINI_LM_MODEL, _OPTIMAL_BATCH_SIZES[MINI_LM_MODEL])
embed_fn = model.encode
split_fn = split_text if self._split else None
split_fn = clustering_spacy_chunker if self._split else None
docs = cast(Iterable[str], docs)
yield from compute_split_embeddings(docs, batch_size, embed_fn=embed_fn, split_fn=split_fn)
3 changes: 2 additions & 1 deletion lilac/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,8 @@ def lilac_span(start: int, end: int, metadata: dict[str, Any] = {}) -> Item:

def lilac_embedding(start: int, end: int, embedding: Optional[np.ndarray]) -> Item:
"""Creates a lilac embedding item, representing a vector with a pointer to a slice of text."""
return lilac_span(start, end, {EMBEDDING_KEY: embedding})
# Cast to int; we've had issues where start/end were np.int64, which caused downstream sadness.
return lilac_span(int(start), int(end), {EMBEDDING_KEY: embedding})


def _parse_field_like(field_like: object, dtype: Optional[Union[DataType, str]] = None) -> Field:
Expand Down

0 comments on commit abcf2b9

Please sign in to comment.