diff --git a/nlu/pipeline.py b/nlu/pipeline.py index 27e68ddd..7a34fa90 100644 --- a/nlu/pipeline.py +++ b/nlu/pipeline.py @@ -1194,7 +1194,7 @@ def save(self, path, component='entire_pipeline', overwrite=False): print(f'Stored model in {path}') # else : print('Please fit untrained pipeline first or predict on a String to save it') def predict(self, data, output_level='', positions=False, keep_stranger_features=True, metadata=False, - multithread=True, drop_irrelevant_cols=True, verbose=False): + multithread=True, drop_irrelevant_cols=True, verbose=False,return_spark_df=False): ''' Annotates a Pandas Dataframe/Pandas Series/Numpy Array/Spark DataFrame/Python List strings /Python String @@ -1372,8 +1372,8 @@ def predict(self, data, output_level='', positions=False, keep_stranger_features except: print( "If you use Modin, make sure you have installed 'pip install modin[ray]' or 'pip install modin[dask]' backend for Modin ") - - return self.pythonify_spark_dataframe(sdf, self.output_different_levels, + if return_spark_df : return sdf + else : return self.pythonify_spark_dataframe(sdf, self.output_different_levels, keep_stranger_features=keep_stranger_features, stranger_features=stranger_features, output_metadata=metadata, index_provided=index_provided,