From 7d1d30f3540cb550a644c0bd8e84f99b9316980e Mon Sep 17 00:00:00 2001 From: chaofengc Date: Tue, 30 Apr 2024 16:37:30 +0800 Subject: [PATCH] feat: :triangular_flag_on_post: add progress bar to console, only set seed in test mode --- pyiqa/models/inference_model.py | 4 +++- pyiqa/pyiqa_cmd.py | 21 ++++++++++++--------- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/pyiqa/models/inference_model.py b/pyiqa/models/inference_model.py index 61ba1b8..8dcc090 100644 --- a/pyiqa/models/inference_model.py +++ b/pyiqa/models/inference_model.py @@ -55,7 +55,9 @@ def __init__( self.net = ARCH_REGISTRY.get(network_type)(**net_opts) self.net = self.net.to(self.device) self.net.eval() - set_random_seed(seed) + + if not as_loss: + set_random_seed(seed) self.dummy_param = torch.nn.Parameter(torch.empty(0)).to(self.device) diff --git a/pyiqa/pyiqa_cmd.py b/pyiqa/pyiqa_cmd.py index 8181da0..9875180 100644 --- a/pyiqa/pyiqa_cmd.py +++ b/pyiqa/pyiqa_cmd.py @@ -30,18 +30,21 @@ def main(): print(f"{'='*50} Loading metrics {'='*50}") metric_func_list = {} - results = {} for metric in args.metric: metric_func = create_metric(metric) - if metric == 'fid': - result = metric_func(args.target, args.ref, mode=args.fid_mode, verbose=args.verbose) - results[metric] = result - elif metric == 'inception_score': - result = metric_func(args.target, splits=args.isc_splits, verbose=args.verbose) - results[metric] = result - else: - metric_func_list[metric] = metric_func + metric_func_list[metric] = metric_func print(f"{'='*50} Metrics loaded {'='*50}") + + results = {} + # Test fid, inception_score + if 'fid' in metric_func_list: + metric_func = metric_func_list.pop('fid') + result = metric_func(args.target, args.ref, mode=args.fid_mode, verbose=args.verbose) + results['fid'] = result + if 'inception_score' in metric_func_list: + metric_func = metric_func_list.pop('inception_score') + result = metric_func(args.target, splits=args.isc_splits, verbose=args.verbose) + results['inception_score'] = result if os.path.isdir(args.target): target_list = scandir_images(args.target)