From 99105bdbbb943d97c8626ee0b34e209e74af1ffd Mon Sep 17 00:00:00 2001 From: plyfager <2744335995@qq.com> Date: Mon, 14 Mar 2022 19:38:44 +0800 Subject: [PATCH] support read data from ceph --- mmgen/datasets/unconditional_image_dataset.py | 22 +++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/mmgen/datasets/unconditional_image_dataset.py b/mmgen/datasets/unconditional_image_dataset.py index 2a128f9e7..cc3090b18 100644 --- a/mmgen/datasets/unconditional_image_dataset.py +++ b/mmgen/datasets/unconditional_image_dataset.py @@ -21,15 +21,19 @@ class UnconditionalImageDataset(Dataset): pipeline (list[dict | callable]): A sequence of data transforms. test_mode (bool, optional): If True, the dataset will work in test mode. Otherwise, in train mode. Default to False. + backend (str): io backend where images are store. Default: 'disk'. """ _VALID_IMG_SUFFIX = ('.jpg', '.png', '.jpeg', '.JPEG') + _VALID_BACKEND = ('petrel', 'disk') - def __init__(self, imgs_root, pipeline, test_mode=False): + def __init__(self, imgs_root, pipeline, test_mode=False, backend='disk'): super().__init__() self.imgs_root = imgs_root self.pipeline = Compose(pipeline) self.test_mode = test_mode + assert backend in self._VALID_BACKEND + self.backend = backend self.load_annotations() # print basic dataset information to check the validity @@ -38,9 +42,19 @@ def __init__(self, imgs_root, pipeline, test_mode=False): def load_annotations(self): """Load annotations.""" # recursively find all of the valid images from imgs_root - imgs_list = mmcv.scandir( - self.imgs_root, self._VALID_IMG_SUFFIX, recursive=True) - self.imgs_list = [osp.join(self.imgs_root, x) for x in imgs_list] + if self.backend == 'disk': + imgs_list = mmcv.scandir( + self.imgs_root, self._VALID_IMG_SUFFIX, recursive=True) + self.imgs_list = [osp.join(self.imgs_root, x) for x in imgs_list] + elif self.backend == 'petrel': + file_client = mmcv.FileClient(backend=self.backend) + # get filename generator + files = file_client.list_dir_or_file(self.imgs_root) + self.imgs_list = [ + file_client.join_path(self.imgs_root, file) for file in files + ] + else: + raise NotImplementedError def prepare_train_data(self, idx): """Prepare training data.