From 2e71213b504bc1b11caf693876304fc87c533adc Mon Sep 17 00:00:00 2001
From: Jakub Kaczmarzyk
Date: Thu, 18 Aug 2022 12:35:09 -0400
Subject: [PATCH 1/7] initial unfinished work on added metrics

---
 validate.py | 26 +++++++++++++++++++++-----
 1 file changed, 21 insertions(+), 5 deletions(-)

diff --git a/validate.py b/validate.py
index 3b982cc..e75c575 100644
--- a/validate.py
+++ b/validate.py
@@ -225,12 +225,19 @@ def validate(args):
     losses = AverageMeter()
     top1 = AverageMeter()
 
+    thres = 0.5  # Threshold of positive values.
     if args.binary_metrics:
         auroc = torchmetrics.AUROC()
-        f1 = torchmetrics.F1Score()
+        f1 = torchmetrics.F1Score(threshold=thres)
+        sens = torchmetrics.Sensitivity(threshold=thres)
+        spec = torchmetrics.Specificity(threshold=thres)
+        statscores = torchmetrics.StatScores(threshold=thres)
     else:
         auroc = torchmetrics.AUROC(num_classes=args.num_classes)
-        f1 = torchmetrics.F1Score(num_classes=args.num_classes)
+        f1 = torchmetrics.F1Score(num_classes=args.num_classes, threshold=thres)
+        sens = torchmetrics.Sensitivity(num_classes=args.num_classes, threshold=thres)
+        spec = torchmetrics.Specificity(num_classes=args.num_classes, threshold=thres)
+        statscores = torchmetrics.StatScores(num_classes=args.num_classes, threshold=thres)
 
     model.eval()
     with torch.no_grad():
@@ -271,8 +278,8 @@ def validate(args):
             if args.binary_metrics:
                 # Keep the probabilities of the "positive" class.
                 probs = probs[:, 1]
-            auroc.update(preds=probs, target=target)
-            f1.update(preds=probs, target=target)
+            for obj in [auroc, f1, sens, spec, statscores]:
+                obj.update(preds=probs, target=target)
 
             # measure elapsed time
             batch_time.update(time.time() - end)
@@ -283,7 +290,7 @@ def validate(args):
                     'Test: [{0:>4d}/{1}] '
                     'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                     'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
-                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
+                    # 'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
                    'AUROC: {auroc:>7.4f} '
                     'F1: {f1:>7.4f}'.format(
                         batch_idx, len(loader), batch_time=batch_time,
@@ -296,6 +303,15 @@ def validate(args):
         top1a = real_labels.get_accuracy(k=1)
     else:
         top1a = top1.avg
+
+    stats = statscores.compute().numpy()  # TODO: what is the shape of this? (2, 5)?
+    print("Shape of stats obj", stats.shape)
+    print(stats)
+    raise RuntimeError(f"Shape of stats obj: {stats.shape}")
+
+    fnr = None  # False negative rate
+    fpr = None  # False positive rate
+
     results = OrderedDict(
         model=args.model,
         top1=round(top1a, 4), top1_err=round(100 - top1a, 4),

From af9f7dc3615cffab8b0f8bddf306492b6686a409 Mon Sep 17 00:00:00 2001
From: Jakub Kaczmarzyk
Date: Thu, 18 Aug 2022 12:49:05 -0400
Subject: [PATCH 2/7] rm sensitivity and specificity

---
 validate.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/validate.py b/validate.py
index e75c575..c2992b6 100644
--- a/validate.py
+++ b/validate.py
@@ -235,8 +235,6 @@ def validate(args):
     else:
         auroc = torchmetrics.AUROC(num_classes=args.num_classes)
         f1 = torchmetrics.F1Score(num_classes=args.num_classes, threshold=thres)
-        sens = torchmetrics.Sensitivity(num_classes=args.num_classes, threshold=thres)
-        spec = torchmetrics.Specificity(num_classes=args.num_classes, threshold=thres)
         statscores = torchmetrics.StatScores(num_classes=args.num_classes, threshold=thres)
 
     model.eval()
     with torch.no_grad():
@@ -278,7 +276,7 @@ def validate(args):
             if args.binary_metrics:
                 # Keep the probabilities of the "positive" class.
                 probs = probs[:, 1]
-            for obj in [auroc, f1, sens, spec, statscores]:
+            for obj in [auroc, f1, statscores]:
                 obj.update(preds=probs, target=target)
 
             # measure elapsed time
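
Note on PATCH 1 and PATCH 2: the per-batch update loop works because every
torchmetrics metric shares the same update()/compute() interface. A minimal
sketch of that interface, assuming torchmetrics ~0.9 (the API current when
these patches were written) and synthetic tensors in place of real model
output:

    import torch
    import torchmetrics

    thres = 0.5
    auroc = torchmetrics.AUROC()
    f1 = torchmetrics.F1Score(threshold=thres)
    statscores = torchmetrics.StatScores(threshold=thres)

    probs = torch.tensor([0.9, 0.2, 0.7, 0.4])  # P(positive) per sample
    target = torch.tensor([1, 0, 1, 1])

    # Metrics accumulate state across update() calls; compute() reduces it.
    for obj in [auroc, f1, statscores]:
        obj.update(preds=probs, target=target)
    print(auroc.compute(), f1.compute(), statscores.compute())

Also worth noting: torchmetrics provides Specificity but has no Sensitivity
class (Recall is the equivalent metric), which is presumably why PATCH 2 and
PATCH 3 remove both objects.
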
From 712543578cb58b07f17bbb4ca737e921e81ae63a Mon Sep 17 00:00:00 2001
From: Jakub Kaczmarzyk
Date: Thu, 18 Aug 2022 12:49:48 -0400
Subject: [PATCH 3/7] actually rm sens and spec...

---
 validate.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/validate.py b/validate.py
index c2992b6..b5112a4 100644
--- a/validate.py
+++ b/validate.py
@@ -229,8 +229,6 @@ def validate(args):
     if args.binary_metrics:
         auroc = torchmetrics.AUROC()
         f1 = torchmetrics.F1Score(threshold=thres)
-        sens = torchmetrics.Sensitivity(threshold=thres)
-        spec = torchmetrics.Specificity(threshold=thres)
         statscores = torchmetrics.StatScores(threshold=thres)
     else:
         auroc = torchmetrics.AUROC(num_classes=args.num_classes)

From f7015840294c5c3b4ef71ae116ad24fab4416a9e Mon Sep 17 00:00:00 2001
From: Jakub Kaczmarzyk
Date: Thu, 18 Aug 2022 13:03:48 -0400
Subject: [PATCH 4/7] calc and print fnr,fpr,tnr,tpr

---
 validate.py | 33 ++++++++++++++++++++++++---------
 1 file changed, 24 insertions(+), 9 deletions(-)

diff --git a/validate.py b/validate.py
index b5112a4..72bd429 100644
--- a/validate.py
+++ b/validate.py
@@ -300,26 +300,41 @@ def validate(args):
     else:
         top1a = top1.avg
 
-    stats = statscores.compute().numpy()  # TODO: what is the shape of this? (2, 5)?
-    print("Shape of stats obj", stats.shape)
-    print(stats)
-    raise RuntimeError(f"Shape of stats obj: {stats.shape}")
-
-    fnr = None  # False negative rate
-    fpr = None  # False positive rate
+    stats = statscores.compute().numpy()
+    if stats.shape != (5,):
+        raise NotImplementedError(
+            "Computing confusion matrix stats only valid when num"
+            " classes == 2 and binary-metrics is used.")
+    tp, fp, tn, fn, sup = stats  # sup is support = tp+fn
+
+    fnr = fn / (fn + tp)  # False negative rate
+    fpr = fp / (fp + tn)  # False positive rate
+    tnr = 1 - fpr  # True negative rate
+    tpr = 1 - fnr  # True positive rate
 
     results = OrderedDict(
         model=args.model,
         top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
         auroc=auroc.compute().item(),
         f1=f1.compute().item(),
+        fnr=fnr,
+        fpr=fpr,
+        tnr=tnr,
+        tpr=tpr,
         param_count=round(param_count / 1e6, 2),
         img_size=data_config['input_size'][-1],
         crop_pct=crop_pct,
         interpolation=data_config['interpolation'])
-    _logger.info(' * Acc@1 {:.3f} ({:.3f}) AUROC {:.3f} F1 {:.3f}'.format(
-        results['top1'], results['top1_err'], results['auroc'], results['f1']))
+    _logger.info(
+        "***"
+        f" AUROC {results['auroc']:.3f}\n"
+        f" F1@{thres:.2f} {results['f1']:.3f}\n"
+        f" FNR {results['fnr']:.3f}\n"
+        f" FPR {results['fpr']:.3f}\n"
+        f" TNR {results['tnr']:.3f}\n"
+        f" TPR {results['tpr']:.3f}"
+    )
 
     return results

From 63728c17f56dc88247c0bc5d08e102052c842f9f Mon Sep 17 00:00:00 2001
From: Jakub Kaczmarzyk
Date: Thu, 18 Aug 2022 13:14:35 -0400
Subject: [PATCH 5/7] add fpr,fnr,tpr,tnr

---
 run_all_evaluations.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/run_all_evaluations.py b/run_all_evaluations.py
index 0dd4053..574102a 100644
--- a/run_all_evaluations.py
+++ b/run_all_evaluations.py
@@ -122,6 +122,10 @@ def _run_one_evaluation(row: pd.Series) -> pd.Series:
         pretrained=row["pretrained"],
         auroc=tmp_results["auroc"],
         f1=tmp_results["f1"],
+        fpr=tmp_results["fpr"],
+        fnr=tmp_results["fnr"],
+        tpr=tmp_results["tpr"],
+        tnr=tmp_results["tnr"],
         accuracy=tmp_results["top1"],
         checkpoint=checkpoint,
         classmap=classmap_file,
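
Note on PATCH 4: with the default (micro-average) settings and binary input,
StatScores.compute() returns a flat tensor of shape (5,) ordered
[tp, fp, tn, fn, support], which is what the unpacking in that patch relies
on and what resolves the "(2, 5)?" TODO from PATCH 1. A minimal sketch, again
assuming torchmetrics ~0.9 and synthetic values:

    import torch
    import torchmetrics

    statscores = torchmetrics.StatScores(threshold=0.5)
    statscores.update(preds=torch.tensor([0.9, 0.2, 0.7, 0.4]),
                      target=torch.tensor([1, 0, 1, 1]))
    # Order of the five entries: [tp, fp, tn, fn, support].
    tp, fp, tn, fn, sup = statscores.compute().numpy()

    fnr = fn / (fn + tp)  # miss rate; here 1/3
    fpr = fp / (fp + tn)  # fall-out; here 0/1
    print(f"TPR={1 - fnr:.3f} TNR={1 - fpr:.3f} support={sup}")

One caveat: fn + tp is the positive support and fp + tn the negative support,
so a validation split with no positives (or no negatives) makes these rate
formulas divide by zero.
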
From da3cc9b77eeb586aaa03b3f949d2e7bcb628352f Mon Sep 17 00:00:00 2001
From: Jakub Kaczmarzyk
Date: Thu, 18 Aug 2022 13:17:12 -0400
Subject: [PATCH 6/7] print fpr,fnr,tpr,tnr

---
 run_all_evaluations.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/run_all_evaluations.py b/run_all_evaluations.py
index 574102a..24888c5 100644
--- a/run_all_evaluations.py
+++ b/run_all_evaluations.py
@@ -171,6 +171,10 @@ def run_all_evaluations(directory) -> pd.DataFrame:
         print("[champkit] Results:")
         print(f"[champkit] AUROC={result['auroc']:0.4f}")
         print(f"[champkit] F1={result['f1']:0.4f}")
+        print(f"[champkit] FPR={result['fpr']:0.4f}")
+        print(f"[champkit] FNR={result['fnr']:0.4f}")
+        print(f"[champkit] TPR={result['tpr']:0.4f}")
+        print(f"[champkit] TNR={result['tnr']:0.4f}")
         result["epoch"] = epoch  # could be None but that's ok
         all_results.append(result)
         del result  # for our sanity

From 0bf31b18df17d792273e044e5d4afb7352811134 Mon Sep 17 00:00:00 2001
From: Jakub Kaczmarzyk
Date: Tue, 23 Aug 2022 23:20:13 -0400
Subject: [PATCH 7/7] add model hyperparams to output csv

---
 run_all_evaluations.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/run_all_evaluations.py b/run_all_evaluations.py
index 24888c5..f980c4b 100644
--- a/run_all_evaluations.py
+++ b/run_all_evaluations.py
@@ -176,6 +176,10 @@ def run_all_evaluations(directory) -> pd.DataFrame:
         print(f"[champkit] TPR={result['tpr']:0.4f}")
         print(f"[champkit] TNR={result['tnr']:0.4f}")
         result["epoch"] = epoch  # could be None but that's ok
+        # Add model hyperparams.
+        for k, v in row.iteritems():
+            if k not in result.index:
+                result[k] = v
         all_results.append(result)
         del result  # for our sanity
     df = pd.DataFrame(all_results).reset_index(drop=True)
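
Note on PATCH 7: the loop copies every run-configuration field from the input
row into the result Series unless the result already has an entry of that
name, so the output CSV carries the model hyperparams alongside the metrics.
A toy illustration, with made-up column names:

    import pandas as pd

    row = pd.Series({"model": "resnet50", "lr": 0.001, "batch_size": 64})
    result = pd.Series({"model": "resnet50", "auroc": 0.91, "f1": 0.85})

    # Add model hyperparams the result does not already carry.
    for k, v in row.items():
        if k not in result.index:
            result[k] = v
    print(result)  # gains lr and batch_size; "model" is left untouched

The sketch uses Series.items() because Series.iteritems(), as used in the
patch, was deprecated in pandas 1.5 and removed in pandas 2.0.
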