Fix linting and formatting for context-aware chunking
Signed-off-by: Khaled Sulayman <[email protected]>
khaledsulayman committed Nov 6, 2024
1 parent 1fe1f76 commit 6851ad5
Showing 4 changed files with 291 additions and 80 deletions.
10 changes: 7 additions & 3 deletions src/instructlab/sdg/generate_data.py
@@ -13,7 +13,6 @@

# Third Party
# instructlab - All of these need to go away (other than sdg) - issue #6
from datasets import Dataset
from xdg_base_dirs import xdg_data_dirs, xdg_data_home
import openai

@@ -367,12 +366,17 @@ def generate_data(
is_knowledge = False
leaf_node_path = leaf_node[0]["taxonomy_path"].replace("->", "_")
samples = leaf_node_to_samples(
leaf_node, taxonomy, server_ctx_size, chunk_word_count, document_output_dir, model_name
leaf_node,
taxonomy,
server_ctx_size,
chunk_word_count,
document_output_dir,
model_name,
)

if not samples:
raise GenerateException("Error: No samples found in leaf node.")

if "document" in samples.column_names:
pipe = knowledge_pipe
is_knowledge = True
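Aside: the branch above keys off the Dataset columns — samples exposes column_names, so it behaves like a Hugging Face datasets.Dataset, and the presence of a "document" column routes the leaf node to the knowledge pipeline. A minimal sketch of that check, with made-up sample fields:

from datasets import Dataset

samples = Dataset.from_list(
    [{"document": "chunked text", "taxonomy_path": "knowledge_example"}]
)
if "document" in samples.column_names:
    pipeline = "knowledge"  # document chunks present -> knowledge pipeline
else:
    pipeline = "skills"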
119 changes: 66 additions & 53 deletions src/instructlab/sdg/utils/chunkers.py
@@ -1,22 +1,21 @@
import json
import logging
import re
import yaml
# Standard
from abc import ABC, abstractmethod
from collections import defaultdict
from enum import Enum
from pathlib import Path
from typing import Iterable, List, Tuple, DefaultDict
from typing import DefaultDict, Iterable, List, Tuple
import json
import logging
import re

# Third Party
from datasets import Dataset, concatenate_datasets
from docling.datamodel.base_models import PipelineOptions
from datasets import Dataset
from docling.datamodel.document import ConvertedDocument, DocumentConversionInput
from docling.document_converter import ConversionStatus, DocumentConverter
from langchain_text_splitters import Language, RecursiveCharacterTextSplitter
from tabulate import tabulate
from transformers import AutoTokenizer

import yaml

logger = logging.getLogger(__name__)
_DEFAULT_CHUNK_OVERLAP = 100
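The langchain_text_splitters import kept above is what the Markdown path ultimately relies on; the chunk_markdowns helper itself is not touched by this diff, so the following is only a plausible sketch of how those names combine, using the module-level overlap constant and an illustrative character budget:

splitter = RecursiveCharacterTextSplitter.from_language(
    language=Language.MARKDOWN,
    chunk_size=5324,  # characters; see the word/token/char helpers below
    chunk_overlap=_DEFAULT_CHUNK_OVERLAP,
)
chunks = splitter.split_text("# Heading\n\nSome markdown body text...")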
@@ -25,6 +24,7 @@
def _num_tokens_from_words(num_words) -> int:
return int(num_words * 1.3) # 1 word ~ 1.3 token


def _num_chars_from_tokens(num_tokens) -> int:
return int(num_tokens * 4) # 1 token ~ 4 English character

@@ -36,34 +36,35 @@ class FileTypes(Enum):

class ChunkerBase(ABC):
@abstractmethod
def chunk_documents():
def chunk_documents(self):
pass


class DocumentChunker:
"""A factory chunker class that instantiates the applicable chunker
Currently, only Markdown and PDF are supported. For Markdown, returns
TextSplitChunker, and for PDF, returns ContextAwareChunker"""

def __new__(
cls,
leaf_node = None,
taxonomy_path = None,
output_dir: Path = None,
leaf_node,
taxonomy_path,
output_dir: Path,
server_ctx_size=4096,
chunk_word_count=1024,
tokenizer_model_name: str = None,
tokenizer_model_name: str | None = None,
):
"""Insantiate the appropriate chunker for the provided document
Args:
leaf_node: a leaf node dict containing "documents",
"filepaths", and "taxonomy_path" keys
output_dir (Path): directory where artifacts should be stored
server_ctx_size (int): Context window size of server
chunk_word_count (int): Maximum number of words to chunk a document
tokenizer_model_name (str): name of huggingface model to get
tokenizer from
tokenizer from
Returns:
TextSplitChunker | ContextAwareChunker: Object of the appropriate
chunker class for the provided filetype
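A hedged usage sketch of this factory: the leaf_node shape follows the docstring above and the leaf_node[0]["taxonomy_path"] access in generate_data.py, and every path and value here is hypothetical.

from pathlib import Path

leaf_node = [
    {
        "documents": ["# Title\n\nSome markdown body."],
        "filepaths": [Path("docs/example.md")],
        "taxonomy_path": "knowledge->example->qna",
    }
]
chunker = DocumentChunker(
    leaf_node=leaf_node,
    taxonomy_path=Path("taxonomy"),
    output_dir=Path("/tmp/docling-artifacts"),
    server_ctx_size=4096,
    chunk_word_count=1024,
    tokenizer_model_name="mistralai/Mixtral-8x7B-Instruct-v0.1",
)
chunks = chunker.chunk_documents()  # a TextSplitChunker for the .md input above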
@@ -88,47 +89,51 @@ def __new__(

doc_dict = cls._split_docs_by_filetype(documents, filepaths)
if len(doc_dict.keys()) > 1:
raise ValueError(f"Received multiple document types")
raise ValueError("Received multiple document types")

if FileTypes.MD in doc_dict:
doc_contents = [d for d, _ in doc_dict[FileTypes.MD]]
return TextSplitChunker(
doc_dict[FileTypes.MD],
doc_contents,
server_ctx_size,
chunk_word_count,
output_dir,
)

if FileTypes.PDF in doc_dict:
doc_paths = [p for _, p in doc_dict[FileTypes.PDF]]
return ContextAwareChunker(
doc_dict[FileTypes.PDF],
doc_paths,
filepaths,
taxonomy_path / leaf_node_path / "qna.yaml",
output_dir,
output_dir,
chunk_word_count,
tokenizer_model_name,
)

@staticmethod
def _split_docs_by_filetype(documents: List[str], filepaths: List[Path]) -> defaultdict[any, List]:
def _split_docs_by_filetype(
documents: List[str], filepaths: List[Path]
) -> DefaultDict[FileTypes, List[Tuple[str, Path]]]:
"""Separate documents into lists based on their filetype.
Currently, only Markdown and PDF are supported.
Args:
documents (List[str]): A list of the document contents as strings
filepaths (List[Path]): Corresponding document filepaths
Returns:
defaultdict: Dictionary with either ".md" or ".pdf" as a key.
DefaultDict: Dictionary with either ".md" or ".pdf" as a key.
Markdown items contain document contents, PDF items contain
paths to documents.
"""
doc_dict = defaultdict(list)
for doc, path in zip(documents, filepaths):
if path.suffix == ".md":
# append doc contents
doc_dict[FileTypes.MD].append(doc)
doc_dict[FileTypes.MD].append((doc, path))
elif path.suffix == ".pdf":
# append doc paths
doc_dict[FileTypes.PDF].append(path)
doc_dict[FileTypes.PDF].append((doc, path))
else:
raise ValueError(
f"Received document of type .{path.suffix}, which is not a supported filetype"
@@ -170,7 +175,7 @@ def chunk_documents(self) -> List:
return chunk_markdowns(self.document_contents, chunk_size)


class ContextAwareChunker(ChunkerBase):
class ContextAwareChunker(ChunkerBase): # pylint: disable=too-many-instance-attributes
def __init__(
self,
document_paths,
@@ -185,7 +190,11 @@ def __init__(
self.leaf_node_path = leaf_node_path
self.output_dir = self._path_validator(output_dir)
self.chunk_word_count = chunk_word_count
self.tokenizer_model_name = tokenizer_model_name if tokenizer_model_name is not None else "mistralai/Mixtral-8x7B-Instruct-v0.1"
self.tokenizer_model_name = (
tokenizer_model_name
if tokenizer_model_name is not None
else "mistralai/Mixtral-8x7B-Instruct-v0.1"
)
self.qna_yaml = self._load_qna_yaml(
self._path_validator(leaf_node_path) if leaf_node_path else None
)
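For context, this default model name feeds the tokenizer helpers further down in this file (create_tokenizer appears to wrap AutoTokenizer, and get_token_count measures chunk sizes with it); boiled down, that plumbing amounts to roughly the following, with an arbitrary sample sentence:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
token_count = len(tokenizer.tokenize("A short paragraph to measure."))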
@@ -229,7 +238,7 @@ def _path_validator(self, path) -> Path:
raise FileNotFoundError(f"{path} does not exist.")
return path

def _load_qna_yaml(self, qna_yaml_path: Path) -> dict:
def _load_qna_yaml(self, qna_yaml_path: Path | None) -> dict:
"""
Load the qna YAML file.
Args:
@@ -254,7 +263,6 @@ def _process_parsed_docling_json(self, json_fp: Path) -> Dataset:
with open(json_fp, "r", encoding="utf-8") as f:
data = json.load(f)

file_name = json_fp.stem
chunks = self.build_chunks_from_docling_json(
data,
max_token_per_chunk=500,
Expand All @@ -265,8 +273,10 @@ def _process_parsed_docling_json(self, json_fp: Path) -> Dataset:
num_tokens_per_doc = _num_tokens_from_words(self.chunk_word_count)
chunk_size = _num_chars_from_tokens(num_tokens_per_doc)
return chunk_markdowns(fused_texts, chunk_size)

def fuse_texts(self, text_list: List, short_length_threshold: int = 130):

def fuse_texts(
self, text_list: List, short_length_threshold: int = 130
) -> List[str]:
"""
Fuse short texts with preceding longer texts if their token count is below the threshold.
Args:
@@ -277,11 +287,13 @@ def fuse_texts(self, text_list: List, short_length_threshold: int = 130):
Returns:
list: List of fused texts.
"""
fused_texts = []
fused_texts: List[str] = []
previous_long_text = ""

for text in text_list:
token_count = self.get_token_count(text, self.tokenizer) # Use tokenizer for token count
token_count = self.get_token_count(
text, self.tokenizer
) # Use tokenizer for token count

if token_count <= short_length_threshold and previous_long_text:
# Append the short text to the last long text
Expand All @@ -292,7 +304,7 @@ def fuse_texts(self, text_list: List, short_length_threshold: int = 130):
previous_long_text = text

return fused_texts

def create_tokenizer(self, model_name: str):
"""
Create a tokenizer instance from a pre-trained model or a local directory.
@@ -311,7 +323,6 @@ def create_tokenizer(self, model_name: str):
logger.error(f"Failed to load tokenizer from {model_name}: {str(e)}")
raise


def get_token_count(self, text, tokenizer):
"""
Get the number of tokens in a text using the provided tokenizer.
@@ -323,7 +334,6 @@ def get_token_count(self, text, tokenizer):
"""
return len(tokenizer.tokenize(text))


def add_heading_formatting(self, text):
"""
Add heading formatting to the text if the first part is short.
@@ -341,7 +351,6 @@ def add_heading_formatting(self, text):
text = ".".join(text)
return text


def generate_table_from_parsed_rep(self, item):
"""
Generate the table from the parsed representation and return as a string.
@@ -360,9 +369,9 @@ def generate_table_from_parsed_rep(self, item):
return ""

table = []
for i, row in enumerate(data):
for _, row in enumerate(data):
trow = []
for j, cell in enumerate(row):
for _, cell in enumerate(row):
trow.append(cell["text"])
table.append(trow)
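The collapsed lines between these hunks are presumably where the assembled table goes through tabulate (imported at the top of this file) before the caption is appended below; a minimal, made-up illustration of that pattern:

from tabulate import tabulate

table = [["Header A", "Header B"], ["1", "2"]]
table_text = tabulate(table)
table_text += "\nCaption: An illustrative caption\n"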

@@ -371,7 +380,6 @@ def generate_table_from_parsed_rep(self, item):
table_text += f"\nCaption: {caption}\n"
return table_text


def get_table(self, json_book, table_ref):
"""
Retrieve a table from a document based on a reference string.
@@ -382,10 +390,11 @@ def get_table(self, json_book, table_ref):
str: Formatted table string.
"""
parts = table_ref.split("/")
table_text = self.generate_table_from_parsed_rep(json_book[parts[1]][int(parts[2])])
table_text = self.generate_table_from_parsed_rep(
json_book[parts[1]][int(parts[2])]
)
return table_text


def get_table_page_number(self, json_book, idx):
"""
Get the page number of a table or other document element.
@@ -407,11 +416,10 @@ def get_table_page_number(self, json_book, idx):
if prev_page_num is not None and next_page_num is not None:
if prev_page_num == next_page_num:
return prev_page_num
else:
return next_page_num
elif prev_page_num is not None:
return next_page_num
if prev_page_num is not None:
return prev_page_num
elif next_page_num is not None:
if next_page_num is not None:
return next_page_num

def build_chunks_from_docling_json(
@@ -448,7 +456,7 @@ def build_chunks_from_docling_json(
"page-header",
]:
continue
elif book_element["type"] == "footnote":
if book_element["type"] == "footnote":
current_book_page_number = book_element["prov"][0]["page"]
elif book_element["type"] in [
"subtitle-level-1",
Expand All @@ -458,7 +466,9 @@ def build_chunks_from_docling_json(
"equation",
]: # 'page-header',
if book_element["type"] == "table":
current_book_page_number = self.get_table_page_number(json_book, idx)
current_book_page_number = self.get_table_page_number(
json_book, idx
)
else:
current_book_page_number = book_element["prov"][0]["page"]
book_text = book_element["text"]
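For orientation, the hunks shown read three fields off each docling JSON element: its "type", its "text", and the page number under "prov". An illustrative element (values made up) looks like:

book_element = {
    "type": "paragraph",
    "text": "Body text for one block of the parsed document.",
    "prov": [{"page": 12}],
}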
@@ -492,16 +502,20 @@ def build_chunks_from_docling_json(
>= max_token_per_chunk
and len(current_buffer) > 1
):
chunk_text = '\n\n'.join(current_buffer[:-1])
print(f"Current chunk size {self.get_token_count(chunk_text, tokenizer)} and max is {max_token_per_chunk}")
chunk_text = "\n\n".join(current_buffer[:-1])
print(
f"Current chunk size {self.get_token_count(chunk_text, tokenizer)} and max is {max_token_per_chunk}"
)

document_chunks.append("\n\n".join(current_buffer[:-1]))

if (
self.get_token_count(current_buffer[-1], tokenizer)
>= max_token_per_chunk
):
print(f"This is too big a document to be left in the current buffer {self.get_token_count(current_buffer[-1], tokenizer)}")
print(
f"This is too big a document to be left in the current buffer {self.get_token_count(current_buffer[-1], tokenizer)}"
)
document_chunks.append(current_buffer[-1])
current_buffer = []
else:
Expand All @@ -518,17 +532,16 @@ def build_chunks_from_docling_json(

try:
prev_page_number = current_book_page_number
except Exception as e:
except Exception as e: # pylint: disable=broad-exception-caught
logger.error(f"Error processing book element: {book_element}, {str(e)}")

if "\n\n".join(current_buffer) not in document_chunks:
document_chunks.append("\n\n".join(current_buffer))
return document_chunks


def export_documents(self, converted_docs: Iterable[ConvertedDocument]):
"""Write converted documents to json files
Check for successful conversions and write those to the docling artifacts directory.
Returns:
Path: path to directory with docling json artifacts