# Copyright (c) 2019-present, Thomas Wolf.
# All rights reserved.
# This source code is licensed under the MIT-style license found in the LICENSE file in the root directory of this source tree.
import logging
import math
import os
from argparse import ArgumentParser
from pprint import pformat
import torch
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
from ignite.contrib.handlers import CosineAnnealingScheduler, create_lr_scheduler_with_warmup
from ignite.engine import Engine, Events
from ignite.metrics import Loss, MetricsLambda
from pytorch_pretrained_bert import BertTokenizer
from pretraining_model import TransformerWithLMHead
from utils import get_and_tokenize_dataset, average_distributed_scalar, add_logging_and_checkpoint_saving, WEIGHTS_NAME

logger = logging.getLogger(__file__)


def get_data_loaders(args, tokenizer):
    """ Prepare the dataloaders for training and evaluation """
    datasets = get_and_tokenize_dataset(tokenizer, args.dataset_path, args.dataset_cache)

    logger.info("Convert to Tensor and reshape in blocks of the transformer's input length")
    for split_name in ['train', 'valid']:
        tensor = torch.tensor(datasets[split_name], dtype=torch.long)
        num_sequences = (tensor.size(0) // args.num_max_positions) * args.num_max_positions
        datasets[split_name] = tensor.narrow(0, 0, num_sequences).view(-1, args.num_max_positions)
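        # Note: trailing tokens that do not fill a complete block of num_max_positions are dropped by the narrow() above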
logger.info("Build train and validation dataloaders")
train_sampler = torch.utils.data.distributed.DistributedSampler(datasets['train']) if args.distributed else None
valid_sampler = torch.utils.data.distributed.DistributedSampler(datasets['valid']) if args.distributed else None
train_loader = DataLoader(datasets['train'], sampler=train_sampler, batch_size=args.train_batch_size, shuffle=(not args.distributed))
valid_loader = DataLoader(datasets['valid'], sampler=valid_sampler, batch_size=args.valid_batch_size, shuffle=False)
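    # Shuffling: in distributed mode the DistributedSampler handles it (re-seeded each epoch via set_epoch in train()),
    # so the DataLoader only shuffles in the single-process case; the validation set is never shuffled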
logger.info("Train dataset (Batch, Seq length): {}".format(datasets['train'].shape))
logger.info("Valid dataset (Batch, Seq length): {}".format(datasets['valid'].shape))
return train_loader, valid_loader, train_sampler, valid_sampler, datasets['train_num_words'], datasets['valid_num_words']


def train():
    parser = ArgumentParser()
    parser.add_argument("--dataset_path", type=str, default='wikitext-2', help="One of ('wikitext-103', 'wikitext-2') or a dict of splits paths.")
    parser.add_argument("--dataset_cache", type=str, default='./dataset_cache', help="Path or url of the dataset cache")
    parser.add_argument("--embed_dim", type=int, default=410, help="Embeddings dim")
    parser.add_argument("--hidden_dim", type=int, default=2100, help="Hidden dimension")
    parser.add_argument("--num_max_positions", type=int, default=256, help="Max input length")
    parser.add_argument("--num_heads", type=int, default=10, help="Number of heads")
    parser.add_argument("--num_layers", type=int, default=16, help="Number of layers")
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout")
    parser.add_argument("--initializer_range", type=float, default=0.02, help="Normal initialization standard deviation")
    parser.add_argument("--sinusoidal_embeddings", action="store_true", help="Use sinusoidal embeddings")
    parser.add_argument("--mlm", action="store_true", help="Train with masked-language modeling loss instead of language modeling")
    parser.add_argument("--mlm_probability", type=float, default=0.15, help="Ratio of tokens to mask for masked language modeling loss")
    parser.add_argument("--train_batch_size", type=int, default=8, help="Batch size for training")
    parser.add_argument("--valid_batch_size", type=int, default=8, help="Batch size for validation")
    parser.add_argument("--lr", type=float, default=2.5e-4, help="Learning rate")
    parser.add_argument("--max_norm", type=float, default=0.25, help="Clipping gradient norm")
    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay")
    parser.add_argument("--n_epochs", type=int, default=200, help="Number of training epochs")
    parser.add_argument("--n_warmup", type=int, default=1000, help="Number of warmup iterations")
    parser.add_argument("--eval_every", type=int, default=-1, help="Evaluate every X steps (-1 => end of epoch)")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Accumulate gradient")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device (cuda or cpu)")
    parser.add_argument("--local_rank", type=int, default=-1, help="Local rank for distributed training (-1: not distributed)")
    args = parser.parse_args()

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log on main process only, logger.warning => log on all processes
    logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Running process %d", args.local_rank)  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))  # This is a logger.info: only printed on the first process

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    logger.info("Prepare tokenizer, model and optimizer")
    tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)  # Let's use a pre-defined tokenizer
    args.num_embeddings = len(tokenizer.vocab)  # We need this to create the model at next line (number of embeddings to use)
    model = TransformerWithLMHead(args)
    model.to(args.device)
    optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    logger.info("Model has %s parameters", sum(p.numel() for p in model.parameters() if p.requires_grad))

    # Prepare model for distributed training if needed
    if args.distributed:
        model = DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler, train_num_words, valid_num_words = get_data_loaders(args, tokenizer)

    # Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original
    def mask_tokens(inputs):
        labels = inputs.clone()
        masked_indices = torch.bernoulli(torch.full(labels.shape, args.mlm_probability)).byte()
        labels[~masked_indices] = -1  # We only compute loss on masked tokens
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).byte() & masked_indices
        inputs[indices_replaced] = tokenizer.vocab["[MASK]"]  # 80% of the time, replace masked input tokens with [MASK]
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).byte() & masked_indices & ~indices_replaced
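        # Half of the remaining (non-[MASK]) masked positions, i.e. roughly 10% of all masked tokens, get a random token below;
        # the final ~10% keep their original token but are still predicted through `labels`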
        random_words = torch.randint(args.num_embeddings, labels.shape, dtype=torch.long, device=args.device)
        inputs[indices_random] = random_words[indices_random]  # 10% of the time, replace masked input tokens with random word
        return inputs, labels

    # Training function and trainer
    def update(engine, batch):
        model.train()
        inputs = batch.transpose(0, 1).contiguous().to(args.device)  # to shape [seq length, batch]
        inputs, labels = mask_tokens(inputs) if args.mlm else (inputs, inputs)  # Prepare masked input/labels if we use masked LM
        logits, loss = model(inputs, labels=labels)
        loss = loss / args.gradient_accumulation_steps
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
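        # The optimizer only steps every `gradient_accumulation_steps` iterations; gradients accumulate in between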
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()
    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            inputs = batch.transpose(0, 1).contiguous().to(args.device)  # to shape [seq length, batch]
            inputs, labels = mask_tokens(inputs) if args.mlm else (inputs, inputs)  # Prepare masked input/labels if we use masked LM
            logits = model(inputs)
            shift_logits = logits[:-1] if not args.mlm else logits
            shift_labels = labels[1:] if not args.mlm else labels
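            # The shift above: for causal LM, the logits at position t predict the token at position t + 1;
            # for masked LM, predictions and labels are already aligned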
            return shift_logits.view(-1, logits.size(-1)), shift_labels.view(-1)
    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate at the end of each epoch and every 'eval_every' iterations if needed
    trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader))
    if args.eval_every > 0:
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  lambda engine: evaluator.run(val_loader) if engine.state.iteration % args.eval_every == 0 else None)
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(Events.EPOCH_STARTED, lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(Events.EPOCH_STARTED, lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Learning rate schedule: linearly warm-up to lr and then decrease the learning rate to zero with cosine schedule
    cos_scheduler = CosineAnnealingScheduler(optimizer, 'lr', args.lr, 0.0, len(train_loader) * args.n_epochs)
    scheduler = create_lr_scheduler_with_warmup(cos_scheduler, 0.0, args.lr, args.n_warmup)
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we average distributed metrics using average_distributed_scalar
    metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1))}
    metrics.update({"average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], args)})
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    # Let's convert sub-word perplexities into word perplexities. If you need details: http://sjmielke.com/comparing-perplexities.htm
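    # Word-level ppl = exp(sub-word nll * num_sub_word_tokens / num_words): rescale the per-sub-word NLL before exponentiating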
metrics["average_word_ppl"] = MetricsLambda(lambda x: math.exp(x * val_loader.dataset.numel() / valid_num_words), metrics["average_nll"])
for name, metric in metrics.items():
metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model and configuration before we start to train
    if args.local_rank in [-1, 0]:
        checkpoint_handler, tb_logger = add_logging_and_checkpoint_saving(trainer, evaluator, metrics, model, optimizer, args)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint for easy re-loading
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        tb_logger.close()


if __name__ == "__main__":
    train()
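

# Example invocations (a sketch; adjust paths and hyper-parameters to your setup):
#   Single GPU or CPU:
#       python pretraining_train.py --dataset_path wikitext-2 --train_batch_size 8
#   Multi-GPU, one process per GPU (torch.distributed.launch passes --local_rank to each process):
#       python -m torch.distributed.launch --nproc_per_node=4 pretraining_train.py --dataset_path wikitext-103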