Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Per Class Analysis and Class Confusion Matrix #21

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tidecv/ap.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ def get_mAP(self) -> float:
aps = [x.get_ap() for x in self.objs.values() if not x.is_empty()]
return sum(aps) / len(aps)

def get_per_class_APs(self) -> dict:
    """Return a dict mapping each class id to that class's AP.

    Unlike get_mAP (which skips classes where is_empty() is true), every
    class in self.objs is included here, so callers can see empty classes
    explicitly.
    """
    # `k: v` spacing kept consistent with get_gt_positives below.
    return {k: v.get_ap() for k, v in self.objs.items()}

def get_gt_positives(self) -> dict:
    """Return a dict mapping each class id to its number of ground-truth positives."""
    counts = {}
    for class_id, data_obj in self.objs.items():
        counts[class_id] = data_obj.num_gt_positives
    return counts

Expand Down
7 changes: 5 additions & 2 deletions tidecv/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,11 @@ def toRLE(mask:object, w:int, h:int):
if type(mask) == list:
# polygon -- a single object might consist of multiple parts
# we merge all parts into one mask rle code
rles = maskUtils.frPyObjects(mask, h, w)
return maskUtils.merge(rles)
if mask:
rles = maskUtils.frPyObjects(mask, h, w)
return maskUtils.merge(rles)
else:
return mask
elif type(mask['counts']) == list:
# uncompressed RLE
return maskUtils.frPyObjects(mask, h, w)
Expand Down
84 changes: 82 additions & 2 deletions tidecv/quantify.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def __init__(self, gt:Data, preds:Data, pos_thresh:float, bg_thresh:float, mode:
self.preds = preds

self.errors = []
self.per_class_errors = []
self.error_dict = {_type: [] for _type in TIDE._error_types}
self.ap_data = ClassedAPDataObject()
self.qualifiers = {}
Expand Down Expand Up @@ -177,6 +178,7 @@ def _run(self):
error.disabled = False

self.ap = self.ap_data.get_mAP()
self.per_classes_ap = self.ap_data.get_per_class_APs()

# Now that we've stored the fixed errors, we can clear the gt info
self._clear()
Expand Down Expand Up @@ -330,6 +332,39 @@ def fix_errors(self, condition=lambda x: False, transform=None, false_neg_dict:d

return new_ap_data

def fix_main_per_class_errors(self, progressive: bool = False, error_types: list = None, qual: Qualifier = None) -> dict:
    """Compute, per class, the dAP gained by fixing each main error type.

    For every error type, that type's errors are fixed and the resulting
    per-class APs are compared against the baseline per-class APs (or, when
    `progressive` is True, against the APs from the previous step, so fixes
    accumulate).

    Args:
        progressive: If True, each fix builds on the previous one instead of
            being applied independently to the original data.
        error_types: Error types to evaluate. Defaults to TIDE._error_types.
        qual: Qualifier restricting which errors count. Defaults to a
            no-op qualifier.

    Returns:
        {error_type: {class_id: dAP}} with negative deltas clamped to 0.
    """
    ap_data = self.ap_data
    last_per_class_ap = self.per_classes_ap

    if qual is None:
        qual = Qualifier('', None)

    if error_types is None:
        error_types = TIDE._error_types

    errors_per_class = {}
    for error in error_types:
        _ap_data = self.fix_errors(qual._make_error_func(error),
                                   ap_data=ap_data, disable_errors=progressive)

        new_per_class_ap = _ap_data.get_per_class_APs()
        # A negative delta is likely due to binning differences, so clamp it
        # to 0 (mirrors fix_main_errors).
        # NOTE(review): assumes every class in new_per_class_ap also appears
        # in the baseline APs — confirm, otherwise this raises KeyError.
        errors_per_class[error] = {k: max(new_per_class_ap[k] - last_per_class_ap[k], 0)
                                   for k in new_per_class_ap}

        if progressive:
            last_per_class_ap = new_per_class_ap
            ap_data = _ap_data

    # Progressive mode disabled errors as it fixed them; re-enable them all
    # so later queries on this run see the full error list.
    if progressive:
        for error in self.errors:
            error.disabled = False

    return errors_per_class

def fix_main_errors(self, progressive:bool=False, error_types:list=None, qual:Qualifier=None) -> dict:
ap_data = self.ap_data
last_ap = self.ap
Expand All @@ -341,7 +376,6 @@ def fix_main_errors(self, progressive:bool=False, error_types:list=None, qual:Qu
error_types = TIDE._error_types

errors = {}

for error in error_types:
_ap_data = self.fix_errors(qual._make_error_func(error),
ap_data=ap_data, disable_errors=progressive)
Expand All @@ -350,7 +384,7 @@ def fix_main_errors(self, progressive:bool=False, error_types:list=None, qual:Qu
# If an error is negative that means it's likely due to binning differences, so just
# Ignore the negative by setting it to 0.
errors[error] = max(new_ap - last_ap, 0)

if progressive:
last_ap = new_ap
ap_data = _ap_data
Expand Down Expand Up @@ -435,6 +469,7 @@ def __init__(self, pos_threshold:float=0.5, background_threshold:float=0.1, mode
self.runs = {}
self.run_thresholds = {}
self.run_main_errors = {}
self.run_main_per_class_errors = {}
self.run_special_errors = {}

self.qualifiers = OrderedDict()
Expand Down Expand Up @@ -492,6 +527,7 @@ def add_qualifiers(self, *quals):
def summarize(self):
""" Summarizes the mAP values and errors for all runs in this TIDE object. Results are printed to the console. """
main_errors = self.get_main_errors()
main_per_class_errors = self.get_main_per_class_errors()
special_errors = self.get_special_errors()

for run_name, run in self.runs.items():
Expand Down Expand Up @@ -552,6 +588,15 @@ def summarize(self):
[' dAP'] + ['{:6.2f}'.format(main_errors[run_name][err.short_name]) for err in TIDE._error_types]
], title='Main Errors')

print()
# Print the per class errors
P.print_table(
[['class'] + ['Type'] + [err.short_name for err in TIDE._error_types]]
+
[[run.gt.classes[k]] + [' dAP'] + ['{:6.2f}'.format(main_per_class_errors[run_name][err.short_name][k])
for err in TIDE._error_types]
for k in sorted(main_per_class_errors[run_name][TIDE._error_types[0].short_name].keys())]
, title='Main Per Class Errors')


print()
Expand Down Expand Up @@ -605,6 +650,20 @@ def get_main_errors(self):

return errors

def get_main_per_class_errors(self):
    """Collect per-class main-error dAPs for every run, keyed by run name.

    Cached results in self.run_main_per_class_errors are returned as-is;
    otherwise they are computed on the fly via fix_main_per_class_errors
    and re-keyed by each error type's short name.
    """
    per_run = {}

    for name, run in self.runs.items():
        if name in self.run_main_per_class_errors:
            per_run[name] = self.run_main_per_class_errors[name]
        else:
            fixed = run.fix_main_per_class_errors()
            per_run[name] = {err.short_name: class_deltas
                             for err, class_deltas in fixed.items()}

    return per_run

def get_special_errors(self):
errors = {}

Expand All @@ -631,4 +690,25 @@ def get_all_errors(self):
'special': self.get_special_errors()
}

def get_confusion_matrix(self):
    """Build and print a class-confusion matrix for every run.

    Each classification error (ClassError) contributes one count at
    [predicted class, ground-truth class]. Rows are predicted classes,
    columns are actual classes; the table printed per run is labeled with
    class names in sorted class-id order.

    Returns:
        {run_name: np.ndarray} of shape (n_classes, n_classes), dtype int32.
    """
    confusion_matrix = {}
    for run_name, run in self.runs.items():
        sorted_keys = sorted(run.gt.classes.keys())
        # Map class ids to matrix indices instead of assuming ids form a
        # contiguous 1-based range (the old `id - 1` indexing). This keeps
        # the counts aligned with the sorted order used for the printed
        # row/column labels below.
        idx = {class_id: i for i, class_id in enumerate(sorted_keys)}
        n_classes = len(sorted_keys)

        # row: predicted classes, col: actual classes
        cm = np.zeros((n_classes, n_classes), dtype=np.int32)
        for error in run.errors:
            if isinstance(error, ClassError):
                cm[idx[error.pred['class']], idx[error.gt['class']]] += 1
        confusion_matrix[run_name] = cm

        print()
        P.print_table([
            ['pred/gt'] + [run.gt.classes[k] for k in sorted_keys],
        ] + [
            [run.gt.classes[k]] + [str(cnt) for cnt in cm[i]] for i, k in enumerate(sorted_keys)
        ], title=f"{run_name} confusion matrix")

    return confusion_matrix