diff --git a/train.py b/train.py
index 7f8a7562..6dffa2cd 100644
--- a/train.py
+++ b/train.py
@@ -4,6 +4,7 @@
 # import math
 import random
 import argparse
+import sys
 from distutils.version import LooseVersion
 # Numerical libs
 import torch
@@ -15,6 +16,10 @@
 from utils import AverageMeter, parse_devices, setup_logger
 from lib.nn import UserScatteredDataParallel, user_scattered_collate, patch_replication_callback
 
+try:
+    from apex import amp
+except ImportError:
+    amp = None  # apex is optional; --apex is rejected at startup when it is missing
 
 # train one epoch
 def train(segmentation_module, iterator, optimizers, history, epoch, cfg):
@@ -43,7 +48,12 @@ def train(segmentation_module, iterator, optimizers, history, epoch, cfg):
         acc = acc.mean()
 
         # Backward
-        loss.backward()
+        if args.apex:  # `args` is the module-level namespace from parse_args()
+            with amp.scale_loss(loss, optimizers) as scaled_loss:
+                scaled_loss.backward()
+        else:
+            loss.backward()
+
         for optimizer in optimizers:
             optimizer.step()
 
@@ -57,10 +67,11 @@ def train(segmentation_module, iterator, optimizers, history, epoch, cfg):
 
         # calculate accuracy, and display
         if i % cfg.TRAIN.disp_iter == 0:
-            print('Epoch: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, '
+            print('Epoch: [{}][{}/{}], Img/s: {:.2f} Time: {:.2f}, Data: {:.2f}, '
                   'lr_encoder: {:.6f}, lr_decoder: {:.6f}, '
                   'Accuracy: {:4.2f}, Loss: {:.6f}'
                   .format(epoch, i, cfg.TRAIN.epoch_iters,
+                          batch_data[0]['img_data'].shape[0]/batch_time.average(),  # NOTE(review): counts batch_data[0] only -- confirm for multi-GPU runs
                           batch_time.average(), data_time.average(),
                           cfg.TRAIN.running_lr_encoder, cfg.TRAIN.running_lr_decoder,
                           ave_acc.average(), ave_total_loss.average()))
@@ -193,6 +204,11 @@ def main(cfg, gpus):
     nets = (net_encoder, net_decoder, crit)
     optimizers = create_optimizers(nets, cfg)
 
+    if args.apex:  # rebinds nets/optimizers to whatever amp.initialize returns
+        nets, optimizers = amp.initialize(list(nets), list(optimizers),
+                                          opt_level=args.apex_opt_level
+                                          )
+
     # Main loop
     history = {'train': {'epoch': [], 'loss': [], 'acc': []}}
 
@@ -230,6 +246,14 @@ def main(cfg, gpus):
         default=None,
         nargs=argparse.REMAINDER,
     )
+    # Mixed precision training parameters
+    parser.add_argument('--apex', action='store_true',
+                        help='Use apex for mixed precision training')
+    parser.add_argument('--apex-opt-level', default='O1', type=str,
+                        help='For apex mixed precision training. '
+                             'O0 for FP32 training, O1 for mixed precision training. '
+                             'For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet'
+                        )
     args = parser.parse_args()
 
     cfg.merge_from_file(args.cfg)
@@ -256,6 +280,13 @@ def main(cfg, gpus):
         assert os.path.exists(cfg.MODEL.weights_encoder) and \
             os.path.exists(cfg.MODEL.weights_decoder), "checkpoint does not exitst!"
 
+    if args.apex:  # fail fast, before any GPU setup, when --apex cannot be honoured
+        if sys.version_info < (3, 0):
+            raise RuntimeError("Apex currently only supports Python 3. Aborting.")
+        if amp is None:
+            raise RuntimeError("Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
+                               "to enable mixed-precision training.")
+
     # Parse gpu ids
     gpus = parse_devices(args.gpus)
     gpus = [x.replace('gpu', '') for x in gpus]