Merge pull request #65 from PalmPalm7/main
Updated chunking_document.
russellb authored Jul 6, 2024
2 parents 6251693 + 19bf54b commit 6842582
Showing 1 changed file with 32 additions and 7 deletions.
39 changes: 32 additions & 7 deletions src/instructlab/sdg/utils/chunking.py
@@ -2,12 +2,16 @@

# Standard
from typing import List
import logging
import re

# Third Party
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_text_splitters import Language, RecursiveCharacterTextSplitter

_DEFAULT_CHUNK_OVERLAP = 100

logger = logging.getLogger(__name__)


def _num_tokens_from_words(num_words) -> int:
return int(num_words * 1.3) # 1 word ~ 1.3 token
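
For scale, the word-to-token factor above, together with the character conversion in _num_chars_from_tokens (not shown in this hunk; assumed here to be roughly 4 characters per token), determines the chunk_size that chunk_document() hands to the splitter, and the 1024-token headroom check against server_ctx_size further below. A minimal sketch of the arithmetic under that assumption:

# Hypothetical walk-through of the sizing math; the 4-chars-per-token factor is an
# assumption for illustration, since _num_chars_from_tokens is not shown in this diff.
chunk_word_count = 1000
tokens_per_chunk = int(chunk_word_count * 1.3)   # _num_tokens_from_words -> 1300
chunk_size_chars = tokens_per_chunk * 4          # assumed chars-per-token factor -> 5200

server_ctx_size = 4096
# chunk_document() raises ValueError unless the chunk leaves ~1024 tokens of headroom:
assert tokens_per_chunk <= server_ctx_size - 1024  # 1300 <= 3072, so this request is accepted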
@@ -21,12 +25,24 @@ def chunk_document(documents: List, server_ctx_size, chunk_word_count) -> List[str]:
"""
Iterates over the documents and splits them into chunks based on the word count provided by the user.
Args:
documents (dict): List of documents retrieved from git (can also consist of a single document).
documents (list): List of documents retrieved from git (can also consist of a single document).
server_ctx_size (int): Context window size of server.
chunk_word_count (int): Maximum number of words to chunk a document.
Returns:
List[str]: List of chunked documents.
"""

# Check for input type errors
if isinstance(documents, str):
documents = [documents]
logger.info(
"Converted single string into a list of string. Assumed the string passed in is the document. Normally, chunk_document() should take a list as input."
)
elif not isinstance(documents, list):
raise TypeError(
"Expected: documents to be a list, but got {}".format(type(documents))
)

no_tokens_per_doc = _num_tokens_from_words(chunk_word_count)
if no_tokens_per_doc > int(server_ctx_size - 1024):
raise ValueError(
@@ -36,15 +52,24 @@ def chunk_document(documents: List, server_ctx_size, chunk_word_count) -> List[str]:
)
)
)
# Placeholder for params
content = []
text_splitter = RecursiveCharacterTextSplitter(
separators=["\n\n", "\n", " "],
chunk_size=_num_chars_from_tokens(no_tokens_per_doc),
chunk_overlap=_DEFAULT_CHUNK_OVERLAP,
chunk_size = _num_chars_from_tokens(no_tokens_per_doc)
chunk_overlap = _DEFAULT_CHUNK_OVERLAP

# Using Markdown as default; document-specific chunking will be implemented in a separate PR.
text_splitter = RecursiveCharacterTextSplitter.from_language(
language=Language.MARKDOWN,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)

# Determine file type for heuristics; default to markdown
for docs in documents:
# Use regex to remove unnecessary dashes in front of pipe characters in a markdown table.
docs = re.sub(r"-{2,}\|", "-|", docs)
# Remove unnecessary spaces in front of pipe characters in a markdown table.
docs = re.sub(r"\ +\|", " |", docs)
temp = text_splitter.create_documents([docs])
content.extend([item.page_content for item in temp])

return content
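
For reference, a minimal usage sketch of the updated chunk_document(), based only on the signature and docstring above; the document text and size values are made up for illustration:

# Hypothetical call site; values are illustrative, not taken from the repository.
from instructlab.sdg.utils.chunking import chunk_document

doc = (
    "# Example\n\n"
    "| col a | col b |\n"
    "|-------|-------|\n"   # the new regexes collapse "----|" to "-|" and "   |" to " |"
    "| 1     | 2     |\n"
)

chunks = chunk_document(
    documents=[doc],        # a bare string is also accepted and wrapped in a list
    server_ctx_size=4096,   # server context window, in tokens
    chunk_word_count=1000,  # maximum words per chunk
)
print(len(chunks), chunks[0][:60])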
