From 7da6d6fd8a828668513ea4e7e4a687a1d3a9508b Mon Sep 17 00:00:00 2001 From: Ghost <1432072586@qq.com> Date: Sat, 24 Dec 2022 20:56:48 +0800 Subject: [PATCH 1/6] Use MMEval DOTAMAP --- mmrotate/evaluation/metrics/dota_metric.py | 256 +++++++++------------ setup.cfg | 2 +- 2 files changed, 106 insertions(+), 152 deletions(-) diff --git a/mmrotate/evaluation/metrics/dota_metric.py b/mmrotate/evaluation/metrics/dota_metric.py index e631329d0..324bd3b43 100644 --- a/mmrotate/evaluation/metrics/dota_metric.py +++ b/mmrotate/evaluation/metrics/dota_metric.py @@ -4,6 +4,7 @@ import os.path as osp import re import tempfile +import warnings import zipfile from collections import OrderedDict, defaultdict from typing import List, Optional, Sequence, Union @@ -11,86 +12,53 @@ import numpy as np import torch from mmcv.ops import nms_quadri, nms_rotated -from mmengine.evaluator import BaseMetric from mmengine.fileio import dump -from mmengine.logging import MMLogger +from mmengine.logging import MMLogger, print_log +from mmeval import DOTAMeanAP +from terminaltables import AsciiTable -from mmrotate.evaluation import eval_rbbox_map from mmrotate.registry import METRICS from mmrotate.structures.bbox import rbox2qbox @METRICS.register_module() -class DOTAMetric(BaseMetric): - """DOTA evaluation metric. - - Note: In addition to format the output results to JSON like CocoMetric, - it can also generate the full image's results by merging patches' results. - The premise is that you must use the tool provided by us to crop the DOTA - large images, which can be found at: ``tools/data/dota/split``. - - Args: - iou_thrs (float or List[float]): IoU threshold. Defaults to 0.5. - scale_ranges (List[tuple], optional): Scale ranges for evaluating - mAP. If not specified, all bounding boxes would be included in - evaluation. Defaults to None. - metric (str | list[str]): Metrics to be evaluated. Only support - 'mAP' now. If is list, the first setting in the list will - be used to evaluate metric. - predict_box_type (str): Box type of model results. If the QuadriBoxes - is used, you need to specify 'qbox'. Defaults to 'rbox'. - format_only (bool): Format the output results without perform - evaluation. It is useful when you want to format the result - to a specific format. Defaults to False. - outfile_prefix (str, optional): The prefix of json/zip files. It - includes the file path and the prefix of filename, e.g., - "a/b/prefix". If not specified, a temp file will be created. - Defaults to None. - merge_patches (bool): Generate the full image's results by merging - patches' results. - iou_thr (float): IoU threshold of ``nms_rotated`` used in merge - patches. Defaults to 0.1. - eval_mode (str): 'area' or '11points', 'area' means calculating the - area under precision-recall curve, '11points' means calculating - the average precision of recalls at [0, 0.1, ..., 1]. - The PASCAL VOC2007 defaults to use '11points', while PASCAL - VOC2012 defaults to use 'area'. Defaults to '11points'. - collect_device (str): Device name used for collecting results from - different ranks during distributed training. Must be 'cpu' or - 'gpu'. Defaults to 'cpu'. - prefix (str, optional): The prefix that will be added in the metric - names to disambiguate homonymous metrics of different evaluators. - If prefix is not provided in the argument, self.default_prefix - will be used instead. Defaults to None. 
- """ - - default_prefix: Optional[str] = 'dota' +class DOTAMetric(DOTAMeanAP): def __init__(self, iou_thrs: Union[float, List[float]] = 0.5, scale_ranges: Optional[List[tuple]] = None, - metric: Union[str, List[str]] = 'mAP', + num_classes: Optional[int] = None, + eval_mode: str = '11points', + nproc: int = 4, + drop_class_ap: bool = True, + dist_backend: str = 'torch_cuda', predict_box_type: str = 'rbox', format_only: bool = False, outfile_prefix: Optional[str] = None, merge_patches: bool = False, iou_thr: float = 0.1, - eval_mode: str = '11points', - collect_device: str = 'cpu', - prefix: Optional[str] = None) -> None: - super().__init__(collect_device=collect_device, prefix=prefix) - self.iou_thrs = [iou_thrs] if isinstance(iou_thrs, float) \ - else iou_thrs - assert isinstance(self.iou_thrs, list) - self.scale_ranges = scale_ranges - # voc evaluation metrics - if not isinstance(metric, str): - assert len(metric) == 1 - metric = metric[0] - allowed_metrics = ['mAP'] - if metric not in allowed_metrics: - raise KeyError(f"metric should be one of 'mAP', but got {metric}.") - self.metric = metric + **kwargs) -> None: + metric = kwargs.pop('metric', None) + if metric is not None: + warnings.warn('DeprecationWarning: The `metric` parameter of ' + '`DOTAMetric` is deprecated, only mAP is supported') + collect_device = kwargs.pop('collect_device', None) + if collect_device is not None: + warnings.warn( + 'DeprecationWarning: The `collect_device` parameter of ' + '`DOTAMetric` is deprecated, use `dist_backend` instead.') + + super().__init__( + iou_thrs=iou_thrs, + scale_ranges=scale_ranges, + num_classes=num_classes, + eval_mode=eval_mode, + nproc=nproc, + drop_class_ap=drop_class_ap, + classwise=True, + dist_backend=dist_backend, + **kwargs) + self.predict_box_type = predict_box_type self.format_only = format_only @@ -102,9 +70,31 @@ def __init__(self, self.outfile_prefix = outfile_prefix self.merge_patches = merge_patches self.iou_thr = iou_thr - self.use_07_metric = True if eval_mode == '11points' else False + def process(self, data_batch: Sequence[dict], + data_samples: Sequence[dict]): + predictions, groundtruths = [], [] + for data_sample in data_samples: + gt = copy.deepcopy(data_sample) + gt_instances = gt['gt_instances'] + gt_ignore_instances = gt['ignored_instances'] + ann = dict( + labels=gt_instances['labels'].cpu().numpy(), + bboxes=gt_instances['bboxes'].cpu().numpy(), + bboxes_ignore=gt_ignore_instances['bboxes'].cpu().numpy(), + labels_ignore=gt_ignore_instances['labels'].cpu().numpy()) + groundtruths.append(ann) + + pred = data_sample['pred_instances'] + # used for merge patches + pred['img_id'] = data_sample['img_id'] + pred['bboxes'] = pred['bboxes'].cpu().numpy() + pred['scores'] = pred['scores'].cpu().numpy() + pred['labels'] = pred['labels'].cpu().numpy() + predictions.append(pred) + self.add(predictions, groundtruths) + def merge_results(self, results: Sequence[dict], outfile_prefix: str) -> str: """Merge patches' predictions into full image's results and generate a @@ -121,7 +111,6 @@ def merge_results(self, results: Sequence[dict], "somepath/xxx/xxx.zip". """ collector = defaultdict(list) - for idx, result in enumerate(results): img_id = result.get('img_id', idx) splitname = img_id.split('__') @@ -223,8 +212,8 @@ def results2json(self, results: Sequence[dict], Args: results (Sequence[dict]): Testing results of the dataset. - outfile_prefix (str): The filename prefix of the json files. 
If the - prefix is "somepath/xxx", the json files will be named + outfile_prefix (str): The filename prefix of the json files. If + the prefix is "somepath/xxx", the json files will be named "somepath/xxx.bbox.json", "somepath/xxx.segm.json", "somepath/xxx.proposal.json". @@ -253,59 +242,9 @@ def results2json(self, results: Sequence[dict], return result_files - def process(self, data_batch: Sequence[dict], - data_samples: Sequence[dict]) -> None: - """Process one batch of data samples and predictions. The processed - results should be stored in ``self.results``, which will be used to - compute the metrics when all batches have been processed. - - Args: - data_batch (dict): A batch of data from the dataloader. - data_samples (Sequence[dict]): A batch of data samples that - contain annotations and predictions. - """ - for data_sample in data_samples: - gt = copy.deepcopy(data_sample) - gt_instances = gt['gt_instances'] - gt_ignore_instances = gt['ignored_instances'] - if gt_instances == {}: - ann = dict() - else: - ann = dict( - labels=gt_instances['labels'].cpu().numpy(), - bboxes=gt_instances['bboxes'].cpu().numpy(), - bboxes_ignore=gt_ignore_instances['bboxes'].cpu().numpy(), - labels_ignore=gt_ignore_instances['labels'].cpu().numpy()) - result = dict() - pred = data_sample['pred_instances'] - result['img_id'] = data_sample['img_id'] - result['bboxes'] = pred['bboxes'].cpu().numpy() - result['scores'] = pred['scores'].cpu().numpy() - result['labels'] = pred['labels'].cpu().numpy() - - result['pred_bbox_scores'] = [] - for label in range(len(self.dataset_meta['CLASSES'])): - index = np.where(result['labels'] == label)[0] - pred_bbox_scores = np.hstack([ - result['bboxes'][index], result['scores'][index].reshape( - (-1, 1)) - ]) - result['pred_bbox_scores'].append(pred_bbox_scores) - - self.results.append((ann, result)) - - def compute_metrics(self, results: list) -> dict: - """Compute the metrics from processed results. - - Args: - results (list): The processed results of each batch. - Returns: - dict: The computed metrics. The keys are the names of the metrics, - and the values are corresponding results. 
- """ + def evaluate(self, *args, **kwargs) -> dict: logger: MMLogger = MMLogger.get_current_instance() - gts, preds = zip(*results) - + preds, gts = zip(*self._results) tmp_dir = None if self.outfile_prefix is None: tmp_dir = tempfile.TemporaryDirectory() @@ -319,35 +258,50 @@ def compute_metrics(self, results: list) -> dict: zip_path = self.merge_results(preds, outfile_prefix) logger.info(f'The submission file save at {zip_path}') return eval_results - else: - # convert predictions to coco format and dump to json file + elif self.format_only: _ = self.results2json(preds, outfile_prefix) - if self.format_only: - logger.info('results are saved in ' - f'{osp.dirname(outfile_prefix)}') - return eval_results - - if self.metric == 'mAP': - assert isinstance(self.iou_thrs, list) - dataset_name = self.dataset_meta['CLASSES'] - dets = [pred['pred_bbox_scores'] for pred in preds] - - mean_aps = [] - for iou_thr in self.iou_thrs: - logger.info(f'\n{"-" * 15}iou_thr: {iou_thr}{"-" * 15}') - mean_ap, _ = eval_rbbox_map( - dets, - gts, - scale_ranges=self.scale_ranges, - iou_thr=iou_thr, - use_07_metric=self.use_07_metric, - box_type=self.predict_box_type, - dataset=dataset_name, - logger=logger) - mean_aps.append(mean_ap) - eval_results[f'AP{int(iou_thr * 100):02d}'] = round(mean_ap, 3) - eval_results['mAP'] = sum(mean_aps) / len(mean_aps) - eval_results.move_to_end('mAP', last=False) - else: - raise NotImplementedError - return eval_results + logger.info('results are saved in ' + f'{osp.dirname(outfile_prefix)}') + return eval_results + + metric_results = self.compute(*args, **kwargs) + self.reset() + classwise_result = metric_results['classwise_result'] + del metric_results['classwise_result'] + + classes = self.dataset_meta['CLASSES'] + header = ['class', 'gts', 'dets', 'recall', 'ap'] + + for i, iou_thr in enumerate(self.iou_thrs): + for j, scale_range in enumerate(self.scale_ranges): + table_title = f' IoU thr: {iou_thr} ' + if scale_range != (None, None): + table_title += f'Scale range: {scale_range} ' + + table_data = [header] + aps = [] + for k in range(len(classes)): + class_results = classwise_result[k] + recalls = class_results['recalls'][i, j] + recall = 0 if len(recalls) == 0 else recalls[-1] + row_data = [ + classes[k], class_results['num_gts'][i, j], + class_results['num_dets'], + round(recall, 3), + round(class_results['ap'][i, j], 3) + ] + table_data.append(row_data) + if class_results['num_gts'][i, j] > 0: + aps.append(class_results['ap'][i, j]) + + mean_ap = np.mean(aps) if aps != [] else 0 + table_data.append(['mAP', '', '', '', f'{mean_ap:.3f}']) + table = AsciiTable(table_data, title=table_title) + table.inner_footing_row_border = True + print_log('\n' + table.table, logger='current') + + evaluate_results = { + f'pascal_voc/{k}': round(float(v), 3) + for k, v in metric_results.items() + } + return evaluate_results diff --git a/setup.cfg b/setup.cfg index a56498218..d7e481ad5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,7 +3,7 @@ line_length = 79 multi_line_output = 0 known_standard_library = setuptools known_first_party = mmrotate -known_third_party = PIL,cv2,e2cnn,matplotlib,mmcv,mmdet,mmengine,numpy,parameterized,pycocotools,pytest,pytorch_sphinx_theme,terminaltables,torch,ts,yaml +known_third_party = PIL,cv2,e2cnn,matplotlib,mmcv,mmdet,mmengine,mmeval,numpy,parameterized,pycocotools,pytest,pytorch_sphinx_theme,terminaltables,torch,ts,yaml no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY From 15ca4e127828c7c067f9f63e3c54ea504af4b773 Mon Sep 17 00:00:00 2001 
From: Ghost <1432072586@qq.com>
Date: Sat, 24 Dec 2022 21:24:21 +0800
Subject: [PATCH 2/6] add eval_metric.py for offline evaluation

---
 tools/analysis_tools/eval_metric.py | 54 +++++++++++++++++++++++++++++
 1 file changed, 54 insertions(+)
 create mode 100644 tools/analysis_tools/eval_metric.py

diff --git a/tools/analysis_tools/eval_metric.py b/tools/analysis_tools/eval_metric.py
new file mode 100644
index 000000000..1d701b1be
--- /dev/null
+++ b/tools/analysis_tools/eval_metric.py
@@ -0,0 +1,54 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+
+import mmengine
+from mmengine import Config, DictAction
+from mmengine.evaluator import Evaluator
+
+from mmrotate.registry import DATASETS
+from mmrotate.utils import register_all_modules
+
+# from mmdet.registry import DATASETS
+# from mmdet.utils import register_all_modules
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Evaluate metric of the '
+                                     'results saved in pkl format')
+    parser.add_argument('config', help='Config of the model')
+    parser.add_argument('pkl_results', help='Results in pickle format')
+    parser.add_argument(
+        '--cfg-options',
+        nargs='+',
+        action=DictAction,
+        help='override some settings in the used config, the key-value pair '
+        'in xxx=yyy format will be merged into config file. If the value to '
+        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
+        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
+        'Note that the quotation marks are necessary and that no white space '
+        'is allowed.')
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+
+    register_all_modules(init_default_scope=True)
+
+    cfg = Config.fromfile(args.config)
+
+    if args.cfg_options is not None:
+        cfg.merge_from_dict(args.cfg_options)
+
+    dataset = DATASETS.build(cfg.test_dataloader.dataset)
+    predictions = mmengine.load(args.pkl_results)
+
+    evaluator = Evaluator(cfg.val_evaluator)
+    evaluator.dataset_meta = dataset.metainfo
+    eval_results = evaluator.offline_evaluate(predictions)
+    print(eval_results)
+
+
+if __name__ == '__main__':
+    main()

From 2e1bcf36888f82f150e5a1bca55a0b551f9fedd5 Mon Sep 17 00:00:00 2001
From: Ghost <1432072586@qq.com>
Date: Wed, 28 Dec 2022 16:00:41 +0800
Subject: [PATCH 3/6] add docstrings

---
 mmrotate/evaluation/metrics/dota_metric.py | 67 ++++++++++++++++++----
 1 file changed, 56 insertions(+), 11 deletions(-)

diff --git a/mmrotate/evaluation/metrics/dota_metric.py b/mmrotate/evaluation/metrics/dota_metric.py
index 324bd3b43..539abe26d 100644
--- a/mmrotate/evaluation/metrics/dota_metric.py
+++ b/mmrotate/evaluation/metrics/dota_metric.py
@@ -23,6 +23,43 @@
 
 @METRICS.register_module()
 class DOTAMetric(DOTAMeanAP):
+    """DOTA evaluation metric.
+
+    Note: In addition to formatting the output results to JSON like CocoMetric,
+    it can also generate the full image's results by merging patches' results.
+    The premise is that you must use the tool provided by us to crop the DOTA
+    large images, which can be found at: ``tools/data/dota/split``.
+
+    Args:
+        iou_thrs (float | List[float], optional): IoU threshold.
+        scale_ranges (List[tuple], optional): Scale ranges for evaluating mAP.
+        num_classes (int, optional): The number of classes. If None, it will be
+            obtained from the 'CLASSES' field in ``self.dataset_meta``.
+        eval_mode (str, optional): 'area' or '11points'. 'area' means the
+            area under the precision-recall curve, '11points' means the
+            average precision of recalls at [0, 0.1, ..., 1].
+        nproc (int, optional): Processes used for computing TP and FP. If nproc
+            is less than or equal to 1, multiprocessing will not be used.
+        drop_class_ap (bool, optional): Whether to drop the class without
+            ground truth when calculating the average precision for each class.
+        dist_backend (str, optional): The name of the mmeval distributed
+            communication backend. You can get all the backend names through
+            ``mmeval.core.list_all_backends()``.
+        predict_box_type (str, optional): Box type of model results. If the
+            QuadriBoxes is used, you need to specify 'qbox'. Defaults to
+            'rbox'.
+        format_only (bool, optional): Format the output results without
+            performing evaluation. It is useful when you want to format the
+            result to a specific format. Defaults to False.
+        outfile_prefix (str, optional): The prefix of json/zip files.
+            It includes the file path and the prefix of filename, e.g.,
+            "a/b/prefix". If not specified, a temp file will be created.
+            Defaults to None.
+        merge_patches (bool, optional): Generate the full image's results by
+            merging patches' results. Defaults to False.
+        iou_thr (float, optional): IoU threshold of ``nms_rotated`` used in
+            merge patches. Defaults to 0.1.
+    """

     def __init__(self,
                  iou_thrs: Union[float, List[float]] = 0.5,
@@ -56,16 +93,16 @@ def __init__(self,
             nproc=nproc,
             drop_class_ap=drop_class_ap,
             classwise=True,
+            predict_box_type=predict_box_type,
             dist_backend=dist_backend,
             **kwargs)

-        self.predict_box_type = predict_box_type
-
         self.format_only = format_only
         if self.format_only:
-            assert outfile_prefix is not None, 'outfile_prefix must be not'
-            'None when format_only is True, otherwise the result files will'
-            'be saved to a temp directory which will be cleaned up at the end.'
+            assert outfile_prefix is not None, 'outfile_prefix must not be' \
+                ' None when format_only is True, otherwise the result' \
+                ' files will be saved to a temp directory which will be' \
+                ' cleaned up at the end.'

         self.outfile_prefix = outfile_prefix
         self.merge_patches = merge_patches
@@ -74,6 +111,14 @@ def __init__(self,

     def process(self, data_batch: Sequence[dict],
                 data_samples: Sequence[dict]):
+        """Process one batch of data samples and predictions. The function will
+        call self.add() to add predictions and groundtruths to self._results.
+
+        Args:
+            data_batch (dict): A batch of data from the dataloader.
+            data_samples (Sequence[dict]): A batch of data samples that
+                contain annotations and predictions.
+ """ predictions, groundtruths = [], [] for data_sample in data_samples: gt = copy.deepcopy(data_sample) @@ -244,7 +289,7 @@ def results2json(self, results: Sequence[dict], def evaluate(self, *args, **kwargs) -> dict: logger: MMLogger = MMLogger.get_current_instance() - preds, gts = zip(*self._results) + preds, _ = zip(*self._results) tmp_dir = None if self.outfile_prefix is None: tmp_dir = tempfile.TemporaryDirectory() @@ -280,12 +325,12 @@ def evaluate(self, *args, **kwargs) -> dict: table_data = [header] aps = [] - for k in range(len(classes)): - class_results = classwise_result[k] + for idx, _ in enumerate(classes): + class_results = classwise_result[idx] recalls = class_results['recalls'][i, j] recall = 0 if len(recalls) == 0 else recalls[-1] row_data = [ - classes[k], class_results['num_gts'][i, j], + classes[idx], class_results['num_gts'][i, j], class_results['num_dets'], round(recall, 3), round(class_results['ap'][i, j], 3) @@ -294,14 +339,14 @@ def evaluate(self, *args, **kwargs) -> dict: if class_results['num_gts'][i, j] > 0: aps.append(class_results['ap'][i, j]) - mean_ap = np.mean(aps) if aps != [] else 0 + mean_ap = np.mean(aps) if aps else 0 table_data.append(['mAP', '', '', '', f'{mean_ap:.3f}']) table = AsciiTable(table_data, title=table_title) table.inner_footing_row_border = True print_log('\n' + table.table, logger='current') evaluate_results = { - f'pascal_voc/{k}': round(float(v), 3) + f'dota/{k}': round(float(v), 3) for k, v in metric_results.items() } return evaluate_results From f3e7865a51a72431dbfca216d40f6271988291db Mon Sep 17 00:00:00 2001 From: YanxingLiu <42299757+YanxingLiu@users.noreply.github.com> Date: Wed, 28 Dec 2022 16:02:38 +0800 Subject: [PATCH 4/6] Delete eval_metric.py --- tools/analysis_tools/eval_metric.py | 54 ----------------------------- 1 file changed, 54 deletions(-) delete mode 100644 tools/analysis_tools/eval_metric.py diff --git a/tools/analysis_tools/eval_metric.py b/tools/analysis_tools/eval_metric.py deleted file mode 100644 index 1d701b1be..000000000 --- a/tools/analysis_tools/eval_metric.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import argparse - -import mmengine -from mmengine import Config, DictAction -from mmengine.evaluator import Evaluator - -from mmrotate.registry import DATASETS -from mmrotate.utils import register_all_modules - -# from mmdet.registry import DATASETS -# from mmdet.utils import register_all_modules - - -def parse_args(): - parser = argparse.ArgumentParser(description='Evaluate metric of the ' - 'results saved in pkl format') - parser.add_argument('config', help='Config of the model') - parser.add_argument('pkl_results', help='Results in pickle format') - parser.add_argument( - '--cfg-options', - nargs='+', - action=DictAction, - help='override some settings in the used config, the key-value pair ' - 'in xxx=yyy format will be merged into config file. If the value to ' - 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' - 'It also allows nested list/tuple values, e.g. 
key="[(a,b),(c,d)]" ' - 'Note that the quotation marks are necessary and that no white space ' - 'is allowed.') - args = parser.parse_args() - return args - - -def main(): - args = parse_args() - - register_all_modules(init_default_scope=True) - - cfg = Config.fromfile(args.config) - - if args.cfg_options is not None: - cfg.merge_from_dict(args.cfg_options) - - dataset = DATASETS.build(cfg.test_dataloader.dataset) - predictions = mmengine.load(args.pkl_results) - - evaluator = Evaluator(cfg.val_evaluator) - evaluator.dataset_meta = dataset.metainfo - eval_results = evaluator.offline_evaluate(predictions) - print(eval_results) - - -if __name__ == '__main__': - main() From 574dcd385a7828f64d9341108b90c0ec643642a9 Mon Sep 17 00:00:00 2001 From: YanxingLiu <42299757+YanxingLiu@users.noreply.github.com> Date: Sun, 8 Jan 2023 09:15:21 +0800 Subject: [PATCH 5/6] Update mmrotate/evaluation/metrics/dota_metric.py Co-authored-by: yancong <32220263+ice-tong@users.noreply.github.com> --- mmrotate/evaluation/metrics/dota_metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmrotate/evaluation/metrics/dota_metric.py b/mmrotate/evaluation/metrics/dota_metric.py index 539abe26d..74f9db783 100644 --- a/mmrotate/evaluation/metrics/dota_metric.py +++ b/mmrotate/evaluation/metrics/dota_metric.py @@ -138,7 +138,7 @@ def process(self, data_batch: Sequence[dict], pred['scores'] = pred['scores'].cpu().numpy() pred['labels'] = pred['labels'].cpu().numpy() predictions.append(pred) - self.add(predictions, groundtruths) + self.add(predictions, groundtruths) def merge_results(self, results: Sequence[dict], outfile_prefix: str) -> str: From 2927e5bac2f6f15166148083175d07e5202a55a1 Mon Sep 17 00:00:00 2001 From: Ghost <1432072586@qq.com> Date: Wed, 11 Jan 2023 00:05:01 +0800 Subject: [PATCH 6/6] modify dota_metric --- mmrotate/evaluation/metrics/dota_metric.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mmrotate/evaluation/metrics/dota_metric.py b/mmrotate/evaluation/metrics/dota_metric.py index 539abe26d..43419fbf8 100644 --- a/mmrotate/evaluation/metrics/dota_metric.py +++ b/mmrotate/evaluation/metrics/dota_metric.py @@ -93,10 +93,9 @@ def __init__(self, nproc=nproc, drop_class_ap=drop_class_ap, classwise=True, - predict_box_type=predict_box_type, dist_backend=dist_backend, **kwargs) - + self.predict_box_type = predict_box_type self.format_only = format_only if self.format_only: assert outfile_prefix is not None, 'outfile_prefix must be not' \ @@ -138,7 +137,7 @@ def process(self, data_batch: Sequence[dict], pred['scores'] = pred['scores'].cpu().numpy() pred['labels'] = pred['labels'].cpu().numpy() predictions.append(pred) - self.add(predictions, groundtruths) + self.add(predictions, groundtruths) def merge_results(self, results: Sequence[dict], outfile_prefix: str) -> str:
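For readers who want to try the refactored metric, the snippet below shows how it might be wired into an mmrotate config. This is a minimal sketch and not part of the patches above: it assumes the standard `val_evaluator`/`test_evaluator` config fields, the `outfile_prefix` path is an illustrative placeholder, and only arguments kept or introduced by this series are used.

```python
# Minimal evaluator config sketch for the MMEval-backed DOTAMetric in this
# series. Values and paths are illustrative, not prescribed by the patches.
val_evaluator = dict(
    type='DOTAMetric',
    iou_thrs=0.5,               # same default as the previous implementation
    eval_mode='11points',       # VOC2007-style AP, the usual DOTA convention
    predict_box_type='rbox')    # switch to 'qbox' for QuadriBoxes heads

# For DOTA test-set submission there is no ground truth, so only format and
# merge the patch-level results into full-image results:
test_evaluator = dict(
    type='DOTAMetric',
    format_only=True,
    merge_patches=True,
    outfile_prefix='./work_dirs/dota/Task1')  # illustrative output prefix
```

With `format_only=True` and `merge_patches=True`, `DOTAMetric.evaluate()` skips the mAP computation and only writes the merged full-image submission zip, which matches the test-set workflow where annotations are not available offline.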