Skip to content

Commit

Permalink
feat: byte track
Browse files Browse the repository at this point in the history
Signed-off-by: wep21 <[email protected]>
  • Loading branch information
wep21 committed Dec 18, 2024
1 parent c4c8c75 commit 7582df9
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 14 deletions.
5 changes: 3 additions & 2 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ cc_binary(
],
deps = [
"engine",
"@cvcuda",
"@byte_track_eigen",
"@cvcuda",
"@opencv",
],
)
Expand All @@ -30,7 +31,7 @@ cc_binary(
],
deps = [
"engine",
"@cvcuda",
"@cvcuda",
"@opencv",
],
)
3 changes: 2 additions & 1 deletion MODULE.bazel
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
module(name = "tensorrt_yolo")

bazel_dep(name = "byte_track_eigen", version = "2.1.0")
bazel_dep(name = "platforms", version = "0.0.10")
bazel_dep(name = "rules_cc", version = "0.0.9")
bazel_dep(name = "rules_cc", version = "0.1.0")
bazel_dep(name = "rules_cuda", version = "0.2.3")

cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain")
Expand Down
104 changes: 93 additions & 11 deletions src/video_demo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,16 @@
#include <cvcuda/OpResize.hpp>
#include <opencv2/opencv.hpp>

#include "BYTETracker.h"
#include "engine.hpp"

int main(int argc, char* argv[]) {
assert(argc == 2 || argc == 3);
assert(argc == 2 || argc == 3 || argc == 4);
bool save_video = false;
if (argc == 4 && std::string(argv[3]) == "save") {
save_video = true;
}

cudaStream_t stream;
CHECK_CUDA_ERROR(cudaStreamCreate(&stream));
auto engine = std::make_unique<tensorrt::Engine>(std::string(argv[1]));
Expand All @@ -21,6 +27,7 @@ int main(int argc, char* argv[]) {
dims.d[2] = 640;
dims.d[3] = 640;
engine->set_input_shape("images", dims);
auto tracker = std::make_unique<BYTETracker>(0.4, 240, 0.7, 12);
cv::VideoCapture cap;
if (argc == 2) {
cap.open(0);
Expand Down Expand Up @@ -89,6 +96,15 @@ int main(int argc, char* argv[]) {
nvcv::Tensor resized_tensor(1, {640, 640}, nvcv::FMT_RGB8);
nvcv::Tensor float_tensor(1, {640, 640}, nvcv::FMT_RGBf32);

cv::VideoWriter video_writer;
if (save_video) {
video_writer.open("output.mp4", cv::VideoWriter::fourcc('M', 'P', '4', 'V'), 30, frame.size());
if (!video_writer.isOpened()) {
std::cerr << "Could not open the output video file for write\n";
return -1;
}
}

// Warm up
{
cvt_color_op(stream, input_tensor, rgb_tensor, NVCV_COLOR_BGR2RGB);
Expand Down Expand Up @@ -126,20 +142,86 @@ int main(int argc, char* argv[]) {
cudaEventElapsedTime(&operatorms, start, stop);
std::cout << "Time for Infer : " << operatorms << " ms" << std::endl;
cudaStreamSynchronize(stream);
Eigen::MatrixXf detections(100, 5);
size_t person_index = 0;
for (int i = 0; i < num_detections[0]; ++i) {
cv::rectangle(
frame, cv::Point(boxes[4 * i] / 640 * frame.cols, boxes[4 * i + 1] / 640 * frame.rows),
cv::Point(boxes[4 * i + 2] / 640 * frame.cols, boxes[4 * i + 3] / 640 * frame.rows),
cv::Scalar(0, 0, 255), 1, 8, 0);
if (classes[i] != 0) {
continue;
}
detections(person_index, 0) = boxes[4 * i] / 640 * frame.cols;
detections(person_index, 1) = boxes[4 * i + 1] / 640 * frame.rows;
detections(person_index, 2) = (boxes[4 * i + 2] - boxes[4 * i]) / 640 * frame.cols;
detections(person_index, 3) = (boxes[4 * i + 3] - boxes[4 * i + 1]) / 640 * frame.rows;
detections(person_index, 4) = scores[i];
person_index++;
}
cv::imshow("win", frame);
const int key = cv::waitKey(1);
if (key == 'q') {
break;
} else if (key == 's') {
cv::imwrite("img.png", frame);
detections.conservativeResize(person_index, Eigen::NoChange);

auto tracking_start = std::chrono::high_resolution_clock::now();
std::vector<KalmanBBoxTrack> tracks = tracker->process_frame_detections(detections);
Eigen::MatrixXf tlbr_boxes(detections.rows(), 4);
tlbr_boxes << detections.col(0), detections.col(1), detections.col(0) + detections.col(2),
detections.col(1) + detections.col(3);
std::vector<int> track_ids = match_detections_with_tracks(
tlbr_boxes.cast<double>(), tracks);
auto tracking_end = std::chrono::high_resolution_clock::now();
std::chrono::duration<float, std::milli> tracking_duration = tracking_end - tracking_start;
std::cout << "Time for Tracking : " << tracking_duration.count() << " ms" << std::endl;

if (save_video) {
for (int i = 0; i < track_ids.size(); ++i) {
if (track_ids[i] == -1) {
continue;
}
cv::Scalar color;
switch (track_ids[i] % 6) {
case 0: color = cv::Scalar(255, 0, 0); break; // Blue
case 1: color = cv::Scalar(0, 255, 0); break; // Green
case 2: color = cv::Scalar(0, 0, 255); break; // Red
case 3: color = cv::Scalar(255, 255, 0); break; // Cyan
case 4: color = cv::Scalar(255, 0, 255); break; // Magenta
case 5: color = cv::Scalar(0, 255, 255); break; // Yellow
}
cv::rectangle(
frame, cv::Point(tlbr_boxes(i, 0), tlbr_boxes(i, 1)),
cv::Point(tlbr_boxes(i, 2), tlbr_boxes(i, 3)), color, 1, 8, 0);
cv::putText(frame, std::to_string(track_ids[i]), cv::Point(tlbr_boxes(i, 0), tlbr_boxes(i, 1)),
cv::FONT_HERSHEY_SIMPLEX, 2, color, 2, cv::LINE_AA);
}
video_writer.write(frame);
} else {
for (int i = 0; i < track_ids.size(); ++i) {
if (track_ids[i] == -1) {
continue;
}
cv::Scalar color;
switch (track_ids[i] % 6) {
case 0: color = cv::Scalar(255, 0, 0); break; // Blue
case 1: color = cv::Scalar(0, 255, 0); break; // Green
case 2: color = cv::Scalar(0, 0, 255); break; // Red
case 3: color = cv::Scalar(255, 255, 0); break; // Cyan
case 4: color = cv::Scalar(255, 0, 255); break; // Magenta
case 5: color = cv::Scalar(0, 255, 255); break; // Yellow
}
cv::rectangle(
frame, cv::Point(tlbr_boxes(i, 0), tlbr_boxes(i, 1)),
cv::Point(tlbr_boxes(i, 2), tlbr_boxes(i, 3)), color, 1, 8, 0);
cv::putText(frame, std::to_string(track_ids[i]), cv::Point(tlbr_boxes(i, 0), tlbr_boxes(i, 1)),
cv::FONT_HERSHEY_SIMPLEX, 2, color, 2, cv::LINE_AA);
}

cv::imshow("win", frame);
const int key = cv::waitKey(1);
if (key == 'q') {
break;
} else if (key == 's') {
cv::imwrite("img.png", frame);
}
}
}
if (save_video) {
video_writer.release();
}
cv::destroyAllWindows();
CHECK_CUDA_ERROR(cudaStreamDestroy(stream));
return 0;
Expand Down

0 comments on commit 7582df9

Please sign in to comment.