From 3951931a28c6f9e1c9ea2d5dac8cdc46822c66ad Mon Sep 17 00:00:00 2001 From: haojinIntel Date: Wed, 22 Nov 2023 00:21:11 +0800 Subject: [PATCH] Support to define engine for pipeline (#447) --- RecDP/pyrecdp/LLM/TextPipeline.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/RecDP/pyrecdp/LLM/TextPipeline.py b/RecDP/pyrecdp/LLM/TextPipeline.py index 5afae7d53..a55c28158 100644 --- a/RecDP/pyrecdp/LLM/TextPipeline.py +++ b/RecDP/pyrecdp/LLM/TextPipeline.py @@ -23,8 +23,9 @@ class TextPipeline(BasePipeline): - def __init__(self, pipeline_file=None): + def __init__(self, engine_name='ray', pipeline_file=None): super().__init__() + self.engine_name = engine_name if pipeline_file != None: self.import_from_json(pipeline_file) if pipeline_file.endswith( '.json') else self.import_from_yaml(pipeline_file) @@ -52,7 +53,10 @@ def check_platform(self, executable_sequence): if op.support_spark: spark_list.append(str(op)) if is_ray: - return 'ray' + if self.engine_name == 'spark' and is_spark: + return 'spark' + else: + return 'ray' elif is_spark: return 'spark' else: @@ -206,8 +210,8 @@ def evaluate(self) -> dict: class ResumableTextPipeline(TextPipeline): # Provide a pipeline for large dir. We will handle files one by one and resume when pipeline broken. - def __init__(self, pipeline_file=None): - super().__init__(pipeline_file) + def __init__(self, engine_name='ray', pipeline_file=None): + super().__init__(engine_name, pipeline_file) # Enabling this option will result in a decrease in execution speed self.statistics_flag = False