diff --git a/RecDP/examples/notebooks/llmutils/rag_pipeline.ipynb b/RecDP/examples/notebooks/llmutils/rag_pipeline.ipynb
index 559e80a50..fc79fa170 100644
--- a/RecDP/examples/notebooks/llmutils/rag_pipeline.ipynb
+++ b/RecDP/examples/notebooks/llmutils/rag_pipeline.ipynb
@@ -48,9 +48,19 @@
    "id": "bMqBJ9eckIs6"
   },
   "source": [
-    "### 2. Set parameters"
+    "## 2. Set parameters according to your environment\n"
   ]
  },
+ {
+  "cell_type": "markdown",
+  "source": [
+   "\n",
+   "### 2.1 Parameters for the vector store\n"
+  ],
+  "metadata": {
+   "id": "nBa-OiRcQhLr"
+  }
+ },
 {
  "cell_type": "code",
  "execution_count": null,
@@ -59,18 +69,76 @@
  },
  "outputs": [],
  "source": [
+   "# Where to store the vector store data\n",
    "out_dir=/content/vs_store\n",
    "vector_store_type=\"FAISS\"\n",
-   "index_name=\"knowledge_db\"\n",
-   "\n",
+   "index_name=\"knowledge_db\""
  ]
 },
+ {
+  "cell_type": "markdown",
+  "source": [
+   "### 2.2 Parameters for the text splitter"
+  ],
+  "metadata": {
+   "id": "PmgACKQzQv7z"
+  }
+ },
+ {
+  "cell_type": "code",
+  "source": [
    "text_splitter = \"RecursiveCharacterTextSplitter\"\n",
-   "text_splitter_args = {\"chunk_size\": 500, \"chunk_overlap\": 0}\n",
-   "\n",
-   "target_urls = [\"https://www.intc.com/news-events/press-releases/detail/1655/intel-reports-third-quarter-2023-financial-results\"]\n",
-   "\n",
+   "text_splitter_args = {\"chunk_size\": 500, \"chunk_overlap\": 0}"
+  ],
+  "metadata": {
+   "id": "tvXP1IysQyza"
+  },
+  "execution_count": null,
+  "outputs": []
+ },
+ {
+  "cell_type": "markdown",
+  "source": [
+   "### 2.3 Parameters for the embedding model"
+  ],
+  "metadata": {
+   "id": "WrdD1PdBQ0ax"
+  }
+ },
+ {
+  "cell_type": "code",
+  "source": [
    "embeddings_type=\"HuggingFaceEmbeddings\"\n",
    "embeddings_args={'model_name': f\"sentence-transformers/all-mpnet-base-v2\"}"
-  ]
+  ],
+  "metadata": {
+   "id": "Pr_caSYPQ21R"
+  },
+  "execution_count": null,
+  "outputs": []
+ },
+ {
+  "cell_type": "markdown",
+  "source": [
+   "### 2.4 Specify the data you need to process"
+  ],
+  "metadata": {
+   "id": "YjMrPVCJQ56w"
+  }
+ },
+ {
+  "cell_type": "code",
+  "source": [
+   "# Web data\n",
+   "target_urls = [\"https://www.intc.com/news-events/press-releases/detail/1655/intel-reports-third-quarter-2023-financial-results\"]\n",
+   "# ... or local file data\n",
+   "# data_path = \"/content/my_pdf_path\""
+  ],
+  "metadata": {
+   "id": "lfl_tOq6Q5fY"
+  },
+  "execution_count": null,
+  "outputs": []
 },
 {
  "cell_type": "markdown",
@@ -78,7 +146,7 @@
   "id": "JjTnnzw_kRVV"
  },
  "source": [
-   "## 3. Extract data and build a knowledge database"
+   "## 3. Use RecDP to extract data and build a knowledge database"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
@@ -102,7 +170,7 @@
  },
  "outputs": [],
  "source": [
    "from pyrecdp.LLM import TextPipeline\n",
-   "from pyrecdp.primitives.operations import Url_Loader, DocumentSplit, DocumentIngestion"
+   "from pyrecdp.primitives.operations import UrlLoader, DocumentSplit, DocumentIngestion, RAGTextFix"
  ]
 },
 {
@@ -128,7 +196,10 @@
  "source": [
    "pipeline = TextPipeline()\n",
    "ops = [\n",
-   "    Url_Loader(urls=target_urls, target_tag='div', target_attrs={'class': 'main-content'}),\n",
+   "    UrlLoader(urls=target_urls),\n",
+   "    # DirectoryLoader(data_path, glob=\"**/*.pdf\"),\n",
+   "    # Use operators provided by RecDP to process the data\n",
+   "    RAGTextFix(),\n",
    "    DocumentSplit(text_splitter=text_splitter, text_splitter_args=text_splitter_args),\n",
    "    DocumentIngestion(\n",
    "        vector_store=vector_store_type,\n",
diff --git a/RecDP/pyrecdp/LLM/README.md b/RecDP/pyrecdp/LLM/README.md
index ac131107c..58486b235 100644
--- a/RecDP/pyrecdp/LLM/README.md
+++ b/RecDP/pyrecdp/LLM/README.md
@@ -48,14 +48,23 @@ pip install pyrecdp[LLM] --pre
 ### Data pipeline
 #### 1. RAG Data Pipeline - Build from public HTML [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/intel/e2eAIOK/blob/main/RecDP/examples/notebooks/llmutils/rag_pipeline.ipynb)
-
-```
+Retrieval-augmented generation (RAG) for large language models (LLMs) aims to improve prediction quality by using an external datastore at inference time to build a richer prompt that includes some combination of context, history, and recent or relevant knowledge.
+RecDP LLM provides a pipeline for ingesting data from a source and indexing it. It mainly provides the following capabilities:
+- **Load Data**: Load your data from its source with `UrlLoader` or `DirectoryLoader`.
+- **Improve Data Quality**: `RAGTextFix` cleans up text for RAG use: it re-joins sentences that were split by spurious line breaks during parsing, removes special characters, fixes Unicode errors, and so on.
+- **Split Text**: `DocumentSplit` breaks large documents into smaller chunks, which makes the data easier to index and easier for the model to consume.
+- **Vector Store**: `DocumentIngestion` uses a vector store and an embeddings model to store and index your data so it can be retrieved later.
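+
+Once the ingestion pipeline has run (see the example below), the persisted index can be loaded back with the same embeddings model and queried. This is a minimal sketch only; it assumes LangChain's `FAISS.load_local` and `similarity_search` APIs and reuses the `/content/vs_store` output directory and `knowledge_db` index name from the notebook above, so adjust them to your own settings:
+
+```python
+from langchain.embeddings import HuggingFaceEmbeddings
+from langchain.vectorstores.faiss import FAISS
+
+# Load the index persisted by DocumentIngestion and run a similarity search.
+embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
+db = FAISS.load_local("/content/vs_store", embeddings, "knowledge_db")
+docs = db.similarity_search("What were Intel's third-quarter 2023 results?", k=4)
+print(docs[0].page_content)
+```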
+
+Here is a basic RAG Data Pipeline example:
+```python
 from pyrecdp.primitives.operations import *
 from pyrecdp.LLM import TextPipeline
 
 pipeline = TextPipeline()
 ops = [
-    Url_Loader(urls=["https://www.intc.com/news-events/press-releases/detail/1655/intel-reports-third-quarter-2023-financial-results"], target_tag='div', target_attrs={'class': 'main-content'}),
+    UrlLoader(urls=["https://www.intc.com/news-events/press-releases/detail/1655/intel-reports-third-quarter-2023-financial-results"], target_tag='div', target_attrs={'class': 'main-content'}),
+    # DirectoryLoader(files_path, glob="**/*.pdf"),
+    RAGTextFix(),
     DocumentSplit(),
     DocumentIngestion(
         vector_store='FAISS',
diff --git a/RecDP/pyrecdp/primitives/llmutils/rag_data_extractor.py b/RecDP/pyrecdp/primitives/llmutils/rag_data_extractor.py
new file mode 100644
index 000000000..4e7cf2cee
--- /dev/null
+++ b/RecDP/pyrecdp/primitives/llmutils/rag_data_extractor.py
@@ -0,0 +1,83 @@
+import argparse
+from typing import Optional, List
+
+from pyrecdp.core.utils import Timer
+from pyrecdp.primitives.operations.logging_utils import logger
+
+from pyrecdp.LLM import TextPipeline
+from pyrecdp.primitives.operations import UrlLoader, DocumentSplit, DocumentIngestion, RAGTextFix, DirectoryLoader
+
+
+def rag_data_prepare(
+        files_path: str = None,
+        target_urls: List[str] = None,
+        text_splitter: str = "RecursiveCharacterTextSplitter",
+        text_splitter_args: Optional[dict] = None,
+        vs_output_dir: str = "recdp_vs",
+        vector_store_type: str = 'FAISS',
+        index_name: str = 'recdp_index',
+        embeddings_type: str = 'HuggingFaceEmbeddings',
+        embeddings_args: Optional[dict] = None,
+):
+    if files_path:
+        loader = DirectoryLoader(files_path, glob="**/*.pdf")
+    elif target_urls:
+        loader = UrlLoader(urls=target_urls, target_tag='div')
+    else:
+        logger.error("You must specify at least one of files_path and target_urls")
+        exit(1)
+    if text_splitter_args is None:
+        text_splitter_args = {"chunk_size": 500, "chunk_overlap": 0}
+    if embeddings_args is None:
+        embeddings_args = {'model_name': "sentence-transformers/all-mpnet-base-v2"}
+    pipeline = TextPipeline()
+    ops = [
+        loader,
+        RAGTextFix(),
+        DocumentSplit(text_splitter=text_splitter, text_splitter_args=text_splitter_args),
+        DocumentIngestion(
+            vector_store=vector_store_type,
+            vector_store_args={
+                "output_dir": vs_output_dir,
+                "index": index_name
+            },
+            embeddings=embeddings_type,
+            embeddings_args=embeddings_args
+        ),
+    ]
+    pipeline.add_operations(ops)
+    pipeline.execute()
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--files_path", dest="files_path", type=str)
+    parser.add_argument("--target_urls", dest="target_urls", type=str)
+    parser.add_argument("--text_splitter", dest="text_splitter", type=str, default='RecursiveCharacterTextSplitter')
+    parser.add_argument("--vs_output_dir", dest="vs_output_dir", type=str, default='recdp_vs')
+    parser.add_argument("--vector_store_type", dest="vector_store_type", type=str, default='FAISS')
+    parser.add_argument("--index_name", dest="index_name", type=str, default='recdp_index')
+    parser.add_argument("--embeddings_type", dest="embeddings_type", type=str, default='HuggingFaceEmbeddings')
+    args = parser.parse_args()
+    files_path = args.files_path
+    if args.target_urls:
+        target_urls = args.target_urls.split(",")
+    else:
+        target_urls = []
+    text_splitter = args.text_splitter
+    vs_output_dir = args.vs_output_dir
+    vector_store_type = args.vector_store_type
+    index_name = args.index_name
+    embeddings_type = args.embeddings_type
+
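+    # Example invocation (illustrative only; the URL is the sample from the README, and the
+    # output directory and index name shown are just the parser defaults defined above):
+    #   python rag_data_extractor.py \
+    #       --target_urls "https://www.intc.com/news-events/press-releases/detail/1655/intel-reports-third-quarter-2023-financial-results" \
+    #       --vs_output_dir recdp_vs --index_name recdp_index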
with Timer(f"Process RAG data"): + rag_data_prepare(files_path=files_path, + target_urls=target_urls, + text_splitter=text_splitter, + vs_output_dir=vs_output_dir, + vector_store_type=vector_store_type, + index_name=index_name, + embeddings_type=embeddings_type, + ) diff --git a/RecDP/pyrecdp/primitives/operations/__init__.py b/RecDP/pyrecdp/primitives/operations/__init__.py index 1bada9daa..10554b204 100644 --- a/RecDP/pyrecdp/primitives/operations/__init__.py +++ b/RecDP/pyrecdp/primitives/operations/__init__.py @@ -44,7 +44,7 @@ from .text_normalize import TextNormalize from .text_bytesize import TextBytesize from .filter import * - from .text_fixer import TextFix + from .text_fixer import TextFix, RAGTextFix from .text_language_identify import LanguageIdentify from .text_split import DocumentSplit, ParagraphsTextSplitter from .text_pii_remove import PIIRemoval @@ -65,11 +65,7 @@ from .text_perplexity_score import TextPerplexityScore from .random_select import RandomSelect from .text_ingestion import DocumentIngestion - from .doc_loader import DirectoryLoader, DocumentLoader, Url_Loader - from .text_specific_chars_remove import TextSpecificCharsRemove - from .text_unicode_fixer import TextUnicodeFixer - from .text_whitespace_normalization import TextWhitespaceNormalization - from .text_sentence_resplit import TextSentenceResplit + from .doc_loader import DirectoryLoader, DocumentLoader, UrlLoader from .text_to_qa import TextToQA except: pass diff --git a/RecDP/pyrecdp/primitives/operations/doc_loader.py b/RecDP/pyrecdp/primitives/operations/doc_loader.py index 1bda5a9d3..afab46eb8 100644 --- a/RecDP/pyrecdp/primitives/operations/doc_loader.py +++ b/RecDP/pyrecdp/primitives/operations/doc_loader.py @@ -148,7 +148,7 @@ def load_html_to_md(page_url, target_tag: str = None, target_attrs: dict = None) ) -class Url_Loader(BaseLLMOperation): +class UrlLoader(BaseLLMOperation): def __init__(self, urls: list = None, target_tag: str = None, target_attrs: dict = None, args_dict: Optional[dict] = None): settings = { @@ -183,4 +183,4 @@ def process_spark(self, spark, spark_df=None): return self.cache -LLMOPERATORS.register(Url_Loader) +LLMOPERATORS.register(UrlLoader) diff --git a/RecDP/pyrecdp/primitives/operations/text_fixer.py b/RecDP/pyrecdp/primitives/operations/text_fixer.py index 7e83d1f73..7d0f4935d 100644 --- a/RecDP/pyrecdp/primitives/operations/text_fixer.py +++ b/RecDP/pyrecdp/primitives/operations/text_fixer.py @@ -2,11 +2,14 @@ from ray.data import Dataset from pyspark.sql import DataFrame -import os +from typing import List, Union import re from typing import Dict from selectolax.parser import HTMLParser +from .constant import VARIOUS_WHITESPACES +from pyrecdp.core.model_utils import prepare_model, get_model + CPAT = re.compile("copyright", re.IGNORECASE) PAT = re.compile("/\\*[^*]*\\*+(?:[^/*][^*]*\\*+)*/") @@ -258,7 +261,7 @@ def process_rayds(self, ds: Dataset) -> Dataset: if self.actual_func is None: self.actual_func = get_fixer_by_type(self.text_type) return ds.map(lambda x: self.process_row(x, self.text_key, new_name, self.actual_func)) - + def process_spark(self, spark, spark_df: DataFrame) -> DataFrame: import pyspark.sql.functions as F fix_by_type_udf = F.udf(get_fixer_by_type(self.text_type)) @@ -270,3 +273,70 @@ def process_spark(self, spark, spark_df: DataFrame) -> DataFrame: LLMOPERATORS.register(TextFix) + + +class RAGTextFix(BaseLLMOperation): + def __init__(self, text_key='text', chars_to_remove: Union[str, List[str]] = '◆●■►▼▲▴∆▻▷❖♡□', language: str = 'en'): + """ + 
Clean up text for LLM RAG to use.
+        Step 1: Fix Unicode errors in the text using ftfy.
+        Step 2: Normalize the different kinds of whitespace to the plain whitespace ' ' (0x20).
+                The different kinds of whitespace are listed here:
+                https://en.wikipedia.org/wiki/Whitespace_character
+        Step 3: Remove the specified special characters from the text.
+        Step 4: Re-segment sentences in the text to avoid sentence-segmentation errors caused by unnecessary line breaks.
+
+        :param text_key: The name of the field that holds the text to process. Default: 'text'.
+        :param chars_to_remove: Characters to remove. Default: '◆●■►▼▲▴∆▻▷❖♡□'.
+        :param language: Language of the text, used to select the NLTK sentence tokenizer. Default: 'en'.
+        """
+        settings = {'chars_to_remove': chars_to_remove, 'text_key': text_key, 'language': language}
+        super().__init__(settings)
+        self.support_spark = True
+        self.support_ray = True
+        self.text_key = text_key
+        self.inplace = True
+        self.chars_to_remove = chars_to_remove
+        self.language = language
+
+    def process_rayds(self, ds: Dataset) -> Dataset:
+        remover = self.get_compute_func()
+        new_ds = ds.map(lambda x: self.process_row(x, self.text_key, self.text_key, remover))
+        return new_ds
+
+    def process_spark(self, spark, spark_df: DataFrame) -> DataFrame:
+        import pyspark.sql.functions as F
+        custom_udf = F.udf(self.get_compute_func())
+        return spark_df.withColumn(self.text_key, custom_udf(F.col(self.text_key)))
+
+    def get_compute_func(self):
+        import ftfy
+        # Character class of the characters to strip; escape them so regex metacharacters are treated literally.
+        pattern = '[' + re.escape(''.join(self.chars_to_remove)) + ']'
+        model_key = prepare_model(lang=self.language, model_type='nltk')
+        nltk_model = get_model(model_key, lang=self.language, model_type='nltk')
+
+        def compute(text):
+            # fix unicode errors
+            text = ftfy.fix_text(text)
+            # normalize all whitespace variants to ' '
+            text = ''.join([
+                char if char not in VARIOUS_WHITESPACES else ' ' for char in text
+            ])
+            # remove the specified special characters
+            text = re.sub(pattern=pattern, repl=r'',
+                          string=text, flags=re.DOTALL)
+            # re-segment sentences: protect paragraph breaks, then join the lines within each sentence
+            paragraph_break_pattern = r"\n\s*\n"
+            replace_str = '*^*^*'
+            text = re.sub(pattern=paragraph_break_pattern, repl=replace_str,
+                          string=text, flags=re.DOTALL)
+            sentences = nltk_model.tokenize(text)
+            new_sentences = []
+            for sentence in sentences:
+                new_sentences.append(sentence.replace("\n", " "))
+            new_text = ' '.join(new_sentences).replace(replace_str, "\n\n")
+            return new_text
+
+        return compute
+
+
+LLMOPERATORS.register(RAGTextFix)
diff --git a/RecDP/pyrecdp/primitives/operations/text_ingestion.py b/RecDP/pyrecdp/primitives/operations/text_ingestion.py
index ae270e66d..434c512da 100644
--- a/RecDP/pyrecdp/primitives/operations/text_ingestion.py
+++ b/RecDP/pyrecdp/primitives/operations/text_ingestion.py
@@ -1,3 +1,4 @@
+import os.path
 from abc import ABC, abstractmethod
 from typing import Optional, List, Dict, Tuple, Union, Iterable
@@ -28,12 +29,14 @@ def __init__(self, text_column: str, vector_store: str,
                  text_embeddings: TextEmbeddings,
                  embeddings_column: str,
-                 vector_store_args: Optional[Dict] = None):
+                 vector_store_args: Optional[Dict] = None,
+                 override: bool = False):
         self.text_column = text_column
         self.embeddings_column = embeddings_column
         self.vector_store = vector_store
         self.text_embeddings = text_embeddings
         self.vector_store_args = vector_store_args
+        self.override = override
 
     @abstractmethod
     def persist(self, ds: Union[Dataset, DataFrame]):
@@ -53,11 +56,15 @@ def __persist_to_faiss(self, ds):
             text_embeddings.append((row[self.text_column], row[self.embeddings_column]))
 
         from langchain.vectorstores.faiss import FAISS
-        db = FAISS.from_embeddings(text_embeddings, embedding=self.text_embeddings.underlying_embeddings())
+        index_name = 
self.vector_store_args.get("index", "index") + if not self.override and os.path.exists(os.path.join(self.vector_store_args["output_dir"], index_name+".faiss")): + db = FAISS.load_local(self.vector_store_args["output_dir"], self.text_embeddings.underlying_embeddings(), index_name) + db.add_embeddings(text_embeddings) + else: + db = FAISS.from_embeddings(text_embeddings, embedding=self.text_embeddings.underlying_embeddings()) if "output_dir" not in self.vector_store_args: raise ValueError(f"You must have `output_dir` option specify for vector store {self.vector_store}") - index_name = self.vector_store_args.get("index", "index") db.save_local(self.vector_store_args["output_dir"], index_name) def persist(self, ds): @@ -99,6 +106,7 @@ class BaseDocumentIngestion(BaseLLMOperation, ABC): def __init__(self, text_column: str = 'text', embeddings_column: str = 'embedding', + override: bool = False, compute_min_size: Optional[int] = None, compute_max_size: Optional[int] = None, batch_size: Optional[int] = None, @@ -116,6 +124,8 @@ def __init__(self, batch_size: The batch size to use when computing the document embeddings(If embedding with Ray). num_cpus: The number of CPUs to use when computing the document embeddings(If embedding with Ray). num_gpus: The number of GPUs to use when computing the document embeddings(If embedding with Ray). + override: Whether to force override the previous vector store data. Default: False, + """ settings = settings or {} settings.update({ @@ -125,6 +135,7 @@ def __init__(self, 'batch_size': batch_size, 'num_gpus': num_gpus, 'num_cpus': num_cpus, + 'override': override }) super().__init__(settings) self.support_ray = True @@ -135,6 +146,7 @@ def __init__(self, self.batch_size = batch_size self.num_cpus = num_cpus self.num_gpus = num_gpus + self.override = override self.embeddings_column = embeddings_column self.text_embeddings = self._get_text_embeddings() self.vector_store = self._get_vector_store() @@ -189,6 +201,7 @@ def __init__(self, text_column: str = 'text', embeddings_column: str = 'embedding', vector_store: str = 'FAISS', + override: bool = False, vector_store_args: Optional[dict] = None, embeddings: str = 'HuggingFaceEmbeddings', embeddings_args: Optional[dict] = None, @@ -242,6 +255,7 @@ def __init__(self, embeddings_column=embeddings_column, compute_min_size=compute_min_size, compute_max_size=compute_max_size, + override=override, batch_size=batch_size, num_gpus=num_gpus, num_cpus=num_cpus, @@ -262,7 +276,8 @@ def _get_vector_store(self) -> VectorStore: embeddings_column=self.embeddings_column, text_embeddings=self.text_embeddings, vector_store=self.vector_store, - vector_store_args=self.vector_store_args + vector_store_args=self.vector_store_args, + override=self.override ) def _get_text_embeddings(self) -> TextEmbeddings: diff --git a/RecDP/pyrecdp/primitives/operations/text_sentence_resplit.py b/RecDP/pyrecdp/primitives/operations/text_sentence_resplit.py deleted file mode 100644 index 885ca6d1f..000000000 --- a/RecDP/pyrecdp/primitives/operations/text_sentence_resplit.py +++ /dev/null @@ -1,54 +0,0 @@ -from pyrecdp.core.model_utils import prepare_model, get_model -from pyrecdp.primitives.operations.base import BaseLLMOperation, LLMOPERATORS -from ray.data import Dataset -from pyspark.sql import DataFrame - -import re - - -class TextSentenceResplit(BaseLLMOperation): - def __init__(self, text_key='text', language: str = 'en'): - """ - Re segment sentences in the text to avoid sentence segmentation errors caused by unnecessary line breaks - - :param 
language: Supported language. Default: en. (en) - - """ - settings = {'language': language, 'text_key': text_key} - super().__init__(settings) - self.support_spark = True - self.support_ray = True - self.text_key = text_key - self.language = language - self.inplace = True - - def process_rayds(self, ds: Dataset) -> Dataset: - remover = self.get_compute_func() - return ds.map(lambda x: self.process_row(x, self.text_key, self.text_key, remover)) - - def process_spark(self, spark, spark_df: DataFrame) -> DataFrame: - import pyspark.sql.functions as F - custom_udf = F.udf(self.get_compute_func()) - return spark_df.withColumn(self.text_key, custom_udf(F.col(self.text_key))) - - def get_compute_func(self): - model_key = prepare_model(lang=self.language, model_type='nltk') - nltk_model = get_model(model_key, lang=self.language, model_type='nltk') - - def compute(text): - pattern = "\\n\s*\\n" - replace_str = '*^*^*' - text = re.sub(pattern=pattern, repl=replace_str, - string=text, flags=re.DOTALL) - - sentences = nltk_model.tokenize(text) - new_sentences = [] - for sentence in sentences: - new_sentences.append(sentence.replace("\n", " ")) - new_text = ' '.join(new_sentences).replace(replace_str, "\n\n") - return new_text - - return compute - - -LLMOPERATORS.register(TextSentenceResplit) diff --git a/RecDP/pyrecdp/primitives/operations/text_specific_chars_remove.py b/RecDP/pyrecdp/primitives/operations/text_specific_chars_remove.py deleted file mode 100644 index 056b0f549..000000000 --- a/RecDP/pyrecdp/primitives/operations/text_specific_chars_remove.py +++ /dev/null @@ -1,39 +0,0 @@ -from pyrecdp.primitives.operations.base import BaseLLMOperation, LLMOPERATORS -from ray.data import Dataset -from pyspark.sql import DataFrame - -from typing import List, Union -import re - - -class TextSpecificCharsRemove(BaseLLMOperation): - def __init__(self, text_key='text', chars_to_remove: Union[str, List[str]] = '◆●■►▼▲▴∆▻▷❖♡□'): - settings = {'chars_to_remove': chars_to_remove, 'text_key': text_key} - super().__init__(settings) - self.support_spark = True - self.support_ray = True - self.text_key = text_key - self.chars_to_remove = chars_to_remove - self.inplace = True - - def process_rayds(self, ds: Dataset) -> Dataset: - remover = self.get_remover() - return ds.map(lambda x: self.process_row(x, self.text_key, self.text_key, remover)) - - def process_spark(self, spark, spark_df: DataFrame) -> DataFrame: - import pyspark.sql.functions as F - custom_udf = F.udf(self.get_remover()) - return spark_df.withColumn(self.text_key, custom_udf(F.col(self.text_key))) - - def get_remover(self): - pattern = '[' + '|'.join(self.chars_to_remove) + ']' - - def remover(text): - text = re.sub(pattern=pattern, repl=r'', - string=text, flags=re.DOTALL) - return text - - return remover - - -LLMOPERATORS.register(TextSpecificCharsRemove) diff --git a/RecDP/pyrecdp/primitives/operations/text_unicode_fixer.py b/RecDP/pyrecdp/primitives/operations/text_unicode_fixer.py deleted file mode 100644 index 1b05ccc3a..000000000 --- a/RecDP/pyrecdp/primitives/operations/text_unicode_fixer.py +++ /dev/null @@ -1,37 +0,0 @@ -from pyrecdp.primitives.operations.base import BaseLLMOperation, LLMOPERATORS -from ray.data import Dataset -from pyspark.sql import DataFrame - - - -class TextUnicodeFixer(BaseLLMOperation): - def __init__(self, text_key='text'): - """ - Fix unicode errors in text using ftfy - """ - settings = {'text_key': text_key} - super().__init__(settings) - self.support_spark = True - self.support_ray = True - self.text_key = 
text_key - self.inplace = True - - def process_rayds(self, ds: Dataset) -> Dataset: - remover = self.get_compute_func() - return ds.map(lambda x: self.process_row(x, self.text_key, self.text_key, remover)) - - def process_spark(self, spark, spark_df: DataFrame) -> DataFrame: - import pyspark.sql.functions as F - custom_udf = F.udf(self.get_compute_func()) - return spark_df.withColumn(self.text_key, custom_udf(F.col(self.text_key))) - - def get_compute_func(self): - import ftfy - - def compute(text): - new_text = ftfy.fix_text(text) - return new_text - return compute - - -LLMOPERATORS.register(TextUnicodeFixer) diff --git a/RecDP/pyrecdp/primitives/operations/text_whitespace_normalization.py b/RecDP/pyrecdp/primitives/operations/text_whitespace_normalization.py deleted file mode 100644 index 077b2ea01..000000000 --- a/RecDP/pyrecdp/primitives/operations/text_whitespace_normalization.py +++ /dev/null @@ -1,42 +0,0 @@ -from pyrecdp.primitives.operations.base import BaseLLMOperation, LLMOPERATORS -from ray.data import Dataset -from pyspark.sql import DataFrame - -from pyrecdp.primitives.operations.constant import VARIOUS_WHITESPACES - - -class TextWhitespaceNormalization(BaseLLMOperation): - def __init__(self, text_key='text'): - """ - Normalize different kinds of whitespaces to whitespace ' ' (0x20) in text - Different kinds of whitespaces can be found here: - https://en.wikipedia.org/wiki/Whitespace_character - """ - settings = {'text_key': text_key} - super().__init__(settings) - self.support_spark = True - self.support_ray = True - self.text_key = text_key - self.inplace = True - - def process_rayds(self, ds: Dataset) -> Dataset: - remover = self.get_compute_func() - return ds.map(lambda x: self.process_row(x, self.text_key, self.text_key, remover)) - - def process_spark(self, spark, spark_df: DataFrame) -> DataFrame: - import pyspark.sql.functions as F - custom_udf = F.udf(self.get_compute_func()) - return spark_df.withColumn(self.text_key, custom_udf(F.col(self.text_key))) - - def get_compute_func(self): - def compute(text): - # replace all kinds of whitespaces with ' ' - new_text = ''.join([ - char if char not in VARIOUS_WHITESPACES else ' ' for char in text - ]) - return new_text - - return compute - - -LLMOPERATORS.register(TextWhitespaceNormalization) diff --git a/RecDP/tests/test_llmutils_operations.py b/RecDP/tests/test_llmutils_operations.py index 07e1d958c..5a7c5c4df 100644 --- a/RecDP/tests/test_llmutils_operations.py +++ b/RecDP/tests/test_llmutils_operations.py @@ -217,25 +217,21 @@ def test_gopherqualityfilter_ray(self): op = GopherQualityFilter() with RayContext("tests/data/llm_data/tiny_c4_sample.jsonl") as ctx: ctx.show(op.process_rayds(ctx.ds)) - def test_text_specific_chars_remove_ray(self): - op = TextSpecificCharsRemove(chars_to_remove="abcdedfhijklmn") - with RayContext("tests/data/llm_data/tiny_c4_sample.jsonl") as ctx: - ctx.show(op.process_rayds(ctx.ds)) - def test_text_unicode_fixer_ray(self): - op = TextUnicodeFixer() + def test_document_load_ray(self): + op = DirectoryLoader("tests/data/llm_data/document", glob="**/*.pdf") with RayContext("tests/data/llm_data/tiny_c4_sample.jsonl") as ctx: - ctx.show(op.process_rayds(ctx.ds)) + ctx.show(op.process_rayds()) - def test_text_whitespace_normalization_ray(self): - op = TextWhitespaceNormalization() + def test_document_split_ray(self): + op = DocumentSplit() with RayContext("tests/data/llm_data/tiny_c4_sample.jsonl") as ctx: ctx.show(op.process_rayds(ctx.ds)) - def test_sentence_resplit_ray(self): - op = 
TextSentenceResplit() - with RayContext("tests/data/llm_data/tiny_c4_sample.jsonl") as ctx: - ctx.show(op.process_rayds(ctx.ds)) + # def test_rag_text_fix_ray(self): + # op = RAGTextFix(chars_to_remove="abcdedfhijklmn") + # with RayContext("tests/data/llm_data/tiny_c4_sample.jsonl") as ctx: + # ctx.show(op.process_rayds(ctx.ds)) ### ====== Spark ====== ### @@ -387,45 +383,20 @@ def test_gopherqualityfilter_spark(self): with SparkContext("tests/data/llm_data/tiny_c4_sample.jsonl") as ctx: ctx.show(op.process_spark(ctx.spark, ctx.ds)) - def test_document_load_ray(self): - op = DirectoryLoader("tests/data/llm_data/document", glob="**/*.pdf") - with RayContext("tests/data/llm_data/tiny_c4_sample.jsonl") as ctx: - ctx.show(op.process_rayds()) - def test_document_load_spark(self): op = DirectoryLoader("tests/data/llm_data/document", glob="**/*.pdf") with SparkContext("tests/data/llm_data/tiny_c4_sample.jsonl") as ctx: ctx.show(op.process_spark(ctx.spark)) - def test_document_split_ray(self): - op = DocumentSplit() - with RayContext("tests/data/llm_data/tiny_c4_sample.jsonl") as ctx: - ctx.show(op.process_rayds(ctx.ds)) - def test_document_split_spark(self): op = DocumentSplit() with SparkContext("tests/data/llm_data/tiny_c4_sample.jsonl") as ctx: ctx.show(op.process_spark(ctx.spark, ctx.ds)) - def test_text_specific_chars_remove_spark(self): - op = TextSpecificCharsRemove(chars_to_remove="abcdedfhijklmn") - with SparkContext("tests/data/llm_data/tiny_c4_sample.jsonl") as ctx: - ctx.show(op.process_spark(ctx.spark, ctx.ds)) - - def test_unicode_fixer_spark(self): - op = TextUnicodeFixer() - with SparkContext("tests/data/llm_data/tiny_c4_sample.jsonl") as ctx: - ctx.show(op.process_spark(ctx.spark, ctx.ds)) - - def test_whitespace_normalization_spark(self): - op = TextWhitespaceNormalization() - with SparkContext("tests/data/llm_data/tiny_c4_sample.jsonl") as ctx: - ctx.show(op.process_spark(ctx.spark, ctx.ds)) - - def test_sentence_resplit_spark(self): - op = TextSentenceResplit() - with SparkContext("tests/data/llm_data/tiny_c4_sample.jsonl") as ctx: - ctx.show(op.process_spark(ctx.spark, ctx.ds)) + # def test_rag_text_fix_spark(self): + # op = RAGTextFix(chars_to_remove="abcdedfhijklmn") + # with SparkContext("tests/data/llm_data/tiny_c4_sample.jsonl") as ctx: + # ctx.show(op.process_spark(ctx.spark, ctx.ds)) def test_document_embed_ray(self): model_root_path = os.path.join(RECDP_MODELS_CACHE, "huggingface") @@ -459,7 +430,6 @@ def test_document_embed_spark(self): with SparkContext("tests/data/llm_data/tiny_c4_sample.jsonl") as ctx: ctx.show(op.process_spark(ctx.spark, ctx.ds)) - def test_document_paragraphs_split_ray(self): model_root_path = os.path.join(RECDP_MODELS_CACHE, "huggingface") model_name = f"{model_root_path}/sentence-transformers/all-mpnet-base-v2" @@ -472,4 +442,4 @@ def test_document_paragraphs_split_spark(self): model_name = f"{model_root_path}/sentence-transformers/all-mpnet-base-v2" op = ParagraphsTextSplitter(model_name=model_name) with SparkContext("tests/data/llm_data/arxiv_sample_100.jsonl") as ctx: - ctx.show(op.process_spark(ctx.spark, ctx.ds)) \ No newline at end of file + ctx.show(op.process_spark(ctx.spark, ctx.ds)) diff --git a/RecDP/tests/test_llmutils_pipelines.py b/RecDP/tests/test_llmutils_pipelines.py index 880668a3e..7acccc8eb 100644 --- a/RecDP/tests/test_llmutils_pipelines.py +++ b/RecDP/tests/test_llmutils_pipelines.py @@ -190,9 +190,9 @@ def test_llm_rag_pipeline(self): faiss_output_dir = 'tests/data/faiss' pipeline = TextPipeline() ops = [ - 
Url_Loader(["https://www.intc.com/news-events/press-releases/detail/" + UrlLoader(["https://www.intc.com/news-events/press-releases/detail/" "1655/intel-reports-third-quarter-2023-financial-results"], - target_tag='div', target_attrs={'class': 'main-content'}), + target_tag='div', target_attrs={'class': 'main-content'}), DocumentSplit(text_splitter='RecursiveCharacterTextSplitter'), DocumentIngestion( vector_store='FAISS',