detokenization parallelization #37

Open · wants to merge 17 commits into base: main
10 changes: 5 additions & 5 deletions README.md
@@ -10,11 +10,11 @@ Below shows the generation speed gain by using FastSeq.
| Model | W/O FastSeq (in samples/s) | W/ FastSeq (in samples/s) | Speedup |
|------------------|:--------------------------:|:-------------------------:|:-----:|
| [ProphetNet](examples/prophetnet/README.md) | 2.7 | 10.3 | 3.8x |
-| [Bart (`fs`)](examples/bart/README.md) | 2.7 | 14.5 | 5.4x |
-| [Bart (`hf`)](examples/bart/README.md#speedup-bart-huggingface-transformers-version-by-using-fastseq) | 3.4 | 6.4 | 1.9x |
-| [DistilBart (`hf`)](examples/distilbart/README.md) | 4.0 | 6.5 | 1.6x |
-| [T5 (`hf`)](examples/t5/README.md) | 4.8 | 7.5 | 1.6x |
-| [WMT16 En-De (`fs`)](examples/wmt/README.md) | 84.0 | 135.0 | 1.6x |
+| [Bart (`fs`)](examples/bart/README.md) | 2.7 | 13.3 | 5x |
+| [Bart (`hf`)](examples/bart/README.md#speedup-bart-huggingface-transformers-version-by-using-fastseq) | 3.4 | 11.0 | 3.2x |
+| [DistilBart (`hf`)](examples/distilbart/README.md) | 4.0 | 13.5 | 3.4x |
+| [T5 (`hf`)](examples/t5/README.md) | 4.8 | 17.0 | 3.5x |
+| [WMT16 En-De (`fs`)](examples/wmt/README.md) | 84.0 | 124.0 | 1.5x |

- All the benchmarking experiments above run on NVIDIA-V100-16GB with [docker](docker/Dockerfile). The highest speed for each model is recorded after tuning the batch size. For parameter-setting details, follow the link of the corresponding model.
- `fs` stands for [Fairseq](https://github.com/pytorch/fairseq) version 0.9.0; `hf` stands for [Huggingface Transformers](https://github.com/huggingface/transformers) version 3.0.2.
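
The Speedup column is the ratio of the two throughput columns, rounded for display. A quick sanity check of the updated BART (`fs`) row (a one-off sketch, not part of this PR):

    # Speedup = samples/s with FastSeq divided by samples/s without it.
    with_fastseq, without_fastseq = 13.3, 2.7
    print(f"{with_fastseq / without_fastseq:.2f}x")  # 4.93x, reported as 5x in the table
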
6 changes: 3 additions & 3 deletions benchmarks/models/hf_bart.sh
@@ -20,9 +20,9 @@ source utils.sh
grep "facebook/bart-large-cnn cnn_dm.1k/raw val " perf | awk '{print $9}' | awk -F'|' '{if($1!="NA"){c+=1;s+=$1}}END{print s/c}' | bash range.sh 34.8 35
# Speed on V100 16GB 250W
grep -E "transformers_v3.0.2 facebook/bart-large-cnn cnn_dm.1k/raw val 32 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 3.2 3.4
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 32 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 5.2 100
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 6.2 100
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 6.4 100
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 32 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 7.9 100
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 10.7 100
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 11.0 100

## Accuracy
#grep "facebook/bart-large-cnn cnn_dm/raw val " perf | awk '{print $9}' | awk -F'|' '{if($1!="NA"){c+=1;s+=$1}}END{print s/c}' | bash range.sh 44.78 44.82
4 changes: 2 additions & 2 deletions benchmarks/models/hf_distibart.sh
@@ -20,9 +20,9 @@ source utils.sh
grep "sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val " perf | awk '{print $9}' | awk -F'|' '{if($1!="NA"){c+=1;s+=$1}}END{print s/c}' | bash range.sh 35.1 35.3
# Speed on V100 16GB 250W
grep -E "transformers_v3.0.2 sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 3.9 4.2
grep -E "transformers_v3.0.2\+fastseq_v.* sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 6.5 100
grep -E "transformers_v3.0.2\+fastseq_v.* sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 13.3 100
# todo: bigger bs doesn't increase speed
grep -E "transformers_v3.0.2\+fastseq_v.* sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 6.5 100
grep -E "transformers_v3.0.2\+fastseq_v.* sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 13.5 100

## Accuracy
#grep "sshleifer/distilbart-cnn-12-6 cnn_dm/raw val " perf | awk '{print $9}' | awk -F'|' '{if($1!="NA"){c+=1;s+=$1}}END{print s/c}' | bash range.sh 45 45.1
2 changes: 1 addition & 1 deletion benchmarks/models/hf_mbart.sh
@@ -14,5 +14,5 @@ source utils.sh
# Accuracy
grep "facebook/mbart-large-en-ro wmt_en_ro/raw val " perf | awk '{if($8!="NA"){c+=1;s+=$8}}END{print s/c}' | bash range.sh 27.79 27.95
# Speed on V100 16GB 250W
grep -E "transformers_v3.0.2 facebook/mbart-large-en-ro wmt_en_ro/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 5.8 6.2
grep -E "transformers_v3.0.2 facebook/mbart-large-en-ro wmt_en_ro/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 7.6 7.7
grep -E "transformers_v3.0.2\+fastseq_v.* facebook/mbart-large-en-ro wmt_en_ro/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 6.0 100
5 changes: 2 additions & 3 deletions benchmarks/models/hf_t5.sh
@@ -14,6 +14,5 @@ source utils.sh
grep "t5-base wmt_en_ro/raw val " perf | awk '{if($8!="NA"){c+=1;s+=$8}}END{print s/c}' | bash range.sh 27.42 27.44
# Speed on V100 16GB 250W
grep -E "transformers_v3.0.2 t5-base wmt_en_ro/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 4.6 5.2
grep -E "transformers_v3.0.2\+fastseq_v.* t5-base wmt_en_ro/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 6.8 7.3
grep -E "transformers_v3.0.2\+fastseq_v.* t5-base wmt_en_ro/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 7.5 8.0

grep -E "transformers_v3.0.2\+fastseq_v.* t5-base wmt_en_ro/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 14.4 14.8
grep -E "transformers_v3.0.2\+fastseq_v.* t5-base wmt_en_ro/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 16.8 17.0
177 changes: 153 additions & 24 deletions fastseq_cli/transformers_generate.py
@@ -2,21 +2,115 @@
import argparse
import json
from pathlib import Path

import torch
from multiprocessing import Process, Queue
from tqdm import tqdm

from fastseq_cli.transformers_utils import use_task_specific_params, trim_batch, calculate_rouge, calculate_bleu_score
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from fastseq_cli.transformers_utils import use_task_specific_params, trim_batch, calculate_rouge, calculate_bleu_score

DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

GENERATE_FINISHED = 'done'
POSTPROCESS_FINISHED = None

class TokenizeDataset(torch.utils.data.Dataset):
"""Characterizes a dataset for PyTorch"""
def __init__(self, examples, tokenizer, model_name, prefix):
"""Multiprocess Dataloader.
Args:
examples (List(str)): a list of input sentences.
tokenizer (AutoTokenizer): instance of AutoTokenizer.
model_name (string): model name.
prefix (string): input example prefix if any.
"""
self.examples = examples
self.tokenizer= tokenizer
self.model_name = model_name
self.prefix = prefix
self.return_tensors="pt"
self.truncation=True
Review comment on lines +30 to +31 (Contributor): Why hard-code these values here? These two could be passed as parameters of the constructor.

self.padding="max_length"

def __len__(self):
return len(self.examples)

def __getitem__(self, index):
batch = self.examples[index]
if "t5" in self.model_name:
batch = self.prefix + batch
batch = self.tokenizer(batch,
return_tensors=self.return_tensors,
truncation=self.truncation,
padding=self.padding)
return batch['input_ids'], batch['attention_mask']
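
A minimal sketch of how this dataset is consumed further down in generate_summaries_or_translations, so tokenization runs in DataLoader worker processes (the checkpoint name and input strings are placeholders): each __getitem__ call returns (1, max_length) tensors, which is why the loop later flattens them with .view().

    import torch
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")  # placeholder checkpoint
    docs = ["first input document ...", "second input document ..."]
    ds = TokenizeDataset(docs, tok, "facebook/bart-large-cnn", prefix="")
    loader = torch.utils.data.DataLoader(ds, batch_size=2, num_workers=2, drop_last=True)
    for ind, (input_ids, attention_mask) in enumerate(loader):
        # Collated shape is (batch_size, 1, max_length); flatten before model.generate().
        input_ids = input_ids.view(2, -1)
        attention_mask = attention_mask.view(2, -1)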

class IOProcess(Process):
""" Write detokenized output to file in order."""
def __init__(self, msg_queue, fout):
Review comment (Contributor): missing docs

super(IOProcess, self).__init__()
self.msg_queue = msg_queue
self.fout = fout
self.waiting_for=0
self.dec_buf = {}

def process_dec(self, dec):
for hypothesis in dec:
self.fout.write(hypothesis + "\n")
self.fout.flush()

def process_buffer(self):
while self.waiting_for in self.dec_buf:
self.process_dec(self.dec_buf[self.waiting_for])
del self.dec_buf[self.waiting_for]
self.waiting_for+=1

def chunks(lst, n):
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), n):
yield lst[i:i + n]
def run(self):
while True:
ind, dec = self.msg_queue.get()
if dec == GENERATE_FINISHED:
break
elif ind != self.waiting_for:
self.dec_buf[ind] = dec
else:
self.process_dec(dec)
self.waiting_for+=1
self.process_buffer()
self.process_buffer()
assert not self.dec_buf, "IO Buffer not empty"
self.msg_queue.close()
self.msg_queue.join_thread()
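
IOProcess keeps the output file in input order even though detokenized batches may arrive out of order: any batch ahead of waiting_for is parked in dec_buf and flushed once the missing index shows up. A small illustrative run (assumes a fork-based multiprocessing start method, e.g. Linux, since sys.stdout is handed straight to the child process):

    import sys
    from multiprocessing import Queue

    q = Queue()
    writer = IOProcess(q, sys.stdout)
    writer.start()
    q.put((1, ["second batch"]))    # arrives early, parked in dec_buf
    q.put((0, ["first batch"]))     # index 0 arrives, both batches are written in order
    q.put((-1, GENERATE_FINISHED))  # sentinel: drain the buffer and exit
    writer.join()                   # output: "first batch" then "second batch"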

class PostProcess(Process):
""" Parallel detokenization """
def __init__(self, tokenizer, data_queue, msg_queue,
Review comment (Contributor): missing docs.

skip_special_tokens, clean_up_tokenization_spaces):
super(PostProcess, self).__init__()
self.data_queue = data_queue
self.msg_queue = msg_queue
self.tokenizer = tokenizer
self.clean_up_tokenization_spaces = clean_up_tokenization_spaces
self.skip_special_tokens = skip_special_tokens

def run(self):
while True:
ind, summaries = self.data_queue.get()
if summaries == GENERATE_FINISHED:
self.data_queue.put((-1, POSTPROCESS_FINISHED))
break
elif summaries == POSTPROCESS_FINISHED:
self.data_queue.put((-1, POSTPROCESS_FINISHED))
break
else:
dec = self.tokenizer.batch_decode(summaries,
skip_special_tokens = self.skip_special_tokens,
clean_up_tokenization_spaces =
self.clean_up_tokenization_spaces)
self.msg_queue.put((ind, dec))

self.data_queue.close()
self.data_queue.join_thread()
self.msg_queue.close()
self.msg_queue.join_thread()
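
The shutdown handshake relies on a single poison pill: the producer enqueues one (-1, GENERATE_FINISHED), the first worker to consume it swaps it for (-1, POSTPROCESS_FINISHED), and every other worker re-posts that marker before exiting, so one sentinel stops any number of PostProcess workers. A minimal wiring sketch (placeholder checkpoint, arbitrary token ids, fork-based start method assumed):

    from multiprocessing import Queue
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")  # placeholder checkpoint
    data_queue, msg_queue = Queue(), Queue()
    workers = [PostProcess(tok, data_queue, msg_queue, True, False) for _ in range(2)]
    for w in workers:
        w.start()
    data_queue.put((0, [[0, 9226, 16, 41, 1246, 2]]))  # one "generated" batch of token ids
    data_queue.put((-1, GENERATE_FINISHED))            # single sentinel shuts down both workers
    for w in workers:
        w.join()
    print(msg_queue.get())                             # (0, ['<detokenized text>'])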

def generate_summaries_or_translations(
examples: list,
@@ -29,6 +123,10 @@ def generate_summaries_or_translations(
decoder_start_token_id=None,
fastseq_opt=True,
no_repeat_ngram_size=None,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
preprocess_cpu_num=2,
postprocess_cpu_num=2,
**gen_kwargs,
) -> None:
"""Run generation"""
@@ -46,30 +144,44 @@

# update config with summarization specific params
use_task_specific_params(model, task)
data_queue = Queue()
msg_queue = Queue()
p_list = []

for i in range(postprocess_cpu_num):
p = PostProcess(tokenizer, data_queue, msg_queue,
skip_special_tokens, clean_up_tokenization_spaces)
p_list.append(p)
p.start()

for batch in tqdm(list(chunks(examples, batch_size))):
if "t5" in model_name:
batch = [model.config.prefix + text for text in batch]
batch = tokenizer(batch,
return_tensors="pt",
truncation=True,
padding="max_length").to(device)
io_process = IOProcess( msg_queue, fout)
io_process.start()
dataset = TokenizeDataset(examples, tokenizer, model_name,
model.config.prefix)
training_generator = torch.utils.data.DataLoader(dataset,
batch_size=batch_size, num_workers = preprocess_cpu_num,
drop_last=True)
for ind, batch in tqdm(enumerate(training_generator)):
input_ids, attention_mask = batch
input_ids = input_ids.view(batch_size, -1).to(device)
attention_mask = attention_mask.view(batch_size, -1).to(device)
input_ids, attention_mask = trim_batch(
**batch, pad_token_id=tokenizer.pad_token_id)
input_ids, tokenizer.pad_token_id, attention_mask)
summaries = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_start_token_id=decoder_start_token_id,
no_repeat_ngram_size=no_repeat_ngram_size,
**gen_kwargs,
)
dec = tokenizer.batch_decode(summaries,
skip_special_tokens=True,
clean_up_tokenization_spaces=False)
for hypothesis in dec:
fout.write(hypothesis + "\n")
fout.flush()

summaries_cpu = summaries.cpu()
data_queue.put((ind, summaries_cpu))
data_queue.put((-1, GENERATE_FINISHED))
for p in p_list:
p.join()
msg_queue.put((-1, GENERATE_FINISHED))
io_process.join()
fout.close()

def run_generate():
"""Entrance is here."""
@@ -118,6 +230,19 @@ def run_generate():
parser.add_argument("--without_fastseq_opt", action="store_true")
parser.add_argument("--no_repeat_ngram_size", type=int, default=None,
required=False, help="size of no repeat ngram")
parser.add_argument("--include_special_tokens", action="store_true")
parser.add_argument("--clean_up_tokenization_spaces", action="store_true")
parser.add_argument("--preprocess_cpu_num",
type=int,
default=2,
required=False,
help="pre-processing worker threads")
parser.add_argument("--postprocess_cpu_num",
type=int,
default=2,
required=False,
help="post-processing worker threads")

args = parser.parse_args()
examples = [
" " + x.rstrip() if "t5" in args.model_name else x.rstrip()
@@ -137,7 +262,11 @@ def run_generate():
decoder_start_token_id=args.decoder_start_token_id,
fastseq_opt=not args.without_fastseq_opt,
no_repeat_ngram_size=args.no_repeat_ngram_size,
)
skip_special_tokens=not args.include_special_tokens,
clean_up_tokenization_spaces=args.clean_up_tokenization_spaces,
preprocess_cpu_num=args.preprocess_cpu_num,
postprocess_cpu_num=args.postprocess_cpu_num,
)
if args.reference_path is None:
return
# Compute scores