From d512ce7c4f103e8887960198505518bed404abdc Mon Sep 17 00:00:00 2001 From: liangliang <107097683+mtjhl@users.noreply.github.com> Date: Tue, 10 Oct 2023 20:46:35 +0800 Subject: [PATCH] support cache images to RAM to speed up training. (#924) * use cache to accelerate the training of yolov6 * normalize codes about cache images to RAM. * fix the bug in cache images into RAM * only display cache progress in gpu 0 in DDP mode. --------- Co-authored-by: zhujiajian98 --- tools/train.py | 1 + yolov6/core/engine.py | 6 ++- yolov6/data/data_load.py | 9 +++-- yolov6/data/datasets.py | 81 ++++++++++++++++++++++++++++++++++------ 4 files changed, 80 insertions(+), 17 deletions(-) diff --git a/tools/train.py b/tools/train.py index 635c68e4..d7f1ee30 100644 --- a/tools/train.py +++ b/tools/train.py @@ -58,6 +58,7 @@ def get_args_parser(add_help=True): parser.add_argument('--specific-shape', action='store_true', help='rectangular training') parser.add_argument('--height', type=int, default=None, help='image height of model input') parser.add_argument('--width', type=int, default=None, help='image width of model input') + parser.add_argument('--cache-ram', action='store_true', help='whether to cache images into RAM to speed up training') return parser diff --git a/yolov6/core/engine.py b/yolov6/core/engine.py index 10545135..d91f0173 100644 --- a/yolov6/core/engine.py +++ b/yolov6/core/engine.py @@ -387,7 +387,8 @@ def get_data_loader(args, cfg, data_dict): hyp=dict(cfg.data_aug), augment=True, rect=args.rect, rank=args.local_rank, workers=args.workers, shuffle=True, check_images=args.check_images, check_labels=args.check_labels, data_dict=data_dict, task='train', - specific_shape=args.specific_shape, height=args.height, width=args.width)[0] + specific_shape=args.specific_shape, height=args.height, width=args.width, + cache_ram=args.cache_ram)[0] # create val dataloader val_loader = None if args.rank in [-1, 0]: @@ -396,7 +397,8 @@ def get_data_loader(args, cfg, data_dict): hyp=dict(cfg.data_aug), rect=True, rank=-1, pad=0.5, workers=args.workers, check_images=args.check_images, check_labels=args.check_labels, data_dict=data_dict, task='val', - specific_shape=args.specific_shape, height=args.height, width=args.width)[0] + specific_shape=args.specific_shape, height=args.height, width=args.width, + cache_ram=args.cache_ram)[0] return train_loader, val_loader diff --git a/yolov6/data/data_load.py b/yolov6/data/data_load.py index e68e8d71..bc0fcff8 100644 --- a/yolov6/data/data_load.py +++ b/yolov6/data/data_load.py @@ -30,9 +30,9 @@ def create_dataloader( task="Train", specific_shape=False, height=1088, - width=1920 - -): + width=1920, + cache_ram=False + ): """Create general dataloader. Returns dataloader and dataset @@ -59,7 +59,8 @@ def create_dataloader( task=task, specific_shape = specific_shape, height=height, - width=width + width=width, + cache_ram=cache_ram ) batch_size = min(batch_size, len(dataset)) diff --git a/yolov6/data/datasets.py b/yolov6/data/datasets.py index a5b8bc05..a9dcd4b0 100644 --- a/yolov6/data/datasets.py +++ b/yolov6/data/datasets.py @@ -30,6 +30,9 @@ mosaic_augmentation, ) from yolov6.utils.events import LOGGER +import copy +import psutil +from multiprocessing.pool import ThreadPool # Parameters @@ -67,11 +70,11 @@ def __init__( task="train", specific_shape = False, height=1088, - width=1920 - + width=1920, + cache_ram=False ): assert task.lower() in ("train", "val", "test", "speed"), f"Not supported task: {task}" - t1 = time.time() + tik = time.time() self.__dict__.update(locals()) self.main_process = self.rank in (-1, 0) self.task = self.task.capitalize() @@ -81,6 +84,12 @@ def __init__( self.specific_shape = specific_shape self.target_height = height self.target_width = width + self.cache_ram = cache_ram + if self.cache_ram: + self.num_imgs = len(self.img_paths) + self.imgs = [None] * self.num_imgs + self.cache_images(num_imgs=self.num_imgs) + if self.rect: shapes = [self.img_info[p]["shape"] for p in self.img_paths] self.shapes = np.array(shapes, dtype=np.float64) @@ -98,14 +107,61 @@ def __init__( self.sort_files_shapes() - t2 = time.time() + tok = time.time() + if self.main_process: - LOGGER.info(f"%.1fs for dataset initialization." % (t2 - t1)) + LOGGER.info(f"%.1fs for dataset initialization." % (tok - tik)) + + def cache_images(self, num_imgs=None): + assert num_imgs is not None, "num_imgs must be specified as the size of the dataset" + + mem = psutil.virtual_memory() + mem_required = self.cal_cache_occupy(num_imgs) + gb = 1 << 30 + + if mem_required > mem.available: + self.cache_ram = False + LOGGER.warning("Not enough RAM to cache images, caching is disabled.") + else: + LOGGER.warning( + f"{mem_required / gb:.1f}GB RAM required, " + f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB RAM available, " + f"Since the first thing we do is cache, " + f"there is no guarantee that the remaining memory space is sufficient" + ) + + print(f"self.imgs: {len(self.imgs)}") + LOGGER.info("You are using cached images in RAM to accelerate training!") + LOGGER.info( + "Caching images...\n" + "This might take some time for your dataset" + ) + num_threads = min(16, max(1, os.cpu_count() - 1)) + load_imgs = ThreadPool(num_threads).imap(self.load_image, range(num_imgs)) + pbar = tqdm(enumerate(load_imgs), total=num_imgs, disable=self.rank > 0) + for i, (x, (h0, w0), shape) in pbar: + self.imgs[i] = x + + def __del__(self): + if self.cache_ram: + del self.imgs + + def cal_cache_occupy(self, num_imgs): + '''estimate the memory required to cache images in RAM. + ''' + cache_bytes = 0 + num_imgs = len(self.img_paths) + num_samples = min(num_imgs, 32) + for _ in range(num_samples): + img, _, _ = self.load_image(index=random.randint(0, len(self.img_paths) - 1)) + cache_bytes += img.nbytes + mem_required = cache_bytes * num_imgs / num_samples + return mem_required def __len__(self): """Get the length of dataset""" return len(self.img_paths) - + def __getitem__(self, index): """Fetching a data sample for a given key. This function applies mosaic and mixup augments during training. @@ -196,7 +252,7 @@ def __getitem__(self, index): img = np.ascontiguousarray(img) return torch.from_numpy(img), labels_out, self.img_paths[index], shapes - + def load_image(self, index, shrink_size=None): """Load image. This function loads image by cv2, resize original image to target shape(img_size) with keeping ratio. @@ -206,12 +262,16 @@ def load_image(self, index, shrink_size=None): """ path = self.img_paths[index] try: - im = cv2.imread(path) + if self.cache_ram and self.imgs[index] is not None: + im = self.imgs[index] + im = copy.deepcopy(im) + else: + im = cv2.imread(path) assert im is not None, f"opencv cannot read image correctly or {path} not exists" - except: + except Exception as e: + print(e) im = cv2.cvtColor(np.asarray(Image.open(path)), cv2.COLOR_RGB2BGR) assert im is not None, f"Image Not Found {path}, workdir: {os.getcwd()}" - h0, w0 = im.shape[:2] # origin shape if self.specific_shape: # keep ratio resize @@ -222,7 +282,6 @@ def load_image(self, index, shrink_size=None): else: ratio = self.img_size / max(h0, w0) - if ratio != 1: im = cv2.resize( im,