Skip to content

Commit

Permalink
Cleaned up test_detect_target and generate_expected
Browse files Browse the repository at this point in the history
  • Loading branch information
Ethan118 committed Feb 7, 2024
1 parent 0ab48c9 commit 435992c
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 18 deletions.
2 changes: 1 addition & 1 deletion tests/model_example/generate_expected.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@


# Downloaded from: https://github.com/ultralytics/assets/releases
MODEL_PATH = pathlib.Path("tests", "model_example", "yolov8s_ultralytics_pretrained_default.pt")
TEST_PATH = pathlib.Path("tests", "model_example")
MODEL_PATH = pathlib.Path(TEST_PATH, "yolov8s_ultralytics_pretrained_default.pt")

BUS_IMAGE_PATH = pathlib.Path(TEST_PATH, "bus.jpg")
ZIDANE_IMAGE_PATH = pathlib.Path(TEST_PATH, "zidane.jpg")
Expand Down
37 changes: 20 additions & 17 deletions tests/test_detect_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,35 +22,39 @@
IMAGE_ZIDANE_PATH = pathlib.Path("tests", "model_example", "zidane.jpg")
BOUNDING_BOX_ZIDANE_PATH = pathlib.Path("tests", "model_example", "bounding_box_zidane.txt")

BOUNDING_BOX_DESIRED_PRECISION = 7
CONFIDENCE_DESIRED_PRECISION = 3
BOUNDING_BOX_DESIRED_TOLERANCE = 7
CONFIDENCE_DESIRED_TOLERANCE = 3


def compare_detections(actual: detections_and_time.DetectionsAndTime, expected: detections_and_time.DetectionsAndTime) -> None:
"""
Compare expected and actual detections.
"""
assert len(expected.detections) == len(actual.detections)
assert len(actual.detections) == len(expected.detections)

for i in range(0, len(expected.detections)):
expected_detection = expected.detections[i]
actual_detection = actual.detections[i]

assert expected_detection.label == actual_detection.label
np.testing.assert_almost_equal(expected_detection.confidence, actual_detection.confidence, decimal=CONFIDENCE_DESIRED_PRECISION)
np.testing.assert_almost_equal(expected_detection.confidence, actual_detection.confidence, decimal=CONFIDENCE_DESIRED_TOLERANCE)

np.testing.assert_almost_equal(expected_detection.x1, actual_detection.x1, decimal=BOUNDING_BOX_DESIRED_PRECISION)
np.testing.assert_almost_equal(expected_detection.y1, actual_detection.y1, decimal=BOUNDING_BOX_DESIRED_PRECISION)
np.testing.assert_almost_equal(expected_detection.x2, actual_detection.x2, decimal=BOUNDING_BOX_DESIRED_PRECISION)
np.testing.assert_almost_equal(expected_detection.y2, actual_detection.y2, decimal=BOUNDING_BOX_DESIRED_PRECISION)
np.testing.assert_almost_equal(actual_detection.x1, expected_detection.x1, decimal=BOUNDING_BOX_DESIRED_TOLERANCE)
np.testing.assert_almost_equal(actual_detection.y1, expected_detection.y1, decimal=BOUNDING_BOX_DESIRED_TOLERANCE)
np.testing.assert_almost_equal(actual_detection.x2, expected_detection.x2, decimal=BOUNDING_BOX_DESIRED_TOLERANCE)
np.testing.assert_almost_equal(actual_detection.y2, expected_detection.y2, decimal=BOUNDING_BOX_DESIRED_TOLERANCE)

def create_detections(expected: np.ndarray) -> detections_and_time.DetectionsAndTime:
def create_detections(detections_from_file: np.ndarray) -> detections_and_time.DetectionsAndTime:
"""
Create DetectionsAndTime from expected.
"""
detections = detections_and_time.DetectionsAndTime(0)
for i in range(0, expected.shape[0]):
result, detection = detections_and_time.Detection.create(expected[i][2:], int(expected[i][1]), expected[i][0])
assert detections_from_file.shape[1] == 6

for i in range(0, detections_from_file.shape[0]):
result, detection = detections_and_time.Detection.create(detections_from_file[i][2:], int(detections_from_file[i][1]), detections_from_file[i][0])
assert result
assert detection is not None
detections.append(detection)

return detections
Expand Down Expand Up @@ -91,29 +95,28 @@ def image_zidane():
def expected_bus():
"""
Load expected bus detections.
Format: [confidence, label, x1, y1, x2, y2]
"""
# Format: [confidence, label, x1, y1, x2, y2]
expected_bus = np.loadtxt(BOUNDING_BOX_BUS_PATH)
yield create_detections(expected_bus)

@pytest.fixture()
def expected_zidane():
"""
Load expected Zidane detections.
Format: [confidence, label, x1, y1, x2, y2]
"""
# Format: [confidence, label, x1, y1, x2, y2]
expected_zidane = np.loadtxt(BOUNDING_BOX_ZIDANE_PATH)
yield create_detections(expected_zidane)

class TestDetector:
"""
Tests `DetectTarget.run()` .
"""

def test_single_bus_image(self,
detector: detect_target.DetectTarget,
image_bus: image_and_time.ImageAndTime,
expected_bus: np.ndarray):
expected_bus: detections_and_time.DetectionsAndTime):
"""
Bus image.
"""
Expand All @@ -130,7 +133,7 @@ def test_single_bus_image(self,
def test_single_zidane_image(self,
detector: detect_target.DetectTarget,
image_zidane: image_and_time.ImageAndTime,
expected_zidane: np.ndarray):
expected_zidane: detections_and_time.DetectionsAndTime):
"""
Zidane image.
"""
Expand All @@ -146,7 +149,7 @@ def test_single_zidane_image(self,
def test_multiple_zidane_image(self,
detector: detect_target.DetectTarget,
image_zidane: image_and_time.ImageAndTime,
expected_zidane: np.ndarray):
expected_zidane: detections_and_time.DetectionsAndTime):
"""
Multiple Zidane images.
"""
Expand Down

0 comments on commit 435992c

Please sign in to comment.