From b0f5260cb43026f6457b6ef6c57200bb113a28a5 Mon Sep 17 00:00:00 2001 From: Vinh Nguyen Date: Wed, 19 Jun 2019 18:19:55 +1000 Subject: [PATCH 1/2] adding mixed precision training, which improves performance by about 40% on GPUs with tensor cores --- train.py | 36 ++++++++++++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 66185200..c74c1009 100644 --- a/train.py +++ b/train.py @@ -4,6 +4,7 @@ # import math import random import argparse +import sys from distutils.version import LooseVersion # Numerical libs import torch @@ -15,6 +16,10 @@ from lib.nn import UserScatteredDataParallel, user_scattered_collate, patch_replication_callback import lib.utils.data as torchdata +try: + from apex import amp +except ImportError: + amp = None # train one epoch def train(segmentation_module, iterator, optimizers, history, epoch, args): @@ -39,7 +44,12 @@ def train(segmentation_module, iterator, optimizers, history, epoch, args): acc = acc.mean() # Backward - loss.backward() + if args.apex: + with amp.scale_loss(loss, optimizers) as scaled_loss: + scaled_loss.backward() + else: + loss.backward() + for optimizer in optimizers: optimizer.step() @@ -53,10 +63,11 @@ def train(segmentation_module, iterator, optimizers, history, epoch, args): # calculate accuracy, and display if i % args.disp_iter == 0: - print('Epoch: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, ' + print('Epoch: [{}][{}/{}], Img/s: {:.2f} Time: {:.2f}, Data: {:.2f}, ' 'lr_encoder: {:.6f}, lr_decoder: {:.6f}, ' 'Accuracy: {:4.2f}, Loss: {:.6f}' .format(epoch, i, args.epoch_iters, + batch_data[0]['img_data'].shape[0]/batch_time.average(), batch_time.average(), data_time.average(), args.running_lr_encoder, args.running_lr_decoder, ave_acc.average(), ave_total_loss.average())) @@ -193,6 +204,11 @@ def main(args): nets = (net_encoder, net_decoder, crit) optimizers = create_optimizers(nets, args) + if args.apex: + nets, optimizers = amp.initialize(list(nets), list(optimizers), + opt_level=args.apex_opt_level + ) + # Main loop history = {'train': {'epoch': [], 'loss': [], 'acc': []}} @@ -281,11 +297,27 @@ def main(args): parser.add_argument('--disp_iter', type=int, default=20, help='frequency to display') + # Mixed precision training parameters + parser.add_argument('--apex', action='store_true', + help='Use apex for mixed precision training') + parser.add_argument('--apex-opt-level', default='O1', type=str, + help='For apex mixed precision training' + 'O0 for FP32 training, O1 for mixed precision training.' + 'For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet' + ) + args = parser.parse_args() print("Input arguments:") for key, val in vars(args).items(): print("{:16} {}".format(key, val)) + if args.apex: + if sys.version_info < (3, 0): + raise RuntimeError("Apex currently only supports Python 3. Aborting.") + if amp is None: + raise RuntimeError("Failed to import apex. Please install apex from https://www.github.com/nvidia/apex " + "to enable mixed-precision training.") + # Parse gpu ids all_gpus = parse_devices(args.gpus) all_gpus = [x.replace('gpu', '') for x in all_gpus] From 08d17adf594cd8fbb5581149e1c3ac91462f6bce Mon Sep 17 00:00:00 2001 From: Vinh Nguyen Date: Thu, 8 Aug 2019 02:17:00 +0000 Subject: [PATCH 2/2] fix conflicts --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 66704e40..6dffa2cd 100644 --- a/train.py +++ b/train.py @@ -70,7 +70,7 @@ def train(segmentation_module, iterator, optimizers, history, epoch, cfg): print('Epoch: [{}][{}/{}], Img/s: {:.2f} Time: {:.2f}, Data: {:.2f}, ' 'lr_encoder: {:.6f}, lr_decoder: {:.6f}, ' 'Accuracy: {:4.2f}, Loss: {:.6f}' - .format(epoch, i, args.epoch_iters, + .format(epoch, i, cfg.TRAIN.epoch_iters, batch_data[0]['img_data'].shape[0]/batch_time.average(), batch_time.average(), data_time.average(), cfg.TRAIN.running_lr_encoder, cfg.TRAIN.running_lr_decoder,