diff --git a/semantic_chunkers/chunkers/base.py b/semantic_chunkers/chunkers/base.py
index 4351e57..2e55e5c 100644
--- a/semantic_chunkers/chunkers/base.py
+++ b/semantic_chunkers/chunkers/base.py
@@ -2,8 +2,8 @@
 from colorama import Fore, Style
 from pydantic.v1 import BaseModel, Extra
-
 from semantic_router.encoders.base import BaseEncoder
+
 from semantic_chunkers.schema import Chunk
 from semantic_chunkers.splitters.base import BaseSplitter
 
 
diff --git a/semantic_chunkers/chunkers/consecutive.py b/semantic_chunkers/chunkers/consecutive.py
index ac28866..c8c8541 100644
--- a/semantic_chunkers/chunkers/consecutive.py
+++ b/semantic_chunkers/chunkers/consecutive.py
@@ -1,11 +1,11 @@
 from typing import Any, List
 
-from tqdm.auto import tqdm
 import numpy as np
-
 from semantic_router.encoders.base import BaseEncoder
-from semantic_chunkers.schema import Chunk
+from tqdm.auto import tqdm
+
 from semantic_chunkers.chunkers.base import BaseChunker
+from semantic_chunkers.schema import Chunk
 from semantic_chunkers.splitters.base import BaseSplitter
 from semantic_chunkers.splitters.sentence import RegexSplitter
 
@@ -58,7 +58,9 @@ def _chunk(self, splits: List[Any], batch_size: int = 64) -> List[Chunk]:
         self.chunks = chunks
         return chunks
 
-    async def _async_chunk(self, splits: List[Any], batch_size: int = 64) -> List[Chunk]:
+    async def _async_chunk(
+        self, splits: List[Any], batch_size: int = 64
+    ) -> List[Chunk]:
         """Merge splits into chunks using semantic similarity.
 
         :param splits: splits to be merged into chunks.
@@ -90,7 +92,6 @@ async def _async_chunk(self, splits: List[Any], batch_size: int = 64) -> List[Ch
         self.chunks = chunks
         return chunks
 
-
     def __call__(self, docs: List[Any]) -> List[List[Chunk]]:
         """Split documents into smaller chunks based on semantic similarity.
 
@@ -120,4 +121,4 @@ async def acall(self, docs: List[Any]) -> List[List[Chunk]]:
                 splits = doc
             doc_chunks = await self._async_chunk(splits)
             all_chunks.append(doc_chunks)
-        return all_chunks
\ No newline at end of file
+        return all_chunks
diff --git a/semantic_chunkers/chunkers/cumulative.py b/semantic_chunkers/chunkers/cumulative.py
index 20a9e3a..b538b14 100644
--- a/semantic_chunkers/chunkers/cumulative.py
+++ b/semantic_chunkers/chunkers/cumulative.py
@@ -1,11 +1,11 @@
 from typing import Any, List
 
-from tqdm.auto import tqdm
 import numpy as np
-
 from semantic_router.encoders import BaseEncoder
-from semantic_chunkers.schema import Chunk
+from tqdm.auto import tqdm
+
 from semantic_chunkers.chunkers.base import BaseChunker
+from semantic_chunkers.schema import Chunk
 from semantic_chunkers.splitters.base import BaseSplitter
 from semantic_chunkers.splitters.sentence import RegexSplitter
 
@@ -76,7 +76,9 @@ def _chunk(self, splits: List[Any], batch_size: int = 64) -> List[Chunk]:
 
         return chunks
 
-    async def _async_chunk(self, splits: List[Any], batch_size: int = 64) -> List[Chunk]:
+    async def _async_chunk(
+        self, splits: List[Any], batch_size: int = 64
+    ) -> List[Chunk]:
         """Merge splits into chunks using semantic similarity.
 
         :param splits: splits to be merged into chunks.
@@ -166,4 +168,4 @@ async def acall(self, docs: List[str]) -> List[List[Chunk]]:
                 splits = doc
             doc_chunks = await self._async_chunk(splits)
             all_chunks.append(doc_chunks)
-        return all_chunks
\ No newline at end of file
+        return all_chunks
diff --git a/semantic_chunkers/chunkers/statistical.py b/semantic_chunkers/chunkers/statistical.py
index addceba..97e5629 100644
--- a/semantic_chunkers/chunkers/statistical.py
+++ b/semantic_chunkers/chunkers/statistical.py
@@ -2,16 +2,15 @@
 from typing import Any, List
 
 import numpy as np
-
 from semantic_router.encoders.base import BaseEncoder
-from semantic_chunkers.schema import Chunk
+from tqdm.auto import tqdm
+
 from semantic_chunkers.chunkers.base import BaseChunker
+from semantic_chunkers.schema import Chunk
 from semantic_chunkers.splitters.base import BaseSplitter
 from semantic_chunkers.splitters.sentence import RegexSplitter
-from semantic_chunkers.utils.text import tiktoken_length
 from semantic_chunkers.utils.logger import logger
-
-from tqdm.auto import tqdm
+from semantic_chunkers.utils.text import tiktoken_length
 
 
 @dataclass
@@ -236,7 +235,6 @@ def __call__(self, docs: List[str], batch_size: int = 64) -> List[List[Chunk]]:
                 raise ValueError("The document must be a string.")
         return all_chunks
 
-
     async def acall(self, docs: List[str], batch_size: int = 64) -> List[List[Chunk]]:
         """Split documents into smaller chunks based on semantic similarity.
 
@@ -265,7 +263,6 @@ async def acall(self, docs: List[str], batch_size: int = 64) -> List[List[Chunk]
                 raise ValueError("The document must be a string.")
         return all_chunks
 
-
     def _encode_documents(self, docs: List[str]) -> np.ndarray:
         """
         Encodes a list of documents into embeddings. If the number of documents
diff --git a/semantic_chunkers/splitters/__init__.py b/semantic_chunkers/splitters/__init__.py
index c6d858a..5b3d258 100644
--- a/semantic_chunkers/splitters/__init__.py
+++ b/semantic_chunkers/splitters/__init__.py
@@ -1,7 +1,6 @@
 from semantic_chunkers.splitters.base import BaseSplitter
 from semantic_chunkers.splitters.sentence import RegexSplitter
 
-
 __all__ = [
     "BaseSplitter",
     "RegexSplitter",
 ]
diff --git a/semantic_chunkers/splitters/sentence.py b/semantic_chunkers/splitters/sentence.py
index cd8b2b3..667cc91 100644
--- a/semantic_chunkers/splitters/sentence.py
+++ b/semantic_chunkers/splitters/sentence.py
@@ -1,6 +1,7 @@
-import regex
 from typing import List
 
+import regex
+
 from semantic_chunkers.splitters.base import BaseSplitter
 
 
diff --git a/tests/unit/test_splitters.py b/tests/unit/test_splitters.py
index 3ec576b..8ef8139 100644
--- a/tests/unit/test_splitters.py
+++ b/tests/unit/test_splitters.py
@@ -2,15 +2,16 @@
 
 import numpy as np
 import pytest
-
 from semantic_router.encoders.base import BaseEncoder
 from semantic_router.encoders.cohere import CohereEncoder
 
-from semantic_chunkers import BaseChunker
-from semantic_chunkers import BaseSplitter
-from semantic_chunkers import ConsecutiveChunker
-from semantic_chunkers import CumulativeChunker
-from semantic_chunkers import StatisticalChunker
+from semantic_chunkers import (
+    BaseChunker,
+    BaseSplitter,
+    ConsecutiveChunker,
+    CumulativeChunker,
+    StatisticalChunker,
+)
 
 
 def test_consecutive_sim_splitter():
@@ -112,6 +113,7 @@ def test_cumulative_sim_splitter():
     # The expected outcome needs to match the logic defined in your mock_encoder's side_effect
     assert len(splits) == 5, f"{len(splits)}"
 
+
 @pytest.mark.asyncio
 async def test_async_cumulative_sim_splitter():
    # Mock the BaseEncoder
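For reference, a minimal sketch of how the async path reformatted above can be exercised, in the grouped-import and @pytest.mark.asyncio style this diff introduces in tests/unit/test_splitters.py. Only ConsecutiveChunker, acall, and the import grouping are taken from the diff itself; the DummyEncoder stub, its fields, and the encoder= constructor keyword are assumptions about the surrounding API.

import numpy as np
import pytest
from semantic_router.encoders.base import BaseEncoder

from semantic_chunkers import ConsecutiveChunker


class DummyEncoder(BaseEncoder):
    # Hypothetical stub: cycles through fixed unit vectors so similarity
    # scores are deterministic. Field names assume a pydantic-style
    # BaseEncoder; adjust to the real base class as needed.
    name: str = "dummy"
    score_threshold: float = 0.5

    def __call__(self, docs):
        return [np.eye(3)[i % 3].tolist() for i, _ in enumerate(docs)]

    async def acall(self, docs):
        return self(docs)


@pytest.mark.asyncio
async def test_acall_returns_one_chunk_list_per_doc():
    # `encoder=` is an assumed keyword argument of the chunker constructor.
    chunker = ConsecutiveChunker(encoder=DummyEncoder())
    all_chunks = await chunker.acall(["First sentence. Second one. A third."])
    # Per the acall loop in the diff, one list of chunks is appended per document.
    assert len(all_chunks) == 1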