From 19bf54b59dc0cdea4947f8537c699bf3490ae2a6 Mon Sep 17 00:00:00 2001
From: Andy Xie <anxie@redhat.com>
Date: Tue, 2 Jul 2024 14:21:44 -0400
Subject: [PATCH] Updated chunk_document().

1. Applied the document-specific text splitter from LangChain in place of the original naive version.
2. Made heuristic changes for Markdown files, notably using regexes to trim Markdown tables so that a whole table is more likely to fit within a limited context window.
3. For the updated chunk_document() function, see Chunking_Demo.ipynb, which demonstrates chunking with server_ctx_size=4096 and chunk_word_count=1024 (Granite 7b has a 4k context window). Short sketches of the heuristics and usage follow below.
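
To illustrate the table heuristics, the two regexes collapse runs of dashes
and spaces that precede pipe characters (the input strings below are
illustrative):

    import re

    sep = "|--------|--------|"
    row = "| alpha    | 1    |"
    print(re.sub(r"-{2,}\|", "-|", sep))  # -> "|-|-|"
    print(re.sub(r" {2,}\|", " |", row))  # -> "| alpha | 1 |"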
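
A minimal usage sketch of the updated function (assuming the module path
matches the file location; the sample document is illustrative):

    from instructlab.sdg.utils.chunking import chunk_document

    # Illustrative Markdown document containing a small table
    doc = "# Demo\n\n| name | value |\n|------|-------|\n| alpha | 1 |\n"
    chunks = chunk_document(
        documents=[doc],
        server_ctx_size=4096,  # Granite 7b has a 4k context window
        chunk_word_count=1024,
    )
    print(len(chunks), chunks[0])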

Signed-off-by: Andy Xie <anxie@redhat.com>
---
 src/instructlab/sdg/utils/chunking.py | 39 ++++++++++++++++++++++-----
 1 file changed, 32 insertions(+), 7 deletions(-)

diff --git a/src/instructlab/sdg/utils/chunking.py b/src/instructlab/sdg/utils/chunking.py
index 79f1e16c..0d04d41e 100644
--- a/src/instructlab/sdg/utils/chunking.py
+++ b/src/instructlab/sdg/utils/chunking.py
@@ -2,12 +2,16 @@
 
 # Standard
 from typing import List
+import logging
+import re
 
 # Third Party
-from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langchain_text_splitters import Language, RecursiveCharacterTextSplitter
 
 _DEFAULT_CHUNK_OVERLAP = 100
 
+logger = logging.getLogger(__name__)
+
 
 def _num_tokens_from_words(num_words) -> int:
     return int(num_words * 1.3)  # 1 word ~ 1.3 token
@@ -21,12 +25,24 @@ def chunk_document(documents: List, server_ctx_size, chunk_word_count) -> List[s
     """
     Iterates over the documents and splits them into chunks based on the word count provided by the user.
     Args:
-        documents (dict): List of documents retrieved from git (can also consist of a single document).
+        documents (list): List of documents retrieved from git (can also consist of a single document).
         server_ctx_size (int): Context window size of server.
         chunk_word_count (int): Maximum number of words to chunk a document.
     Returns:
          List[str]: List of chunked documents.
     """
+
+    # Validate input: accept a single string as a one-document list
+    if isinstance(documents, str):
+        documents = [documents]
+        logger.info(
+            "Converted a single string into a one-element list; assuming the string passed in is the document. chunk_document() normally takes a list as input."
+        )
+    elif not isinstance(documents, list):
+        raise TypeError(
+            f"Expected documents to be a list, but got {type(documents)}"
+        )
+
     no_tokens_per_doc = _num_tokens_from_words(chunk_word_count)
     if no_tokens_per_doc > int(server_ctx_size - 1024):
         raise ValueError(
@@ -36,15 +52,24 @@ def chunk_document(documents: List, server_ctx_size, chunk_word_count) -> List[s
                 )
             )
         )
+    # Output accumulator and splitter parameters
     content = []
-    text_splitter = RecursiveCharacterTextSplitter(
-        separators=["\n\n", "\n", " "],
-        chunk_size=_num_chars_from_tokens(no_tokens_per_doc),
-        chunk_overlap=_DEFAULT_CHUNK_OVERLAP,
+    chunk_size = _num_chars_from_tokens(no_tokens_per_doc)
+    chunk_overlap = _DEFAULT_CHUNK_OVERLAP
+
+    # Markdown is used as the default; document-specific chunking will be implemented in a separate PR.
+    text_splitter = RecursiveCharacterTextSplitter.from_language(
+        language=Language.MARKDOWN,
+        chunk_size=chunk_size,
+        chunk_overlap=chunk_overlap,
     )
 
+    # Apply Markdown table heuristics before splitting; Markdown is the default format.
     for docs in documents:
+        # Collapse runs of dashes before pipe characters in a Markdown table separator row.
+        docs = re.sub(r"-{2,}\|", "-|", docs)
+        # Collapse runs of spaces before pipe characters in a Markdown table row.
+        docs = re.sub(r" {2,}\|", " |", docs)
         temp = text_splitter.create_documents([docs])
         content.extend([item.page_content for item in temp])
-
     return content