[v1.2][ISSUE-306] Refactor rouge_score_dedup to support more compute functions (#431)

* optimize rouge-score impl spark version

Signed-off-by: Xue, Chendi <[email protected]>

* update rouge-score method

Signed-off-by: Xue, Chendi <[email protected]>

* update

Signed-off-by: Xue, Chendi <[email protected]>

* optimize by using local file

Signed-off-by: Xue, Chendi <[email protected]>

* Format codes and remove unnecessary comments.

* Refactor rouge_score_dedup to support more compute functions.

* Remove debug codes.

* Update text_compare_dedup.py

---------

Signed-off-by: Xue, Chendi <[email protected]>
Co-authored-by: Xue, Chendi <[email protected]>
yao531441 and xuechendi authored Nov 3, 2023
1 parent 9c62d4d commit 106a2a0
Showing 4 changed files with 183 additions and 143 deletions.
1 change: 1 addition & 0 deletions RecDP/pyrecdp/primitives/llmutils/rouge_score_dedup.py
@@ -46,5 +46,6 @@ def rouge_score_dedup(data_dir, out_dir, data_file_type="jsonl", max_ratio=0.7,
output_dir = args.output_dir
max_ratio = args.max_ratio
batch_size = args.batch_size

with Timer(f"Remove duplicate item by rouge score for {data_dir}"):
rouge_score_dedup(data_dir, output_dir, data_file_type, max_ratio, batch_size)
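
For reference, the entry point above can also be driven directly from Python rather than through the CLI wiring shown in this hunk. A minimal sketch, assuming hypothetical input/output paths and the defaults visible in the hunk header:

from pyrecdp.primitives.llmutils.rouge_score_dedup import rouge_score_dedup

# Hypothetical paths; rows whose best ROUGE-L similarity against the rest
# of the dataset exceeds max_ratio (0.7) are dropped as near-duplicates.
rouge_score_dedup("/path/to/input", "/path/to/output", "jsonl", 0.7, 100)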
2 changes: 1 addition & 1 deletion RecDP/pyrecdp/primitives/operations/__init__.py
@@ -16,5 +16,5 @@
from .text_diversityindicate import TextDiversityIndicate
from .text_custom import TextCustomerMap, TextCustomerFilter
from .text_toxicity import TextToxicity
- from .text_rouge_score_dedup import RougeScoreDedup
+ from .text_compare_dedup import RougeScoreDedup
from .text_perplexity_score import TextPerplexityScore
181 changes: 181 additions & 0 deletions RecDP/pyrecdp/primitives/operations/text_compare_dedup.py
@@ -0,0 +1,181 @@
from .base import BaseLLMOperation, LLMOPERATORS, statistics_decorator
from ray.data import Dataset
from pyspark.sql import DataFrame

import pyspark.sql.functions as F
from pyspark.sql import types as T
from pyspark.sql import Row
from rouge_score import rouge_scorer
from pyrecdp.primitives.llmutils.third_party import generate_connected_components

from .logging_utils import logger
from pyrecdp.core.utils import Timer
from tqdm import tqdm
import pandas as pd


class BaseCompareDedup(BaseLLMOperation):
def __init__(self, text_key='text', max_ratio=0.7, batch_size=100, score_store_path='RougeScorefiltered.parquet',
args_dict={}):
settings = {'text_key': text_key, 'max_ratio': max_ratio, 'batch_size': batch_size,
'score_store_path': score_store_path}
settings.update(args_dict)
super().__init__(settings)
self.text_key = text_key
self.max_ratio = max_ratio
self.batch_size = batch_size
self.score_store_path = score_store_path
self.support_spark = True
self.support_ray = False
self.new_column_name = "score"

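    # Ray path (currently disabled via support_ray=False): compare each sample
    # against all earlier samples and keep it only when its best similarity
    # score stays below max_ratio.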
    def process_rayds(self, ds=None):
        total_rows = ds.count()
        for i in range(1, total_rows):
            d1, d2, d3 = ds.split_at_indices([i, i + 1])
            target_sample = d2.take(1)[0]
            instruction = target_sample[self.text_key]

            compute_func = self.get_compute_func()

            def process_row(sample):
                sample[self.new_column_name] = compute_func(instruction, sample[self.text_key])
                return sample

            ds_score: Dataset = d1.map(lambda x: process_row(x))
            if i == 1:
                filtered_ds = d1
            if ds_score.max(self.new_column_name) < self.max_ratio:
                filtered_ds = filtered_ds.union(d2)

        return filtered_ds

@statistics_decorator
def process_spark(self, spark, spark_df: DataFrame) -> DataFrame:
max_ratio = self.max_ratio
spark_df = spark_df.withColumn('id_1', F.monotonically_increasing_id())
instruction_df_1 = spark_df.withColumnRenamed(self.text_key, "similarity_left")
instruction_df_2 = (spark_df.withColumnRenamed("id_1", "id_2")
.withColumnRenamed(self.text_key, "similarity_right"))

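        # Collect all row ids on the driver and split them into fixed-size batches,
        # so each cross join below compares at most batch_size rows against the full set.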
monotonically_increasing_id_list = spark_df.rdd.map(lambda x: x.id_1).collect()
batches = [monotonically_increasing_id_list[i: i + self.batch_size] for i in
range(0, len(monotonically_increasing_id_list), self.batch_size)]

def gen_id(id_1, id_2):
if id_1 == id_2:
return -1
if id_1 < id_2:
return f"{id_1} :: {id_2}"
else:
return f"{id_2} :: {id_1}"

gen_id_udf = F.udf(gen_id, T.StringType())
compare_rouge_score_udf = F.udf(self.get_compute_func(), T.FloatType())
history_pair_df = None
score_df_list = []

for batch_count, to_process_ids in tqdm(enumerate(batches), total=len(batches)):
with Timer(f"Round {batch_count}"):
# prepare matrix for one batch calculation
# 1. cross join to get n*n pairs
                # 2. use id_pair to reduce calculated pairs; if we have done i_j, then skip j_i
# 3. skip i_i
R = Row('id_2')
tmp_id_df = spark.createDataFrame([R(i) for i in to_process_ids])
batch_df = instruction_df_2.join(tmp_id_df, on='id_2', how='inner')
dupli_score_matrix = instruction_df_1.crossJoin(batch_df)
dupli_score_matrix = dupli_score_matrix.withColumn("id_pair",
gen_id_udf(F.column("id_1"), F.column("id_2")))
dupli_score_matrix = dupli_score_matrix.dropDuplicates(["id_pair"])
dupli_score_matrix = dupli_score_matrix.filter(F.column("id_1") != F.column("id_2"))
dupli_score_matrix = dupli_score_matrix.cache()

                # Now we have the minimal set of pairs; start to calculate the similarity score
remove_df = dupli_score_matrix.withColumn(self.new_column_name,
compare_rouge_score_udf(F.column("similarity_left"),
F.column("similarity_right")))

# find out sample_pairs whose similarity > threshold
remove_df = remove_df.filter(F.column(self.new_column_name) > max_ratio).cache()
logger.info(
f"Round {batch_count}: total processing num_samples is {dupli_score_matrix.count()}, detected high score num_samples is {remove_df.count()}")
# materialize one round

score_df = remove_df.select('id_1', 'id_2', 'id_pair', 'similarity_left', 'similarity_right',
self.new_column_name).toPandas()
score_df_list.append(score_df)

                instruction_df_1.join(tmp_id_df.withColumnRenamed('id_2', 'id_1'), on='id_1', how='anti').write.parquet(
                    f"{self.score_store_path}.tmp_df", mode='overwrite')
                instruction_df_1 = spark.read.parquet(f"{self.score_store_path}.tmp_df")

# Final join
with Timer("generate_connected_components => duplicates"):
results = []
[results.extend(df_['id_pair'].to_list()) for df_ in score_df_list]
components = generate_connected_components.generate_connected_components_py(results)
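            # Keep the first member of each connected component as the representative;
            # every other member is marked as a duplicate.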
duplicates = [c for c_list in components for c in c_list[1:]]
R = Row('id_1')
total_dup = len(duplicates)
if total_dup != 0:
duplicates_sdf = spark.createDataFrame([R(dup) for dup in duplicates]).cache()
total_dup = duplicates_sdf.count()
spark_df = spark_df.join(duplicates_sdf,
on='id_1', how="left_anti").drop("id_1")
logger.info(f"Finally detected duplicated num_samples is {total_dup}")
else:
spark_df = spark_df.drop("id_1")
score_df = pd.concat(score_df_list, ignore_index=True).reset_index(drop=True)
if self.score_store_path:
            import os
if os.path.exists(self.score_store_path):
os.remove(self.score_store_path)
score_df.to_parquet(self.score_store_path)
if self.statistics_flag:
self.statistics.example = score_df

return spark_df

def get_compute_func(self, *args, **kwargs):
raise NotImplementedError("Abstract func")

    def summarize(self) -> str:
        return (
            f"A total of {self.statistics.total_in} rows of data were processed in {self.statistics.used_time} seconds. "
            f"A duplication list containing {self.statistics.total_out} entries was found. "
            f"Duplication preview (first 50 rows): {self.statistics.example.head(50)}")


LLMOPERATORS.register(BaseCompareDedup)


class RougeScoreDedup(BaseCompareDedup):
def __init__(self, text_key='text', max_ratio=0.7, batch_size=100, score_store_path='RougeScorefiltered.parquet'):
settings = {'text_key': text_key, 'max_ratio': max_ratio, 'batch_size': batch_size,
"score_store_path": score_store_path}
super().__init__(args_dict=settings)
self.text_key = text_key
self.max_ratio = max_ratio
self.batch_size = batch_size
self.score_store_path = score_store_path
self.rouge_type = 'rougeL'
self.support_spark = True
self.support_ray = False

def get_compute_func(self, *args, **kwargs):
from rouge_score import rouge_scorer
scorer = rouge_scorer.RougeScorer([self.rouge_type], use_stemmer=False)

        def compare_rouge_score(str_1, str_2):
            scores = scorer.score(str_1, str_2)
            return scores[self.rouge_type].fmeasure

return compare_rouge_score


LLMOPERATORS.register(RougeScoreDedup)
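
The point of the refactor is that BaseCompareDedup now owns the batching, UDF scoring, and connected-components logic, so a new dedup operator only has to supply a compute function via get_compute_func. A hypothetical sketch (not part of this commit), assuming the python-Levenshtein package is available as a dependency:

from Levenshtein import ratio  # assumed third-party dependency

from .base import LLMOPERATORS
from .text_compare_dedup import BaseCompareDedup


class LevenshteinDedup(BaseCompareDedup):
    def __init__(self, text_key='text', max_ratio=0.9, batch_size=100,
                 score_store_path='LevenshteinFiltered.parquet'):
        super().__init__(text_key=text_key, max_ratio=max_ratio, batch_size=batch_size,
                         score_store_path=score_store_path)

    def get_compute_func(self, *args, **kwargs):
        # Similarity in [0, 1]; pairs scoring above max_ratio are treated as duplicates.
        def compare_levenshtein(str_1, str_2):
            return float(ratio(str_1, str_2))

        return compare_levenshtein


LLMOPERATORS.register(LevenshteinDedup)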
142 changes: 0 additions & 142 deletions RecDP/pyrecdp/primitives/operations/text_rouge_score_dedup.py

This file was deleted.
