Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correct computation of recall in case of missing detection files #119

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 20 additions & 26 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import sys
import argparse
import math
from pathlib import Path

import numpy as np

Expand Down Expand Up @@ -363,27 +364,20 @@ def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, out

gt_files = []
for txt_file in ground_truth_files_list:
#print(txt_file)
file_id = txt_file.split(".txt", 1)[0]
file_id = os.path.basename(os.path.normpath(file_id))
# check if there is a correspondent detection-results file
temp_path = os.path.join(DR_PATH, (file_id + ".txt"))
if not os.path.exists(temp_path):
error_msg = "Error. File not found: {}\n".format(temp_path)
error_msg += "(You can avoid this error message by running extra/intersect-gt-and-dr.py)"
error(error_msg)
lines_list = file_lines_to_list(txt_file)

txt_file = Path(txt_file)
lines_list = txt_file.read_text().splitlines()
# create ground-truth dictionary
bounding_boxes = []
is_difficult = False
already_seen_classes = []
for line in lines_list:
try:
if "difficult" in line:
class_name, left, top, right, bottom, _difficult = line.split()
is_difficult = True
class_name, left, top, right, bottom, _difficult = line.split()
is_difficult = True
else:
class_name, left, top, right, bottom = line.split()
class_name, left, top, right, bottom = line.split()
except ValueError:
error_msg = "Error: File " + txt_file + " in the wrong format.\n"
error_msg += " Expected: <class_name> <left> <top> <right> <bottom> ['difficult']\n"
Expand All @@ -396,10 +390,10 @@ def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, out
continue
bbox = left + " " + top + " " + right + " " +bottom
if is_difficult:
bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False, "difficult":True})
bounding_boxes.append({"class_name": class_name, "bbox": bbox, "used": False, "difficult": True})
is_difficult = False
else:
bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False})
bounding_boxes.append({"class_name": class_name, "bbox": bbox, "used": False})
# count that object
if class_name in gt_counter_per_class:
gt_counter_per_class[class_name] += 1
Expand All @@ -415,9 +409,8 @@ def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, out
counter_images_per_class[class_name] = 1
already_seen_classes.append(class_name)


# dump bounding_boxes into a ".json" file
new_temp_file = TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json"
new_temp_file = Path(TEMP_FILES_PATH).joinpath(txt_file.stem + "_ground_truth.json")
gt_files.append(new_temp_file)
with open(new_temp_file, 'w') as outfile:
json.dump(bounding_boxes, outfile)
Expand Down Expand Up @@ -466,15 +459,16 @@ def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, out
for txt_file in dr_files_list:
#print(txt_file)
# the first time it checks if all the corresponding ground-truth files exist
file_id = txt_file.split(".txt",1)[0]
file_id = os.path.basename(os.path.normpath(file_id))
temp_path = os.path.join(GT_PATH, (file_id + ".txt"))

txt_file = Path(txt_file)
# check if there is a correspondent detection-results file
temp_path = Path(GT_PATH).joinpath(txt_file.name)
if class_index == 0:
if not os.path.exists(temp_path):
if not temp_path.exists():
error_msg = "Error. File not found: {}\n".format(temp_path)
error_msg += "(You can avoid this error message by running extra/intersect-gt-and-dr.py)"
error(error_msg)
lines = file_lines_to_list(txt_file)
lines = txt_file.read_text().splitlines()
for line in lines:
try:
tmp_class_name, confidence, left, top, right, bottom = line.split()
Expand All @@ -485,11 +479,11 @@ def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, out
error(error_msg)
if tmp_class_name == class_name:
#print("match")
bbox = left + " " + top + " " + right + " " +bottom
bounding_boxes.append({"confidence":confidence, "file_id":file_id, "bbox":bbox})
bbox = left + " " + top + " " + right + " " + bottom
bounding_boxes.append({"confidence": confidence, "file_id": txt_file.stem, "bbox": bbox})
#print(bounding_boxes)
# sort detection-results by decreasing confidence
bounding_boxes.sort(key=lambda x:float(x['confidence']), reverse=True)
bounding_boxes.sort(key=lambda x: float(x['confidence']), reverse=True)
with open(TEMP_FILES_PATH + "/" + class_name + "_dr.json", 'w') as outfile:
json.dump(bounding_boxes, outfile)

Expand Down Expand Up @@ -739,7 +733,7 @@ def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, out
#print(ground_truth_data)
# get name of corresponding image
start = TEMP_FILES_PATH + '/'
img_id = tmp_file[tmp_file.find(start)+len(start):tmp_file.rfind('_ground_truth.json')]
img_id = Path(tmp_file).stem.split('_ground_truth')[0]
img_cumulative_path = output_files_path + "/images/" + img_id + ".jpg"
img = cv2.imread(img_cumulative_path)
if img is None:
Expand Down