From f0c6c24b0e736135aea511059b485a8ab52e394d Mon Sep 17 00:00:00 2001 From: JiangongWang Date: Thu, 5 May 2022 14:06:35 +0800 Subject: [PATCH 1/2] fix 'Dist' parameter duplication of val_dataloader --- mmgen/apis/train.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mmgen/apis/train.py b/mmgen/apis/train.py index 56c738984..fd1347e9a 100644 --- a/mmgen/apis/train.py +++ b/mmgen/apis/train.py @@ -179,8 +179,7 @@ def train_model(model, **loader_cfg, 'shuffle': False, **cfg.data.get('val_data_loader', {}) } - val_dataloader = build_dataloader( - val_dataset, dist=distributed, **val_loader_cfg) + val_dataloader = build_dataloader(val_dataset, **val_loader_cfg) eval_cfg = deepcopy(cfg.get('evaluation')) priority = eval_cfg.pop('priority', 'LOW') eval_cfg.update(dict(dist=distributed, dataloader=val_dataloader)) From 86b52119b68d9242557d3c09373b8f8d37a92d66 Mon Sep 17 00:00:00 2001 From: JiangongWang Date: Thu, 5 May 2022 14:45:35 +0800 Subject: [PATCH 2/2] fix support for val_samples_per_gpu and val_workers_per_gpu --- mmgen/apis/train.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/mmgen/apis/train.py b/mmgen/apis/train.py index fd1347e9a..cde663403 100644 --- a/mmgen/apis/train.py +++ b/mmgen/apis/train.py @@ -62,7 +62,7 @@ def train_model(model, k: v for k, v in cfg.data.items() if k not in [ 'train', 'val', 'test', 'train_dataloader', 'val_dataloader', - 'test_dataloader' + 'test_dataloader', 'val_samples_per_gpu', 'val_workers_per_gpu' ] }) @@ -179,6 +179,12 @@ def train_model(model, **loader_cfg, 'shuffle': False, **cfg.data.get('val_data_loader', {}) } + val_loader_cfg.update({ + 'samples_per_gpu': + cfg.data.get('val_samples_per_gpu', cfg.data.samples_per_gpu), + 'workers_per_gpu': + cfg.data.get('val_workers_per_gpu', cfg.data.workers_per_gpu) + }) val_dataloader = build_dataloader(val_dataset, **val_loader_cfg) eval_cfg = deepcopy(cfg.get('evaluation')) priority = eval_cfg.pop('priority', 'LOW')