Skip to content

Commit

Permalink
support cache images to RAM to speed up training. (#924)
Browse files Browse the repository at this point in the history
* use cache to accelerate the training of yolov6

* normalize codes about cache images to RAM.

* fix the bug in cache images into RAM

* only display cache progress in gpu 0 in DDP mode.

---------

Co-authored-by: zhujiajian98 <[email protected]>
  • Loading branch information
mtjhl and zhujiajian98 authored Oct 10, 2023
1 parent 9d19edb commit d512ce7
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 17 deletions.
1 change: 1 addition & 0 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def get_args_parser(add_help=True):
parser.add_argument('--specific-shape', action='store_true', help='rectangular training')
parser.add_argument('--height', type=int, default=None, help='image height of model input')
parser.add_argument('--width', type=int, default=None, help='image width of model input')
parser.add_argument('--cache-ram', action='store_true', help='whether to cache images into RAM to speed up training')
return parser


Expand Down
6 changes: 4 additions & 2 deletions yolov6/core/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,8 @@ def get_data_loader(args, cfg, data_dict):
hyp=dict(cfg.data_aug), augment=True, rect=args.rect, rank=args.local_rank,
workers=args.workers, shuffle=True, check_images=args.check_images,
check_labels=args.check_labels, data_dict=data_dict, task='train',
specific_shape=args.specific_shape, height=args.height, width=args.width)[0]
specific_shape=args.specific_shape, height=args.height, width=args.width,
cache_ram=args.cache_ram)[0]
# create val dataloader
val_loader = None
if args.rank in [-1, 0]:
Expand All @@ -396,7 +397,8 @@ def get_data_loader(args, cfg, data_dict):
hyp=dict(cfg.data_aug), rect=True, rank=-1, pad=0.5,
workers=args.workers, check_images=args.check_images,
check_labels=args.check_labels, data_dict=data_dict, task='val',
specific_shape=args.specific_shape, height=args.height, width=args.width)[0]
specific_shape=args.specific_shape, height=args.height, width=args.width,
cache_ram=args.cache_ram)[0]

return train_loader, val_loader

Expand Down
9 changes: 5 additions & 4 deletions yolov6/data/data_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ def create_dataloader(
task="Train",
specific_shape=False,
height=1088,
width=1920

):
width=1920,
cache_ram=False
):
"""Create general dataloader.
Returns dataloader and dataset
Expand All @@ -59,7 +59,8 @@ def create_dataloader(
task=task,
specific_shape = specific_shape,
height=height,
width=width
width=width,
cache_ram=cache_ram
)

batch_size = min(batch_size, len(dataset))
Expand Down
81 changes: 70 additions & 11 deletions yolov6/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
mosaic_augmentation,
)
from yolov6.utils.events import LOGGER
import copy
import psutil
from multiprocessing.pool import ThreadPool


# Parameters
Expand Down Expand Up @@ -67,11 +70,11 @@ def __init__(
task="train",
specific_shape = False,
height=1088,
width=1920

width=1920,
cache_ram=False
):
assert task.lower() in ("train", "val", "test", "speed"), f"Not supported task: {task}"
t1 = time.time()
tik = time.time()
self.__dict__.update(locals())
self.main_process = self.rank in (-1, 0)
self.task = self.task.capitalize()
Expand All @@ -81,6 +84,12 @@ def __init__(
self.specific_shape = specific_shape
self.target_height = height
self.target_width = width
self.cache_ram = cache_ram
if self.cache_ram:
self.num_imgs = len(self.img_paths)
self.imgs = [None] * self.num_imgs
self.cache_images(num_imgs=self.num_imgs)

if self.rect:
shapes = [self.img_info[p]["shape"] for p in self.img_paths]
self.shapes = np.array(shapes, dtype=np.float64)
Expand All @@ -98,14 +107,61 @@ def __init__(

self.sort_files_shapes()

t2 = time.time()
tok = time.time()

if self.main_process:
LOGGER.info(f"%.1fs for dataset initialization." % (t2 - t1))
LOGGER.info(f"%.1fs for dataset initialization." % (tok - tik))

def cache_images(self, num_imgs=None):
assert num_imgs is not None, "num_imgs must be specified as the size of the dataset"

mem = psutil.virtual_memory()
mem_required = self.cal_cache_occupy(num_imgs)
gb = 1 << 30

if mem_required > mem.available:
self.cache_ram = False
LOGGER.warning("Not enough RAM to cache images, caching is disabled.")
else:
LOGGER.warning(
f"{mem_required / gb:.1f}GB RAM required, "
f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB RAM available, "
f"Since the first thing we do is cache, "
f"there is no guarantee that the remaining memory space is sufficient"
)

print(f"self.imgs: {len(self.imgs)}")
LOGGER.info("You are using cached images in RAM to accelerate training!")
LOGGER.info(
"Caching images...\n"
"This might take some time for your dataset"
)
num_threads = min(16, max(1, os.cpu_count() - 1))
load_imgs = ThreadPool(num_threads).imap(self.load_image, range(num_imgs))
pbar = tqdm(enumerate(load_imgs), total=num_imgs, disable=self.rank > 0)
for i, (x, (h0, w0), shape) in pbar:
self.imgs[i] = x

def __del__(self):
if self.cache_ram:
del self.imgs

def cal_cache_occupy(self, num_imgs):
'''estimate the memory required to cache images in RAM.
'''
cache_bytes = 0
num_imgs = len(self.img_paths)
num_samples = min(num_imgs, 32)
for _ in range(num_samples):
img, _, _ = self.load_image(index=random.randint(0, len(self.img_paths) - 1))
cache_bytes += img.nbytes
mem_required = cache_bytes * num_imgs / num_samples
return mem_required

def __len__(self):
"""Get the length of dataset"""
return len(self.img_paths)

def __getitem__(self, index):
"""Fetching a data sample for a given key.
This function applies mosaic and mixup augments during training.
Expand Down Expand Up @@ -196,7 +252,7 @@ def __getitem__(self, index):
img = np.ascontiguousarray(img)

return torch.from_numpy(img), labels_out, self.img_paths[index], shapes

def load_image(self, index, shrink_size=None):
"""Load image.
This function loads image by cv2, resize original image to target shape(img_size) with keeping ratio.
Expand All @@ -206,12 +262,16 @@ def load_image(self, index, shrink_size=None):
"""
path = self.img_paths[index]
try:
im = cv2.imread(path)
if self.cache_ram and self.imgs[index] is not None:
im = self.imgs[index]
im = copy.deepcopy(im)
else:
im = cv2.imread(path)
assert im is not None, f"opencv cannot read image correctly or {path} not exists"
except:
except Exception as e:
print(e)
im = cv2.cvtColor(np.asarray(Image.open(path)), cv2.COLOR_RGB2BGR)
assert im is not None, f"Image Not Found {path}, workdir: {os.getcwd()}"

h0, w0 = im.shape[:2] # origin shape
if self.specific_shape:
# keep ratio resize
Expand All @@ -222,7 +282,6 @@ def load_image(self, index, shrink_size=None):

else:
ratio = self.img_size / max(h0, w0)

if ratio != 1:
im = cv2.resize(
im,
Expand Down

0 comments on commit d512ce7

Please sign in to comment.