[Feat] Add new burn-vision crate with one initial op #2753

Open · wants to merge 30 commits into main

Changes from 28 of 30 commits

Commits
e7a84b9
Update cubecl
wingertge Jan 25, 2025
267ec63
Update to scope merge
wingertge Jan 26, 2025
71584e0
Fix bitwise shift
wingertge Jan 26, 2025
5838364
Initial JIT implementation
wingertge Jan 28, 2025
d292f85
Merge branch 'main' into feat/burn-vision
wingertge Jan 28, 2025
9e65150
Move testgen to burn-jit
wingertge Jan 28, 2025
0484f51
Improve HA4/8 algo
wingertge Jan 28, 2025
f62a9ee
Terminate units past the predefined 32 plane size
wingertge Jan 28, 2025
8edac2b
move jit backend back into `burn-vision` and make tests work
wingertge Jan 30, 2025
05b40e3
Add initial CPU implementation without stats
wingertge Jan 30, 2025
7708993
Implement stats
wingertge Jan 31, 2025
aeea3a8
Implement all backends except fusion
wingertge Jan 31, 2025
a994ca7
Fix autodiff to use GPU when available
wingertge Jan 31, 2025
866307b
Fixes and cleanup
wingertge Jan 31, 2025
a8e3994
Add docs
wingertge Jan 31, 2025
021360b
Update cubecl
wingertge Jan 31, 2025
a1d727f
Merge branch 'main' into feat/burn-vision
wingertge Jan 31, 2025
01ff01b
Compact labels for JIT
wingertge Feb 1, 2025
d790113
Improve JIT backend implementation by adding label compaction
wingertge Feb 2, 2025
15c431c
Use GPU reduction for max label
wingertge Feb 2, 2025
e3ec085
Manually fuse presence and prefix sum
wingertge Feb 2, 2025
11c8f1f
Make prefix sum more generic over line size
wingertge Feb 2, 2025
ee5ad73
Merge branch 'main' into feat/burn-vision
wingertge Feb 3, 2025
e6126c8
Add vision tests to xtask
wingertge Feb 3, 2025
1bbf50a
Fix CPU and other review stuff
wingertge Feb 3, 2025
70b8b7e
Merge branch 'main' into feat/burn-vision
wingertge Feb 5, 2025
db92cd8
Merge branch 'main' into feat/burn-vision
wingertge Feb 6, 2025
4f174a8
Add publish job
laggui Feb 7, 2025
5a1ada3
Review fixes
wingertge Feb 8, 2025
5a2931f
Merge branch 'feat/burn-vision' of https://github.com/wingertge/burn …
wingertge Feb 8, 2025
19 changes: 19 additions & 0 deletions .github/workflows/publish.yml
@@ -6,6 +6,25 @@ on:
- "v*"

jobs:
publish-burn-vision:
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
with:
crate: burn-vision
needs:
- publish-burn-autodiff
- publish-burn-candle
- publish-burn-fusion
- publish-burn-jit
- publish-burn-ndarray
- publish-burn-tch
- publish-burn-tensor
- publish-burn-tensor-testgen
# dev dependencies
- publish-burn-wgpu
- publish-burn-cuda
secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

publish-burn-router:
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
with:
20 changes: 20 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

3 changes: 3 additions & 0 deletions crates/burn-candle/src/element.rs
@@ -4,8 +4,11 @@ use burn_tensor::Element;
use candle_core::{FloatDType, Tensor, WithDType};
use half::{bf16, f16};

/// Candle element
pub trait CandleElement: Element + WithDType {}
/// Candle float element
pub trait FloatCandleElement: CandleElement + FloatDType {}
/// Candle int element
pub trait IntCandleElement: CandleElement {}

impl CandleElement for f64 {}
1 change: 1 addition & 0 deletions crates/burn-candle/src/lib.rs
@@ -13,6 +13,7 @@ mod ops;
mod tensor;

pub use backend::*;
pub use element::*;
pub use tensor::*;

#[cfg(test)]
2 changes: 1 addition & 1 deletion crates/burn-cuda/Cargo.toml
@@ -12,8 +12,8 @@ repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-cuda"
version.workspace = true

[features]
default = ["fusion", "autotune", "burn-jit/default", "cubecl/default"]
autotune = ["burn-jit/autotune"]
default = ["fusion", "autotune", "burn-jit/default", "cubecl/default"]
doc = ["burn-jit/doc"]
fusion = ["burn-fusion", "burn-jit/fusion"]
std = ["burn-jit/std", "cubecl/std"]
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/index/mod.rs
@@ -11,7 +11,7 @@ pub(crate) use flip::*;
pub(crate) use repeat_dim::*;
pub(crate) use select::*;
pub(crate) use select_assign::*;
pub(crate) use slice::*;
pub use slice::*;
pub(crate) use slice_assign::*;

pub(crate) use gather::*;
3 changes: 2 additions & 1 deletion crates/burn-jit/src/kernel/index/slice.rs
@@ -3,7 +3,8 @@ use burn_tensor::Shape;
use cubecl::{calculate_cube_count_elemwise, prelude::*};
use std::ops::Range;

pub(crate) fn slice<R: JitRuntime, E: JitElement>(
/// Slice a jit tensor with a set of ranges
pub fn slice<R: JitRuntime, E: JitElement>(
tensor: JitTensor<R>,
indices: &[Range<usize>],
) -> JitTensor<R> {
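With `slice` now exported as `pub` (see the `index/mod.rs` hunk above and the `kernel/mod.rs` hunk below), kernels outside `burn-jit`, such as the new `burn-vision` crate, can call it directly. A minimal usage sketch, assuming the re-export path `burn_jit::kernel::slice` and the public `burn_jit::tensor::JitTensor` type; the `crop_roi` helper itself is hypothetical, not part of this PR:

```rust
// Hypothetical downstream usage; paths assume the re-exports added in this PR.
use std::ops::Range;

use burn_jit::{kernel::slice, tensor::JitTensor, JitElement, JitRuntime};

/// Crop a region of interest out of a 2D tensor laid out as [rows, cols].
fn crop_roi<R: JitRuntime, E: JitElement>(
    image: JitTensor<R>,
    rows: Range<usize>,
    cols: Range<usize>,
) -> JitTensor<R> {
    // `slice` copies the selected ranges into a new tensor; the ranges must
    // lie within the tensor's shape.
    slice::<R, E>(image, &[rows, cols])
}
```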
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/mod.rs
@@ -39,4 +39,4 @@ pub mod reduce;

pub(crate) use clamp::*;
pub(crate) use comparison::*;
pub(crate) use index::*;
pub use index::*;
3 changes: 2 additions & 1 deletion crates/burn-jit/src/lib.rs
@@ -7,7 +7,8 @@
extern crate derive_new;
extern crate alloc;

mod ops;
/// Utilities for implementing JIT kernels
pub mod ops;

/// Kernel module
pub mod kernel;
4 changes: 3 additions & 1 deletion crates/burn-jit/src/ops/base.rs
@@ -76,6 +76,7 @@ pub(crate) fn swap_dims<R: JitRuntime>(
tensor
}

/// Permute a tensor's dimensions
pub fn permute<R: JitRuntime>(mut tensor: JitTensor<R>, axes: &[usize]) -> JitTensor<R> {
// remap strides
tensor.strides = axes.iter().map(|i| tensor.strides[*i]).collect();
@@ -135,7 +136,8 @@ pub(crate) fn expand<R: JitRuntime>(tensor: JitTensor<R>, target_shape: Shape) -
}
}

pub(crate) fn reshape<R: JitRuntime>(tensor: JitTensor<R>, shape: Shape) -> JitTensor<R> {
/// Reshape a jit tensor to a new shape
pub fn reshape<R: JitRuntime>(tensor: JitTensor<R>, shape: Shape) -> JitTensor<R> {
// TODO: Not force standard layout all the time (improve performance).
let tensor = kernel::into_contiguous(tensor);

5 changes: 3 additions & 2 deletions crates/burn-jit/src/ops/mod.rs
@@ -7,6 +7,7 @@ mod qtensor;
mod transaction;

pub(crate) mod base;
pub(crate) use base::*;
pub use base::*;

pub(crate) mod numeric;
/// Numeric utility functions for jit backends
pub mod numeric;
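With `ops` now a public module and `base::*` re-exported (the two hunks above), helpers such as `permute` and `reshape` become callable from other crates. A short sketch under those assumptions; the `to_channels_first` helper is hypothetical, not part of this PR:

```rust
// Hypothetical downstream usage; paths assume the re-exports added in this PR.
use burn_jit::{ops::{permute, reshape}, tensor::JitTensor, JitRuntime};
use burn_tensor::Shape;

/// Reorder an NHWC batch to NCHW, then force the contiguous target shape.
fn to_channels_first<R: JitRuntime>(images: JitTensor<R>, nchw_shape: Shape) -> JitTensor<R> {
    // `permute` only remaps strides, so it is cheap; `reshape` currently copies
    // into a contiguous buffer, as noted by the TODO in the `reshape` hunk above.
    let channels_first = permute::<R>(images, &[0, 3, 1, 2]);
    reshape::<R>(channels_first, nchw_shape)
}
```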
24 changes: 24 additions & 0 deletions crates/burn-jit/src/ops/numeric.rs
@@ -9,6 +9,7 @@ use cubecl::client::ComputeClient;
use cubecl::tensor_vectorization_factor;
use cubecl::{calculate_cube_count_elemwise, prelude::*};

/// Create a tensor filled with `value`
pub fn full<R: JitRuntime, E: JitElement>(
shape: Shape,
device: &R::Device,
@@ -19,6 +20,7 @@ pub fn full<R: JitRuntime, E: JitElement>(
full_device::<R, E>(client, shape, device.clone(), value)
}

/// Create a tensor filled with `value`
pub fn full_device<R: JitRuntime, E: JitElement>(
client: ComputeClient<R::Server, R::Channel>,
shape: Shape,
@@ -56,12 +58,14 @@ pub fn full_device<R: JitRuntime, E: JitElement>(
empty
}

/// Create a tensor filled with zeros
pub fn zeros<R: JitRuntime, E: JitElement>(shape: Shape, device: &R::Device) -> JitTensor<R> {
let client = R::client(device);

zeros_device::<R, E>(client, device.clone(), shape)
}

/// Create a tensor filled with zeros
pub fn zeros_device<R: JitRuntime, E: JitElement>(
client: ComputeClient<R::Server, R::Channel>,
device: R::Device,
@@ -70,12 +74,14 @@ pub fn zeros_device<R: JitRuntime, E: JitElement>(
full_device::<R, E>(client, shape, device, 0.elem())
}

/// Create a tensor filled with ones
pub fn ones<R: JitRuntime, E: JitElement>(shape: Shape, device: &R::Device) -> JitTensor<R> {
let client = R::client(device);

ones_device::<R, E>(client, device.clone(), shape)
}

/// Create a tensor filled with ones
pub fn ones_device<R: JitRuntime, E: JitElement>(
client: ComputeClient<R::Server, R::Channel>,
device: R::Device,
@@ -84,6 +90,7 @@ pub fn ones_device<R: JitRuntime, E: JitElement>(
full_device::<R, E>(client, shape, device, 1.elem())
}

/// Create a tensor with uninitialized memory
pub fn empty_device<R: JitRuntime, E: JitElement>(
client: ComputeClient<R::Server, R::Channel>,
device: R::Device,
@@ -94,82 +101,99 @@ pub fn empty_device<R: JitRuntime, E: JitElement>(
JitTensor::new_contiguous(client, device, shape, buffer, E::dtype())
}

/// Add two tensors
pub fn add<R: JitRuntime, E: JitElement>(lhs: JitTensor<R>, rhs: JitTensor<R>) -> JitTensor<R> {
launch_binop::<R, E, AddOp>(lhs, rhs)
}

/// Add a tensor and a scalar
pub fn add_scalar<R: JitRuntime, E: JitElement>(lhs: JitTensor<R>, rhs: E) -> JitTensor<R> {
launch_scalar_binop::<R, E, AddOp>(lhs, rhs)
}

/// Subtract two tensors
pub fn sub<R: JitRuntime, E: JitElement>(lhs: JitTensor<R>, rhs: JitTensor<R>) -> JitTensor<R> {
launch_binop::<R, E, SubOp>(lhs, rhs)
}

/// Subtract a tensor and a scalar
pub fn sub_scalar<R: JitRuntime, E: JitElement>(lhs: JitTensor<R>, rhs: E) -> JitTensor<R> {
launch_scalar_binop::<R, E, SubOp>(lhs, rhs)
}

/// Multiply two tensors
pub fn mul<R: JitRuntime, E: JitElement>(lhs: JitTensor<R>, rhs: JitTensor<R>) -> JitTensor<R> {
launch_binop::<R, E, MulOp>(lhs, rhs)
}

/// Multiply a tensor and a scalar
pub fn mul_scalar<R: JitRuntime, E: JitElement>(lhs: JitTensor<R>, rhs: E) -> JitTensor<R> {
launch_scalar_binop::<R, E, MulOp>(lhs, rhs)
}

/// Divide two tensors
pub fn div<R: JitRuntime, E: JitElement>(lhs: JitTensor<R>, rhs: JitTensor<R>) -> JitTensor<R> {
launch_binop::<R, E, DivOp>(lhs, rhs)
}

/// Divide a tensor by a scalar
pub fn div_scalar<R: JitRuntime, E: JitElement>(lhs: JitTensor<R>, rhs: E) -> JitTensor<R> {
launch_scalar_binop::<R, E, DivOp>(lhs, rhs)
}

/// Calculate remainder of two tensors
pub fn remainder<R: JitRuntime, E: JitElement>(
lhs: JitTensor<R>,
rhs: JitTensor<R>,
) -> JitTensor<R> {
launch_binop::<R, E, RemainderOp>(lhs, rhs)
}

/// Calculate the remainder of a tensor with a scalar
pub fn remainder_scalar<R: JitRuntime, E: JitElement>(lhs: JitTensor<R>, rhs: E) -> JitTensor<R> {
launch_scalar_binop::<R, E, RemainderOp>(lhs, rhs)
}

/// Calculate the power of two tensors
pub fn pow<R: JitRuntime, E: FloatElement>(lhs: JitTensor<R>, rhs: JitTensor<R>) -> JitTensor<R> {
launch_binop::<R, E, PowOp<E>>(lhs, rhs)
}

/// Bitwise and two tensors
pub fn bitwise_and<R: JitRuntime, E: IntElement>(
lhs: JitTensor<R>,
rhs: JitTensor<R>,
) -> JitTensor<R> {
launch_binop_int::<R, E, BitwiseAndOp>(lhs, rhs)
}

/// Bitwise and with a scalar
pub fn bitwise_and_scalar<R: JitRuntime, E: IntElement>(lhs: JitTensor<R>, rhs: E) -> JitTensor<R> {
launch_scalar_binop_int::<R, E, BitwiseAndOp>(lhs, rhs)
}

/// Bitwise or two tensors
pub fn bitwise_or<R: JitRuntime, E: IntElement>(
lhs: JitTensor<R>,
rhs: JitTensor<R>,
) -> JitTensor<R> {
launch_binop_int::<R, E, BitwiseOrOp>(lhs, rhs)
}

/// Bitwise or with a scalar
pub fn bitwise_or_scalar<R: JitRuntime, E: IntElement>(lhs: JitTensor<R>, rhs: E) -> JitTensor<R> {
launch_scalar_binop_int::<R, E, BitwiseOrOp>(lhs, rhs)
}

/// Bitwise xor two tensors
pub fn bitwise_xor<R: JitRuntime, E: IntElement>(
lhs: JitTensor<R>,
rhs: JitTensor<R>,
) -> JitTensor<R> {
launch_binop_int::<R, E, BitwiseXorOp>(lhs, rhs)
}

/// Bitwise xor with a scalar
pub fn bitwise_xor_scalar<R: JitRuntime, E: IntElement>(lhs: JitTensor<R>, rhs: E) -> JitTensor<R> {
launch_scalar_binop_int::<R, E, BitwiseXorOp>(lhs, rhs)
}
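The element-wise helpers above were previously crate-private; with `pub mod numeric` they can serve as building blocks for kernels in other crates. A minimal sketch assuming the `burn_jit::ops::numeric` path introduced in this PR; `scale_and_shift` is a hypothetical helper:

```rust
// Hypothetical downstream usage; paths assume the re-exports added in this PR.
use burn_jit::{ops::numeric::{add_scalar, mul_scalar}, tensor::JitTensor, JitElement, JitRuntime};

/// Compute `x * scale + shift` element-wise with the now-public numeric kernels.
fn scale_and_shift<R: JitRuntime, E: JitElement>(
    x: JitTensor<R>,
    scale: E,
    shift: E,
) -> JitTensor<R> {
    // Each call launches one element-wise kernel and returns a new tensor.
    let scaled = mul_scalar::<R, E>(x, scale);
    add_scalar::<R, E>(scaled, shift)
}
```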
3 changes: 2 additions & 1 deletion crates/burn-jit/src/tensor/base.rs
@@ -23,7 +23,8 @@ pub struct JitTensor<R: JitRuntime> {
pub device: R::Device,
/// The strides of the tensor.
pub strides: Vec<usize>,
pub(crate) dtype: DType,
/// The datatype of the tensor.
pub dtype: DType,
}

impl<R: JitRuntime, E: JitElement> From<JitTensor<R>> for TensorHandle<R, E> {
9 changes: 9 additions & 0 deletions crates/burn-ndarray/src/element.rs
@@ -16,6 +16,7 @@
{
}

/// An int element for ndarray backend.
pub trait IntNdArrayElement: NdArrayElement + Signed {}

/// A general element for ndarray backend.
@@ -34,13 +35,21 @@ pub trait NdArrayElement:

/// A element for ndarray backend that supports exp ops.
pub trait ExpElement {
/// Exponent
fn exp_elem(self) -> Self;
/// Log
fn log_elem(self) -> Self;
/// Log1p
fn log1p_elem(self) -> Self;
/// Powf
fn powf_elem(self, value: f32) -> Self;
/// Powi
fn powi_elem(self, value: i32) -> Self;
/// Sqrt
fn sqrt_elem(self) -> Self;
/// Abs
fn abs_elem(self) -> Self;
/// Abs for int
fn int_abs_elem(self) -> Self;
}

2 changes: 1 addition & 1 deletion crates/burn-ndarray/src/lib.rs
@@ -21,7 +21,7 @@ mod sharing;
mod tensor;

pub use backend::*;
pub use element::FloatNdArrayElement;
pub use element::*;
pub(crate) use sharing::*;
pub use tensor::*;

10 changes: 5 additions & 5 deletions crates/burn-ndarray/src/ops/conv.rs
@@ -11,7 +11,7 @@ use ndarray::{
};

use crate::{
element::{FloatNdArrayElement, IntNdArrayElement, QuantElement},
element::FloatNdArrayElement,
ops::padding::{apply_padding_4d, apply_padding_5d},
sharing::UnsafeSharedRef,
tensor::NdArrayTensor,
@@ -98,7 +98,7 @@ fn conv3d_mad_inner<E: FloatNdArrayElement>(
}
}

pub(crate) fn conv2d<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement>(
pub(crate) fn conv2d<E: FloatNdArrayElement>(
x: NdArrayTensor<E>,
weight: NdArrayTensor<E>,
bias: Option<NdArrayTensor<E>>,
@@ -126,7 +126,7 @@ pub(crate) fn conv2d<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantEleme
in_width,
);

let x = apply_padding_4d::<E, I, Q>(x, options.padding, 0i32.elem()).array;
let x = apply_padding_4d::<E>(x, options.padding, 0i32.elem()).array;

// Convert inputs from dynamic indexes to static to improve perf.
let x = x.into_dimensionality::<ndarray::Ix4>().unwrap();
@@ -310,7 +310,7 @@ pub(crate) fn conv_transpose2d<E: FloatNdArrayElement>(
NdArrayTensor::new(output.into_dyn().into_shared())
}

pub(crate) fn conv3d<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement>(
pub(crate) fn conv3d<E: FloatNdArrayElement>(
x: NdArrayTensor<E>,
weight: NdArrayTensor<E>,
bias: Option<NdArrayTensor<E>>,
@@ -345,7 +345,7 @@ pub(crate) fn conv3d<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantEleme
in_width,
);

let x = apply_padding_5d::<E, I, Q>(x, options.padding, 0i32.elem()).array;
let x = apply_padding_5d::<E>(x, options.padding, 0i32.elem()).array;

// Convert inputs from dynamic indexes to static to improve perf.
let x = x.into_dimensionality::<ndarray::Ix5>().unwrap();