Skip to content

Commit

Permalink
feat: 🚩 add progress bar to console, only set seed in test mode
Browse files Browse the repository at this point in the history
  • Loading branch information
chaofengc committed Apr 30, 2024
1 parent a89fa44 commit 7d1d30f
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 10 deletions.
4 changes: 3 additions & 1 deletion pyiqa/models/inference_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ def __init__(
self.net = ARCH_REGISTRY.get(network_type)(**net_opts)
self.net = self.net.to(self.device)
self.net.eval()
set_random_seed(seed)

if not as_loss:
set_random_seed(seed)

self.dummy_param = torch.nn.Parameter(torch.empty(0)).to(self.device)

Expand Down
21 changes: 12 additions & 9 deletions pyiqa/pyiqa_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,21 @@ def main():

print(f"{'='*50} Loading metrics {'='*50}")
metric_func_list = {}
results = {}
for metric in args.metric:
metric_func = create_metric(metric)
if metric == 'fid':
result = metric_func(args.target, args.ref, mode=args.fid_mode, verbose=args.verbose)
results[metric] = result
elif metric == 'inception_score':
result = metric_func(args.target, splits=args.isc_splits, verbose=args.verbose)
results[metric] = result
else:
metric_func_list[metric] = metric_func
metric_func_list[metric] = metric_func
print(f"{'='*50} Metrics loaded {'='*50}")

results = {}
# Test fid, inception_score
if 'fid' in metric_func_list:
metric_func = metric_func_list.pop('fid')
result = metric_func(args.target, args.ref, mode=args.fid_mode, verbose=args.verbose)
results['fid'] = result
if 'inception_score' in metric_func_list:
metric_func = metric_func_list.pop('inception_score')
result = metric_func(args.target, splits=args.isc_splits, verbose=args.verbose)
results['inception_score'] = result

if os.path.isdir(args.target):
target_list = scandir_images(args.target)
Expand Down

0 comments on commit 7d1d30f

Please sign in to comment.