Skip to content

Commit

Permalink
Merge pull request #532 from robertknight/quantize-linear-reciprocal
Browse files Browse the repository at this point in the history
Add small optimizations for QuantizeLinear
  • Loading branch information
robertknight authored Jan 11, 2025
2 parents c6d4245 + 0aa771e commit e3eb30f
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions src/ops/quantize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,25 +120,27 @@ impl Operator for DequantizeLinear {
/// y = saturate(round(self / scale) + zero_point)
/// ```
///
/// For efficiency the `quantize` method takes the reciprocal of the scale.
///
/// See https://onnx.ai/onnx/operators/onnx__QuantizeLinear.html for
/// additional details.
pub trait Quantize<To> {
    /// Quantize `self` into the integer type `To`.
    ///
    /// `inv_scale` is the reciprocal of the quantization scale (`1. / scale`),
    /// so callers can compute it once and the per-element division becomes a
    /// multiplication. `zero_point` is added after rounding, and the result
    /// is saturated to the range of `To`.
    fn quantize(self, inv_scale: Self, zero_point: To) -> To;
}

impl Quantize<u8> for f32 {
    /// Quantize to `u8`, rounding ties to even and saturating to `[0, 255]`.
    fn quantize(self, inv_scale: Self, zero_point: u8) -> u8 {
        let y = (self * inv_scale).round_ties_even();
        let y = y + zero_point as f32;
        // Float-to-int `as` casts saturate at the bounds of the target type
        // (and map NaN to 0), so no explicit `clamp` is needed.
        y as u8
    }
}

impl Quantize<i8> for f32 {
    /// Quantize to `i8`, rounding ties to even and saturating to `[-128, 127]`.
    fn quantize(self, inv_scale: Self, zero_point: i8) -> i8 {
        let y = (self * inv_scale).round_ties_even();
        let y = y + zero_point as f32;
        // Float-to-int `as` casts saturate at the bounds of the target type
        // (and map NaN to 0), so no explicit `clamp` is needed.
        y as i8
    }
}

Expand All @@ -162,10 +164,10 @@ where

match scale.ndim() {
0 => {
let scale = scale.item().unwrap();
let inv_scale = 1. / *scale.item().unwrap();
let zero_point = zero_point.and_then(|z| z.item()).unwrap();

Ok(input.map_in(pool, |x| x.quantize(*scale, *zero_point)))
Ok(input.map_in(pool, |x| x.quantize(inv_scale, *zero_point)))
}
1 => {
let axis = resolve_axis(input.ndim(), axis)?;
Expand All @@ -185,8 +187,9 @@ where
.zip(scale.iter())
.zip(zero_point.iter())
.for_each(|(((mut out_slice, in_slice), &scale), &zero_point)| {
let inv_scale = 1. / scale;
for (y, &x) in out_slice.iter_mut().zip(in_slice.iter()) {
y.write(x.quantize(scale, zero_point));
y.write(x.quantize(inv_scale, zero_point));
}
});

Expand Down

0 comments on commit e3eb30f

Please sign in to comment.