Commit
* Add yolox base models
* Fix 2d grid for anchors
* Change sample image
* Add post-processing and inference results
* Cleanup
* Fix 2d grid repeat with dim=1
* Default to ndarray backend
* Switch to YOLOX-tiny for example
* Remove dead comment
* Use tensor.dims()
* Use the new tensor.permute()
* Fix comments
  - Pre-trained weights are from COCO (README)
  - Remove training outputs TODO
  - Current example uses YOLOX-Tiny (Nano WIP)
* Add YOLOX-Nano w/ depthwise separable conv (enum)
* Remove dead code comment
* Remove incorrect return comment
* Add YOLOX to models README
* Fix dead comments and add enum variants doc
* Rephrase enum variants doc
* Change burn-rs -> tracel-ai links
* Upgrade to Burn 0.13.0
  - Removed init_with methods
  - Fixed empty MaxPool2d vec initialization
* Update image version
Showing 19 changed files with 1,885 additions and 6 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# Output image | ||
*.output.png |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
[package] | ||
authors = ["guillaumelagrange <[email protected]>"] | ||
license = "MIT OR Apache-2.0" | ||
name = "yolox-burn" | ||
version = "0.1.0" | ||
edition = "2021" | ||
|
||
[features] | ||
default = [] | ||
std = [] | ||
pretrained = ["burn/network", "std", "dep:dirs"] | ||
|
||
[dependencies] | ||
# Note: default-features = false is needed to disable std | ||
burn = { version = "0.13.0", default-features = false } | ||
burn-import = { version = "0.13.0" } | ||
itertools = { version = "0.12.1", default-features = false, features = [ | ||
"use_alloc", | ||
] } | ||
dirs = { version = "5.0.1", optional = true } | ||
serde = { version = "1.0.192", default-features = false, features = [ | ||
"derive", | ||
"alloc", | ||
] } # alloc is for no_std, derive is needed | ||
|
||
[dev-dependencies] | ||
burn = { version = "0.13.0", features = ["ndarray"] } | ||
image = { version = "0.24.9", features = ["png", "jpeg"] } |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../LICENSE-APACHE |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../LICENSE-MIT |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# NOTICES AND INFORMATION | ||
|
||
This file contains notices and information required by libraries that this repository copied or derived from. The use of the following resources complies with the licenses provided. | ||
|
||
## Sample Image | ||
|
||
Image Title: Man with Bike and Pet Dog circa 1900 (archive ref DDX1319-2-3) | ||
Author: East Riding Archives | ||
Source: https://commons.wikimedia.org/wiki/File:Man_with_Bike_and_Pet_Dog_circa_1900_%28archive_ref_DDX1319-2-3%29_%2826507570321%29.jpg | ||
License: [Creative Commons](https://www.flickr.com/commons/usage/) | ||
|
||
## Pre-trained Model | ||
|
||
The COCO pre-trained model was ported from the original [YOLOX implementation](https://github.com/Megvii-BaseDetection/YOLOX). | ||
|
||
As opposed to other YOLO variants (YOLOv8, YOLO-NAS, etc.), both the code and pre-trained weights are distributed under the [Apache 2.0](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/LICENSE) open source license. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# YOLOX Burn | ||
|
||
There have been many different object detection models with the YOLO prefix released in recent | ||
years, though most of them carry a GPL or AGPL license, which restricts their usage. For this reason, | ||
we selected [YOLOX](https://arxiv.org/abs/2107.08430) as the first object detection architecture | ||
since both the original code and pre-trained weights are released under the | ||
[Apache 2.0](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/LICENSE) open source license. | ||
|
||
You can find the [Burn](https://github.com/tracel-ai/burn) implementation for the YOLOX variants in | ||
[src/model/yolox.rs](src/model/yolox.rs). | ||
|
||
The model is [no_std compatible](https://docs.rust-embedded.org/book/intro/no-std.html). | ||
|
||
## Usage | ||
|
||
### `Cargo.toml` | ||
|
||
Add this to your `Cargo.toml`: | ||
|
||
```toml | ||
[dependencies] | ||
yolox-burn = { git = "https://github.com/tracel-ai/models", package = "yolox-burn", default-features = false } | ||
``` | ||
|
||
If you want to get the COCO pre-trained weights, enable the `pretrained` feature flag. | ||
|
||
```toml | ||
[dependencies] | ||
yolox-burn = { git = "https://github.com/tracel-ai/models", package = "yolox-burn", features = ["pretrained"] } | ||
``` | ||
|
||
**Important:** this feature requires `std`. | ||
|
||
### Example Usage | ||
|
||
The [inference example](examples/inference.rs) initializes a YOLOX-Tiny from the COCO | ||
[pre-trained weights](https://github.com/Megvii-BaseDetection/YOLOX?tab=readme-ov-file#standard-models) | ||
with the `NdArray` backend and performs inference on the provided input image. | ||
|
||
You can run the example with the following command: | ||
|
||
```sh | ||
cargo run --release --features pretrained --example inference samples/dog_bike_man.jpg | ||
``` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
use std::path::Path; | ||
|
||
use image::{DynamicImage, ImageBuffer}; | ||
use yolox_burn::model::{boxes::nms, weights, yolox::Yolox, BoundingBox}; | ||
|
||
use burn::{ | ||
backend::NdArray, | ||
tensor::{backend::Backend, Data, Device, Element, Shape, Tensor}, | ||
}; | ||
|
||
const HEIGHT: usize = 640; | ||
const WIDTH: usize = 640; | ||
|
||
fn to_tensor<B: Backend, T: Element>( | ||
data: Vec<T>, | ||
shape: [usize; 3], | ||
device: &Device<B>, | ||
) -> Tensor<B, 3> { | ||
Tensor::<B, 3>::from_data(Data::new(data, Shape::new(shape)).convert(), device) | ||
// [H, W, C] -> [C, H, W] | ||
.permute([2, 0, 1]) | ||
} | ||
|
||
/// Draws bounding boxes on the given image. | ||
/// | ||
/// # Arguments | ||
/// | ||
/// * `image` - Original input image. | ||
/// * `boxes` - Bounding boxes, grouped per class. | ||
/// * `color` - [R, G, B] color values to draw the boxes. | ||
/// * `ratio` - [x, y] scale factors that map the predicted boxes back to the original image size. | ||
/// | ||
/// # Returns | ||
/// | ||
/// The image annotated with bounding boxes. | ||
fn draw_boxes( | ||
image: DynamicImage, | ||
boxes: &[Vec<BoundingBox>], | ||
color: &[u8; 3], | ||
ratio: &[f32; 2], // (x, y) ratio | ||
) -> DynamicImage { | ||
// Assumes x1 <= x2 and y1 <= y2 | ||
fn draw_rect( | ||
image: &mut ImageBuffer<image::Rgb<u8>, Vec<u8>>, | ||
x1: u32, | ||
x2: u32, | ||
y1: u32, | ||
y2: u32, | ||
color: &[u8; 3], | ||
) { | ||
for x in x1..=x2 { | ||
let pixel = image.get_pixel_mut(x, y1); | ||
*pixel = image::Rgb(*color); | ||
let pixel = image.get_pixel_mut(x, y2); | ||
*pixel = image::Rgb(*color); | ||
} | ||
for y in y1..=y2 { | ||
let pixel = image.get_pixel_mut(x1, y); | ||
*pixel = image::Rgb(*color); | ||
let pixel = image.get_pixel_mut(x2, y); | ||
*pixel = image::Rgb(*color); | ||
} | ||
} | ||
|
||
// Annotate the original image and print boxes information. | ||
let (image_h, image_w) = (image.height(), image.width()); | ||
let mut image = image.to_rgb8(); | ||
for (class_index, bboxes_for_class) in boxes.iter().enumerate() { | ||
for b in bboxes_for_class.iter() { | ||
let xmin = (b.xmin * ratio[0]).clamp(0., image_w as f32 - 1.); | ||
let ymin = (b.ymin * ratio[1]).clamp(0., image_h as f32 - 1.); | ||
let xmax = (b.xmax * ratio[0]).clamp(0., image_w as f32 - 1.); | ||
let ymax = (b.ymax * ratio[1]).clamp(0., image_h as f32 - 1.); | ||
|
||
println!( | ||
"Predicted {} ({:.2}) at [{:.2}, {:.2}, {:.2}, {:.2}]", | ||
class_index, b.confidence, xmin, ymin, xmax, ymax, | ||
); | ||
|
||
draw_rect( | ||
&mut image, | ||
xmin as u32, | ||
xmax as u32, | ||
ymin as u32, | ||
ymax as u32, | ||
color, | ||
); | ||
} | ||
} | ||
DynamicImage::ImageRgb8(image) | ||
} | ||
|
||
pub fn main() { | ||
// Parse arguments | ||
let img_path = std::env::args().nth(1).expect("No image path provided"); | ||
|
||
// Create YOLOX-Tiny | ||
let device = Default::default(); | ||
let model: Yolox<NdArray> = Yolox::yolox_tiny_pretrained(weights::YoloxTiny::Coco, &device) | ||
.map_err(|err| format!("Failed to load pre-trained weights.\nError: {err}")) | ||
.unwrap(); | ||
|
||
// Load image | ||
let img = image::open(&img_path) | ||
.map_err(|err| format!("Failed to load image {img_path}.\nError: {err}")) | ||
.unwrap(); | ||
|
||
// Resize to 640x640 | ||
let resized_img = img.resize_exact( | ||
WIDTH as u32, | ||
HEIGHT as u32, | ||
image::imageops::FilterType::Triangle, // also known as bilinear in 2D | ||
); | ||
|
||
// Create tensor from image data | ||
let x = to_tensor( | ||
resized_img.into_rgb8().into_raw(), | ||
[HEIGHT, WIDTH, 3], | ||
&device, | ||
) | ||
.unsqueeze::<4>(); // [B, C, H, W] | ||
|
||
// Forward pass | ||
let out = model.forward(x); | ||
|
||
// Post-processing | ||
let [_, num_boxes, num_outputs] = out.dims(); | ||
let boxes = out.clone().slice([0..1, 0..num_boxes, 0..4]); | ||
let obj_scores = out.clone().slice([0..1, 0..num_boxes, 4..5]); | ||
let cls_scores = out.slice([0..1, 0..num_boxes, 5..num_outputs]); | ||
let scores = cls_scores * obj_scores; | ||
let boxes = nms(boxes, scores, 0.65, 0.5); | ||
|
||
// Draw outputs and save results | ||
let (h, w) = (img.height(), img.width()); | ||
let img_out = draw_boxes( | ||
img, | ||
&boxes[0], | ||
&[239u8, 62u8, 5u8], | ||
&[w as f32 / WIDTH as f32, h as f32 / HEIGHT as f32], | ||
); | ||
|
||
let img_path = Path::new(&img_path); | ||
let _ = img_out.save(img_path.with_extension("output.png")); | ||
} |
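The post-processing step in the example above multiplies class scores by objectness, filters the boxes with `nms(boxes, scores, 0.65, 0.5)`, and scales the surviving boxes back to the original image size. The sketch below illustrates that idea in plain Rust for a single class. It is not the crate's `nms` implementation: the `Detection` struct and the `iou`, `greedy_nms`, and `rescale` functions are hypothetical names introduced here, and reading `0.65` as an IoU threshold and `0.5` as a score threshold is an assumption based on the argument order in the example.

```rust
/// Hypothetical axis-aligned box with a score (illustration only; this is
/// not the crate's `BoundingBox` type).
#[derive(Debug, Clone, Copy, PartialEq)]
struct Detection {
    xmin: f32,
    ymin: f32,
    xmax: f32,
    ymax: f32,
    score: f32,
}

/// Intersection over union of two boxes, in the range [0, 1].
fn iou(a: &Detection, b: &Detection) -> f32 {
    // Width and height of the overlap, clamped to zero when boxes are disjoint.
    let iw = (a.xmax.min(b.xmax) - a.xmin.max(b.xmin)).max(0.0);
    let ih = (a.ymax.min(b.ymax) - a.ymin.max(b.ymin)).max(0.0);
    let inter = iw * ih;
    let area_a = (a.xmax - a.xmin) * (a.ymax - a.ymin);
    let area_b = (b.xmax - b.xmin) * (b.ymax - b.ymin);
    let union = area_a + area_b - inter;
    if union > 0.0 {
        inter / union
    } else {
        0.0
    }
}

/// Greedy non-maximum suppression for one class: drop low-scoring boxes,
/// then keep boxes in descending score order unless they overlap an
/// already-kept box by more than `iou_threshold`.
fn greedy_nms(
    mut candidates: Vec<Detection>,
    iou_threshold: f32,
    score_threshold: f32,
) -> Vec<Detection> {
    candidates.retain(|d| d.score >= score_threshold);
    candidates.sort_by(|a, b| b.score.total_cmp(&a.score));

    let mut kept: Vec<Detection> = Vec::new();
    for cand in candidates {
        if kept.iter().all(|k| iou(k, &cand) <= iou_threshold) {
            kept.push(cand);
        }
    }
    kept
}

/// Maps a box from the resized model input back to the original image,
/// clamping coordinates to the valid pixel range as `draw_boxes` does.
/// Assumes `image_w` and `image_h` are at least 1.
fn rescale(d: &Detection, ratio: [f32; 2], image_w: u32, image_h: u32) -> Detection {
    let max_x = image_w as f32 - 1.0;
    let max_y = image_h as f32 - 1.0;
    Detection {
        xmin: (d.xmin * ratio[0]).clamp(0.0, max_x),
        ymin: (d.ymin * ratio[1]).clamp(0.0, max_y),
        xmax: (d.xmax * ratio[0]).clamp(0.0, max_x),
        ymax: (d.ymax * ratio[1]).clamp(0.0, max_y),
        score: d.score,
    }
}
```

With the example's thresholds, two boxes of the same class that overlap by more than 65% collapse to the higher-scoring one, while boxes scoring below 0.5 are discarded before any overlap test.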
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
#![cfg_attr(not(feature = "std"), no_std)] | ||
pub mod model; | ||
extern crate alloc; |
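The three lines above make the crate build without the standard library unless the `std` feature is enabled, while still linking `alloc` for heap-backed collections. As a minimal sketch of what code written under that constraint can look like, the function below uses only `alloc::vec::Vec`. The `grid_coords` helper is hypothetical and is not taken from the crate; the crate-level attribute is shown as a comment because it only applies at the top of a crate root.

```rust
// In a library crate, this attribute would sit at the top of `lib.rs`:
// #![cfg_attr(not(feature = "std"), no_std)]
extern crate alloc;

use alloc::vec::Vec;

/// Hypothetical helper (not part of the crate): returns the (x, y)
/// coordinates of every cell in a `width` x `height` grid in row-major
/// order, using only `alloc` types so it also compiles under `no_std`.
fn grid_coords(width: usize, height: usize) -> Vec<(usize, usize)> {
    let mut coords = Vec::with_capacity(width * height);
    for y in 0..height {
        for x in 0..width {
            coords.push((x, y));
        }
    }
    coords
}
```

Importing collection types from `alloc` rather than `std` keeps the same source valid in both configurations, since `std` re-exports those types.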