Support to generate perplexity_score and support to store rouge score column. (#427)

yao531441 authored Nov 1, 2023
1 parent 8e82905 commit 9c62d4d
Showing 5 changed files with 154 additions and 18 deletions.
48 changes: 48 additions & 0 deletions RecDP/pyrecdp/primitives/llmutils/perplexity_score.py
@@ -0,0 +1,48 @@
import argparse

from pyrecdp.core.utils import Timer
from pyrecdp.primitives.operations import JsonlReader, ParquetReader, PerfileParquetWriter


def perplexity_score_spark(spark_df, language: str = 'en'):
    from pyrecdp.primitives.operations import TextPerplexityScore
    op = TextPerplexityScore(language=language)
    ret = op.process_spark(spark_df.sparkSession, spark_df)
    return ret


def perplexity_score(data_dir, out_dir, data_file_type="jsonl", language: str = 'en'):
    from pyrecdp.primitives.operations import TextPerplexityScore
    from pyrecdp.LLM import ResumableTextPipeline

    if data_file_type == 'jsonl':
        reader = JsonlReader(data_dir)
    elif data_file_type == 'parquet':
        reader = ParquetReader(data_dir)
    else:
        raise NotImplementedError(f"{data_file_type} is not supported in RecDP LLM ResumableTextPipeline yet.")

    pipeline = ResumableTextPipeline()
    ops = [
        reader,
        TextPerplexityScore(language=language),
        PerfileParquetWriter(out_dir)
    ]
    pipeline.add_operations(ops)
    pipeline.execute()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", dest="data_dir", type=str)
    parser.add_argument("--data_file_type", dest="data_file_type", type=str, default="jsonl")
    parser.add_argument("--output_dir", dest="output_dir", type=str, default="")
    parser.add_argument("--language", dest="language", type=str, default="en")
    args = parser.parse_args()

    data_dir = args.data_dir
    data_file_type = args.data_file_type
    output_dir = args.output_dir
    language = args.language
    with Timer(f"Generate perplexity score for {data_dir}"):
        perplexity_score(data_dir, output_dir, data_file_type, language)
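For reference, the same entry point can be driven programmatically instead of through the CLI block above; a minimal sketch, assuming the package layout shown in the file path (the data/ and out/ directories are illustrative, not part of the commit):

    from pyrecdp.primitives.llmutils.perplexity_score import perplexity_score

    # Score every .jsonl file under data/ and write per-file parquet output to out/.
    perplexity_score("data/", "out/", data_file_type="jsonl", language="en")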
1 change: 1 addition & 0 deletions RecDP/pyrecdp/primitives/operations/__init__.py
@@ -17,3 +17,4 @@
from .text_custom import TextCustomerMap, TextCustomerFilter
from .text_toxicity import TextToxicity
from .text_rouge_score_dedup import RougeScoreDedup
+from .text_perplexity_score import TextPerplexityScore
66 changes: 66 additions & 0 deletions RecDP/pyrecdp/primitives/operations/text_perplexity_score.py
@@ -0,0 +1,66 @@
from .base import BaseLLMOperation
from ray.data import Dataset
from pyspark.sql import DataFrame
from pyrecdp.core.model_utils import get_model, prepare_model
from pyrecdp.primitives.operations.base import LLMOPERATORS
from pyrecdp.primitives.operations.utils import get_words_from_document


def text_bytesize(s):
    return len(s.encode('utf-8'))


class TextPerplexityScore(BaseLLMOperation):
    def __init__(self, language: str = 'en'):
        """
        Generate a perplexity score for each sample.

        :param language: language of the samples ('en' or 'zh'). Default: 'en'.
        """
        settings = {'language': language}
        super().__init__(args_dict=settings)
        self.language = language
        self.text_key = 'text'
        self.inplace = False
        self.sp_model_key = prepare_model(lang=language,
                                          model_type='sentencepiece')
        self.kl_model_key = prepare_model(lang=language, model_type='kenlm')
        self.tokenizer = get_model(self.sp_model_key, self.language, 'sentencepiece')
        self.kenlm_model = get_model(self.kl_model_key, self.language, 'kenlm')

    def process_rayds(self, ds: Dataset) -> Dataset:
        if self.inplace:
            raise NotImplementedError("In-place modification of the text column is not supported; the score is written to a new column")
        else:
            new_name = 'perplexity'
            compute_func = self.get_compute_func()
            return ds.map(lambda x: self.process_row(x, self.text_key, new_name, compute_func))

    def process_spark(self, spark, spark_df: DataFrame) -> DataFrame:
        import pyspark.sql.functions as F
        from pyspark.sql import types as T
        perplexity_udf = F.udf(self.get_compute_func(), T.FloatType())
        return spark_df.withColumn("perplexity", perplexity_udf(F.col(self.text_key)))

    def get_compute_func(self, *args, **kwargs):
        tokenizer = self.tokenizer
        kenlm_model = self.kenlm_model

        def compute(text):
            words = get_words_from_document(
                text,
                token_func=tokenizer.encode_as_pieces if tokenizer else None)
            join_text = ' '.join(words)
            # compute perplexity: kenlm scores are log10 probabilities,
            # so ppl = 10 ** (-average log10 probability per token)
            logits, length = 0, 0
            for line in join_text.splitlines():
                logits += kenlm_model.score(line)
                length += (len(line.split()) + 1)
            ppl = (10.0 ** (-logits / length)) if length != 0 else 0.0
            perplexity = round(ppl, 1)
            return perplexity

        return compute


LLMOPERATORS.register(TextPerplexityScore)
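To make the arithmetic in compute() concrete, a worked example with made-up numbers (kenlm's score() returns a log10 probability, so the exponent is the negated average log10 probability per token; lower perplexity suggests more fluent text):

    # Hypothetical two-line sample: kenlm log10 scores -12.0 and -8.0,
    # over 4 and 3 words (+1 each for the line boundary, as in compute()).
    logits = -12.0 + -8.0             # total log10 probability
    length = (4 + 1) + (3 + 1)        # normalization length
    ppl = 10.0 ** (-logits / length)  # 10 ** (20 / 9)
    print(round(ppl, 1))              # 166.8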
41 changes: 29 additions & 12 deletions RecDP/pyrecdp/primitives/operations/text_rouge_score_dedup.py
@@ -30,19 +30,21 @@ def split2df(prod_df, limit_size, spark):


class RougeScoreDedup(BaseLLMOperation):
-    def __init__(self, text_key='text', max_ratio=0.7, batch_size=1000):
-        settings = {'text_key': text_key, 'max_ratio': max_ratio, 'batch_size': batch_size}
+    def __init__(self, text_key='text', max_ratio=0.7, batch_size=1000, score_store_path=None):
+        settings = {'text_key': text_key, 'max_ratio': max_ratio, 'batch_size': batch_size, "score_store_path": score_store_path}
        super().__init__(settings)
        self.text_key = text_key
        self.max_ratio = max_ratio
        self.batch_size = batch_size
+        self.score_store_path = score_store_path
+        self.rouge_type = 'rougeL'
        self.support_spark = True
        self.support_ray = False

    def process_rayds(self, ds=None):
        total_rows = ds.count()
        line_num = []
-        scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=False)
+        scorer = rouge_scorer.RougeScorer([self.rouge_type], use_stemmer=False)
        for i in range(1, total_rows):

            d1, d2, d3 = ds.split_at_indices([i, i + 1])
@@ -62,20 +64,23 @@ def process_row(sample, target_token):
            ds_score: Dataset = d1.map(lambda x: process_row(x, instruction_token))
            if i == 1:
                filterd_ds = d1
-            if ds_score.max("rouge_score") < 0.7:
+            if ds_score.max("rouge_score") < self.max_ratio:
                filterd_ds = filterd_ds.union(d2)

        return filterd_ds

    def process_spark(self, spark, spark_df: DataFrame) -> DataFrame:
+        rouge_type = self.rouge_type
+        rouge_score_column_name = "rouge_score"
+        max_ratio = self.max_ratio

        instruction_df_1 = (spark_df.select(self.text_key).rdd.zipWithIndex().toDF()
                            .select("_1.*", "_2").withColumnRenamed("_2", "id_1"))
        instruction_df_1 = instruction_df_1.withColumnRenamed(self.text_key, "instruction")
        instruction_df_2 = (instruction_df_1.withColumnRenamed("id_1", "id_2")
                            .withColumnRenamed("instruction", "instruction_2"))

-        max_ratio = self.max_ratio
-        scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=False)
+        scorer = rouge_scorer.RougeScorer([rouge_type], use_stemmer=False)

        def gen_id(id_1, id_2):
            if id_1 == id_2:
@@ -87,9 +92,9 @@ def gen_id(id_1, id_2):

        def compare_rouge_score(str_1, str_2):
            scores = scorer.score(str_1, str_2)
-            return scores['rougeL'].fmeasure > max_ratio
+            return scores[rouge_type].fmeasure

-        compare_rouge_score_udf = F.udf(compare_rouge_score, T.BooleanType())
+        compare_rouge_score_udf = F.udf(compare_rouge_score, T.FloatType())

        batch_count = 0
        while instruction_df_2 is not None:
@@ -104,8 +109,18 @@ def compare_rouge_score(str_1, str_2):
            dupli_score_matrix = dupli_score_matrix.dropDuplicates(["id_pair"])
            dupli_score_matrix = dupli_score_matrix.filter(F.column("id_1") != F.column("id_2"))

-            remove_df = dupli_score_matrix.filter(
-                compare_rouge_score_udf(F.column("instruction"), F.column("instruction_2"))).select(
+            remove_df = dupli_score_matrix.withColumn(rouge_score_column_name,
+                                                      compare_rouge_score_udf(F.column("instruction"),
+                                                                              F.column("instruction_2")))
+            remove_df = remove_df.filter(F.column(rouge_score_column_name) > max_ratio)
+            if self.score_store_path:
+                if batch_count == 0:
+                    score_df = remove_df.select('id_pair', 'rouge_score')
+                else:
+                    score_df = score_df.union(remove_df.select('id_pair', 'rouge_score'))

+            remove_df = remove_df.select(
                "instruction",
                "id_1")
            remove_df = remove_df.dropDuplicates(["id_1"])
@@ -119,6 +134,8 @@
        instruction_df_1 = instruction_df_1.withColumnRenamed("instruction", self.text_key)
        spark_df = spark_df.join(instruction_df_1,
                                 on=self.text_key, how="inner").select(spark_df.columns)
+        if self.score_store_path:
+            score_df.write.parquet(self.score_store_path, mode='overwrite')
        return spark_df


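Taken together, the RougeScoreDedup changes add an optional score_store_path that persists the pairwise scores exceeding max_ratio as parquet, instead of only filtering on them. A hedged usage sketch (the output path is illustrative; passing score_store_path=None, as shown above, skips persistence entirely):

    from pyrecdp.primitives.operations import RougeScoreDedup

    op = RougeScoreDedup(text_key="text", max_ratio=0.7,
                         score_store_path="out/rouge_scores")
    # deduped = op.process_spark(spark, df)  # requires an active SparkSession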
16 changes: 10 additions & 6 deletions RecDP/tests/test_llmutils_operations.py
@@ -155,22 +155,21 @@ def test_filter_by_word_num_ray(self):
        with RayContext("tests/data/llm_data/tiny_c4_sample.jsonl") as ctx:
            ctx.show(op.process_rayds(ctx.ds))

    def test_filter_by_perplexity_ray(self):
-        pass
-        # Ray version not supported yet
        op = PerplexityFilter()
        with RayContext("tests/data/llm_data/tiny_c4_sample.jsonl") as ctx:
            ctx.show(op.process_rayds(ctx.ds))

    def test_filter_by_word_repetition_ray(self):
-        pass
-        # Ray version not supported yet
        op = WordRepetitionFilter()
        with RayContext("tests/data/llm_data/tiny_c4_sample.jsonl") as ctx:
            ctx.show(op.process_rayds(ctx.ds))

+    def test_perplexity_score_ray(self):
+        op = TextPerplexityScore(language='en')
+        with RayContext("tests/data/llm_data/tiny_c4_sample.jsonl") as ctx:
+            ctx.show(op.process_rayds(ctx.ds))

    def test_text_fixer_ray(self):
        op = TextFix()
        with RayContext("tests/data/llm_data/tiny_c4_sample.jsonl") as ctx:
@@ -346,3 +345,8 @@ def test_rouge_score_dedup_spark(self):
        op = RougeScoreDedup()
        with SparkContext("tests/data/llm_data/github_sample_50.jsonl") as ctx:
            ctx.show(op.process_spark(ctx.spark, ctx.ds))

+    def test_perplexity_score_spark(self):
+        op = TextPerplexityScore(language='en')
+        with SparkContext("tests/data/llm_data/tiny_c4_sample.jsonl") as ctx:
+            ctx.show(op.process_spark(ctx.spark, ctx.ds))
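Once a run completes with score_store_path set, the persisted pair scores can be inspected directly; a minimal sketch (the path is illustrative and an active SparkSession named spark is assumed):

    # Each row holds an "id_pair" plus its "rouge_score" (rougeL F-measure).
    scores = spark.read.parquet("out/rouge_scores")
    scores.orderBy("rouge_score", ascending=False).show(10)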
