-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
bug fix
- Loading branch information
Showing
5 changed files
with
516 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,7 @@ __pycache__/ | |
*.py[cod] | ||
*$py.class | ||
.DS_Store | ||
datasets | ||
/datasets/ | ||
|
||
# C extensions | ||
*.so | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
#!/usr/bin/env python3 | ||
# -*- coding:utf-8 -*- | ||
# Copyright (c) Megvii, Inc. and its affiliates. | ||
|
||
from .datasets_wrapper import ConcatDataset, Dataset, MixConcatDataset | ||
from .mosaicdetection import MosaicDetection | ||
from .mot import MOTDataset |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
#!/usr/bin/env python3 | ||
# -*- coding:utf-8 -*- | ||
# Copyright (c) Megvii, Inc. and its affiliates. | ||
|
||
from torch.utils.data.dataset import ConcatDataset as torchConcatDataset | ||
from torch.utils.data.dataset import Dataset as torchDataset | ||
|
||
import bisect | ||
from functools import wraps | ||
|
||
|
||
class ConcatDataset(torchConcatDataset):
    """Concatenation of datasets that also exposes ``pull_item``.

    If the first wrapped dataset has an ``input_dim`` attribute, its value is
    copied onto this object as both ``_input_dim`` and ``input_dim``.
    """

    def __init__(self, datasets):
        super().__init__(datasets)
        first = self.datasets[0]
        if hasattr(first, "input_dim"):
            self._input_dim = first.input_dim
            self.input_dim = first.input_dim

    def pull_item(self, idx):
        """Forward ``pull_item`` to the sub-dataset that owns global index ``idx``.

        Args:
            idx: Global index into the concatenation; negative values count
                from the end.

        Returns:
            Whatever the owning sub-dataset's ``pull_item`` returns.

        Raises:
            ValueError: If ``idx`` is negative and its magnitude exceeds the
                total length.
        """
        total = len(self)
        if idx < 0:
            if -idx > total:
                raise ValueError(
                    "absolute value of index should not exceed dataset length"
                )
            idx = total + idx
        # cumulative_sizes[k] is the number of samples in datasets[0..k], so
        # bisect_right yields the position of the dataset that contains idx.
        which = bisect.bisect_right(self.cumulative_sizes, idx)
        offset = self.cumulative_sizes[which - 1] if which != 0 else 0
        return self.datasets[which].pull_item(idx - offset)
|
||
|
||
class MixConcatDataset(torchConcatDataset):
    """Concatenation of datasets that accepts composite indices.

    ``__getitem__`` accepts either a plain ``int`` or an indexable composite
    whose element ``1`` is the global sample index and whose elements ``0``
    and ``2`` are forwarded untouched to the selected sub-dataset.

    If the first wrapped dataset has an ``input_dim`` attribute, its value is
    copied onto this object as both ``_input_dim`` and ``input_dim``.
    """

    def __init__(self, datasets):
        super().__init__(datasets)
        if hasattr(self.datasets[0], "input_dim"):
            self._input_dim = self.datasets[0].input_dim
            self.input_dim = self.datasets[0].input_dim

    def __getitem__(self, index):
        """Fetch a sample by global index, translating it to a local index.

        Args:
            index: Either an ``int`` global index, or a composite
                ``(first, global_index, third)``.

        Returns:
            The item returned by the owning sub-dataset. It is indexed with
            the local ``int`` index, or with ``(first, local_index, third)``
            when a composite index was given.

        Raises:
            ValueError: If the global index is negative and its magnitude
                exceeds the total length.
        """
        is_plain = isinstance(index, int)
        # Previously ``idx`` was only bound for composite indices, so a plain
        # integer index crashed with UnboundLocalError.
        idx = index if is_plain else index[1]

        if idx < 0:
            if -idx > len(self):
                raise ValueError(
                    "absolute value of index should not exceed dataset length"
                )
            idx = len(self) + idx

        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]

        if is_plain:
            index = sample_idx
        else:
            index = (index[0], sample_idx, index[2])

        return self.datasets[dataset_idx][index]
|
||
|
||
class Dataset(torchDataset):
    """Subclass of :class:`torch.utils.data.Dataset` with a resizable ``input_dim``.

    The input dimension can be overridden for the duration of a single
    ``__getitem__`` call; see :meth:`resize_getitem`.

    Args:
        input_dimension (tuple): Default dimensions of the network input;
            only the first two entries are kept.
        mosaic (bool): Initial value of ``enable_mosaic``.
    """

    def __init__(self, input_dimension, mosaic=True):
        super().__init__()
        self.__input_dim = input_dimension[:2]
        self.enable_mosaic = mosaic

    @property
    def input_dim(self):
        """Current input dimension.

        Gives transforms a single source of truth for the network input size.

        Return:
            The per-call override ``_input_dim`` when one is set, otherwise
            the default passed to the constructor.
        """
        if hasattr(self, "_input_dim"):
            return self._input_dim
        return self.__input_dim

    @staticmethod
    def resize_getitem(getitem_fn):
        """Decorator to be used around a subclass's ``__getitem__`` method.

        With a plain ``int`` index the wrapped method is called unchanged.
        With a composite index ``(input_dim, index, mosaic)`` the wrapper
        sets ``self._input_dim`` (picked up by :attr:`input_dim`) and
        ``self.enable_mosaic``, calls the wrapped method with ``index``, and
        removes the ``_input_dim`` override afterwards. ``enable_mosaic``
        keeps the new value.

        Example:
            >>> class CustomSet(Dataset):
            ...     def __len__(self):
            ...         return 10
            ...     @Dataset.resize_getitem
            ...     def __getitem__(self, index):
            ...         # Should return (image, anno) but here we return input_dim
            ...         return self.input_dim
            >>> data = CustomSet((200, 200))
            >>> data[0]
            (200, 200)
            >>> data[(480, 320), 0, True]
            (480, 320)
        """

        @wraps(getitem_fn)
        def wrapper(self, index):
            if isinstance(index, int):
                return getitem_fn(self, index)

            # Read every component before mutating state, so a malformed
            # index cannot leave a half-applied override behind.
            input_dim, item_index, mosaic = index[0], index[1], index[2]
            self._input_dim = input_dim
            self.enable_mosaic = mosaic
            try:
                return getitem_fn(self, item_index)
            finally:
                # Always drop the override, even when the wrapped method
                # raises; otherwise it would leak into later calls.
                del self._input_dim

        return wrapper
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,242 @@ | ||
#!/usr/bin/env python3 | ||
# -*- coding:utf-8 -*- | ||
# Copyright (c) Megvii, Inc. and its affiliates. | ||
|
||
import cv2 | ||
import numpy as np | ||
|
||
from yolox.utils import adjust_box_anns | ||
|
||
import random | ||
|
||
from ..data_augment import box_candidates, random_perspective, augment_hsv | ||
from .datasets_wrapper import Dataset | ||
|
||
|
||
def get_mosaic_coordinate(mosaic_image, mosaic_index, xc, yc, w, h, input_h, input_w):
    """Compute matching regions for one tile of a four-tile mosaic.

    The canvas is ``2 * input_w`` wide and ``2 * input_h`` high, and the four
    tiles meet at the point ``(xc, yc)``. The canvas region is clamped to the
    canvas bounds, and the tile region is the part of the ``w`` x ``h`` tile
    that falls inside it.

    Args:
        mosaic_image: Unused; kept for interface compatibility.
        mosaic_index (int): 0 top-left, 1 top-right, 2 bottom-left,
            3 bottom-right.
        xc, yc: Mosaic centre on the canvas.
        w, h: Width and height of the tile image.
        input_h, input_w: Half the canvas height and width.

    Returns:
        tuple: ``((x1, y1, x2, y2), (sx1, sy1, sx2, sy2))`` - the region on
        the canvas, and the corresponding region of the tile.

    Raises:
        ValueError: If ``mosaic_index`` is not one of 0, 1, 2, 3.
    """
    # index0 to top left part of image
    if mosaic_index == 0:
        x1, y1, x2, y2 = max(xc - w, 0), max(yc - h, 0), xc, yc
        small_coord = w - (x2 - x1), h - (y2 - y1), w, h
    # index1 to top right part of image
    elif mosaic_index == 1:
        x1, y1, x2, y2 = xc, max(yc - h, 0), min(xc + w, input_w * 2), yc
        small_coord = 0, h - (y2 - y1), min(w, x2 - x1), h
    # index2 to bottom left part of image
    elif mosaic_index == 2:
        x1, y1, x2, y2 = max(xc - w, 0), yc, xc, min(input_h * 2, yc + h)
        small_coord = w - (x2 - x1), 0, w, min(y2 - y1, h)
    # index3 to bottom right part of image
    elif mosaic_index == 3:
        x1, y1, x2, y2 = xc, yc, min(xc + w, input_w * 2), min(input_h * 2, yc + h)  # noqa
        small_coord = 0, 0, min(w, x2 - x1), min(y2 - y1, h)
    else:
        # Previously this fell through and failed at the return statement
        # with an UnboundLocalError.
        raise ValueError(
            "mosaic_index must be 0, 1, 2 or 3, got {!r}".format(mosaic_index)
        )
    return (x1, y1, x2, y2), small_coord
|
||
|
||
class MosaicDetection(Dataset):
    """Detection dataset wrapper that applies mosaic and mixup augmentation.

    The wrapped dataset must provide ``pull_item(index)`` returning
    ``(img, labels, img_info, id)``, ``load_anno(index)``, ``input_dim`` and
    ``__len__``. Label rows hold box coordinates in columns 0-3; ``mixup``
    additionally reads columns 4 and 5 (named class and id in the code).
    """

    def __init__(
        self, dataset, img_size, mosaic=True, preproc=None,
        degrees=10.0, translate=0.1, scale=(0.5, 1.5), mscale=(0.5, 1.5),
        shear=2.0, perspective=0.0, enable_mixup=True, *args
    ):
        """
        Args:
            dataset(Dataset) : Pytorch dataset object to wrap.
            img_size (tuple): default input dimension, passed to the base class.
            mosaic (bool): enable mosaic augmentation or not.
            preproc (func): called as ``preproc(img, labels, input_dim)`` and
                expected to return ``(img, labels)``. ``__getitem__`` always
                calls it, so the default ``None`` is not usable as-is.
            degrees (float): forwarded to ``random_perspective``.
            translate (float): forwarded to ``random_perspective``.
            scale (tuple): forwarded to ``random_perspective``.
            mscale (tuple): range the mixup resize factor is drawn from.
            shear (float): forwarded to ``random_perspective``.
            perspective (float): forwarded to ``random_perspective``.
            enable_mixup (bool): apply ``mixup`` after the mosaic step.
            *args(tuple) : accepted for compatibility; currently unused.
        """
        super().__init__(img_size, mosaic=mosaic)
        self._dataset = dataset
        self.preproc = preproc
        self.degrees = degrees
        self.translate = translate
        self.scale = scale
        self.shear = shear
        self.perspective = perspective
        self.mixup_scale = mscale
        self.enable_mosaic = mosaic
        self.enable_mixup = enable_mixup

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self._dataset)

    @Dataset.resize_getitem
    def __getitem__(self, idx):
        """Return one (optionally augmented) sample.

        With ``enable_mosaic`` off, the sample is pulled from the wrapped
        dataset and passed through ``preproc``. With it on, the sample is
        combined with three randomly chosen samples on a canvas twice the
        input size, warped with ``random_perspective``, optionally blended
        with another sample via ``mixup``, and passed through ``preproc``.

        Returns:
            tuple: ``(img, labels, img_info, id)``. In the mosaic path
            ``img_info`` is ``(img.shape[1], img.shape[0])`` of the
            preprocessed image and ``id`` is ``np.array([idx])``.
        """
        if not self.enable_mosaic:
            # Propagate the current input size to the wrapped dataset.
            self._dataset._input_dim = self.input_dim
            img, label, img_info, id_ = self._dataset.pull_item(idx)
            img, label = self.preproc(img, label, self.input_dim)
            return img, label, img_info, id_

        mosaic_labels = []
        input_dim = self._dataset.input_dim
        input_h, input_w = input_dim[0], input_dim[1]

        # Mosaic centre, drawn from the middle half of the 2x canvas.
        yc = int(random.uniform(0.5 * input_h, 1.5 * input_h))
        xc = int(random.uniform(0.5 * input_w, 1.5 * input_w))

        # 3 additional image indices
        indices = [idx] + [random.randint(0, len(self._dataset) - 1) for _ in range(3)]

        for i_mosaic, index in enumerate(indices):
            img, _labels, _, _ = self._dataset.pull_item(index)
            h0, w0 = img.shape[:2]  # orig hw
            # Aspect-preserving resize so the tile fits inside the input size.
            scale = min(1.0 * input_h / h0, 1.0 * input_w / w0)
            img = cv2.resize(
                img, (int(w0 * scale), int(h0 * scale)), interpolation=cv2.INTER_LINEAR
            )
            # generate output mosaic image
            (h, w, c) = img.shape[:3]
            if i_mosaic == 0:
                mosaic_img = np.full((input_h * 2, input_w * 2, c), 114, dtype=np.uint8)

            # suffix l means large image, while s means small image in mosaic aug.
            (l_x1, l_y1, l_x2, l_y2), (s_x1, s_y1, s_x2, s_y2) = get_mosaic_coordinate(
                mosaic_img, i_mosaic, xc, yc, w, h, input_h, input_w
            )

            mosaic_img[l_y1:l_y2, l_x1:l_x2] = img[s_y1:s_y2, s_x1:s_x2]
            padw, padh = l_x1 - s_x1, l_y1 - s_y1

            labels = _labels.copy()
            # Rescale the box columns (x in 0 and 2, y in 1 and 3) and shift
            # them from tile coordinates into canvas coordinates.
            if _labels.size > 0:
                labels[:, 0] = scale * _labels[:, 0] + padw
                labels[:, 1] = scale * _labels[:, 1] + padh
                labels[:, 2] = scale * _labels[:, 2] + padw
                labels[:, 3] = scale * _labels[:, 3] + padh
            mosaic_labels.append(labels)

        if len(mosaic_labels):
            mosaic_labels = np.concatenate(mosaic_labels, 0)
            # Boxes are not clipped to the canvas; those lying entirely
            # outside it are dropped.
            mosaic_labels = mosaic_labels[mosaic_labels[:, 0] < 2 * input_w]
            mosaic_labels = mosaic_labels[mosaic_labels[:, 2] > 0]
            mosaic_labels = mosaic_labels[mosaic_labels[:, 1] < 2 * input_h]
            mosaic_labels = mosaic_labels[mosaic_labels[:, 3] > 0]

        mosaic_img, mosaic_labels = random_perspective(
            mosaic_img,
            mosaic_labels,
            degrees=self.degrees,
            translate=self.translate,
            scale=self.scale,
            shear=self.shear,
            perspective=self.perspective,
            border=[-input_h // 2, -input_w // 2],
        )  # border to remove

        # -----------------------------------------------------------------
        # CopyPaste: https://arxiv.org/abs/2012.07177
        # -----------------------------------------------------------------
        if self.enable_mixup and len(mosaic_labels) != 0:
            mosaic_img, mosaic_labels = self.mixup(mosaic_img, mosaic_labels, self.input_dim)

        mix_img, padded_labels = self.preproc(mosaic_img, mosaic_labels, self.input_dim)
        img_info = (mix_img.shape[1], mix_img.shape[0])

        return mix_img, padded_labels, img_info, np.array([idx])

    def mixup(self, origin_img, origin_labels, input_dim):
        """Blend a second, randomly chosen sample into ``origin_img``.

        The second sample is letterboxed to ``input_dim``, resized by a
        random factor from ``self.mixup_scale``, optionally flipped
        horizontally, and randomly cropped to the size of ``origin_img``.
        If at least one of its boxes passes ``box_candidates``, the two
        images are averaged 50/50 and the surviving boxes are appended to
        ``origin_labels``. Otherwise the inputs are returned unchanged.

        Args:
            origin_img: Image to blend into.
            origin_labels: Label array belonging to ``origin_img``.
            input_dim: ``(height, width)`` used for the letterbox step.

        Returns:
            tuple: ``(image, labels)``; the image is float32 when blending
            took place.
        """
        jit_factor = random.uniform(*self.mixup_scale)
        flip = random.uniform(0, 1) > 0.5
        cp_labels = []
        # NOTE(review): loops until a sample with annotations is drawn, so it
        # never terminates if the wrapped dataset has none.
        while len(cp_labels) == 0:
            cp_index = random.randint(0, len(self) - 1)
            cp_labels = self._dataset.load_anno(cp_index)
        img, cp_labels, _, _ = self._dataset.pull_item(cp_index)

        # Letterbox onto a grey (114) canvas of the input size.
        if len(img.shape) == 3:
            cp_img = np.ones((input_dim[0], input_dim[1], 3)) * 114.0
        else:
            cp_img = np.ones(input_dim) * 114.0
        cp_scale_ratio = min(input_dim[0] / img.shape[0], input_dim[1] / img.shape[1])
        resized_img = cv2.resize(
            img,
            (int(img.shape[1] * cp_scale_ratio), int(img.shape[0] * cp_scale_ratio)),
            interpolation=cv2.INTER_LINEAR,
        ).astype(np.float32)
        cp_img[
            : int(img.shape[0] * cp_scale_ratio), : int(img.shape[1] * cp_scale_ratio)
        ] = resized_img
        # Random scale jitter of the whole letterboxed image.
        cp_img = cv2.resize(
            cp_img,
            (int(cp_img.shape[1] * jit_factor), int(cp_img.shape[0] * jit_factor)),
        )
        cp_scale_ratio *= jit_factor
        if flip:
            cp_img = cp_img[:, ::-1, :]

        origin_h, origin_w = cp_img.shape[:2]
        target_h, target_w = origin_img.shape[:2]
        # Zero-pad so the image is at least as large as the target, then take
        # a random target-sized crop.
        padded_img = np.zeros(
            (max(origin_h, target_h), max(origin_w, target_w), 3), dtype=np.uint8
        )
        padded_img[:origin_h, :origin_w] = cp_img

        x_offset, y_offset = 0, 0
        # The upper bounds below exclude the last valid offset; kept as-is so
        # the sequence of random draws is unchanged.
        if padded_img.shape[0] > target_h:
            y_offset = random.randint(0, padded_img.shape[0] - target_h - 1)
        if padded_img.shape[1] > target_w:
            x_offset = random.randint(0, padded_img.shape[1] - target_w - 1)
        padded_cropped_img = padded_img[
            y_offset: y_offset + target_h, x_offset: x_offset + target_w
        ]

        cp_bboxes_origin_np = adjust_box_anns(
            cp_labels[:, :4].copy(), cp_scale_ratio, 0, 0, origin_w, origin_h
        )
        if flip:
            # Mirror the x coordinates and swap them so x1 <= x2 still holds.
            cp_bboxes_origin_np[:, 0::2] = (
                origin_w - cp_bboxes_origin_np[:, 0::2][:, ::-1]
            )
        # Shift boxes into the crop's coordinate frame (no clipping).
        cp_bboxes_transformed_np = cp_bboxes_origin_np.copy()
        cp_bboxes_transformed_np[:, 0::2] = cp_bboxes_transformed_np[:, 0::2] - x_offset
        cp_bboxes_transformed_np[:, 1::2] = cp_bboxes_transformed_np[:, 1::2] - y_offset
        keep_list = box_candidates(cp_bboxes_origin_np.T, cp_bboxes_transformed_np.T, 5)

        if keep_list.sum() >= 1.0:
            cls_labels = cp_labels[keep_list, 4:5].copy()
            id_labels = cp_labels[keep_list, 5:6].copy()
            box_labels = cp_bboxes_transformed_np[keep_list]
            labels = np.hstack((box_labels, cls_labels, id_labels))
            # remove outside bbox
            labels = labels[labels[:, 0] < target_w]
            labels = labels[labels[:, 2] > 0]
            labels = labels[labels[:, 1] < target_h]
            labels = labels[labels[:, 3] > 0]
            origin_labels = np.vstack((origin_labels, labels))
            origin_img = origin_img.astype(np.float32)
            origin_img = 0.5 * origin_img + 0.5 * padded_cropped_img.astype(np.float32)

        return origin_img, origin_labels
Oops, something went wrong.