Skip to content

Commit

Permalink
[naga spv-out] Avoid undefined behaviour for integer division and modulo
Browse files Browse the repository at this point in the history
Integer division or modulo is undefined behaviour in SPIR-V when the
divisor is zero, or when the dividend is the most negative number
representable by the result type and the divisor is negative one.

This patch makes us avoid this undefined behaviour and instead ensures
we adhere to the WGSL spec in these cases: for divisions the
expression evaluates to the value of the dividend, and for modulos the
expression evaluates to zero.

Similarily to how we handle these cases for the MSL and HLSL backends,
prior to emitting each function we emit code for any helper functions
required by that function's expressions. In this case that is helper
functions for integer division and modulo. Then, when emitting the
actual function's body, if we encounter an expression which needs
wrapped we instead emit a function call to the helper.
  • Loading branch information
jamienicol committed Jan 28, 2025
1 parent 7109a66 commit 46e2945
Show file tree
Hide file tree
Showing 9 changed files with 2,730 additions and 2,207 deletions.
389 changes: 208 additions & 181 deletions naga/src/back/spv/block.rs

Large diffs are not rendered by default.

28 changes: 28 additions & 0 deletions naga/src/back/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,18 @@ impl NumericType {
_ => None,
}
}

const fn with_scalar(self, scalar: crate::Scalar) -> Self {
match self {
NumericType::Scalar(_) => NumericType::Scalar(scalar),
NumericType::Vector { size, .. } => NumericType::Vector { size, scalar },
NumericType::Matrix { columns, rows, .. } => NumericType::Matrix {
columns,
rows,
scalar,
},
}
}
}

/// A SPIR-V type constructed during code generation.
Expand Down Expand Up @@ -475,6 +487,18 @@ enum Dimension {
Matrix,
}

/// Key used to look up an operation which we have wrapped in a helper
/// function, which should be called instead of directly emitting code
/// for the expression. See [`Writer::wrapped_functions`].
#[derive(Debug, Eq, PartialEq, Hash)]
enum WrappedFunction {
BinaryOp {
op: crate::BinaryOperator,
left_type_id: Word,
right_type_id: Word,
},
}

/// A map from evaluated [`Expression`](crate::Expression)s to their SPIR-V ids.
///
/// When we emit code to evaluate a given `Expression`, we record the
Expand Down Expand Up @@ -752,6 +776,10 @@ pub struct Writer {
lookup_type: crate::FastHashMap<LookupType, Word>,
lookup_function: crate::FastHashMap<Handle<crate::Function>, Word>,
lookup_function_type: crate::FastHashMap<LookupFunctionType, Word>,
/// Operations which have been wrapped in a helper function. The value is
/// the ID of the function, which should be called instead of emitting code
/// for the operation directly.
wrapped_functions: crate::FastHashMap<WrappedFunction, Word>,
/// Indexed by const-expression handle indexes
constant_ids: HandleVec<crate::Expression, Word>,
cached_constants: crate::FastHashMap<CachedConstant, Word>,
Expand Down
215 changes: 214 additions & 1 deletion naga/src/back/spv/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use super::{
};
use crate::{
arena::{Handle, HandleVec, UniqueArena},
back::spv::BindingInfo,
back::spv::{BindingInfo, WrappedFunction},
proc::{Alignment, TypeResolution},
valid::{FunctionInfo, ModuleInfo},
};
Expand Down Expand Up @@ -74,6 +74,7 @@ impl Writer {
lookup_type: crate::FastHashMap::default(),
lookup_function: crate::FastHashMap::default(),
lookup_function_type: crate::FastHashMap::default(),
wrapped_functions: crate::FastHashMap::default(),
constant_ids: HandleVec::new(),
cached_constants: crate::FastHashMap::default(),
global_variables: HandleVec::new(),
Expand Down Expand Up @@ -127,6 +128,7 @@ impl Writer {
lookup_type: take(&mut self.lookup_type).recycle(),
lookup_function: take(&mut self.lookup_function).recycle(),
lookup_function_type: take(&mut self.lookup_function_type).recycle(),
wrapped_functions: take(&mut self.wrapped_functions).recycle(),
constant_ids: take(&mut self.constant_ids).recycle(),
cached_constants: take(&mut self.cached_constants).recycle(),
global_variables: take(&mut self.global_variables).recycle(),
Expand Down Expand Up @@ -305,6 +307,215 @@ impl Writer {
.push(Instruction::decorate(id, decoration, operands));
}

/// Emits code for any wrapper functions required by the expressions in ir_function.
/// The IDs of any emitted functions will be stored in [`Self::wrapped_functions`].
fn write_wrapped_functions(
&mut self,
ir_function: &crate::Function,
info: &FunctionInfo,
ir_module: &crate::Module,
) -> Result<(), Error> {
log::trace!("Generating wrapped functions for {:?}", ir_function.name);

for (expr_handle, expr) in ir_function.expressions.iter() {
match *expr {
crate::Expression::Binary { op, left, right } => {
let expr_ty = info[expr_handle].ty.inner_with(&ir_module.types);
match (op, expr_ty.scalar_kind()) {
// Division and modulo are undefined behaviour when the dividend is the
// minimum representable value and the divisor is negative one, or when
// the divisor is zero. These wrapped functions override the divisor to
// one in these cases, matching the WGSL spec.
(
crate::BinaryOperator::Divide | crate::BinaryOperator::Modulo,
Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
) => {
let return_type_id = self.get_expression_type_id(&info[expr_handle].ty);
let left_type_id = self.get_expression_type_id(&info[left].ty);
let right_type_id = self.get_expression_type_id(&info[right].ty);
let wrapped = WrappedFunction::BinaryOp {
op,
left_type_id,
right_type_id,
};
let function_id = *match self.wrapped_functions.entry(wrapped) {
Entry::Occupied(_) => continue,
Entry::Vacant(e) => e.insert(self.id_gen.next()),
};
if self.flags.contains(WriterFlags::DEBUG) {
let function_name = match op {
crate::BinaryOperator::Divide => "naga_div",
crate::BinaryOperator::Modulo => "naga_mod",
_ => unreachable!(),
};
self.debugs
.push(Instruction::name(function_id, function_name));
}
let mut function = Function::default();

let function_type_id = self.get_function_type(LookupFunctionType {
parameter_type_ids: vec![left_type_id, right_type_id],
return_type_id,
});
function.signature = Some(Instruction::function(
return_type_id,
function_id,
spirv::FunctionControl::empty(),
function_type_id,
));

let lhs_id = self.id_gen.next();
let rhs_id = self.id_gen.next();
if self.flags.contains(WriterFlags::DEBUG) {
self.debugs.push(Instruction::name(lhs_id, "lhs"));
self.debugs.push(Instruction::name(rhs_id, "rhs"));
}
let left_par = Instruction::function_parameter(left_type_id, lhs_id);
let right_par = Instruction::function_parameter(right_type_id, rhs_id);
for instruction in [left_par, right_par] {
function.parameters.push(FunctionArgument {
instruction,
handle_id: 0,
});
}

let label_id = self.id_gen.next();
let mut block = Block::new(label_id);

let scalar = expr_ty.scalar().unwrap();
let numeric_type = NumericType::from_inner(expr_ty).unwrap();
let bool_type = numeric_type.with_scalar(crate::Scalar::BOOL);
let bool_type_id =
self.get_type_id(LookupType::Local(LocalType::Numeric(bool_type)));

let maybe_splat_const = |writer: &mut Self, const_id| match numeric_type
{
NumericType::Scalar(_) => const_id,
NumericType::Vector { size, .. } => {
let constituent_ids = [const_id; crate::VectorSize::MAX];
writer.get_constant_composite(
LookupType::Local(LocalType::Numeric(numeric_type)),
&constituent_ids[..size as usize],
)
}
NumericType::Matrix { .. } => unreachable!(),
};

let const_zero_id = self.get_constant_scalar_with(0, scalar)?;
let composite_zero_id = maybe_splat_const(self, const_zero_id);
let rhs_eq_zero_id = self.id_gen.next();
block.body.push(Instruction::binary(
spirv::Op::IEqual,
bool_type_id,
rhs_eq_zero_id,
rhs_id,
composite_zero_id,
));
let divisor_selector_id = match scalar.kind {
crate::ScalarKind::Sint => {
let (const_min_id, const_neg_one_id) = match scalar.width {
4 => Ok((
self.get_constant_scalar(crate::Literal::I32(i32::MIN)),
self.get_constant_scalar(crate::Literal::I32(-1i32)),
)),
8 => Ok((
self.get_constant_scalar(crate::Literal::I64(i64::MIN)),
self.get_constant_scalar(crate::Literal::I64(-1i64)),
)),
_ => Err(Error::Validation("Unexpected scalar width")),
}?;
let composite_min_id = maybe_splat_const(self, const_min_id);
let composite_neg_one_id =
maybe_splat_const(self, const_neg_one_id);

let lhs_eq_int_min_id = self.id_gen.next();
block.body.push(Instruction::binary(
spirv::Op::IEqual,
bool_type_id,
lhs_eq_int_min_id,
lhs_id,
composite_min_id,
));
let rhs_eq_neg_one_id = self.id_gen.next();
block.body.push(Instruction::binary(
spirv::Op::IEqual,
bool_type_id,
rhs_eq_neg_one_id,
rhs_id,
composite_neg_one_id,
));
let lhs_eq_int_min_and_rhs_eq_neg_one_id = self.id_gen.next();
block.body.push(Instruction::binary(
spirv::Op::LogicalAnd,
bool_type_id,
lhs_eq_int_min_and_rhs_eq_neg_one_id,
lhs_eq_int_min_id,
rhs_eq_neg_one_id,
));
let rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id =
self.id_gen.next();
block.body.push(Instruction::binary(
spirv::Op::LogicalOr,
bool_type_id,
rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id,
rhs_eq_zero_id,
lhs_eq_int_min_and_rhs_eq_neg_one_id,
));
rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id
}
crate::ScalarKind::Uint => rhs_eq_zero_id,
_ => unreachable!(),
};

let const_one_id = self.get_constant_scalar_with(1, scalar)?;
let composite_one_id = maybe_splat_const(self, const_one_id);
let divisor_id = self.id_gen.next();
block.body.push(Instruction::select(
right_type_id,
divisor_id,
divisor_selector_id,
composite_one_id,
rhs_id,
));
let op = match (op, scalar.kind) {
(crate::BinaryOperator::Divide, crate::ScalarKind::Sint) => {
spirv::Op::SDiv
}
(crate::BinaryOperator::Divide, crate::ScalarKind::Uint) => {
spirv::Op::UDiv
}
(crate::BinaryOperator::Modulo, crate::ScalarKind::Sint) => {
spirv::Op::SRem
}
(crate::BinaryOperator::Modulo, crate::ScalarKind::Uint) => {
spirv::Op::UMod
}
_ => unreachable!(),
};
let return_id = self.id_gen.next();
block.body.push(Instruction::binary(
op,
return_type_id,
return_id,
lhs_id,
divisor_id,
));

function.consume(block, Instruction::return_value(return_id));
function.to_words(&mut self.logical_layout.function_definitions);
Instruction::function_end()
.to_words(&mut self.logical_layout.function_definitions);
}
_ => {}
}
}
_ => {}
}
}

Ok(())
}

fn write_function(
&mut self,
ir_function: &crate::Function,
Expand All @@ -313,6 +524,8 @@ impl Writer {
mut interface: Option<FunctionInterface>,
debug_info: &Option<DebugInfoInner>,
) -> Result<Word, Error> {
self.write_wrapped_functions(ir_function, info, ir_module)?;

log::trace!("Generating code for {:?}", ir_function.name);
let mut function = Function::default();

Expand Down
Loading

0 comments on commit 46e2945

Please sign in to comment.