This repository has been archived by the owner on Jul 18, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support to generate perplexity_score and support to store rouge score column. (#427)
- Loading branch information
Showing
5 changed files
with
154 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import argparse | ||
|
||
from pyrecdp.core.utils import Timer | ||
from pyrecdp.primitives.operations import JsonlReader, ParquetReader, PerfileParquetWriter | ||
|
||
|
||
def perplexity_score_spark(spark_df, language: str = 'en'):
    """Append a perplexity score column to a Spark DataFrame.

    :param spark_df: input Spark DataFrame holding the text samples.
    :param language: language of the samples ('en' or 'zh'). Default: 'en'.
    :return: the DataFrame produced by the TextPerplexityScore operation.
    """
    from pyrecdp.primitives.operations import TextPerplexityScore

    scorer = TextPerplexityScore(language=language)
    return scorer.process_spark(spark_df.sparkSession, spark_df)
|
||
|
||
def perplexity_score(data_dir, out_dir, data_file_type="jsonl", language: str = 'en'):
    """Run a ResumableTextPipeline that reads text files, scores each sample's
    perplexity, and writes the result out as per-file parquet.

    :param data_dir: directory containing the input files.
    :param out_dir: directory to write the scored parquet files to.
    :param data_file_type: input format, either 'jsonl' or 'parquet'.
    :param language: language of the samples ('en' or 'zh'). Default: 'en'.
    :raises NotImplementedError: for any unsupported data_file_type.
    """
    from pyrecdp.primitives.operations import TextPerplexityScore
    from pyrecdp.LLM import ResumableTextPipeline

    # Dispatch table instead of an if/elif chain; unknown types fail fast.
    reader_types = {'jsonl': JsonlReader, 'parquet': ParquetReader}
    if data_file_type not in reader_types:
        raise NotImplementedError(f"{data_file_type} is not supported in RecDP LLM ResumableTextPipeline yet.")
    reader = reader_types[data_file_type](data_dir)

    pipeline = ResumableTextPipeline()
    pipeline.add_operations([
        reader,
        TextPerplexityScore(language=language),
        PerfileParquetWriter(out_dir),
    ])
    pipeline.execute()
|
||
|
||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # data_dir is mandatory: without required=True a missing flag passed
    # data_dir=None into the pipeline, which only failed later with an
    # obscure reader error. Fail fast with a clear argparse message instead.
    parser.add_argument("--data_dir", dest="data_dir", type=str, required=True)
    parser.add_argument("--data_file_type", dest="data_file_type", type=str, default="jsonl")
    parser.add_argument("--output_dir", dest="output_dir", type=str, default="")
    parser.add_argument("--language", dest="language", type=str, default="en")
    args = parser.parse_args()

    data_dir = args.data_dir
    data_file_type = args.data_file_type
    output_dir = args.output_dir
    language = args.language
    with Timer(f"Generate perplexity score for {data_dir}"):
        perplexity_score(data_dir, output_dir, data_file_type, language)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
66 changes: 66 additions & 0 deletions
66
RecDP/pyrecdp/primitives/operations/text_perplexity_score.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
from .base import BaseLLMOperation | ||
from ray.data import Dataset | ||
from pyspark.sql import DataFrame | ||
from pyrecdp.core.model_utils import get_model, prepare_model | ||
from pyrecdp.primitives.operations.base import LLMOPERATORS | ||
from pyrecdp.primitives.operations.utils import get_words_from_document | ||
|
||
|
||
def text_bytesize(s):
    """Return the size of *s* in bytes when encoded as UTF-8."""
    encoded = s.encode('utf-8')
    return len(encoded)
|
||
|
||
class TextPerplexityScore(BaseLLMOperation):
    """LLM operation that adds a 'perplexity' column computed with a
    sentencepiece tokenizer and a kenlm language model."""

    def __init__(self, language: str = 'en'):
        """
        Generate perplexity score

        :param language: Sample in which language. Default: en.(en, zh)
        """
        super().__init__(args_dict={'language': language})
        self.language = language
        self.text_key = 'text'
        self.inplace = False
        # Resolve and load both models once at construction time so the
        # per-row compute function only performs inference.
        self.sp_model_key = prepare_model(lang=language, model_type='sentencepiece')
        self.kl_model_key = prepare_model(lang=language, model_type='kenlm')
        self.tokenizer = get_model(self.sp_model_key, self.language, 'sentencepiece')
        self.kenlm_model = get_model(self.kl_model_key, self.language, 'kenlm')

    def process_rayds(self, ds: Dataset) -> Dataset:
        # In-place modification is not supported; a new column is added.
        if self.inplace:
            raise NotImplementedError("We don't inplace modify text with normalization")
        new_name = 'perplexity'
        compute_func = self.get_compute_func()
        return ds.map(lambda sample: self.process_row(sample, self.text_key, new_name, compute_func))

    def process_spark(self, spark, spark_df: DataFrame) -> DataFrame:
        import pyspark.sql.functions as F
        from pyspark.sql import types as T
        # Wrap the scoring closure as a float-returning UDF over the text column.
        perplexity_udf = F.udf(self.get_compute_func(), T.FloatType())
        return spark_df.withColumn("perplexity", perplexity_udf(F.col(self.text_key)))

    def get_compute_func(self, *args, **kwargs):
        """Return a picklable closure that maps a text string to its
        kenlm perplexity, rounded to one decimal place."""
        # Bind the models to locals so the closure does not capture self.
        tokenizer = self.tokenizer
        kenlm_model = self.kenlm_model

        def compute(text):
            words = get_words_from_document(
                text,
                token_func=tokenizer.encode_as_pieces if tokenizer else None)
            joined = ' '.join(words)
            total_logits = 0
            total_length = 0
            for line in joined.splitlines():
                total_logits += kenlm_model.score(line)
                total_length += len(line.split()) + 1
            if total_length == 0:
                return 0.0
            # kenlm scores are log10 probabilities, hence the base-10 power.
            return round(10.0 ** (-total_logits / total_length), 1)

        return compute


LLMOPERATORS.register(TextPerplexityScore)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters