Fix parallel macro + CI (#2678)
* Fix rayon issues

* Fix typos

* Fix for_each no std

* Fix clippy
laggui authored Jan 10, 2025
1 parent da8de56 commit 95593fc
Showing 7 changed files with 119 additions and 60 deletions.
2 changes: 1 addition & 1 deletion backend-comparison/src/burnbenchapp/auth/base.rs
@@ -133,7 +133,7 @@ fn verify_tokens(tokens: &Tokens) -> bool {
         )
         .header(GITHUB_API_VERSION_HEADER, GITHUB_API_VERSION)
         .send();
-    response.map_or(false, |resp| resp.status().is_success())
+    response.is_ok_and(|resp| resp.status().is_success())
 }
 
 fn refresh_tokens(tokens: &Tokens) -> Option<Tokens> {
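Note on the change above: `Result::is_ok_and` (stable since Rust 1.70) is the form clippy favors over `map_or(false, …)`. A minimal standalone sketch of the equivalence — the `status_ok` helper and its types are illustrative, not from the repository:

fn status_ok(response: Result<u16, String>) -> bool {
    // Before: `map_or` returns the default (`false`) for `Err`
    // and applies the closure to the `Ok` value.
    let old = response.clone().map_or(false, |code| code == 200);
    // After: `is_ok_and` expresses the same check directly.
    let new = response.is_ok_and(|code| code == 200);
    assert_eq!(old, new);
    new
}

fn main() {
    assert!(status_ok(Ok(200)));
    assert!(!status_ok(Ok(500)));
    assert!(!status_ok(Err("network error".into())));
}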
3 changes: 3 additions & 0 deletions crates/burn-common/src/lib.rs
@@ -11,6 +11,9 @@ pub mod id;
 
 pub use cubecl_common::*;
 
+#[cfg(feature = "rayon")]
+pub use rayon;
+
 extern crate alloc;
 
 /// Network utilities.
75 changes: 57 additions & 18 deletions crates/burn-common/src/parallel.rs
@@ -1,51 +1,90 @@
 /// Macro for running a function in parallel.
+#[cfg(feature = "rayon")]
 #[macro_export(local_inner_macros)]
 macro_rules! run_par {
     (
         $func:expr
     ) => {{
-        #[cfg(feature = "rayon")]
-        use rayon::prelude::*;
+        use $crate::rayon::prelude::*;
 
-        #[cfg(feature = "rayon")]
         #[allow(clippy::redundant_closure_call)]
-        let output = rayon::scope(|_| $func());
+        $crate::rayon::scope(|_| $func())
+    }};
+}
 
-        #[cfg(not(feature = "rayon"))]
-        let output = $func();
+/// Macro for running a function in parallel.
+#[cfg(not(feature = "rayon"))]
+#[macro_export(local_inner_macros)]
+macro_rules! run_par {
+    (
+        $func:expr
+    ) => {{
+        $func()
+    }};
+}
 
-        output
+/// Macro for iterating in parallel.
+#[cfg(not(feature = "rayon"))]
+#[macro_export(local_inner_macros)]
+macro_rules! iter_par {
+    (
+        $iter:expr
+    ) => {{
+        $iter
     }};
 }
 
 /// Macro for iterating in parallel.
+#[cfg(feature = "rayon")]
 #[macro_export(local_inner_macros)]
 macro_rules! iter_par {
     (
         $iter:expr
     ) => {{
-        #[cfg(feature = "rayon")]
-        let output = $iter.into_par_iter();
+        $iter.into_par_iter()
+    }};
+}
 
-        #[cfg(not(feature = "rayon"))]
-        let output = $iter;
+/// Macro for iterating in parallel.
+#[cfg(feature = "rayon")]
+#[macro_export(local_inner_macros)]
+macro_rules! iter_slice_par {
+    (
+        $slice:expr
+    ) => {{
+        $slice.into_par_iter()
+    }};
+}
 
-        output
+/// Macro for iterating in parallel.
+#[cfg(not(feature = "rayon"))]
+#[macro_export(local_inner_macros)]
+macro_rules! iter_slice_par {
+    (
+        $slice:expr
+    ) => {{
+        $slice.iter()
     }};
 }
 
 /// Macro for iterating over a range in parallel.
+#[cfg(feature = "rayon")]
 #[macro_export(local_inner_macros)]
 macro_rules! iter_range_par {
     (
         $start:expr, $end:expr
     ) => {{
-        #[cfg(feature = "rayon")]
-        let output = ($start..$end).into_par_iter();
-
-        #[cfg(not(feature = "rayon"))]
-        let output = ($start..$end);
+        ($start..$end).into_par_iter()
+    }};
+}
 
-        output
+/// Macro for iterating over a range in parallel.
+#[cfg(not(feature = "rayon"))]
+#[macro_export(local_inner_macros)]
+macro_rules! iter_range_par {
+    (
+        $start:expr, $end:expr
+    ) => {{
+        ($start..$end)
    }};
 }
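Taken together with the `pub use rayon` re-export added to `crates/burn-common/src/lib.rs` above, the `$crate::rayon` paths let the macros expand in downstream crates without their own `rayon` dependency, and splitting each macro into a `rayon`/non-`rayon` definition removes the `let output` plumbing. A minimal sketch of a call site — the `sum_of_squares` function is illustrative, not from the repository, and compiles the same way with or without the `rayon` feature:

use burn_common::{iter_range_par, run_par};

// Sums x^2 over 0..n: in parallel when burn-common's `rayon` feature is
// enabled, sequentially otherwise; the call site is identical either way.
fn sum_of_squares(n: u64) -> u64 {
    run_par!(|| iter_range_par!(0, n).map(|x| x * x).sum::<u64>())
}

fn main() {
    assert_eq!(sum_of_squares(4), 0 + 1 + 4 + 9);
}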
2 changes: 1 addition & 1 deletion crates/burn-jit/src/fusion/on_write/ir.rs
@@ -154,7 +154,7 @@ impl<R: Runtime> GlobalArgsLaunch<'_, R> {
         }
     }
 
-    /// Resolve the [argument](Arg) to a [tensor arguemnt](TensorArg).
+    /// Resolve the [argument](Arg) to a [tensor argument](TensorArg).
     ///
     /// # Panics
    ///
52 changes: 29 additions & 23 deletions crates/burn-ndarray/src/ops/deform_conv.rs
@@ -6,7 +6,7 @@ use burn_tensor::{
 use core::ops::AddAssign;
 use ndarray::{
     s, Array2, Array4, ArrayView2, ArrayView3, ArrayView4, ArrayView6, ArrayViewMut2, Axis, Dim,
-    Ix4,
+    Ix4, Zip,
 };
 #[cfg(not(feature = "std"))]
 use num_traits::Float;
@@ -593,31 +593,37 @@ pub mod backward {
             AtomicF32::new(0.0)
         });
 
+        let compute_for_each = |(in_channel, kernel_y, kernel_x, batch, out_y, out_x), col: &F| {
+            let group = in_channel / channels_per_offset_group;
+            let offset = offset.slice(s![batch, .., out_y, out_x]);
+            let offset = offset
+                .to_shape((offs_groups, kernel_h, kernel_w, 2))
+                .unwrap();
+            let offset = offset.slice(s![group, kernel_y, kernel_x, ..]);
+            let offset = [offset[0], offset[1]];
+            let mask = mask
+                .as_ref()
+                .map(|it| it[[batch, group, kernel_y, kernel_x, out_y, out_x]].to_f32());
+            let y = F::from_elem(out_y * args.stride[0] + kernel_y * args.dilation[0])
+                - F::from_elem(args.padding[0])
+                + offset[0];
+            let x = F::from_elem(out_x * args.stride[1] + kernel_x * args.dilation[1])
+                - F::from_elem(args.padding[1])
+                + offset[1];
+            let grad_in = grad_in.slice(s![batch, in_channel, .., ..]);
+            deform_col2img_kernel(y.to_f32(), x.to_f32(), mask, col.to_f32(), grad_in);
+        };
+
+        // `for_each` expects a 2-tuple argument with `.into_par_iter()`, but 2 separate arguments otherwise
+        #[cfg(feature = "std")]
         run_par!(|| {
-            iter_par!(columns.indexed_iter()).for_each(
-                |((in_channel, kernel_y, kernel_x, batch, out_y, out_x), col)| {
-                    let group = in_channel / channels_per_offset_group;
-                    let offset = offset.slice(s![batch, .., out_y, out_x]);
-                    let offset = offset
-                        .to_shape((offs_groups, kernel_h, kernel_w, 2))
-                        .unwrap();
-                    let offset = offset.slice(s![group, kernel_y, kernel_x, ..]);
-                    let offset = [offset[0], offset[1]];
-                    let mask = mask
-                        .as_ref()
-                        .map(|it| it[[batch, group, kernel_y, kernel_x, out_y, out_x]].to_f32());
-                    let y = F::from_elem(out_y * args.stride[0] + kernel_y * args.dilation[0])
-                        - F::from_elem(args.padding[0])
-                        + offset[0];
-                    let x = F::from_elem(out_x * args.stride[1] + kernel_x * args.dilation[1])
-                        - F::from_elem(args.padding[1])
-                        + offset[1];
-                    let grad_in = grad_in.slice(s![batch, in_channel, .., ..]);
-                    deform_col2img_kernel(y.to_f32(), x.to_f32(), mask, col.to_f32(), grad_in);
-                },
-            )
+            iter_par!(Zip::indexed(columns))
+                .for_each(|(args0, args1)| compute_for_each(args0, args1))
         });
 
+        #[cfg(not(feature = "std"))]
+        run_par!(|| { iter_par!(Zip::indexed(columns).for_each(compute_for_each)) });
+
         let grad_in: Array1<F> = grad_in
             .into_iter()
             .map(|it| F::from_elem(it.into_inner()))
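The new comment marks the subtle part of this change: with rayon, `Zip::indexed(…).into_par_iter().for_each(…)` hands the closure a single `(index, item)` tuple, while ndarray's sequential `Zip::for_each` passes the index and the item as two separate arguments — hence the two differently shaped call sites around the shared `compute_for_each` closure. A standalone sketch of that arity difference; the array shape and values are illustrative and it assumes ndarray's `rayon` feature:

// Cargo.toml (assumed): ndarray = { version = "0.16", features = ["rayon"] }
use ndarray::parallel::prelude::*;
use ndarray::{Array2, Zip};

fn main() {
    let a = Array2::<f32>::ones((2, 3));

    // Sequential: `Zip::for_each` passes index and item as two arguments.
    Zip::indexed(&a).for_each(|(i, j), x| println!("seq ({i}, {j}) = {x}"));

    // Parallel: the same zip as a rayon iterator yields one (index, item) tuple.
    Zip::indexed(&a)
        .into_par_iter()
        .for_each(|((i, j), x)| println!("par ({i}, {j}) = {x}"));
}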
2 changes: 1 addition & 1 deletion crates/burn-tensor/src/tensor/quantization/bytes.rs
@@ -100,7 +100,7 @@ impl QuantizedBytes {
 
     /// Splits the quantized values of the tensor from the quantization parameters.
     ///
-    /// Returns the packed values and a newly allocated vector containining the quantization parameters.
+    /// Returns the packed values and a newly allocated vector containing the quantization parameters.
     fn split_values_off(self) -> (Vec<u32>, Vec<u32>) {
         // The bytes can be created either from packed u32 or existing bytes with the same representation.
         let mut values = match self.bytes.align() {
43 changes: 27 additions & 16 deletions crates/burn-tensor/src/tensor/quantization/strategy.rs
@@ -4,7 +4,7 @@ use core::{
 };
 
 use alloc::vec::Vec;
-use burn_common::{iter_par, run_par};
+use burn_common::{iter_slice_par, run_par};
 use num_traits::{Float, PrimInt};
 use serde::{Deserialize, Serialize};
 
@@ -35,7 +35,7 @@ impl QuantizationStrategy {
 
 /// Quantization scheme to convert elements of a higher precision data type `E` to a lower precision
 /// data type `Q` and vice-versa.
-pub trait Quantization<E: Float, Q: PrimInt> {
+pub trait Quantization<E: Float + Send + Sync, Q: PrimInt + Send + Sync> {
     /// Create a new quantization scheme for an input range `[alpha, beta]`.
     fn new(alpha: E, beta: E) -> Self;
     /// Convert the values to a lower precision data type.
@@ -48,7 +48,7 @@ pub trait Quantization<E: Float, Q: PrimInt> {
 ///
 /// Note that the accumulation type `A` should have a bigger range than quantized type `Q`.
 #[derive(Debug, Clone, Copy, Serialize, Deserialize)]
-pub struct AffineQuantization<E: Float, Q: PrimInt, A: PrimInt> {
+pub struct AffineQuantization<E: Float + Send + Sync, Q: PrimInt + Send + Sync, A: PrimInt> {
     /// The scaling factor.
     pub scale: E,
     /// The zero-point offset.
@@ -66,7 +66,7 @@ fn valid_scale<E: Float>(mut scale: E) -> E {
     scale
 }
 
-impl<E: Float, Q: PrimInt, A: PrimInt> AffineQuantization<E, Q, A> {
+impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync, A: PrimInt> AffineQuantization<E, Q, A> {
     /// Initialize an affine quantization scheme with the given parameters.
     pub fn init(scale: E, offset: Q) -> Self {
         Self {
@@ -77,7 +77,9 @@ impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync, A: PrimInt> AffineQuanti
     }
 }
 
-impl<E: Float, Q: PrimInt, A: PrimInt> Quantization<E, Q> for AffineQuantization<E, Q, A> {
+impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync, A: PrimInt + Send + Sync> Quantization<E, Q>
+    for AffineQuantization<E, Q, A>
+{
     fn new(alpha: E, beta: E) -> Self {
         // Q range `[a, b]`
         let a = E::from(Q::min_value()).unwrap();
@@ -107,7 +109,7 @@ impl<E: Float, Q: PrimInt, A: PrimInt> Quantization<E, Q> for AffineQuantization
         // x_q = clamp(round(x / scale + offset), a, b)
         let z = E::from(self.offset).unwrap();
         run_par!(|| {
-            iter_par!(values.iter())
+            iter_slice_par!(values)
                 .map(|x| Q::from(x.div(self.scale).add(z).round().clamp(a, b)).unwrap())
                 .collect()
         })
@@ -116,7 +118,7 @@ impl<E: Float, Q: PrimInt, A: PrimInt> Quantization<E, Q> for AffineQuantization
     fn dequantize(&self, values: &[Q]) -> Vec<E> {
         // x = scale * (x_q - offset)
         run_par!(|| {
-            iter_par!(values.iter())
+            iter_slice_par!(values)
                 .map(|x_q| {
                     self.scale
                         * (E::from(
@@ -133,14 +135,14 @@ impl<E: Float, Q: PrimInt, A: PrimInt> Quantization<E, Q> for AffineQuantization
 
 /// Symmetric quantization scheme.
 #[derive(Debug, Clone, Copy, Serialize, Deserialize)]
-pub struct SymmetricQuantization<E: Float, Q: PrimInt> {
+pub struct SymmetricQuantization<E: Float + Send + Sync, Q: PrimInt + Send + Sync> {
     /// The scaling factor.
     pub scale: E,
     /// The quantized type.
     _q: PhantomData<Q>,
 }
 
-impl<E: Float, Q: PrimInt> SymmetricQuantization<E, Q> {
+impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync> SymmetricQuantization<E, Q> {
     /// Initialize a symmetric quantization scheme with the given parameters.
     pub fn init(scale: E) -> Self {
         Self {
@@ -150,7 +152,9 @@ impl<E: Float, Q: PrimInt> SymmetricQuantization<E, Q> {
     }
 }
 
-impl<E: Float, Q: PrimInt> Quantization<E, Q> for SymmetricQuantization<E, Q> {
+impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync> Quantization<E, Q>
+    for SymmetricQuantization<E, Q>
+{
     fn new(alpha: E, beta: E) -> Self {
         assert!(
             !Q::min_value().is_zero(),
@@ -214,7 +218,9 @@ fn canonicalize_signed_zero<T: Float>(x: T) -> T {
     x + T::zero()
 }
 
-impl<E: Float, Q: PrimInt + Hash, A: PrimInt> Hash for AffineQuantization<E, Q, A> {
+impl<E: Float + Send + Sync, Q: PrimInt + Hash + Send + Sync, A: PrimInt> Hash
+    for AffineQuantization<E, Q, A>
+{
     fn hash<H: Hasher>(&self, state: &mut H) {
         // Hash raw bits.
         let bits = raw_double_bits(&canonicalize_signed_zero(self.scale));
@@ -223,29 +229,34 @@ impl<E: Float, Q: PrimInt + Hash, A: PrimInt> Hash for AffineQuantization<E, Q,
     }
 }
 
-impl<E: Float, Q: PrimInt, A: PrimInt> PartialEq for AffineQuantization<E, Q, A> {
+impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync, A: PrimInt> PartialEq
+    for AffineQuantization<E, Q, A>
+{
     fn eq(&self, other: &Self) -> bool {
         self.scale == other.scale && self.offset == other.offset
     }
 }
 
-impl<E: Float, Q: PrimInt, A: PrimInt> Eq for AffineQuantization<E, Q, A> {}
+impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync, A: PrimInt> Eq
+    for AffineQuantization<E, Q, A>
+{
+}
 
-impl<E: Float, Q: PrimInt> Hash for SymmetricQuantization<E, Q> {
+impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync> Hash for SymmetricQuantization<E, Q> {
     fn hash<H: Hasher>(&self, state: &mut H) {
         // Hash raw bits.
         let bits = raw_double_bits(&canonicalize_signed_zero(self.scale));
         bits.hash(state);
     }
 }
 
-impl<E: Float, Q: PrimInt> PartialEq for SymmetricQuantization<E, Q> {
+impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync> PartialEq for SymmetricQuantization<E, Q> {
     fn eq(&self, other: &Self) -> bool {
         self.scale == other.scale
     }
 }
 
-impl<E: Float, Q: PrimInt> Eq for SymmetricQuantization<E, Q> {}
+impl<E: Float + Send + Sync, Q: PrimInt + Send + Sync> Eq for SymmetricQuantization<E, Q> {}
 
 #[cfg(test)]
 mod tests {
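For context on the two comments preserved in this diff: quantization computes `x_q = clamp(round(x / scale + offset), a, b)` and dequantization computes `x = scale * (x_q - offset)`. A small self-contained sketch with concrete `f32`/`i8` numbers, independent of the generic traits above — the parameter derivation is illustrative, not burn's exact `new` implementation:

fn main() {
    // Affine quantization of the range [alpha, beta] into i8's [-128, 127].
    let (alpha, beta) = (-1.0f32, 2.0f32);
    let (a, b) = (i8::MIN as f32, i8::MAX as f32);

    // Choose scale and zero-point so that alpha maps to a and beta maps to b.
    let scale = (beta - alpha) / (b - a); // 3/255 ≈ 0.01176
    let offset = (a - alpha / scale).round(); // -43

    // x_q = clamp(round(x / scale + offset), a, b)
    let quantize = |x: f32| (x / scale + offset).round().clamp(a, b) as i8;
    // x = scale * (x_q - offset)
    let dequantize = |x_q: i8| scale * (x_q as f32 - offset);

    let x = 0.4f32;
    let x_q = quantize(x); // -9
    let x_back = dequantize(x_q); // ≈ 0.4
    // A round trip loses at most half a quantization step.
    assert!((x - x_back).abs() <= scale / 2.0);
}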
