Skip to content

Commit

Permalink
Update to burn 0.14
Browse files Browse the repository at this point in the history
  • Loading branch information
laggui committed Sep 16, 2024
1 parent 00dfeac commit 6fc4767
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 15 deletions.
6 changes: 3 additions & 3 deletions yolox-burn/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ pretrained = ["burn/network", "std", "dep:dirs"]

[dependencies]
# Note: default-features = false is needed to disable std
burn = { version = "0.13.0", default-features = false }
burn-import = { version = "0.13.0" }
burn = { version = "0.14.0", default-features = false }
burn-import = { version = "0.14.0" }
itertools = { version = "0.12.1", default-features = false, features = [
"use_alloc",
] }
Expand All @@ -24,5 +24,5 @@ serde = { version = "1.0.192", default-features = false, features = [
] } # alloc is for no_std, derive is needed

[dev-dependencies]
burn = { version = "0.13.0", features = ["ndarray"] }
burn = { version = "0.14.0", features = ["ndarray"] }
image = { version = "0.24.9", features = ["png", "jpeg"] }
11 changes: 7 additions & 4 deletions yolox-burn/examples/inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use yolox_burn::model::{boxes::nms, weights, yolox::Yolox, BoundingBox};

use burn::{
backend::NdArray,
tensor::{backend::Backend, Data, Device, Element, Shape, Tensor},
tensor::{backend::Backend, Device, Element, Tensor, TensorData},
};

const HEIGHT: usize = 640;
Expand All @@ -16,9 +16,12 @@ fn to_tensor<B: Backend, T: Element>(
shape: [usize; 3],
device: &Device<B>,
) -> Tensor<B, 3> {
Tensor::<B, 3>::from_data(Data::new(data, Shape::new(shape)).convert(), device)
// [H, W, C] -> [C, H, W]
.permute([2, 0, 1])
Tensor::<B, 3>::from_data(
TensorData::new(data, shape).convert::<B::FloatElem>(),
device,
)
// [H, W, C] -> [C, H, W]
.permute([2, 0, 1])
}

/// Draws bounding boxes on the given image.
Expand Down
9 changes: 3 additions & 6 deletions yolox-burn/src/model/boxes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,22 +45,19 @@ pub fn nms<B: Backend>(
let (cls_score, cls_idx) = candidate_scores.squeeze::<2>(0).max_dim_with_indices(1);
let cls_score: Vec<_> = cls_score
.into_data()
.value
.iter()
.iter::<B::FloatElem>()
.map(|v| v.elem::<f32>())
.collect();
let cls_idx: Vec<_> = cls_idx
.into_data()
.value
.iter()
.iter::<B::IntElem>()
.map(|v| v.elem::<i64>() as usize)
.collect();

// [num_boxes, 4]
let candidate_boxes: Vec<_> = candidate_boxes
.into_data()
.value
.iter()
.iter::<B::FloatElem>()
.map(|v| v.elem::<f32>())
.collect();

Expand Down
4 changes: 2 additions & 2 deletions yolox-burn/src/model/head.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ const PRIOR_PROB: f64 = 1e-2;
fn create_2d_grid<B: Backend>(x: usize, y: usize, device: &Device<B>) -> Tensor<B, 3, Int> {
let y_idx = Tensor::arange(0..y as i64, device)
.reshape(Shape::new([y, 1]))
.repeat(1, x)
.repeat_dim(1, x)
.reshape(Shape::new([y, x]));
let x_idx = Tensor::arange(0..x as i64, device)
.reshape(Shape::new([1, x])) // can only repeat with dim=1
.repeat(0, y)
.repeat_dim(0, y)
.reshape(Shape::new([y, x]));

Tensor::stack(vec![x_idx, y_idx], 2)
Expand Down

0 comments on commit 6fc4767

Please sign in to comment.