diff --git a/Cargo.toml b/Cargo.toml index 6d4ad975..98d12192 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,7 @@ resolver = "2" members = [ "crates/kornia-core", + "crates/kornia-dnn", "crates/kornia-image", "crates/kornia-io", "crates/kornia-imgproc", @@ -26,6 +27,7 @@ version = "0.1.6+dev" [workspace.dependencies] kornia-core = { path = "crates/kornia-core", version = "0.1.6+dev" } +kornia-dnn = { path = "crates/kornia-dnn", version = "0.1.6+dev" } kornia-image = { path = "crates/kornia-image", version = "0.1.6+dev" } kornia-io = { path = "crates/kornia-io", version = "0.1.6+dev" } kornia-imgproc = { path = "crates/kornia-imgproc", version = "0.1.6+dev" } diff --git a/crates/kornia-dnn/Cargo.toml b/crates/kornia-dnn/Cargo.toml new file mode 100644 index 00000000..8b8c952d --- /dev/null +++ b/crates/kornia-dnn/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "kornia-dnn" +authors.workspace = true +description = "ONNX Deep Neural Network (DNN) library for Rust" +edition.workspace = true +homepage.workspace = true +license.workspace = true +publish = true +repository.workspace = true +rust-version.workspace = true +version.workspace = true + +[features] +ort-load-dynamic = ["ort/load-dynamic"] +ort-cuda = ["ort/cuda"] + +[dependencies] +kornia-core = { workspace = true } +kornia-image = { workspace = true } +ort = { version = "2.0.0-rc.4", default-features = false } +thiserror = "1" diff --git a/crates/kornia-dnn/src/error.rs b/crates/kornia-dnn/src/error.rs new file mode 100644 index 00000000..e0c4cda2 --- /dev/null +++ b/crates/kornia-dnn/src/error.rs @@ -0,0 +1,14 @@ +#[derive(thiserror::Error, Debug)] +pub enum DnnError { + #[error("Please set the ORT_DYLIB_PATH environment variable to the path of the ORT dylib. Error: {0}")] + OrtDylibError(String), + + #[error("Failed to create ORT session")] + OrtError(#[from] ort::Error), + + #[error("Image error")] + ImageError(#[from] kornia_image::ImageError), + + #[error("Tensor error")] + TensorError(#[from] kornia_core::TensorError), +} diff --git a/crates/kornia-dnn/src/lib.rs b/crates/kornia-dnn/src/lib.rs new file mode 100644 index 00000000..0711869d --- /dev/null +++ b/crates/kornia-dnn/src/lib.rs @@ -0,0 +1,30 @@ +//! # Kornia DNN +//! +//! This module contains DNN (Deep Neural Network) related functionality. + +/// Error type for the dnn module. +pub mod error; + +/// This module contains the RT-DETR model. +pub mod rtdetr; + +// re-export ort::ExecutionProvider +pub use ort::{CPUExecutionProvider, CUDAExecutionProvider, TensorRTExecutionProvider}; + +// TODO: put this in to some sort of structs pool module +/// Represents a detected object in an image. +#[derive(Debug)] +pub struct Detection { + /// The class label of the detected object. + pub label: u32, + /// The confidence score of the detection (typically between 0 and 1). + pub score: f32, + /// The x-coordinate of the top-left corner of the bounding box. + pub x: f32, + /// The y-coordinate of the top-left corner of the bounding box. + pub y: f32, + /// The width of the bounding box. + pub w: f32, + /// The height of the bounding box. + pub h: f32, +} diff --git a/crates/kornia-dnn/src/rtdetr.rs b/crates/kornia-dnn/src/rtdetr.rs new file mode 100644 index 00000000..8efa8951 --- /dev/null +++ b/crates/kornia-dnn/src/rtdetr.rs @@ -0,0 +1,206 @@ +//! # RT-DETR +//! +//! This module contains the RT-DETR model. +//! +//! The RT-DETR model is a state-of-the-art object detection model. 
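+//!
+//! ## Example
+//!
+//! A minimal usage sketch (added as documentation, not taken from the upstream example):
+//! the model path is a placeholder, the `ORT_DYLIB_PATH` environment variable must point to
+//! the onnxruntime dylib, and the score threshold is arbitrary.
+//!
+//! ```no_run
+//! use kornia_dnn::rtdetr::RTDETRDetectorBuilder;
+//! use kornia_image::Image;
+//!
+//! fn detect(image: &Image<u8, 3>) -> Result<(), Box<dyn std::error::Error>> {
+//!     // configure and build the detector from an ONNX model file (placeholder path)
+//!     let detector = RTDETRDetectorBuilder::new("rtdetr.onnx".into())?
+//!         .with_num_threads(4)
+//!         .build()?;
+//!
+//!     // run inference and keep only reasonably confident detections
+//!     for d in detector.run(image)?.iter().filter(|d| d.score > 0.5) {
+//!         println!("label={} score={:.2} box=({}, {}, {}, {})", d.label, d.score, d.x, d.y, d.w, d.h);
+//!     }
+//!     Ok(())
+//! }
+//! ```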
+
+use std::{path::PathBuf, sync::Arc};
+
+use crate::{error::DnnError, CPUExecutionProvider, Detection};
+use kornia_core::{CpuAllocator, Tensor};
+use kornia_image::Image;
+use ort::{ExecutionProviderDispatch, GraphOptimizationLevel, Session};
+
+/// Builder for the RT-DETR detector.
+///
+/// This struct provides a convenient way to configure and create an `RTDETRDetector` instance.
+pub struct RTDETRDetectorBuilder {
+    /// Path to the RT-DETR model file.
+    pub model_path: PathBuf,
+    /// Number of threads to use for inference.
+    pub num_threads: usize,
+    /// Execution providers to use for inference.
+    pub execution_providers: Vec<ExecutionProviderDispatch>,
+}
+
+impl RTDETRDetectorBuilder {
+    /// Creates a new `RTDETRDetectorBuilder` with default settings.
+    ///
+    /// # Arguments
+    ///
+    /// * `model_path` - Path to the RT-DETR model file.
+    ///
+    /// # Returns
+    ///
+    /// A `Result` containing the `RTDETRDetectorBuilder` if successful, or a `DnnError` if an error occurred.
+    pub fn new(model_path: PathBuf) -> Result<Self, DnnError> {
+        Ok(Self {
+            model_path,
+            num_threads: 4,
+            execution_providers: vec![CPUExecutionProvider::default().build()],
+        })
+    }
+
+    /// Sets the number of threads to use for inference.
+    ///
+    /// # Arguments
+    ///
+    /// * `num_threads` - The number of threads to use.
+    ///
+    /// # Returns
+    ///
+    /// The updated `RTDETRDetectorBuilder` instance.
+    pub fn with_num_threads(mut self, num_threads: usize) -> Self {
+        self.num_threads = num_threads;
+        self
+    }
+
+    /// Sets the execution providers to use for inference.
+    ///
+    /// # Arguments
+    ///
+    /// * `execution_providers` - The execution providers to use.
+    ///
+    /// # Returns
+    ///
+    /// The updated `RTDETRDetectorBuilder` instance.
+    pub fn with_execution_providers(
+        mut self,
+        execution_providers: Vec<ExecutionProviderDispatch>,
+    ) -> Self {
+        self.execution_providers = execution_providers;
+        self
+    }
+
+    /// Builds and returns an `RTDETRDetector` instance.
+    ///
+    /// # Returns
+    ///
+    /// A `Result` containing the `RTDETRDetector` if successful, or a `DnnError` if an error occurred.
+    pub fn build(self) -> Result<RTDETRDetector, DnnError> {
+        RTDETRDetector::new(self.model_path, self.num_threads, self.execution_providers)
+    }
+}
+
+/// RT-DETR object detector.
+///
+/// This struct represents an instance of the RT-DETR object detection model.
+pub struct RTDETRDetector {
+    session: Arc<Session>,
+}
+
+impl RTDETRDetector {
+    // TODO: default to hf hub
+    /// Creates a new `RTDETRDetector` instance.
+    ///
+    /// # Arguments
+    ///
+    /// * `model_path` - Path to the RT-DETR model file.
+    /// * `num_threads` - Number of threads to use for inference.
+    /// * `execution_providers` - Execution providers to use for inference.
+    ///
+    /// # Returns
+    ///
+    /// A `Result` containing the `RTDETRDetector` if successful, or a `DnnError` if an error occurred.
+    ///
+    /// Pre-requisites:
+    /// - The `ORT_DYLIB_PATH` environment variable must be set to the path of the ORT dylib.
+    pub fn new(
+        model_path: PathBuf,
+        num_threads: usize,
+        execution_providers: Vec<ExecutionProviderDispatch>,
+    ) -> Result<Self, DnnError> {
+        // get the ort dylib path from the environment variable
+        let dylib_path =
+            std::env::var("ORT_DYLIB_PATH").map_err(|e| DnnError::OrtDylibError(e.to_string()))?;
+
+        // set the ort dylib path
+        ort::init_from(dylib_path).commit()?;
+
+        // create the ort session
+        let session = Session::builder()?
+            .with_optimization_level(GraphOptimizationLevel::Level3)?
+            .with_intra_threads(num_threads)?
+            .with_execution_providers(execution_providers)?
+            .commit_from_file(model_path)?;
+
+        let session = Arc::new(session);
+        // TODO: perform a dummy run to warm up the model
+        // let session_clone = session.clone();
+        // std::thread::spawn(move || -> Result<(), DnnError> {
+        //     let dummy_input =
+        //         ort::Tensor::from_array(([480, 640 * 3], vec![0.0f32; 480 * 640 * 3]))?;
+        //     session_clone.run(ort::inputs!["input" => dummy_input]?)?;
+        //     Ok(())
+        // });
+
+        Ok(Self { session })
+    }
+
+    /// Runs object detection on the given image.
+    ///
+    /// # Arguments
+    ///
+    /// * `image` - The input image as an RGB `Image<u8, 3>`.
+    ///
+    /// # Returns
+    ///
+    /// A `Result` containing a vector of `Detection` objects if successful, or a `DnnError` if an error occurred.
+    pub fn run(&self, image: &Image<u8, 3>) -> Result<Vec<Detection>, DnnError> {
+        // TODO: explore pre-allocating memory for the image
+        // cast and scale the image to f32
+        let mut image_hwc_f32 = Image::from_size_val(image.size(), 0.0f32)?;
+        kornia_image::ops::cast_and_scale(image, &mut image_hwc_f32, 1.0 / 255.)?;
+
+        // convert HWC -> CHW
+        let image_chw = image_hwc_f32.permute_axes([2, 0, 1]).as_contiguous();
+
+        // TODO: create a Tensor::insert_axis in kornia-rs
+        let image_nchw = Tensor::from_shape_vec(
+            [
+                1,
+                image_chw.shape[0],
+                image_chw.shape[1],
+                image_chw.shape[2],
+            ],
+            image_chw.into_vec(),
+            CpuAllocator,
+        )?;
+
+        // make the ort tensor
+        let ort_tensor = ort::Tensor::from_array((image_nchw.shape, image_nchw.into_vec()))?;
+
+        // run the model
+        let outputs = self.session.run(ort::inputs!["input" => ort_tensor]?)?;
+
+        // extract the output tensor
+        let (out_shape, out_ort) = outputs[0].try_extract_raw_tensor::<f32>()?;
+
+        let out_tensor = Tensor::<f32, 3>::from_shape_vec(
+            [
+                out_shape[0] as usize,
+                out_shape[1] as usize,
+                out_shape[2] as usize,
+            ],
+            out_ort.to_vec(),
+            CpuAllocator,
+        )?;
+
+        // parse the output tensor
+        // we expect the output tensor to be a tensor of shape [1, N, 6]
+        // where each element is a detection [label, score, x, y, w, h]
+        let detections = out_tensor
+            .as_slice()
+            .chunks_exact(6)
+            .map(|chunk| Detection {
+                label: chunk[0] as u32,
+                score: chunk[1],
+                x: chunk[2],
+                y: chunk[3],
+                w: chunk[4],
+                h: chunk[5],
+            })
+            .collect::<Vec<_>>();
+
+        Ok(detections)
+    }
+}
diff --git a/crates/kornia/Cargo.toml b/crates/kornia/Cargo.toml
index 075283de..69687994 100644
--- a/crates/kornia/Cargo.toml
+++ b/crates/kornia/Cargo.toml
@@ -14,9 +14,12 @@ version.workspace = true
 [features]
 gstreamer = ["kornia-io/gstreamer"]
 jpegturbo = ["kornia-io/jpegturbo"]
+ort-load-dynamic = ["kornia-dnn/ort-load-dynamic"]
+ort-cuda = ["kornia-dnn/ort-cuda"]
 
 [dependencies]
 kornia-core.workspace = true
+kornia-dnn = { workspace = true, features = [] }
 kornia-image.workspace = true
 kornia-imgproc.workspace = true
 kornia-io = { workspace = true, features = [] }
diff --git a/crates/kornia/src/lib.rs b/crates/kornia/src/lib.rs
index c494ce0e..09cbe966 100644
--- a/crates/kornia/src/lib.rs
+++ b/crates/kornia/src/lib.rs
@@ -1,6 +1,9 @@
 #[doc(inline)]
 pub use kornia_core as core;
 
+#[doc(inline)]
+pub use kornia_dnn as dnn;
+
 #[doc(inline)]
 pub use kornia_image as image;
 
diff --git a/examples/onnx/Cargo.toml b/examples/onnx/Cargo.toml
index 6c7ce852..d395987a 100644
--- a/examples/onnx/Cargo.toml
+++ b/examples/onnx/Cargo.toml
@@ -7,7 +7,6 @@ edition.workspace = true
 homepage.workspace = true
 include.workspace = true
 license.workspace = true
-license-file.workspace = true
 readme.workspace = true
 repository.workspace = true
 rust-version.workspace = true
diff --git a/examples/rtdetr/Cargo.toml b/examples/rtdetr/Cargo.toml
new file mode 100644
index 00000000..f381335a
--- /dev/null
+++ b/examples/rtdetr/Cargo.toml
@@ -0,0 +1,18 @@
+[package]
+name = "rtdetr"
+version = "0.1.0"
+authors = ["Edgar Riba "]
+license = "Apache-2.0"
+edition = "2021"
+publish = false
+
+[dependencies]
+clap = { version = "4.5.4", features = ["derive"] }
+ctrlc = "3.4.4"
+kornia = { workspace = true, features = [
+    "gstreamer",
+    "ort-load-dynamic",
+    "ort-cuda",
+] }
+rerun = "0.18"
+tokio = { version = "1" }
diff --git a/examples/rtdetr/README.md b/examples/rtdetr/README.md
new file mode 100644
index 00000000..d92bf5e9
--- /dev/null
+++ b/examples/rtdetr/README.md
@@ -0,0 +1,27 @@
+An example showing how to run the RT-DETR model from the `kornia::dnn` module on a webcam feed captured with the `kornia::io` module, with the ability to cancel the feed with Ctrl-C. The webcam feed and the detections are displayed in a [`rerun`](https://github.com/rerun-io/rerun) window.
+
+NOTE: This example requires the gstreamer backend to be enabled. To enable it, use the `gstreamer` feature flag when building the `kornia` crate and its dependencies.
+
+## Prerequisites
+
+Mainly, you need to download the onnxruntime shared library (see the ONNX Runtime releases page) and set `ORT_DYLIB_PATH` to point to it.
+
+## Usage
+
+```bash
+Usage: rtdetr [OPTIONS] --model-path <MODEL_PATH>
+
+Options:
+  -c, --camera-id <CAMERA_ID>              [default: 0]
+  -f, --fps <FPS>                          [default: 5]
+  -m, --model-path <MODEL_PATH>
+  -n, --num-threads <NUM_THREADS>          [default: 8]
+  -s, --score-threshold <SCORE_THRESHOLD>  [default: 0.75]
+  -h, --help                               Print help
+```
+
+Example:
+
+```bash
+ORT_DYLIB_PATH=/path/to/libonnxruntime.so cargo run --bin rtdetr --release -- --camera-id 0 --model-path rtdetr.onnx --num-threads 8 --score-threshold 0.75
+```
diff --git a/examples/rtdetr/src/main.rs b/examples/rtdetr/src/main.rs
new file mode 100644
index 00000000..8995fa78
--- /dev/null
+++ b/examples/rtdetr/src/main.rs
@@ -0,0 +1,140 @@
+use clap::Parser;
+use std::{
+    path::PathBuf,
+    sync::{
+        atomic::{AtomicBool, Ordering},
+        Arc, Mutex,
+    },
+};
+
+use kornia::{
+    dnn::{
+        rtdetr::RTDETRDetectorBuilder,
+        {CPUExecutionProvider, CUDAExecutionProvider},
+    },
+    io::{
+        fps_counter::FpsCounter,
+        stream::{StreamCaptureError, V4L2CameraConfig},
+    },
+};
+
+#[derive(Parser)]
+struct Args {
+    #[arg(short, long, default_value = "0")]
+    camera_id: u32,
+
+    #[arg(short, long, default_value = "5")]
+    fps: u32,
+
+    #[arg(short, long)]
+    model_path: PathBuf,
+
+    #[arg(short, long, default_value = "8")]
+    num_threads: usize,
+
+    #[arg(short, long, default_value = "0.75")]
+    score_threshold: f32,
+
+    #[arg(short, long)]
+    use_cuda: bool,
+}
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    let args = Args::parse();
+
+    // start the recording stream
+    let rec = rerun::RecordingStreamBuilder::new("Kornia RTDETR App").spawn()?;
+
+    // create a webcam capture object with camera id 0
+    // and force the image size to 640x480
+    let camera = V4L2CameraConfig::new()
+        .with_camera_id(args.camera_id)
+        .with_size([640, 480].into())
+        .with_fps(args.fps)
+        .build()?;
+
+    let mut execution_providers = vec![CPUExecutionProvider::default().build()];
+    if args.use_cuda {
+        execution_providers.push(CUDAExecutionProvider::default().build());
+    }
+
+    let detector = RTDETRDetectorBuilder::new(args.model_path)?
+ .with_num_threads(args.num_threads) + .with_execution_providers(execution_providers) + .build()?; + + // create a cancel token to stop the webcam capture + let cancel_token = Arc::new(AtomicBool::new(false)); + + // create a shared fps counter + let fps_counter = Arc::new(Mutex::new(FpsCounter::new())); + + ctrlc::set_handler({ + let cancel_token = cancel_token.clone(); + move || { + println!("Received Ctrl-C signal. Sending cancel signal !!"); + cancel_token.store(true, Ordering::SeqCst); + } + })?; + + // start grabbing frames from the camera + camera + .run(|img| { + // check if the cancel token is set, if so we return an error to stop the pipeline + if cancel_token.load(Ordering::SeqCst) { + return Err(StreamCaptureError::PipelineCancelled.into()); + } + + // run the detector + let detections = detector.run(&img)?; + + // filter the detections by score + let detections = detections + .into_iter() + .filter(|d| d.score > args.score_threshold); + + // update the fps counter + fps_counter + .lock() + .expect("Failed to lock fps counter") + .new_frame(); + + // log the detections + let mut boxes_mins = Vec::new(); + let mut boxes_sizes = Vec::new(); + let mut class_ids = Vec::new(); + for detection in detections { + boxes_mins.push((detection.x, detection.y)); + boxes_sizes.push((detection.w, detection.h)); + class_ids.push(detection.label as u16); + } + + // log the image + rec.log_static( + "image", + &rerun::Image::from_elements( + img.as_slice(), + img.size().into(), + rerun::ColorModel::RGB, + ), + )?; + + // log the detections + rec.log_static( + "detections", + &rerun::Boxes2D::from_mins_and_sizes(boxes_mins, boxes_sizes) + .with_class_ids(class_ids), + )?; + + Ok(()) + }) + .await?; + + // NOTE: this is important to close the webcam properly, otherwise the app will hang + camera.close()?; + + println!("Finished recording. Closing app."); + + Ok(()) +} diff --git a/kornia-py/Cargo.toml b/kornia-py/Cargo.toml index 4ffcfd7e..ac141012 100644 --- a/kornia-py/Cargo.toml +++ b/kornia-py/Cargo.toml @@ -18,7 +18,11 @@ crate-type = ["cdylib"] [dependencies] # kornia -kornia = { path = "../crates/kornia", features = ["gstreamer", "jpegturbo"] } +kornia = { path = "../crates/kornia", features = [ + "gstreamer", + "jpegturbo", + "ort-load-dynamic", +] } # external pyo3 = { version = "0.21.2", features = ["extension-module"] }