From 9adb57c54195c017f7c02675a1a146670c18363c Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Wed, 5 Jun 2024 11:30:46 +0100 Subject: [PATCH] PredictedInstance repr --- sleap_io/model/instance.py | 16 ++++++++++++++++ tests/model/test_instance.py | 5 +++++ 2 files changed, 21 insertions(+) diff --git a/sleap_io/model/instance.py b/sleap_io/model/instance.py index 89838e5f..25dcd0aa 100644 --- a/sleap_io/model/instance.py +++ b/sleap_io/model/instance.py @@ -334,6 +334,22 @@ class PredictedInstance(Instance): score: float = 0.0 tracking_score: Optional[float] = 0 + def __repr__(self) -> str: + """Return a readable representation of the instance.""" + pts = self.numpy().tolist() + track = f'"{self.track.name}"' if self.track is not None else self.track + + score = str(self.score) if self.score is None else f"{self.score:.2f}" + tracking_score = ( + str(self.tracking_score) + if self.tracking_score is None + else f"{self.tracking_score:.2f}" + ) + return ( + f"PredictedInstance(points={pts}, track={track}, " + f"score={score}, tracking_score={tracking_score})" + ) + @classmethod def from_numpy( # type: ignore[override] cls, diff --git a/tests/model/test_instance.py b/tests/model/test_instance.py index 7e626098..e3181ec3 100644 --- a/tests/model/test_instance.py +++ b/tests/model/test_instance.py @@ -159,3 +159,8 @@ def test_predicted_instance(): assert inst[0].score == 0.4 assert inst[1].score == 0.5 assert inst.score == 0.6 + + assert ( + str(inst) == "PredictedInstance(points=[[0.0, 1.0], [2.0, 3.0]], track=None, " + "score=0.60, tracking_score=None)" + )