Skip to content

Commit

Permalink
Merge pull request #29 from shrimo/dev
Browse files Browse the repository at this point in the history
Added VIT tracker
  • Loading branch information
shrimo authored Oct 6, 2024
2 parents 1f47d34 + 311f702 commit 73d8ef2
Show file tree
Hide file tree
Showing 7 changed files with 206 additions and 25 deletions.
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,19 @@ dnf install ffmpeg
Ubuntu
```bash
...

sudo apt update
sudo apt install ffmpeg
```
<br>

> [!NOTE]
> Please install the latest version of opencv-python:
```bash
python3 -m pip install --upgrade opencv-python
```


The **g2o** framework for Python can also be built from [source code](https://github.com/RainerKuemmerle/g2o/tree/pymem). In that case, add the path to the compiled library in *config.py* (see the *g2opy_path* variable).
<br>

Expand Down
123 changes: 123 additions & 0 deletions boxes/plugins/vit_track.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""
Node and function library for Tracking
"""

import cv2
import numpy as np
from boxes import RootNode, insert_frame, Color

cc = Color()


class VitTrack(RootNode):
    """
    Object-tracking node built on OpenCV's VIT (vision transformer) tracker.

    When ``self.buffer.switch`` is set, the tracker is (re)initialised on the
    region given by ``self.buffer.roi``. On subsequent frames the tracker is
    updated and the result is drawn onto the frame.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.variable = self.param["variable"]
        self.show_ROI = self.param["show_ROI"]
        self.model_path = self.param["model_path"]
        self.font = cv2.FONT_HERSHEY_SIMPLEX
        # True once a ROI has been selected and the tracker initialised.
        self.go = False
        # Copy of the ROI crop taken at initialisation time.
        self.Image = None
        self.backend_target_pairs = [cv2.dnn.DNN_BACKEND_OPENCV, cv2.dnn.DNN_TARGET_CPU]
        self.backend_id, self.target_id = self.backend_target_pairs
        self.params = cv2.TrackerVit_Params()
        self.params.net = self.model_path
        self.params.backend = self.backend_id
        self.params.target = self.target_id
        self.model = cv2.TrackerVit_create(self.params)

    def out_frame(self):
        """
        Fetch the input frame, run ROI selection or tracking, return the frame.

        Returns the frame from input 0 (``None`` when no frame is available).
        Tracking results are drawn onto the frame in place.
        """
        frame = self.get_frame(0)
        if frame is None:
            print("VitTrack stop")
        elif self.disabled:
            return frame
        elif self.buffer.switch:
            # A new ROI was selected: (re)initialise the tracker on it.
            self.go = self.calculations_for_ROI(frame, self.buffer.roi)
            self.buffer.switch = False
        elif self.go:
            is_located, bbox, score = self.infer(frame)
            self.visualize(frame, bbox, score, is_located)
            if self.show_ROI:
                insert_frame(frame, self.Image)

        return frame

    def calculations_for_ROI(self, frame, coord):
        """
        Set the region of interest (ROI) and initialise the tracker.

        ``coord`` is unpacked as ``(x0, y0, x1, y1)``. Returns ``True`` when
        tracking may start, ``False`` when the tracker reported a failure.
        """
        x0, y0, x1, y1 = coord
        track_window = (x0, y0, x1 - x0, y1 - y0)

        # In the OpenCV 4.x Python bindings Tracker.init() returns None, so
        # only an explicit False (legacy tracker API) is treated as failure.
        status = self.model.init(frame, track_window)
        if status is False:
            print("[ERROR] tracker not initialized")
            return False

        self.Image = frame[y0:y1, x0:x1].copy()
        return True

    def infer(self, image):
        """Update the tracker with ``image``; return (is_located, bbox, score)."""
        is_located, bbox = self.model.update(image)
        score = self.model.getTrackingScore()
        return is_located, bbox, score

    def visualize(
        self,
        frame,
        bbox,
        score,
        isLocated,
        fontScale=0.5,
        fontSize=1,
    ):
        """
        Draw the tracking result onto ``frame`` in place and return it.

        When the target is located with a score of at least 0.3, the bounding
        box, a score label and the box centre are drawn. Otherwise a
        "Target lost!" message is drawn in the middle of the frame.
        """
        # Only height and width are needed; slicing also accepts 2-D frames.
        frame_h, frame_w = frame.shape[:2]

        if isLocated and score >= 0.3:
            # bbox: tuple of length 4 (x, y, width, height)
            x, y, box_w, box_h = bbox
            cv2.rectangle(frame, (x, y), (x + box_w, y + box_h), cc.yellow, 1)
            label = f"{score:.2f}"
            width_label = (6 * len(label)) + 16
            cv2.rectangle(frame, (x, y), (x + width_label, y + 20), cc.yellow, -1)
            cv2.putText(
                frame,
                label,
                (x + 2, y + 15),
                self.font,
                fontScale,
                cc.gray,
                fontSize,
            )
            # Plain Python ints keep the drawing call independent of how the
            # binding parses NumPy scalar types.
            center = (int(x + box_w * 0.5), int(y + box_h * 0.5))
            cv2.circle(frame, center, 2, cc.red, -1)
        else:
            message = "Target lost!"
            # Measure exactly the string that is drawn so it is centred.
            text_size, _ = cv2.getTextSize(
                message, cv2.FONT_HERSHEY_DUPLEX, fontScale, fontSize
            )
            text_x = int((frame_w - text_size[0]) / 2)
            # putText anchors at the bottom-left of the text, so the baseline
            # sits half a text height below the frame centre.
            text_y = int((frame_h + text_size[1]) / 2)
            cv2.putText(
                frame,
                message,
                (text_x, text_y),
                cv2.FONT_HERSHEY_DUPLEX,
                fontScale,
                cc.red,
                fontSize,
            )

        return frame

    def update(self, param):
        """Refresh the node settings from ``param``."""
        self.disabled = param["disabled"]
        # NOTE(review): self.model is created once in __init__; changing
        # model_path here does not reload the tracker network.
        self.model_path = param["model_path"]
        self.variable = param["variable"]
        self.show_ROI = param["show_ROI"]
4 changes: 2 additions & 2 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
port: int = 50001
recv_size: int = 10240
name: str = "SLAM Box"
version: str = "0.7.6"
version: str = "0.7.7"
system: str = platform.system()
gui: str = "Node-based UI"
description: str = "Computer Vision Node Graph Editor"
date: str = "(Sat Feb 3 08:52:20 PM EET 2024)"
date: str = "(Sun Oct 6 05:27:06 AM EEST 2024)"
nodegraphqt: str = "./NodeGraphQt/"
css_style: str = "QLabel {background-color: #363636; color: white; font-size: 11pt;}"
g2opy_path: str = "/home/cds/github/g2o-pymem/build/lib"
Binary file added data/object_tracking_vittrack_2023sep.onnx
Binary file not shown.
50 changes: 29 additions & 21 deletions plugins_ui/tracking_gui_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,40 @@
from Qt import QtCore, QtWidgets
from plugins_ui.main_gui_nodes import NodeColorStyle
from NodeGraphQt import BaseNode
from NodeGraphQt.constants import (NODE_PROP_QLABEL,
NODE_PROP_QLINEEDIT,
NODE_PROP_QCOMBO,
NODE_PROP_QSPINBOX,
NODE_PROP_COLORPICKER,
NODE_PROP_SLIDER,
NODE_PROP_FILE,
NODE_PROP_QCHECKBOX,
NODE_PROP_FLOAT)
from NodeGraphQt.constants import (
NODE_PROP_QLABEL,
NODE_PROP_QLINEEDIT,
NODE_PROP_QCOMBO,
NODE_PROP_QSPINBOX,
NODE_PROP_COLORPICKER,
NODE_PROP_SLIDER,
NODE_PROP_FILE,
NODE_PROP_QCHECKBOX,
NODE_PROP_FLOAT,
)

ncs = NodeColorStyle()
ncs.set_value(15)


class AllTrackers(BaseNode):
__identifier__ = 'nodes.Tracking'
NODE_NAME = 'AllTrackers'
__identifier__ = "nodes.Tracking"
NODE_NAME = "AllTrackers"

def __init__(self):
super().__init__()
self.add_input('in', color=(180, 80, 180))
self.add_output('out')
tracker_type = ['BOOSTING', 'MIL', 'KCF', 'TLD', 'MEDIANFLOW', 'MOSSE', 'CSRT']
self.create_property('label_variable', 'Variable', widget_type=NODE_PROP_QLABEL)
self.create_property('variable', 'track1', widget_type=NODE_PROP_QLINEEDIT)
self.create_property('label_tracker_types', 'Tracker type', widget_type=NODE_PROP_QLABEL)
self.create_property('tracker_type', 'CSRT', items=tracker_type, widget_type=NODE_PROP_QCOMBO)
self.add_checkbox('show_ROI', 'Show ROI', text='On/Off', state=False, tab='attributes')
self.add_input("in", color=(180, 80, 180))
self.add_output("out")
tracker_type = ["BOOSTING", "MIL", "KCF", "TLD", "MEDIANFLOW", "MOSSE", "CSRT"]
self.create_property("label_variable", "Variable", widget_type=NODE_PROP_QLABEL)
self.create_property("variable", "track1", widget_type=NODE_PROP_QLINEEDIT)
self.create_property(
"label_tracker_types", "Tracker type", widget_type=NODE_PROP_QLABEL
)
self.create_property(
"tracker_type", "CSRT", items=tracker_type, widget_type=NODE_PROP_QCOMBO
)
self.add_checkbox(
"show_ROI", "Show ROI", text="On/Off", state=False, tab="attributes"
)
self.set_color(*ncs.Tracking)


42 changes: 42 additions & 0 deletions plugins_ui/vit_track_gui_nodes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
""" GUI for tracking nodes """
from Qt import QtCore, QtWidgets
from plugins_ui.main_gui_nodes import NodeColorStyle
from NodeGraphQt import BaseNode
from NodeGraphQt.constants import (
NODE_PROP_QLABEL,
NODE_PROP_QLINEEDIT,
NODE_PROP_QCOMBO,
NODE_PROP_QSPINBOX,
NODE_PROP_COLORPICKER,
NODE_PROP_SLIDER,
NODE_PROP_FILE,
NODE_PROP_QCHECKBOX,
NODE_PROP_FLOAT,
)

ncs = NodeColorStyle()
ncs.set_value(15)


class VitTrack(BaseNode):
    """Node-graph GUI node for the VIT tracker."""

    __identifier__ = "nodes.Tracking"
    NODE_NAME = "VitTrack"

    def __init__(self):
        super().__init__()

        # Ports.
        self.add_input("in", color=(180, 80, 180))
        self.add_output("out")

        # Properties, registered in display order: (name, value, widget).
        property_specs = (
            ("label_model_path", "Model path", NODE_PROP_QLABEL),
            (
                "model_path",
                "data/object_tracking_vittrack_2023sep.onnx",
                NODE_PROP_FILE,
            ),
            ("label_variable", "Variable", NODE_PROP_QLABEL),
            ("variable", "track1", NODE_PROP_QLINEEDIT),
        )
        for prop_name, prop_value, widget in property_specs:
            self.create_property(prop_name, prop_value, widget_type=widget)

        checkbox_options = {"text": "On/Off", "state": False, "tab": "attributes"}
        self.add_checkbox("show_ROI", "Show ROI", **checkbox_options)

        self.set_color(*ncs.Tracking)
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ Flask-Cors==4.0.0
g2o-python==0.0.12
numpy==1.26.1
open3d==0.17.0
opencv-contrib-python==4.8.1.78
opencv-python==4.8.1.78
opencv-contrib-python==4.10.0.84
opencv-python==4.10.0.84
PySide2==5.15.2.1
Qt.py==1.3.8
scikit-build==0.17.6
Expand Down

0 comments on commit 73d8ef2

Please sign in to comment.