diff --git a/.gitignore b/.gitignore
index 9e24aa8..cdd8db4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,7 +3,7 @@ __pycache__/
 *.py[cod]
 *$py.class
 .DS_Store
-datasets
+/datasets/
 
 # C extensions
 *.so
diff --git a/yolox/data/datasets/__init__.py b/yolox/data/datasets/__init__.py
new file mode 100644
index 0000000..61065a8
--- /dev/null
+++ b/yolox/data/datasets/__init__.py
@@ -0,0 +1,7 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+from .datasets_wrapper import ConcatDataset, Dataset, MixConcatDataset
+from .mosaicdetection import MosaicDetection
+from .mot import MOTDataset
diff --git a/yolox/data/datasets/datasets_wrapper.py b/yolox/data/datasets/datasets_wrapper.py
new file mode 100644
index 0000000..a262e6a
--- /dev/null
+++ b/yolox/data/datasets/datasets_wrapper.py
@@ -0,0 +1,128 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+from torch.utils.data.dataset import ConcatDataset as torchConcatDataset
+from torch.utils.data.dataset import Dataset as torchDataset
+
+import bisect
+from functools import wraps
+
+
+class ConcatDataset(torchConcatDataset):
+    def __init__(self, datasets):
+        super(ConcatDataset, self).__init__(datasets)
+        if hasattr(self.datasets[0], "input_dim"):
+            self._input_dim = self.datasets[0].input_dim
+            self.input_dim = self.datasets[0].input_dim
+
+    def pull_item(self, idx):
+        if idx < 0:
+            if -idx > len(self):
+                raise ValueError(
+                    "absolute value of index should not exceed dataset length"
+                )
+            idx = len(self) + idx
+        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+        if dataset_idx == 0:
+            sample_idx = idx
+        else:
+            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+        return self.datasets[dataset_idx].pull_item(sample_idx)
+
+
+class MixConcatDataset(torchConcatDataset):
+    def __init__(self, datasets):
+        super(MixConcatDataset, self).__init__(datasets)
+        if hasattr(self.datasets[0], "input_dim"):
+            self._input_dim = self.datasets[0].input_dim
+            self.input_dim = self.datasets[0].input_dim
+
+    def __getitem__(self, index):
+        # ``index`` is either a plain int or the (input_dim, idx, enable_mosaic)
+        # tuple emitted by the training dataloader.
+        idx = index if isinstance(index, int) else index[1]
+        if idx < 0:
+            if -idx > len(self):
+                raise ValueError(
+                    "absolute value of index should not exceed dataset length"
+                )
+            idx = len(self) + idx
+        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+        if dataset_idx == 0:
+            sample_idx = idx
+        else:
+            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+        if isinstance(index, int):
+            index = sample_idx
+        else:
+            index = (index[0], sample_idx, index[2])
+        return self.datasets[dataset_idx][index]
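+
+# A minimal usage sketch (assuming ``dataset_a`` and ``dataset_b`` implement
+# the tuple-index protocol of the ``Dataset`` class defined below):
+#
+#     mixed = MixConcatDataset([dataset_a, dataset_b])
+#     sample = mixed[(608, 1088), 3, True]  # (input_dim, idx, enable_mosaic)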
+
+
+class Dataset(torchDataset):
+    """ This class is a subclass of the base :class:`torch.utils.data.Dataset`
+    that enables on-the-fly resizing of the ``input_dim``.
+
+    Args:
+        input_dimension (tuple): (height, width) tuple with the default dimensions of the network
+    """
+
+    def __init__(self, input_dimension, mosaic=True):
+        super().__init__()
+        self.__input_dim = input_dimension[:2]
+        self.enable_mosaic = mosaic
+
+    @property
+    def input_dim(self):
+        """
+        Dimension that can be used by transforms to set the correct image size, etc.
+        This allows transforms to have a single source of truth
+        for the input dimension of the network.
+
+        Return:
+            tuple: the current (height, width)
+        """
+        if hasattr(self, "_input_dim"):
+            return self._input_dim
+        return self.__input_dim
+
+    @staticmethod
+    def resize_getitem(getitem_fn):
+        """
+        Decorator method that needs to be used around the ``__getitem__`` method. |br|
+        This decorator enables the on-the-fly resizing of
+        the ``input_dim`` with our :class:`~lightnet.data.DataLoader` class.
+
+        Example:
+            >>> class CustomSet(ln.data.Dataset):
+            ...     def __len__(self):
+            ...         return 10
+            ...     @ln.data.Dataset.resize_getitem
+            ...     def __getitem__(self, index):
+            ...         # Should return (image, anno) but here we return input_dim
+            ...         return self.input_dim
+            >>> data = CustomSet((200,200))
+            >>> data[0]
+            (200, 200)
+            >>> data[(480,320), 0, True]
+            (480, 320)
+        """
+
+        @wraps(getitem_fn)
+        def wrapper(self, index):
+            if not isinstance(index, int):
+                has_dim = True
+                self._input_dim = index[0]
+                self.enable_mosaic = index[2]
+                index = index[1]
+            else:
+                has_dim = False
+
+            ret_val = getitem_fn(self, index)
+
+            if has_dim:
+                del self._input_dim
+
+            return ret_val
+
+        return wrapper
diff --git a/yolox/data/datasets/mosaicdetection.py b/yolox/data/datasets/mosaicdetection.py
new file mode 100644
index 0000000..d2bf39f
--- /dev/null
+++ b/yolox/data/datasets/mosaicdetection.py
@@ -0,0 +1,242 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import cv2
+import numpy as np
+
+from yolox.utils import adjust_box_anns
+
+import random
+
+from ..data_augment import box_candidates, random_perspective, augment_hsv
+from .datasets_wrapper import Dataset
+
+
+def get_mosaic_coordinate(mosaic_image, mosaic_index, xc, yc, w, h, input_h, input_w):
+    # TODO update doc
+    # index0 to top left part of image
+    if mosaic_index == 0:
+        x1, y1, x2, y2 = max(xc - w, 0), max(yc - h, 0), xc, yc
+        small_coord = w - (x2 - x1), h - (y2 - y1), w, h
+    # index1 to top right part of image
+    elif mosaic_index == 1:
+        x1, y1, x2, y2 = xc, max(yc - h, 0), min(xc + w, input_w * 2), yc
+        small_coord = 0, h - (y2 - y1), min(w, x2 - x1), h
+    # index2 to bottom left part of image
+    elif mosaic_index == 2:
+        x1, y1, x2, y2 = max(xc - w, 0), yc, xc, min(input_h * 2, yc + h)
+        small_coord = w - (x2 - x1), 0, w, min(y2 - y1, h)
+    # index3 to bottom right part of image
+    elif mosaic_index == 3:
+        x1, y1, x2, y2 = xc, yc, min(xc + w, input_w * 2), min(input_h * 2, yc + h)  # noqa
+        small_coord = 0, 0, min(w, x2 - x1), min(y2 - y1, h)
+    return (x1, y1, x2, y2), small_coord
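+
+# Worked example (hypothetical numbers): with input_h = input_w = 640 and a
+# mosaic center of (xc, yc) = (640, 640), a 640x640 tile at mosaic_index 0
+# (top-left) yields
+#     (x1, y1, x2, y2) = (0, 0, 640, 640) and small_coord = (0, 0, 640, 640),
+# i.e. the tile fills the top-left quadrant of the 2x-sized mosaic canvas.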
+ """ + super().__init__(img_size, mosaic=mosaic) + self._dataset = dataset + self.preproc = preproc + self.degrees = degrees + self.translate = translate + self.scale = scale + self.shear = shear + self.perspective = perspective + self.mixup_scale = mscale + self.enable_mosaic = mosaic + self.enable_mixup = enable_mixup + + def __len__(self): + return len(self._dataset) + + @Dataset.resize_getitem + def __getitem__(self, idx): + if self.enable_mosaic: + mosaic_labels = [] + input_dim = self._dataset.input_dim + input_h, input_w = input_dim[0], input_dim[1] + + # yc, xc = s, s # mosaic center x, y + yc = int(random.uniform(0.5 * input_h, 1.5 * input_h)) + xc = int(random.uniform(0.5 * input_w, 1.5 * input_w)) + + # 3 additional image indices + indices = [idx] + [random.randint(0, len(self._dataset) - 1) for _ in range(3)] + + for i_mosaic, index in enumerate(indices): + img, _labels, _, _ = self._dataset.pull_item(index) + h0, w0 = img.shape[:2] # orig hw + scale = min(1. * input_h / h0, 1. * input_w / w0) + img = cv2.resize( + img, (int(w0 * scale), int(h0 * scale)), interpolation=cv2.INTER_LINEAR + ) + # generate output mosaic image + (h, w, c) = img.shape[:3] + if i_mosaic == 0: + mosaic_img = np.full((input_h * 2, input_w * 2, c), 114, dtype=np.uint8) + + # suffix l means large image, while s means small image in mosaic aug. + (l_x1, l_y1, l_x2, l_y2), (s_x1, s_y1, s_x2, s_y2) = get_mosaic_coordinate( + mosaic_img, i_mosaic, xc, yc, w, h, input_h, input_w + ) + + mosaic_img[l_y1:l_y2, l_x1:l_x2] = img[s_y1:s_y2, s_x1:s_x2] + padw, padh = l_x1 - s_x1, l_y1 - s_y1 + + labels = _labels.copy() + # Normalized xywh to pixel xyxy format + if _labels.size > 0: + labels[:, 0] = scale * _labels[:, 0] + padw + labels[:, 1] = scale * _labels[:, 1] + padh + labels[:, 2] = scale * _labels[:, 2] + padw + labels[:, 3] = scale * _labels[:, 3] + padh + mosaic_labels.append(labels) + + if len(mosaic_labels): + mosaic_labels = np.concatenate(mosaic_labels, 0) + ''' + np.clip(mosaic_labels[:, 0], 0, 2 * input_w, out=mosaic_labels[:, 0]) + np.clip(mosaic_labels[:, 1], 0, 2 * input_h, out=mosaic_labels[:, 1]) + np.clip(mosaic_labels[:, 2], 0, 2 * input_w, out=mosaic_labels[:, 2]) + np.clip(mosaic_labels[:, 3], 0, 2 * input_h, out=mosaic_labels[:, 3]) + ''' + + mosaic_labels = mosaic_labels[mosaic_labels[:, 0] < 2 * input_w] + mosaic_labels = mosaic_labels[mosaic_labels[:, 2] > 0] + mosaic_labels = mosaic_labels[mosaic_labels[:, 1] < 2 * input_h] + mosaic_labels = mosaic_labels[mosaic_labels[:, 3] > 0] + + #augment_hsv(mosaic_img) + mosaic_img, mosaic_labels = random_perspective( + mosaic_img, + mosaic_labels, + degrees=self.degrees, + translate=self.translate, + scale=self.scale, + shear=self.shear, + perspective=self.perspective, + border=[-input_h // 2, -input_w // 2], + ) # border to remove + + # ----------------------------------------------------------------- + # CopyPaste: https://arxiv.org/abs/2012.07177 + # ----------------------------------------------------------------- + if self.enable_mixup and not len(mosaic_labels) == 0: + mosaic_img, mosaic_labels = self.mixup(mosaic_img, mosaic_labels, self.input_dim) + + mix_img, padded_labels = self.preproc(mosaic_img, mosaic_labels, self.input_dim) + img_info = (mix_img.shape[1], mix_img.shape[0]) + + return mix_img, padded_labels, img_info, np.array([idx]) + + else: + self._dataset._input_dim = self.input_dim + img, label, img_info, id_ = self._dataset.pull_item(idx) + img, label = self.preproc(img, label, self.input_dim) + return img, label, img_info, 
+    def mixup(self, origin_img, origin_labels, input_dim):
+        jit_factor = random.uniform(*self.mixup_scale)
+        FLIP = random.uniform(0, 1) > 0.5
+        cp_labels = []
+        # resample until we draw an image that actually has annotations
+        while len(cp_labels) == 0:
+            cp_index = random.randint(0, self.__len__() - 1)
+            cp_labels = self._dataset.load_anno(cp_index)
+        img, cp_labels, _, _ = self._dataset.pull_item(cp_index)
+
+        if len(img.shape) == 3:
+            cp_img = np.ones((input_dim[0], input_dim[1], 3)) * 114.0
+        else:
+            cp_img = np.ones(input_dim) * 114.0
+        cp_scale_ratio = min(input_dim[0] / img.shape[0], input_dim[1] / img.shape[1])
+        resized_img = cv2.resize(
+            img,
+            (int(img.shape[1] * cp_scale_ratio), int(img.shape[0] * cp_scale_ratio)),
+            interpolation=cv2.INTER_LINEAR,
+        ).astype(np.float32)
+        cp_img[
+            : int(img.shape[0] * cp_scale_ratio), : int(img.shape[1] * cp_scale_ratio)
+        ] = resized_img
+        cp_img = cv2.resize(
+            cp_img,
+            (int(cp_img.shape[1] * jit_factor), int(cp_img.shape[0] * jit_factor)),
+        )
+        cp_scale_ratio *= jit_factor
+        if FLIP:
+            cp_img = cp_img[:, ::-1, :]
+
+        origin_h, origin_w = cp_img.shape[:2]
+        target_h, target_w = origin_img.shape[:2]
+        padded_img = np.zeros(
+            (max(origin_h, target_h), max(origin_w, target_w), 3)
+        ).astype(np.uint8)
+        padded_img[:origin_h, :origin_w] = cp_img
+
+        x_offset, y_offset = 0, 0
+        if padded_img.shape[0] > target_h:
+            y_offset = random.randint(0, padded_img.shape[0] - target_h - 1)
+        if padded_img.shape[1] > target_w:
+            x_offset = random.randint(0, padded_img.shape[1] - target_w - 1)
+        padded_cropped_img = padded_img[
+            y_offset: y_offset + target_h, x_offset: x_offset + target_w
+        ]
+
+        cp_bboxes_origin_np = adjust_box_anns(
+            cp_labels[:, :4].copy(), cp_scale_ratio, 0, 0, origin_w, origin_h
+        )
+        if FLIP:
+            cp_bboxes_origin_np[:, 0::2] = (
+                origin_w - cp_bboxes_origin_np[:, 0::2][:, ::-1]
+            )
+        cp_bboxes_transformed_np = cp_bboxes_origin_np.copy()
+        '''
+        cp_bboxes_transformed_np[:, 0::2] = np.clip(
+            cp_bboxes_transformed_np[:, 0::2] - x_offset, 0, target_w
+        )
+        cp_bboxes_transformed_np[:, 1::2] = np.clip(
+            cp_bboxes_transformed_np[:, 1::2] - y_offset, 0, target_h
+        )
+        '''
+        cp_bboxes_transformed_np[:, 0::2] = cp_bboxes_transformed_np[:, 0::2] - x_offset
+        cp_bboxes_transformed_np[:, 1::2] = cp_bboxes_transformed_np[:, 1::2] - y_offset
+        keep_list = box_candidates(cp_bboxes_origin_np.T, cp_bboxes_transformed_np.T, 5)
+
+        if keep_list.sum() >= 1.0:
+            cls_labels = cp_labels[keep_list, 4:5].copy()
+            id_labels = cp_labels[keep_list, 5:6].copy()
+            box_labels = cp_bboxes_transformed_np[keep_list]
+            labels = np.hstack((box_labels, cls_labels, id_labels))
+            # remove boxes that fall outside the target image
+            labels = labels[labels[:, 0] < target_w]
+            labels = labels[labels[:, 2] > 0]
+            labels = labels[labels[:, 1] < target_h]
+            labels = labels[labels[:, 3] > 0]
+            origin_labels = np.vstack((origin_labels, labels))
+            origin_img = origin_img.astype(np.float32)
+            origin_img = 0.5 * origin_img + 0.5 * padded_cropped_img.astype(np.float32)
+
+        return origin_img, origin_labels
diff --git a/yolox/data/datasets/mot.py b/yolox/data/datasets/mot.py
new file mode 100644
index 0000000..34d2ed3
--- /dev/null
+++ b/yolox/data/datasets/mot.py
@@ -0,0 +1,138 @@
+import cv2
+import numpy as np
+from pycocotools.coco import COCO
+
+import os
+import torch
+from ..dataloading import get_yolox_datadir
+from .datasets_wrapper import Dataset
+
+
+class MOTDataset(Dataset):
+    """
+    MOT dataset class. Annotations are stored in COCO format.
+    """
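+
+    # Expected directory layout (an assumption inferred from the paths joined
+    # in ``__init__`` and ``pull_item`` below):
+    #     <data_dir>/annotations/<json_file>
+    #     <data_dir>/<name>/<file_name from the annotation record>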
+ """ + + def __init__( + self, + data_dir=None, + json_file="train_half.json", + name="train", + img_size=(608, 1088), + preproc=None, + return_origin_img=False, + ): + """ + COCO dataset initialization. Annotation data are read into memory by COCO API. + Args: + data_dir (str): dataset root directory + json_file (str): COCO json file name + name (str): COCO data name (e.g. 'train2017' or 'val2017') + img_size (int): target image size after pre-processing + preproc: data augmentation strategy + return_origin_img (bool): whether to return origin image + """ + super().__init__(img_size) + if data_dir is None: + data_dir = os.path.join(get_yolox_datadir(), "mot") + self.data_dir = data_dir + self.json_file = json_file + + self.coco = COCO(os.path.join(self.data_dir, "annotations", self.json_file)) + self.ids = self.coco.getImgIds() + self.class_ids = sorted(self.coco.getCatIds()) + cats = self.coco.loadCats(self.coco.getCatIds()) + self._classes = tuple([c["name"] for c in cats]) + self.annotations = self._load_coco_annotations() + self.name = name + self.img_size = img_size + self.preproc = preproc + self.return_origin_img = return_origin_img + + def __len__(self): + return len(self.ids) + + def _load_coco_annotations(self): + return [self.load_anno_from_ids(_ids) for _ids in self.ids] + + def load_anno_from_ids(self, id_): + im_ann = self.coco.loadImgs(id_)[0] + width = im_ann["width"] + height = im_ann["height"] + frame_id = im_ann["frame_id"] + video_id = im_ann["video_id"] + anno_ids = self.coco.getAnnIds(imgIds=[int(id_)], iscrowd=False) + annotations = self.coco.loadAnns(anno_ids) + objs = [] + for obj in annotations: + x1 = obj["bbox"][0] + y1 = obj["bbox"][1] + x2 = x1 + obj["bbox"][2] + y2 = y1 + obj["bbox"][3] + if obj["area"] > 0 and x2 >= x1 and y2 >= y1: + obj["clean_bbox"] = [x1, y1, x2, y2] + objs.append(obj) + + num_objs = len(objs) + + res = np.zeros((num_objs, 6)) + + for ix, obj in enumerate(objs): + cls = self.class_ids.index(obj["category_id"]) + res[ix, 0:4] = obj["clean_bbox"] + res[ix, 4] = cls + res[ix, 5] = obj["track_id"] + + file_name = im_ann["file_name"] if "file_name" in im_ann else "{:012}".format(id_) + ".jpg" + img_info = (height, width, frame_id, video_id, file_name) + + del im_ann, annotations + + return (res, img_info, file_name) + + def load_anno(self, index): + return self.annotations[index][0] + + def pull_item(self, index): + id_ = self.ids[index] + + res, img_info, file_name = self.annotations[index] + # load image and preprocess + img_file = os.path.join( + self.data_dir, self.name, file_name + ) + img = cv2.imread(img_file) + assert img is not None + + return img, res.copy(), img_info, np.array([id_]) + + @Dataset.resize_getitem + def __getitem__(self, index): + """ + One image / label pair for the given index is picked up and pre-processed. + + Args: + index (int): data index + + Returns: + img (numpy.ndarray): pre-processed image + padded_labels (torch.Tensor): pre-processed label data. + The shape is :math:`[max_labels, 5]`. + each label consists of [class, xc, yc, w, h]: + class (float): class index. + xc, yc (float) : center of bbox whose values range from 0 to 1. + w, h (float) : size of bbox whose values range from 0 to 1. + info_img : tuple of h, w, nh, nw, dx, dy. + h, w (int): original shape of the image + nh, nw (int): shape of the resized image without padding + dx, dy (int): pad size + img_id (int): same as the input index. Used for evaluation. 
+ """ + img, target, img_info, img_id = self.pull_item(index) + origin_img = torch.from_numpy(img).permute(2,0,1) + if self.preproc is not None: + img, target = self.preproc(img, target, self.input_dim) + if self.return_origin_img: + return origin_img, img, target, img_info, img_id + else: + return img, target, img_info, img_id