detokenization parallelization #37

Open. Wants to merge 17 commits into base: main.
6 changes: 3 additions & 3 deletions benchmarks/models/hf_bart.sh
@@ -20,9 +20,9 @@ source utils.sh
grep "facebook/bart-large-cnn cnn_dm.1k/raw val " perf | awk '{print $9}' | awk -F'|' '{if($1!="NA"){c+=1;s+=$1}}END{print s/c}' | bash range.sh 34.8 35
# Speed on V100 16GB 250W
grep -E "transformers_v3.0.2 facebook/bart-large-cnn cnn_dm.1k/raw val 32 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 3.2 3.4
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 32 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 5.2 100
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 6.2 100
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 6.4 100
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 32 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 6.8 100
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 8.7 100
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 9.1 100

## Accuracy
#grep "facebook/bart-large-cnn cnn_dm/raw val " perf | awk '{print $9}' | awk -F'|' '{if($1!="NA"){c+=1;s+=$1}}END{print s/c}' | bash range.sh 44.78 44.82
4 changes: 2 additions & 2 deletions benchmarks/models/hf_distibart.sh
@@ -20,9 +20,9 @@ source utils.sh
grep "sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val " perf | awk '{print $9}' | awk -F'|' '{if($1!="NA"){c+=1;s+=$1}}END{print s/c}' | bash range.sh 35.1 35.3
# Speed on V100 16GB 250W
grep -E "transformers_v3.0.2 sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 3.9 4.2
grep -E "transformers_v3.0.2\+fastseq_v.* sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 6.5 100
grep -E "transformers_v3.0.2\+fastseq_v.* sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 9.5 100
# todo: bigger bs doesn't increase speed
grep -E "transformers_v3.0.2\+fastseq_v.* sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 6.5 100
grep -E "transformers_v3.0.2\+fastseq_v.* sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 9.5 100

## Accuracy
#grep "sshleifer/distilbart-cnn-12-6 cnn_dm/raw val " perf | awk '{print $9}' | awk -F'|' '{if($1!="NA"){c+=1;s+=$1}}END{print s/c}' | bash range.sh 45 45.1
2 changes: 1 addition & 1 deletion benchmarks/models/hf_mbart.sh
@@ -15,4 +15,4 @@ source utils.sh
grep "facebook/mbart-large-en-ro wmt_en_ro/raw val " perf | awk '{if($8!="NA"){c+=1;s+=$8}}END{print s/c}' | bash range.sh 27.79 27.95
# Speed on V100 16GB 250W
grep -E "transformers_v3.0.2 facebook/mbart-large-en-ro wmt_en_ro/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 5.8 6.2
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/mbart-large-en-ro wmt_en_ro/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 6.0 100
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/mbart-large-en-ro wmt_en_ro/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 7.2 100
119 changes: 102 additions & 17 deletions fastseq_cli/transformers_generate.py
@@ -2,21 +2,89 @@
import argparse
import json
from pathlib import Path

import torch
from tqdm import tqdm

-from fastseq_cli.transformers_utils import use_task_specific_params, trim_batch, calculate_rouge, calculate_bleu_score
+from multiprocessing import Process, Queue, JoinableQueue, cpu_count
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+from fastseq_cli.transformers_utils import use_task_specific_params, trim_batch, calculate_rouge, calculate_bleu_score

DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

GENERATE_FINISHED = 'done'
POSTPROCESS_FINISHED = None


def chunks(lst, n):
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), n):
yield lst[i:i + n]
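
chunks is the batching helper used below; the final chunk may be shorter than n:

>>> list(chunks([1, 2, 3, 4, 5], 2))
[[1, 2], [3, 4], [5]]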

class IOProcess (Process) :
    """ Write detokenized output to file in order."""
    def __init__ (self, msg_queue, fout):
[Contributor] Suggested change:
-    def __init__ (self, msg_queue, fout):
+    def __init__(self, msg_queue, fout):

[Contributor] Remove the similar spaces in other places.

        super(IOProcess, self).__init__()
        self.msg_queue = msg_queue
        self.fout = fout
        self.waiting_for=0
        self.dec_buf = {}

    def process_dec (self, dec) :
        for hypothesis in dec:
            self.fout.write(hypothesis + "\n")
        self.fout.flush()

    def process_buffer (self):
        while self.waiting_for in self.dec_buf :
            self.process_dec(self.dec_buf[self.waiting_for])
            del self.dec_buf[self.waiting_for]
            self.waiting_for+=1

    def run (self) :
        while (True) :
            ind, dec = self.msg_queue.get()
            if dec == GENERATE_FINISHED :
[Contributor] Suggested change:
-            if dec == GENERATE_FINISHED :
+            if dec == GENERATE_FINISHED:

                break
            elif ind != self.waiting_for:
                self.dec_buf[ind] = dec
            else :
                self.process_dec(dec)
                self.waiting_for+=1
                self.process_buffer()
        self.process_buffer()
        assert not self.dec_buf, "IO Buffer not empty"
        self.msg_queue.close()
        self.msg_queue.join_thread()
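
The buffering in run() exists because PostProcess workers can finish batches out of order, while the output file must preserve input order. The same reordering logic in isolation, a minimal sketch with hard-coded out-of-order arrivals:

import sys

waiting_for = 0   # next batch index the output expects
dec_buf = {}      # early arrivals, keyed by batch index
for ind, dec in [(1, ["b"]), (0, ["a"]), (2, ["c"])]:
    dec_buf[ind] = dec
    while waiting_for in dec_buf:                 # flush any ready run of batches
        for hypothesis in dec_buf.pop(waiting_for):
            sys.stdout.write(hypothesis + "\n")   # prints a, b, c in order
        waiting_for += 1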

class PostProcess (Process) :
    """ Parallel detokenization """
    def __init__ (self, tokenizer, data_queue, msg_queue, skip_special_tokens, clean_up_tokenization_spaces) :
        super(PostProcess, self).__init__()
        self.data_queue = data_queue
        self.msg_queue = msg_queue
        self.tokenizer = tokenizer
        self.clean_up_tokenization_spaces = clean_up_tokenization_spaces
        self.skip_special_tokens = skip_special_tokens

    def run (self) :
        while True :
            ind, summaries = self.data_queue.get()
            if summaries == GENERATE_FINISHED :
                self.data_queue.put((-1,POSTPROCESS_FINISHED))
                break
            elif summaries == POSTPROCESS_FINISHED :
                self.data_queue.put((-1,POSTPROCESS_FINISHED))
                break
            else :
                dec = self.tokenizer.batch_decode(summaries,
                    skip_special_tokens = self.skip_special_tokens,
                    clean_up_tokenization_spaces = self.clean_up_tokenization_spaces)
                self.msg_queue.put((ind,dec))

        self.data_queue.close()
        self.data_queue.join_thread()
        self.msg_queue.close()
        self.msg_queue.join_thread()
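
For context, the work each PostProcess worker offloads from the GPU loop is tokenizer.batch_decode, which maps a batch of generated token-id sequences back to strings. A minimal standalone example (the checkpoint name simply mirrors the benchmarks above):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
ids = tokenizer(["Hello world"], return_tensors="pt")["input_ids"]
# One output string per input sequence; special tokens like <s> are dropped.
print(tokenizer.batch_decode(ids, skip_special_tokens=True,
                             clean_up_tokenization_spaces=False))
# ['Hello world']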


def generate_summaries_or_translations(
    examples: list,
@@ -29,6 +97,8 @@ def generate_summaries_or_translations(
    decoder_start_token_id=None,
    fastseq_opt=True,
    no_repeat_ngram_size=None,
+    skip_special_tokens=True,
+    clean_up_tokenization_spaces=False,
    **gen_kwargs,
) -> None:
    """Run generation"""
@@ -41,13 +111,25 @@
        model = model.half()
    if decoder_start_token_id is None:
        decoder_start_token_id = gen_kwargs.pop("decoder_start_token_id", None)

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # update config with summarization specific params
    use_task_specific_params(model, task)

-    for batch in tqdm(list(chunks(examples, batch_size))):
+    data_queue = Queue()
+    msg_queue = Queue()
+    p_list = []
+    threads = cpu_count()
[Contributor] It may be better to allow users to specify the number of CPUs.

[Contributor, author] It shouldn't make a big difference, right? Although I can create an argument.

[Contributor] There should be some differences: it will waste CPU resources, and it also brings overhead to create and manage these processes and to sync data across them.

[Member] There is a parameter defined for this in the parallel support for fairseq. A GPU machine has 32/64 or more CPUs. Do you get better speed when threads > 1?

[Contributor, author] @feihugis I added support for this. @yuyan2do I haven't yet analyzed the effect of the number of threads on speed; let me do that.

[Contributor, author] I didn't notice significant changes in overall time when the number of threads is changed.
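
A sketch of the knob this thread converges on; the flag name --postprocess_workers is illustrative only, since the diff here still shows the unconditional cpu_count():

import argparse
from multiprocessing import cpu_count

parser = argparse.ArgumentParser()
parser.add_argument("--postprocess_workers", type=int, default=cpu_count(),
                    required=False,
                    help="number of processes used for parallel detokenization")
args = parser.parse_args()
threads = args.postprocess_workers   # would replace: threads = cpu_count()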


+    for i in range(threads) :
+        p = PostProcess(tokenizer, data_queue, msg_queue, skip_special_tokens, clean_up_tokenization_spaces)
+        p_list.append(p)
+        p.start()
+
+    io_process = IOProcess( msg_queue, fout)
+    io_process.start()
+
+    for ind, batch in tqdm(enumerate(list(chunks(examples, batch_size)))):
        if "t5" in model_name:
            batch = [model.config.prefix + text for text in batch]
        batch = tokenizer(batch,
@@ -63,14 +145,14 @@ def generate_summaries_or_translations(
            no_repeat_ngram_size=no_repeat_ngram_size,
            **gen_kwargs,
        )
-        dec = tokenizer.batch_decode(summaries,
-            skip_special_tokens=True,
-            clean_up_tokenization_spaces=False)
-        for hypothesis in dec:
-            fout.write(hypothesis + "\n")
-        fout.flush()
+        summaries_cpu = summaries.cpu()
+        data_queue.put((ind,summaries_cpu))
+    data_queue.put((-1,GENERATE_FINISHED))
+    for p in p_list :
+        p.join()
+    msg_queue.put((-1,GENERATE_FINISHED))
+    io_process.join()

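Putting the pieces together, the shutdown protocol is: the main process streams (index, tensor) pairs into data_queue, sends one GENERATE_FINISHED sentinel, joins every PostProcess worker (each re-enqueues a sentinel before exiting so its siblings also wake up), and only then signals and joins IOProcess. A self-contained toy run of the same two classes with a stand-in tokenizer, assuming a fork-based multiprocessing start method (as on Linux); FakeTokenizer and the integer "batches" are illustrative only:

import sys
from multiprocessing import Queue

class FakeTokenizer:
    # Stand-in for a HuggingFace tokenizer: "decodes" a batch of ints.
    def batch_decode(self, ids, skip_special_tokens, clean_up_tokenization_spaces):
        return ["decoded-%d" % i for i in ids]

if __name__ == "__main__":
    data_queue, msg_queue = Queue(), Queue()
    workers = [PostProcess(FakeTokenizer(), data_queue, msg_queue, True, False)
               for _ in range(2)]
    for w in workers:
        w.start()
    io_process = IOProcess(msg_queue, sys.stdout)
    io_process.start()
    for ind, batch in enumerate([[1, 2], [3, 4], [5]]):
        data_queue.put((ind, batch))          # producer side
    data_queue.put((-1, GENERATE_FINISHED))   # one sentinel is enough
    for w in workers:
        w.join()                              # all decoded lines flushed by now
    msg_queue.put((-1, GENERATE_FINISHED))
    io_process.join()                         # prints decoded-1 ... decoded-5 in order
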
def run_generate():
    """Entry point."""
    parser = argparse.ArgumentParser()
@@ -116,8 +198,9 @@ def run_generate():
help="How many observations. Defaults to all.")
parser.add_argument("--fp16", action="store_true")
parser.add_argument("--without_fastseq_opt", action="store_true")
parser.add_argument("--no_repeat_ngram_size", type=int, default=None,
parser.add_argument("--no_repeat_ngram_size", type=int, default=None,
required=False, help="size of no repeat ngram")

    args = parser.parse_args()
    examples = [
        " " + x.rstrip() if "t5" in args.model_name else x.rstrip()
@@ -137,7 +220,9 @@ def run_generate():
        decoder_start_token_id=args.decoder_start_token_id,
        fastseq_opt=not args.without_fastseq_opt,
        no_repeat_ngram_size=args.no_repeat_ngram_size,
-    )
+        skip_special_tokens=True,
+        clean_up_tokenization_spaces=False,
+    )
    if args.reference_path is None:
        return
    # Compute scores