diff --git a/RecDP/pyrecdp/primitives/operations/text_split.py b/RecDP/pyrecdp/primitives/operations/text_split.py index 67b06ca60..2ee11280e 100644 --- a/RecDP/pyrecdp/primitives/operations/text_split.py +++ b/RecDP/pyrecdp/primitives/operations/text_split.py @@ -21,7 +21,7 @@ def prepare_nltk_model(model, lang): import nltk nltk.download('punkt') - prepare_model(model_type="nltk", prepare_model_func=prepare_nltk_model) + prepare_model(model_type="nltk", model_key="nltk_langchain", prepare_model_func=prepare_nltk_model) from pyrecdp.core.class_utils import new_instance splitter = new_instance("langchain.text_splitter", text_splitter, **text_splitter_args) diff --git a/RecDP/tests/test_llmutils_operations.py b/RecDP/tests/test_llmutils_operations.py index 5a7c5c4df..c4bd41af2 100644 --- a/RecDP/tests/test_llmutils_operations.py +++ b/RecDP/tests/test_llmutils_operations.py @@ -228,10 +228,10 @@ def test_document_split_ray(self): with RayContext("tests/data/llm_data/tiny_c4_sample.jsonl") as ctx: ctx.show(op.process_rayds(ctx.ds)) - # def test_rag_text_fix_ray(self): - # op = RAGTextFix(chars_to_remove="abcdedfhijklmn") - # with RayContext("tests/data/llm_data/tiny_c4_sample.jsonl") as ctx: - # ctx.show(op.process_rayds(ctx.ds)) + def test_rag_text_fix_ray(self): + op = RAGTextFix(chars_to_remove="abcdedfhijklmn") + with RayContext("tests/data/llm_data/tiny_c4_sample.jsonl") as ctx: + ctx.show(op.process_rayds(ctx.ds)) ### ====== Spark ====== ### @@ -393,10 +393,10 @@ def test_document_split_spark(self): with SparkContext("tests/data/llm_data/tiny_c4_sample.jsonl") as ctx: ctx.show(op.process_spark(ctx.spark, ctx.ds)) - # def test_rag_text_fix_spark(self): - # op = RAGTextFix(chars_to_remove="abcdedfhijklmn") - # with SparkContext("tests/data/llm_data/tiny_c4_sample.jsonl") as ctx: - # ctx.show(op.process_spark(ctx.spark, ctx.ds)) + def test_rag_text_fix_spark(self): + op = RAGTextFix(chars_to_remove="abcdedfhijklmn") + with SparkContext("tests/data/llm_data/tiny_c4_sample.jsonl") as ctx: + ctx.show(op.process_spark(ctx.spark, ctx.ds)) def test_document_embed_ray(self): model_root_path = os.path.join(RECDP_MODELS_CACHE, "huggingface")