-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
17 changed files
with
203 additions
and
219 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,5 @@ | ||
import argparse | ||
import logging | ||
import os | ||
from pathlib import Path | ||
from threading import Thread | ||
|
||
import numpy as np | ||
import torch | ||
|
@@ -29,14 +26,14 @@ def get_thres(data, | |
set_logging() | ||
device = select_device(opt.device, batch_size=batch_size) | ||
if isinstance(data, str): | ||
is_coco = data.endswith('coco.yaml') | ||
# is_coco = data.endswith('coco.yaml') | ||
with open(data) as f: | ||
data = yaml.load(f, Loader=yaml.SafeLoader) | ||
check_dataset(data) # check | ||
nc = int(data['nc']) # number of classes | ||
iouv = torch.linspace(0.5, 0.95, | ||
10).to(device) # iou vector for [email protected]:0.95 | ||
niou = iouv.numel() | ||
# iouv = torch.linspace(0.5, 0.95, | ||
# 10).to(device) # iou vector for [email protected]:0.95 | ||
# niou = iouv.numel() | ||
|
||
# Load model | ||
model = Model(cfg, ch=3, nc=nc) # create | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,11 +26,10 @@ | |
from utils.autoanchor import check_anchors | ||
from utils.checkpoint import get_state_dict | ||
from utils.datasets import create_dataloader | ||
from utils.general import (check_dataset, check_file, check_git_status, | ||
check_img_size, colorstr, fitness, get_latest_run, | ||
increment_path, init_seeds, labels_to_class_weights, | ||
labels_to_image_weights, one_cycle, set_logging, | ||
strip_optimizer) | ||
from utils.general import (check_dataset, check_file, check_img_size, colorstr, | ||
fitness, get_latest_run, increment_path, init_seeds, | ||
labels_to_class_weights, labels_to_image_weights, | ||
one_cycle, set_logging, strip_optimizer) | ||
from utils.loss import ComputeLoss, ComputeLossOTA, ComputeLossOTADual | ||
from utils.plots import plot_images, plot_lr_scheduler, plot_results | ||
from utils.torch_utils import (ModelEMA, intersect_dicts, is_parallel, | ||
|
@@ -85,7 +84,8 @@ def train(hyp, opt, device, tb_writer=None): | |
loggers['wandb'] = wandb_logger.wandb | ||
data_dict = wandb_logger.data_dict | ||
if wandb_logger.wandb: | ||
weight, epochs, hyp = opt.weight, opt.epochs, opt.hyp # WandbLogger might update weights, epochs if resuming | ||
# WandbLogger might update weights, epochs if resuming | ||
weight, epochs, hyp = opt.weight, opt.epochs, opt.hyp | ||
|
||
nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes | ||
names = ['item'] if opt.single_cls and len( | ||
|
@@ -225,8 +225,11 @@ def train(hyp, opt, device, tb_writer=None): | |
# Scheduler https://arxiv.org/pdf/1812.01187.pdf | ||
# https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR | ||
if opt.linear_lr: | ||
lf = lambda x: (1 - x / (epochs - 1)) * (1.0 - hyp['lrf']) + hyp[ | ||
'lrf'] # linear | ||
|
||
def get_linear_lr(x): | ||
return (1 - x / (epochs - 1)) * (1.0 - hyp['lrf']) + hyp['lrf'] | ||
|
||
lf = get_linear_lr | ||
else: | ||
lf = one_cycle(1, hyp['lrf'], epochs) # cosine 1->hyp['lrf'] | ||
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) | ||
|
@@ -362,9 +365,10 @@ def train(hyp, opt, device, tb_writer=None): | |
|
||
# Start training | ||
t0 = time.time() | ||
nw = max(round(hyp['warmup_epochs'] * nb), | ||
1000) # number of warmup iterations, max(3 epochs, 1k iterations) | ||
# nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training | ||
nw = max( | ||
round(hyp['warmup_epochs'] * nb), | ||
1000) # number of warm up iterations, max(3 epochs, 1k iterations) | ||
# nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warm up to < 1/2 of training | ||
maps = np.zeros(nc) # mAP per class | ||
results = (0, 0, 0, 0, 0, 0, 0 | ||
) # P, R, [email protected], [email protected], val_loss(box, obj, cls) | ||
|
@@ -425,7 +429,7 @@ def train(hyp, opt, device, tb_writer=None): | |
imgs = imgs.to(device, non_blocking=True).float( | ||
) / 255.0 # uint8 to float32, 0-255 to 0.0-1.0 | ||
|
||
# Warmup | ||
# Warm up | ||
if ni <= nw: | ||
xi = [0, nw] # x interp | ||
# model.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou) | ||
|
@@ -505,8 +509,8 @@ def train(hyp, opt, device, tb_writer=None): | |
] | ||
}) | ||
|
||
# end batch ------------------------------------------------------------------------------------------------ | ||
# end epoch ---------------------------------------------------------------------------------------------------- | ||
# end batch ---------------------------------------------------------------------------------------------- | ||
# end epoch -------------------------------------------------------------------------------------------------- | ||
|
||
# Scheduler | ||
lr = [x['lr'] for x in optimizer.param_groups] # for tensorboard | ||
|
@@ -620,7 +624,7 @@ def train(hyp, opt, device, tb_writer=None): | |
best_model=best_fitness == fi) | ||
del ckpt | ||
|
||
# end epoch ---------------------------------------------------------------------------------------------------- | ||
# end epoch -------------------------------------------------------------------------------------------------- | ||
# end training | ||
if rank in [-1, 0]: | ||
# Plots | ||
|
@@ -795,7 +799,8 @@ def train(hyp, opt, device, tb_writer=None): | |
f, Loader=yaml.SafeLoader)) # replace | ||
opt.cfg, opt.weight, opt.resume = os.path.relpath( | ||
Path(ckpt).parent.parent / 'cfg.yaml'), ckpt, True | ||
opt.batch_size, opt.global_rank, opt.local_rank = opt.total_batch_size, *apriori # reinstate | ||
opt.batch_size, opt.global_rank, opt.local_rank = \ | ||
opt.total_batch_size, *apriori # reinstate | ||
opt.save_dir = os.path.relpath(Path(ckpt).parent.parent) | ||
logger.info('Resuming training from %s' % ckpt) | ||
else: | ||
|
Oops, something went wrong.