Skip to content

Commit

Permalink
Merge pull request #4
Browse files Browse the repository at this point in the history
bug fix
  • Loading branch information
ret-1 authored Aug 19, 2023
2 parents 52d317c + 549321b commit 61978b3
Show file tree
Hide file tree
Showing 5 changed files with 516 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ __pycache__/
*.py[cod]
*$py.class
.DS_Store
datasets
/datasets/

# C extensions
*.so
Expand Down
7 changes: 7 additions & 0 deletions yolox/data/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

from .datasets_wrapper import ConcatDataset, Dataset, MixConcatDataset
from .mosaicdetection import MosaicDetection
from .mot import MOTDataset
128 changes: 128 additions & 0 deletions yolox/data/datasets/datasets_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

from torch.utils.data.dataset import ConcatDataset as torchConcatDataset
from torch.utils.data.dataset import Dataset as torchDataset

import bisect
from functools import wraps


class ConcatDataset(torchConcatDataset):
def __init__(self, datasets):
super(ConcatDataset, self).__init__(datasets)
if hasattr(self.datasets[0], "input_dim"):
self._input_dim = self.datasets[0].input_dim
self.input_dim = self.datasets[0].input_dim

def pull_item(self, idx):
if idx < 0:
if -idx > len(self):
raise ValueError(
"absolute value of index should not exceed dataset length"
)
idx = len(self) + idx
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
if dataset_idx == 0:
sample_idx = idx
else:
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
return self.datasets[dataset_idx].pull_item(sample_idx)


class MixConcatDataset(torchConcatDataset):
def __init__(self, datasets):
super(MixConcatDataset, self).__init__(datasets)
if hasattr(self.datasets[0], "input_dim"):
self._input_dim = self.datasets[0].input_dim
self.input_dim = self.datasets[0].input_dim

def __getitem__(self, index):

if not isinstance(index, int):
idx = index[1]
if idx < 0:
if -idx > len(self):
raise ValueError(
"absolute value of index should not exceed dataset length"
)
idx = len(self) + idx
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
if dataset_idx == 0:
sample_idx = idx
else:
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
if not isinstance(index, int):
index = (index[0], sample_idx, index[2])

return self.datasets[dataset_idx][index]


class Dataset(torchDataset):
""" This class is a subclass of the base :class:`torch.utils.data.Dataset`,
that enables on the fly resizing of the ``input_dim``.
Args:
input_dimension (tuple): (width,height) tuple with default dimensions of the network
"""

def __init__(self, input_dimension, mosaic=True):
super().__init__()
self.__input_dim = input_dimension[:2]
self.enable_mosaic = mosaic

@property
def input_dim(self):
"""
Dimension that can be used by transforms to set the correct image size, etc.
This allows transforms to have a single source of truth
for the input dimension of the network.
Return:
list: Tuple containing the current width,height
"""
if hasattr(self, "_input_dim"):
return self._input_dim
return self.__input_dim

@staticmethod
def resize_getitem(getitem_fn):
"""
Decorator method that needs to be used around the ``__getitem__`` method. |br|
This decorator enables the on the fly resizing of
the ``input_dim`` with our :class:`~lightnet.data.DataLoader` class.
Example:
>>> class CustomSet(ln.data.Dataset):
... def __len__(self):
... return 10
... @ln.data.Dataset.resize_getitem
... def __getitem__(self, index):
... # Should return (image, anno) but here we return input_dim
... return self.input_dim
>>> data = CustomSet((200,200))
>>> data[0]
(200, 200)
>>> data[(480,320), 0]
(480, 320)
"""

@wraps(getitem_fn)
def wrapper(self, index):
if not isinstance(index, int):
has_dim = True
self._input_dim = index[0]
self.enable_mosaic = index[2]
index = index[1]
else:
has_dim = False

ret_val = getitem_fn(self, index)

if has_dim:
del self._input_dim

return ret_val

return wrapper
242 changes: 242 additions & 0 deletions yolox/data/datasets/mosaicdetection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

import cv2
import numpy as np

from yolox.utils import adjust_box_anns

import random

from ..data_augment import box_candidates, random_perspective, augment_hsv
from .datasets_wrapper import Dataset


def get_mosaic_coordinate(mosaic_image, mosaic_index, xc, yc, w, h, input_h, input_w):
# TODO update doc
# index0 to top left part of image
if mosaic_index == 0:
x1, y1, x2, y2 = max(xc - w, 0), max(yc - h, 0), xc, yc
small_coord = w - (x2 - x1), h - (y2 - y1), w, h
# index1 to top right part of image
elif mosaic_index == 1:
x1, y1, x2, y2 = xc, max(yc - h, 0), min(xc + w, input_w * 2), yc
small_coord = 0, h - (y2 - y1), min(w, x2 - x1), h
# index2 to bottom left part of image
elif mosaic_index == 2:
x1, y1, x2, y2 = max(xc - w, 0), yc, xc, min(input_h * 2, yc + h)
small_coord = w - (x2 - x1), 0, w, min(y2 - y1, h)
# index2 to bottom right part of image
elif mosaic_index == 3:
x1, y1, x2, y2 = xc, yc, min(xc + w, input_w * 2), min(input_h * 2, yc + h) # noqa
small_coord = 0, 0, min(w, x2 - x1), min(y2 - y1, h)
return (x1, y1, x2, y2), small_coord


class MosaicDetection(Dataset):
"""Detection dataset wrapper that performs mixup for normal dataset."""

def __init__(
self, dataset, img_size, mosaic=True, preproc=None,
degrees=10.0, translate=0.1, scale=(0.5, 1.5), mscale=(0.5, 1.5),
shear=2.0, perspective=0.0, enable_mixup=True, *args
):
"""
Args:
dataset(Dataset) : Pytorch dataset object.
img_size (tuple):
mosaic (bool): enable mosaic augmentation or not.
preproc (func):
degrees (float):
translate (float):
scale (tuple):
mscale (tuple):
shear (float):
perspective (float):
enable_mixup (bool):
*args(tuple) : Additional arguments for mixup random sampler.
"""
super().__init__(img_size, mosaic=mosaic)
self._dataset = dataset
self.preproc = preproc
self.degrees = degrees
self.translate = translate
self.scale = scale
self.shear = shear
self.perspective = perspective
self.mixup_scale = mscale
self.enable_mosaic = mosaic
self.enable_mixup = enable_mixup

def __len__(self):
return len(self._dataset)

@Dataset.resize_getitem
def __getitem__(self, idx):
if self.enable_mosaic:
mosaic_labels = []
input_dim = self._dataset.input_dim
input_h, input_w = input_dim[0], input_dim[1]

# yc, xc = s, s # mosaic center x, y
yc = int(random.uniform(0.5 * input_h, 1.5 * input_h))
xc = int(random.uniform(0.5 * input_w, 1.5 * input_w))

# 3 additional image indices
indices = [idx] + [random.randint(0, len(self._dataset) - 1) for _ in range(3)]

for i_mosaic, index in enumerate(indices):
img, _labels, _, _ = self._dataset.pull_item(index)
h0, w0 = img.shape[:2] # orig hw
scale = min(1. * input_h / h0, 1. * input_w / w0)
img = cv2.resize(
img, (int(w0 * scale), int(h0 * scale)), interpolation=cv2.INTER_LINEAR
)
# generate output mosaic image
(h, w, c) = img.shape[:3]
if i_mosaic == 0:
mosaic_img = np.full((input_h * 2, input_w * 2, c), 114, dtype=np.uint8)

# suffix l means large image, while s means small image in mosaic aug.
(l_x1, l_y1, l_x2, l_y2), (s_x1, s_y1, s_x2, s_y2) = get_mosaic_coordinate(
mosaic_img, i_mosaic, xc, yc, w, h, input_h, input_w
)

mosaic_img[l_y1:l_y2, l_x1:l_x2] = img[s_y1:s_y2, s_x1:s_x2]
padw, padh = l_x1 - s_x1, l_y1 - s_y1

labels = _labels.copy()
# Normalized xywh to pixel xyxy format
if _labels.size > 0:
labels[:, 0] = scale * _labels[:, 0] + padw
labels[:, 1] = scale * _labels[:, 1] + padh
labels[:, 2] = scale * _labels[:, 2] + padw
labels[:, 3] = scale * _labels[:, 3] + padh
mosaic_labels.append(labels)

if len(mosaic_labels):
mosaic_labels = np.concatenate(mosaic_labels, 0)
'''
np.clip(mosaic_labels[:, 0], 0, 2 * input_w, out=mosaic_labels[:, 0])
np.clip(mosaic_labels[:, 1], 0, 2 * input_h, out=mosaic_labels[:, 1])
np.clip(mosaic_labels[:, 2], 0, 2 * input_w, out=mosaic_labels[:, 2])
np.clip(mosaic_labels[:, 3], 0, 2 * input_h, out=mosaic_labels[:, 3])
'''

mosaic_labels = mosaic_labels[mosaic_labels[:, 0] < 2 * input_w]
mosaic_labels = mosaic_labels[mosaic_labels[:, 2] > 0]
mosaic_labels = mosaic_labels[mosaic_labels[:, 1] < 2 * input_h]
mosaic_labels = mosaic_labels[mosaic_labels[:, 3] > 0]

#augment_hsv(mosaic_img)
mosaic_img, mosaic_labels = random_perspective(
mosaic_img,
mosaic_labels,
degrees=self.degrees,
translate=self.translate,
scale=self.scale,
shear=self.shear,
perspective=self.perspective,
border=[-input_h // 2, -input_w // 2],
) # border to remove

# -----------------------------------------------------------------
# CopyPaste: https://arxiv.org/abs/2012.07177
# -----------------------------------------------------------------
if self.enable_mixup and not len(mosaic_labels) == 0:
mosaic_img, mosaic_labels = self.mixup(mosaic_img, mosaic_labels, self.input_dim)

mix_img, padded_labels = self.preproc(mosaic_img, mosaic_labels, self.input_dim)
img_info = (mix_img.shape[1], mix_img.shape[0])

return mix_img, padded_labels, img_info, np.array([idx])

else:
self._dataset._input_dim = self.input_dim
img, label, img_info, id_ = self._dataset.pull_item(idx)
img, label = self.preproc(img, label, self.input_dim)
return img, label, img_info, id_

def mixup(self, origin_img, origin_labels, input_dim):
jit_factor = random.uniform(*self.mixup_scale)
FLIP = random.uniform(0, 1) > 0.5
cp_labels = []
while len(cp_labels) == 0:
cp_index = random.randint(0, self.__len__() - 1)
cp_labels = self._dataset.load_anno(cp_index)
img, cp_labels, _, _ = self._dataset.pull_item(cp_index)

if len(img.shape) == 3:
cp_img = np.ones((input_dim[0], input_dim[1], 3)) * 114.0
else:
cp_img = np.ones(input_dim) * 114.0
cp_scale_ratio = min(input_dim[0] / img.shape[0], input_dim[1] / img.shape[1])
resized_img = cv2.resize(
img,
(int(img.shape[1] * cp_scale_ratio), int(img.shape[0] * cp_scale_ratio)),
interpolation=cv2.INTER_LINEAR,
).astype(np.float32)
cp_img[
: int(img.shape[0] * cp_scale_ratio), : int(img.shape[1] * cp_scale_ratio)
] = resized_img
cp_img = cv2.resize(
cp_img,
(int(cp_img.shape[1] * jit_factor), int(cp_img.shape[0] * jit_factor)),
)
cp_scale_ratio *= jit_factor
if FLIP:
cp_img = cp_img[:, ::-1, :]

origin_h, origin_w = cp_img.shape[:2]
target_h, target_w = origin_img.shape[:2]
padded_img = np.zeros(
(max(origin_h, target_h), max(origin_w, target_w), 3)
).astype(np.uint8)
padded_img[:origin_h, :origin_w] = cp_img

x_offset, y_offset = 0, 0
if padded_img.shape[0] > target_h:
y_offset = random.randint(0, padded_img.shape[0] - target_h - 1)
if padded_img.shape[1] > target_w:
x_offset = random.randint(0, padded_img.shape[1] - target_w - 1)
padded_cropped_img = padded_img[
y_offset: y_offset + target_h, x_offset: x_offset + target_w
]

cp_bboxes_origin_np = adjust_box_anns(
cp_labels[:, :4].copy(), cp_scale_ratio, 0, 0, origin_w, origin_h
)
if FLIP:
cp_bboxes_origin_np[:, 0::2] = (
origin_w - cp_bboxes_origin_np[:, 0::2][:, ::-1]
)
cp_bboxes_transformed_np = cp_bboxes_origin_np.copy()
'''
cp_bboxes_transformed_np[:, 0::2] = np.clip(
cp_bboxes_transformed_np[:, 0::2] - x_offset, 0, target_w
)
cp_bboxes_transformed_np[:, 1::2] = np.clip(
cp_bboxes_transformed_np[:, 1::2] - y_offset, 0, target_h
)
'''
cp_bboxes_transformed_np[:, 0::2] = cp_bboxes_transformed_np[:, 0::2] - x_offset
cp_bboxes_transformed_np[:, 1::2] = cp_bboxes_transformed_np[:, 1::2] - y_offset
keep_list = box_candidates(cp_bboxes_origin_np.T, cp_bboxes_transformed_np.T, 5)

if keep_list.sum() >= 1.0:
cls_labels = cp_labels[keep_list, 4:5].copy()
id_labels = cp_labels[keep_list, 5:6].copy()
box_labels = cp_bboxes_transformed_np[keep_list]
labels = np.hstack((box_labels, cls_labels, id_labels))
# remove outside bbox
labels = labels[labels[:, 0] < target_w]
labels = labels[labels[:, 2] > 0]
labels = labels[labels[:, 1] < target_h]
labels = labels[labels[:, 3] > 0]
origin_labels = np.vstack((origin_labels, labels))
origin_img = origin_img.astype(np.float32)
origin_img = 0.5 * origin_img + 0.5 * padded_cropped_img.astype(np.float32)

return origin_img, origin_labels
Loading

0 comments on commit 61978b3

Please sign in to comment.