Merge pull request #7 from aurelio-labs/stat-chunker-batch-fix
feat: Process large docs in batches
jamescalam authored May 28, 2024
2 parents a544ce3 + da4e594 commit 0cd1186
Showing 1 changed file with 75 additions and 21 deletions.
96 changes: 75 additions & 21 deletions semantic_chunkers/chunkers/statistical.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import List
from typing import Any, List

import numpy as np

@@ -9,6 +9,8 @@
from semantic_chunkers.utils.text import tiktoken_length
from semantic_chunkers.utils.logger import logger

from tqdm.auto import tqdm


@dataclass
class ChunkStatistics:
@@ -62,19 +64,82 @@ def __init__(
self.enable_statistics = enable_statistics
self.statistics: ChunkStatistics

def __call__(self, docs: List[str]) -> List[List[Chunk]]:
"""Chunk documents into smaller chunks based on semantic similarity.
def _chunk(
self, splits: List[Any], batch_size: int = 64, enforce_max_tokens: bool = False
) -> List[Chunk]:
"""Merge splits into chunks using semantic similarity, with optional enforcement of maximum token limits per chunk.
:param splits: Splits to be merged into chunks.
:param batch_size: Number of splits to process in one batch.
:param enforce_max_tokens: If True, further split chunks that exceed the maximum token limit.
:return: List of chunks.
"""
# Split the docs that already exceed max_split_tokens to smaller chunks
if enforce_max_tokens:
new_splits = []
for split in splits:
token_count = tiktoken_length(split)
if token_count > self.max_split_tokens:
logger.info(
f"Single document exceeds the maximum token limit "
f"of {self.max_split_tokens}. "
"Splitting to sentences before semantically merging."
)
_splits = self._split(split)
new_splits.extend(_splits)
else:
new_splits.append(split)

splits = [split for split in new_splits if split and split.strip()]

chunks = []
last_split = None
for i in tqdm(range(0, len(splits), batch_size)):
batch_splits = splits[i : i + batch_size]
if last_split is not None:
batch_splits = last_split.splits + batch_splits

encoded_splits = self._encode_documents(batch_splits)
similarities = self._calculate_similarity_scores(encoded_splits)
if self.dynamic_threshold:
self._find_optimal_threshold(batch_splits, similarities)
else:
self.calculated_threshold = self.encoder.score_threshold
split_indices = self._find_split_indices(similarities=similarities)
doc_chunks = self._split_documents(
batch_splits, split_indices, similarities
)

if len(doc_chunks) > 1:
chunks.extend(doc_chunks[:-1])
last_split = doc_chunks[-1]
else:
last_split = doc_chunks[0]

if self.plot_chunks:
self.plot_similarity_scores(similarities, split_indices, doc_chunks)

if self.enable_statistics:
print(self.statistics)

if last_split:
chunks.append(last_split)

return chunks

def __call__(self, docs: List[str], batch_size: int = 64) -> List[List[Chunk]]:
"""Split documents into smaller chunks based on semantic similarity.
:param docs: list of text documents to be split, if only wanted to
split a single document, pass it as a list with a single element.
:return: list of DocumentChunk objects containing the split documents.
:return: list of Chunk objects containing the split documents.
"""
if not docs:
raise ValueError("At least one document is required for splitting.")

all_chunks = []

for doc in docs:
token_count = tiktoken_length(doc)
if token_count > self.max_split_tokens:
@@ -83,23 +148,12 @@ def __call__(self, docs: List[str]) -> List[List[Chunk]]:
f"of {self.max_split_tokens}. "
"Splitting to sentences before semantically merging."
)
splits = self._split(doc)
encoded_splits = self._encode_documents(splits)
similarities = self._calculate_similarity_scores(encoded_splits)
if self.dynamic_threshold:
self._find_optimal_threshold(splits, similarities)
if isinstance(doc, str):
splits = self._split(doc)
doc_chunks = self._chunk(splits, batch_size=batch_size)
all_chunks.append(doc_chunks)
else:
self.calculated_threshold = self.encoder.score_threshold
split_indices = self._find_split_indices(similarities=similarities)
doc_chunks = self._split_documents(splits, split_indices, similarities)

if self.plot_chunks:
self.plot_similarity_scores(similarities, split_indices, doc_chunks)

if self.enable_statistics:
print(self.statistics)
all_chunks.append(doc_chunks)

raise ValueError("The document must be a string.")
return all_chunks

def _encode_documents(self, docs: List[str]) -> np.ndarray:
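For context, a minimal usage sketch of the batched interface this commit adds. The StatisticalChunker class name, the OpenAIEncoder import, and the sample file path are assumptions drawn from the surrounding aurelio-labs libraries rather than from this diff; any encoder compatible with the chunker's encoder argument should behave the same way.

# Minimal usage sketch. Assumptions not shown in this diff: the chunker in
# statistical.py is exported as StatisticalChunker, and an encoder such as
# semantic_router's OpenAIEncoder can be passed to it.
from semantic_router.encoders import OpenAIEncoder
from semantic_chunkers import StatisticalChunker

encoder = OpenAIEncoder()  # assumes an OpenAI API key is configured in the environment
chunker = StatisticalChunker(encoder=encoder)

# With this commit, each document's sentence splits are merged in batches of
# `batch_size` (default 64); the trailing chunk of each batch is carried over
# into the next batch, so chunk boundaries are not forced at batch edges.
docs = [open("long_document.txt").read()]  # hypothetical input file
chunks_per_doc = chunker(docs=docs, batch_size=64)  # returns List[List[Chunk]]

for chunk in chunks_per_doc[0]:
    print(chunk.splits)  # Chunk.splits is inferred from the diff above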
