diff --git a/train.py b/train.py
index 5f49fbc72..cc9f18081 100644
--- a/train.py
+++ b/train.py
@@ -169,7 +169,11 @@ def gather(self, outputs, output_device):
 
         return out
 
-def train():
+def train(optimizer=None):
+    """
+    @param optimizer: set custom optimizer, default (None) uses
+    `torch.optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.decay)`
+    """
     if not os.path.exists(args.save_folder):
         os.mkdir(args.save_folder)
 
@@ -212,8 +216,10 @@ def train():
         print('Initializing weights...')
         yolact_net.init_weights(backbone_path=args.save_folder + cfg.backbone.path)
 
-    optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum,
+    if optimizer is None:
+        optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum,
                           weight_decay=args.decay)
+
     criterion = MultiBoxLoss(num_classes=cfg.num_classes,
                              pos_threshold=cfg.positive_iou_threshold,
                              neg_threshold=cfg.negative_iou_threshold,
@@ -291,11 +297,11 @@ def train():
             if changed:
                 cfg.delayed_settings = [x for x in cfg.delayed_settings if x[0] > iteration]
 
-            # Warm up by linearly interpolating the learning rate from some smaller value
+            # Warm up by linearly interpolating the learning rate from some smaller value
             if cfg.lr_warmup_until > 0 and iteration <= cfg.lr_warmup_until:
                 set_lr(optimizer, (args.lr - cfg.lr_warmup_init) * (iteration / cfg.lr_warmup_until) + cfg.lr_warmup_init)
 
-            # Adjust the learning rate at the given iterations, but also if we resume from past that iteration
+            # Adjust the learning rate at the given iterations, but also if we resume from past that iteration
             while step_index < len(cfg.lr_steps) and iteration >= cfg.lr_steps[step_index]:
                 step_index += 1
                 set_lr(optimizer, args.lr * (args.gamma ** step_index))
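
For reference, below is a minimal standalone sketch of the lazily-constructed-default-optimizer pattern this patch introduces. The toy `nn.Linear` model, the hard-coded hyperparameters, and the explicit `net` argument are illustrative assumptions, not part of train.py; the real `train()` builds its Yolact network internally, so a caller-supplied optimizer has to wrap that same network's parameters.

```python
# Illustrative sketch only -- toy model and hard-coded hyperparameters, not YOLACT's train.py.
import torch
import torch.nn as nn
import torch.optim as optim

def train(net, optimizer=None, lr=1e-3, momentum=0.9, decay=5e-4):
    # Same pattern as the patch: build the default SGD optimizer only when the
    # caller did not supply one.
    if optimizer is None:
        optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum,
                              weight_decay=decay)

    # Dummy training step so the sketch runs end to end.
    x, y = torch.randn(4, 8), torch.randn(4, 2)
    loss = nn.functional.mse_loss(net(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

model = nn.Linear(8, 2)                                           # stand-in for Yolact()
train(model)                                                      # default SGD path
train(model, optimizer=optim.Adam(model.parameters(), lr=1e-3))   # custom optimizer path
```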