detokenization parallelization #37

Open
wants to merge 17 commits into base: main
92 changes: 60 additions & 32 deletions fastseq_cli/transformers_generate.py
@@ -2,7 +2,7 @@
import argparse
import json
from pathlib import Path
from multiprocessing import Process, Queue
from tqdm import tqdm
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
@@ -13,13 +13,28 @@
GENERATE_FINISHED = 'done'
POSTPROCESS_FINISHED = None

class Dataset(torch.utils.data.Dataset):
"""Characterizes a dataset for PyTorch"""
def __init__(self, examples, tokenizer, model_name, prefix):
self.examples = examples
self.tokenizer= tokenizer
self.model_name = model_name
self.prefix = prefix

def __len__(self):
return len(self.examples)

    def __getitem__(self, index):
        batch = self.examples[index]
        if "t5" in self.model_name:
            batch = self.prefix + batch
        batch = self.tokenizer(batch,
                               return_tensors="pt",
                               truncation=True,
                               padding="max_length")
        return batch['input_ids'], batch['attention_mask']

Contributor (on the hard-coded tokenizer arguments): Add these parameters to the constructor instead of hard coding.
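One way to act on that suggestion, as a sketch only and not part of this PR (the added keyword arguments and their defaults are assumptions): pass the tokenizer options through the constructor so callers can override them instead of editing __getitem__.

class Dataset(torch.utils.data.Dataset):
    """Same dataset, but with the tokenizer options injected via the constructor."""
    def __init__(self, examples, tokenizer, model_name, prefix,
                 return_tensors="pt", truncation=True, padding="max_length"):
        self.examples = examples
        self.tokenizer = tokenizer
        self.model_name = model_name
        self.prefix = prefix
        # Previously hard-coded in __getitem__; now configurable per instance.
        self.return_tensors = return_tensors
        self.truncation = truncation
        self.padding = padding

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        text = self.examples[index]
        if "t5" in self.model_name:
            text = self.prefix + text
        batch = self.tokenizer(text,
                               return_tensors=self.return_tensors,
                               truncation=self.truncation,
                               padding=self.padding)
        return batch['input_ids'], batch['attention_mask']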

class IOProcess (Process):
""" Write detokenized output to file in order."""
def __init__(self, msg_queue, fout):

Contributor: missing docs

super(IOProcess, self).__init__()
@@ -28,25 +43,25 @@ def __init__(self, msg_queue, fout):
self.waiting_for=0
self.dec_buf = {}

def process_dec(self, dec):
for hypothesis in dec:
self.fout.write(hypothesis + "\n")
self.fout.flush()

def process_buffer(self):
while self.waiting_for in self.dec_buf:
self.process_dec(self.dec_buf[self.waiting_for])
del self.dec_buf[self.waiting_for]
self.waiting_for+=1

def run(self):
while True:
ind, dec = self.msg_queue.get()
if dec == GENERATE_FINISHED:
break
elif ind != self.waiting_for:
self.dec_buf[ind] = dec
else:
self.process_dec(dec)
self.waiting_for+=1
self.process_buffer()
@@ -55,27 +70,27 @@ def run(self):
self.msg_queue.close()
self.msg_queue.join_thread()
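
For readers skimming IOProcess above: decoded batches can reach msg_queue out of index order, so anything that arrives early is parked in dec_buf until the missing index shows up. A minimal standalone illustration of that buffering, with toy data that is not part of the PR:

waiting_for, buf = 0, {}
arrivals = [(2, "third"), (0, "first"), (3, "fourth"), (1, "second")]
for ind, text in arrivals:
    if ind != waiting_for:
        buf[ind] = text                # arrived early; hold it
    else:
        print(text)                    # next expected index; emit immediately
        waiting_for += 1
        while waiting_for in buf:      # drain whatever is now unblocked
            print(buf.pop(waiting_for))
            waiting_for += 1
# prints: first, second, third, fourth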

class PostProcess(Process):
""" Parallel detokenization """
    def __init__(self, tokenizer, data_queue, msg_queue,
                 skip_special_tokens, clean_up_tokenization_spaces):

Contributor: missing docs.

super(PostProcess, self).__init__()
self.data_queue = data_queue
self.msg_queue = msg_queue
self.tokenizer = tokenizer
self.clean_up_tokenization_spaces = clean_up_tokenization_spaces
self.skip_special_tokens = skip_special_tokens

def run(self):
while True:
ind, summaries = self.data_queue.get()
if summaries == GENERATE_FINISHED:
self.data_queue.put((-1, POSTPROCESS_FINISHED))
break
elif summaries == POSTPROCESS_FINISHED:
self.data_queue.put((-1, POSTPROCESS_FINISHED))
break
else:
dec = self.tokenizer.batch_decode(summaries,
skip_special_tokens = self.skip_special_tokens,
clean_up_tokenization_spaces =
@@ -87,7 +102,6 @@ def run(self):
self.msg_queue.close()
self.msg_queue.join_thread()
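
The shutdown handshake above deserves a note: the producer puts a single GENERATE_FINISHED on data_queue, the first PostProcess worker that sees it re-posts POSTPROCESS_FINISHED before exiting, and every remaining worker does the same, so one sentinel drains the whole pool. A minimal sketch of that relay pattern with illustrative names (DONE, worker, q are assumptions, not the PR's code):

from multiprocessing import Process, Queue

DONE = None  # plays the role of POSTPROCESS_FINISHED

def worker(q):
    while True:
        item = q.get()
        if item is DONE:
            q.put(DONE)   # re-post so sibling workers also see the sentinel
            break
        # ... process item ...

if __name__ == "__main__":
    q = Queue()
    workers = [Process(target=worker, args=(q,)) for _ in range(4)]
    for w in workers:
        w.start()
    for i in range(10):
        q.put(i)
    q.put(DONE)           # one sentinel is enough; the workers relay it
    for w in workers:
        w.join()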


def generate_summaries_or_translations(
examples: list,
out_file: str,
@@ -101,6 +115,8 @@ def generate_summaries_or_translations(
no_repeat_ngram_size=None,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
pre_process_threads=2,
post_process_threads=2,
**gen_kwargs,
) -> None:
"""Run generation"""
@@ -121,26 +137,24 @@
data_queue = Queue()
msg_queue = Queue()
p_list = []

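    # Launch post_process_threads PostProcess workers; each detokenizes generated batches in parallel.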
for i in range(post_process_threads):
p = PostProcess(tokenizer, data_queue, msg_queue,
skip_special_tokens, clean_up_tokenization_spaces)
p_list.append(p)
p.start()

io_process = IOProcess( msg_queue, fout)
io_process.start()

dataset = Dataset(examples, tokenizer, model_name, model.config.prefix)
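    # pre_process_threads becomes the DataLoader's num_workers, so tokenization runs in parallel with generation.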
training_generator = torch.utils.data.DataLoader(dataset,
batch_size=batch_size, num_workers = pre_process_threads)
for ind, batch in tqdm(enumerate(training_generator)):
input_ids, attention_mask = batch
input_ids = input_ids.view(batch_size, -1).to(device)
attention_mask = attention_mask.view(batch_size, -1).to(device)
input_ids, attention_mask = trim_batch(
input_ids, tokenizer.pad_token_id, attention_mask)
summaries = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
@@ -151,7 +165,7 @@ def generate_summaries_or_translations(
summaries_cpu = summaries.cpu()
data_queue.put((ind, summaries_cpu))
data_queue.put((-1, GENERATE_FINISHED))
for p in p_list:
p.join()
msg_queue.put((-1, GENERATE_FINISHED))
io_process.join()
@@ -204,6 +218,18 @@ def run_generate():
parser.add_argument("--without_fastseq_opt", action="store_true")
parser.add_argument("--no_repeat_ngram_size", type=int, default=None,
required=False, help="size of no repeat ngram")
parser.add_argument("--include_special_tokens", action="store_true")
parser.add_argument("--clean_up_tokenization_spaces", action="store_true")
parser.add_argument("--pre_process_threads",
type=int,
default=2,
required=False,
help="pre-processing worker threads")
parser.add_argument("--post_process_threads",
type=int,
default=2,
required=False,
help="post-processing worker threads")

args = parser.parse_args()
examples = [
@@ -224,8 +250,10 @@ def run_generate():
decoder_start_token_id=args.decoder_start_token_id,
fastseq_opt=not args.without_fastseq_opt,
no_repeat_ngram_size=args.no_repeat_ngram_size,
skip_special_tokens=not args.include_special_tokens,
clean_up_tokenization_spaces=args.clean_up_tokenization_spaces,
pre_process_threads=args.pre_process_threads,
post_process_threads=args.post_process_threads,
)
if args.reference_path is None:
return