-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathtrackers.py
126 lines (110 loc) · 4.03 KB
/
trackers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""Unified tracker interface for supported trackers."""
from abc import abstractmethod
from dataclasses import asdict
from typing import Dict, List, Tuple, Union
import numpy as np
from similari import (
Sort as SortImpl,
VisualSort as VisualSortImpl,
BoundingBox,
SpatioTemporalConstraints,
PositionalMetricType,
)
from .config import (
OriginalSortParams,
SortParams,
VisualSortParams,
PositionalMetricType as PositionalMetricConfigType,
)
from .original_sort import Sort as OriginalSortImpl
class Tracker:
@abstractmethod
def process_frame(
self, frame_num: int, detections: List[Tuple[float, float, float, float, float]]
) -> List[Tuple[int, float, float, float, float, float]]:
"""(left, top, width, height, confidence) =>
(track_id, left, top, width, height, confidence)
"""
pass
class OriginalSort(Tracker):
def __init__(self, params: OriginalSortParams):
self._tracker = OriginalSortImpl(**asdict(params))
def process_frame(
self, frame_num: int, detections: List[Tuple[float, float, float, float, float]]
) -> List[Tuple[int, float, float, float, float, float]]:
# tuple(top, left, width, height) to np.array([x1, y1, x2, y2])
np_detections = np.array(detections)
np_detections[:, 2:4] += np_detections[:, 0:2]
tracks = self._tracker.update(np_detections)
return [
(
int(track[4]),
track[0],
track[1],
track[2] - track[0],
track[3] - track[1],
1.0,
)
for track in tracks
]
class SimilariTracker(Tracker):
def __init__(self, params: Union[SortParams, VisualSortParams]):
constraints = None
if params.spatio_temporal_constraints:
constraints = SpatioTemporalConstraints()
constraints.add_constraints(
list(map(tuple, params.spatio_temporal_constraints))
)
positional_metric = None
if params.positional_metric:
if params.positional_metric.type == PositionalMetricConfigType.IoU:
positional_metric = PositionalMetricType.iou(
threshold=params.positional_metric.threshold
)
else:
positional_metric = PositionalMetricType.maha()
if isinstance(params, SortParams):
self._tracker = SortImpl(
shards=params.shards,
bbox_history=params.bbox_history,
max_idle_epochs=params.max_idle_epochs,
method=positional_metric,
spatio_temporal_constraints=constraints,
)
else:
raise NotImplementedError
self._use_confidence = params.use_confidence
self._track_id_map: Dict[int, int] = {} # to have 1-based track id
def process_frame(
self, frame_num: int, detections: List[Tuple[float, float, float, float, float]]
) -> List[Tuple[int, float, float, float, float, float]]:
if self._use_confidence:
dets = [
(BoundingBox.new_with_confidence(*detection).as_xyaah(), 0)
for detection in detections
]
else:
dets = [
(BoundingBox(*detection[:-1]).as_xyaah(), 0) for detection in detections
]
tracks = self._tracker.predict(dets)
rows = []
for track in tracks:
track_id = track.id
_bbox = track.predicted_bbox.as_ltwh()
if track_id not in self._track_id_map:
self._track_id_map[track_id] = len(self._track_id_map) + 1
rows.append(
(
self._track_id_map[track_id],
_bbox.left,
_bbox.top,
_bbox.width,
_bbox.height,
1.0,
)
)
# TODO
# if frame_num % ...:
# self._tracker.wasted()
return rows