diff --git a/src/instructlab/sdg/utils/chunking.py b/src/instructlab/sdg/utils/chunking.py
index 160e9f96..e2b9843a 100644
--- a/src/instructlab/sdg/utils/chunking.py
+++ b/src/instructlab/sdg/utils/chunking.py
@@ -1,12 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
 
 # Standard
-from typing import List
 import logging
 import re
+from pathlib import Path
+from typing import List
 
 # Third Party
 from langchain_text_splitters import Language, RecursiveCharacterTextSplitter
+from instructlab.sdg.utils.docprocessor import DocProcessor
 
 _DEFAULT_CHUNK_OVERLAP = 100
 
@@ -21,28 +23,69 @@ def _num_chars_from_tokens(num_tokens) -> int:
     return int(num_tokens * 4)  # 1 token ~ 4 English character
 
 
-def chunk_document(documents: List | str, server_ctx_size, chunk_word_count) -> List[str]:
+def _extract_filetypes_from_docs(documents: List):
+    """Separate documents into lists based on their filetype.
+
+    Currently, only Markdown and PDF are supported.
+    Args:
+        documents (list): List of documents retrieved from git (can also consist of a single document).
+    Returns:
+        (List[str], List[str]): Lists of Markdown and PDF documents, respectively
+    """
+    md_docs = []
+    pdf_docs = []
+
+    for doc in documents:
+        filetype = doc.rsplit(".", 1)[-1]  # text after the last "." (assumes doc is a file name or path)
+        if filetype == "md":
+            md_docs.append(doc)
+        elif filetype == "pdf":
+            pdf_docs.append(doc)
+        else:
+            raise ValueError(f"Received document of type .{filetype}, which is not a supported filetype")
+
+    return md_docs, pdf_docs
+
+
+def chunk_document(documents: List | str, server_ctx_size, chunk_word_count, qna_yaml_path=None) -> List[str]:
     """
-    Iterates over the documents and splits them into chunks based on the word count provided by the user.
+    Iterate over the documents and split them into chunks based on the word count provided by the user.
     Args:
         documents (list): List of documents retrieved from git (can also consist of a single document).
         server_ctx_size (int): Context window size of server.
         chunk_word_count (int): Maximum number of words to chunk a document.
+        qna_yaml_path (optional): Path to the qna_yaml corresponding with these documents; forwarded to chunk_pdfs().
     Returns:
         List[str]: List of chunked documents.
     """
-
-    # Checks for input type error
+    # Check for input type error
     if isinstance(documents, str):
         documents = [documents]
         logger.info(
             "Converted single string into list of strings. Assumed the string passed in is the document. Normally, chunk_document() should take a list as input."
         )
+
     elif not isinstance(documents, list):
-        raise TypeError(
-            "Expected: documents to be a list, but got {}".format(type(documents))
-        )
+        raise TypeError(f"Expected list of documents but got {type(documents)}")
+
+    md_docs, pdf_docs = _extract_filetypes_from_docs(documents)
+    chunked_mds = chunk_markdowns(md_docs, server_ctx_size, chunk_word_count)
+    chunked_pdfs = chunk_pdfs(pdf_docs, qna_yaml_path)
+
+
+    return chunked_mds + chunked_pdfs
+
+
+def chunk_markdowns(documents: List | str, server_ctx_size, chunk_word_count) -> List[str]:
+    """Naively chunk markdown documents based on the word count provided by the user.
+    Args:
+        documents (list): List of markdown documents.
+        server_ctx_size (int): Context window size of server.
+        chunk_word_count (int): Maximum number of words to chunk a document.
+    Returns:
+        List[str]: List of chunked documents.
+    """
 
     num_tokens_per_doc = _num_tokens_from_words(chunk_word_count)
     if num_tokens_per_doc > int(server_ctx_size - 1024):
         raise ValueError(
@@ -66,16 +109,24 @@ def chunk_document(documents: List | str, server_ctx_size, chunk_word_count) ->
 
     # Determine file type for heuristics, default with markdown
     for doc in documents:
-        filetype = doc.rsplit(".")
-        if filetype == "md":
-            # Use regex to remove unnecessary dashes in front of pipe characters in a markdown table.
-            doc = re.sub(r"-{2,}\|", "-|", doc)
-            # Remove unnecessary spaces in front of pipe characters in a markdown table.
-            doc = re.sub(r"\ +\|", " |", doc)
-            temp = md_text_splitter.create_documents([doc])
-            content.extend([item.page_content for item in temp])
-        elif filetype == "pdf":
-            pass
-        else:
-            raise ValueError(f"Received document of type .{filetype}, which is not a supported filetype")
+        # Use regex to remove unnecessary dashes in front of pipe characters in a markdown table.
+        doc = re.sub(r"-{2,}\|", "-|", doc)
+        # Remove unnecessary spaces in front of pipe characters in a markdown table.
+        doc = re.sub(r"\ +\|", " |", doc)
+        temp = md_text_splitter.create_documents([doc])
+        content.extend([item.page_content for item in temp])
     return content
+
+def chunk_pdfs(pdf_docs: List, qna_yaml_path=None):
+    """Semantically chunk PDF documents.
+
+    TODO
+    """
+    tokenizer_name = "TODO"
+
+    chunked_pdfs = []
+    for doc in pdf_docs:
+        dp = DocProcessor(Path(doc).parent, tokenizer_name, user_config_path=qna_yaml_path)
+        # TODO
+
+    return chunked_pdfs