Skip to content

Commit

Permalink
Implement zeroed for enums
Browse files Browse the repository at this point in the history
  • Loading branch information
jfecher committed Jan 31, 2025
1 parent 8d39337 commit 90e16ac
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 63 deletions.
143 changes: 80 additions & 63 deletions compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ impl<'local, 'context> Interpreter<'local, 'context> {
"unresolved_type_is_bool" => unresolved_type_is_bool(interner, arguments, location),
"unresolved_type_is_field" => unresolved_type_is_field(interner, arguments, location),
"unresolved_type_is_unit" => unresolved_type_is_unit(interner, arguments, location),
"zeroed" => zeroed(return_type, location.span),
"zeroed" => Ok(zeroed(return_type, location.span)),
_ => {
let item = format!("Comptime evaluation for builtin function '{name}'");
Err(InterpreterError::Unimplemented { item, location })
Expand Down Expand Up @@ -499,21 +499,21 @@ fn struct_def_generics(
_ => return Err(InterpreterError::TypeMismatch { expected, actual, location }),
};

let generics: IResult<_> = struct_def
let generics = struct_def
.generics
.iter()
.map(|generic| -> IResult<Value> {
.map(|generic| {
let generic_as_named = generic.clone().as_named_generic();
let numeric_type = match generic_as_named.kind() {
Kind::Numeric(numeric_type) => Some(Value::Type(*numeric_type)),
_ => None,
};
let numeric_type = option(option_typ.clone(), numeric_type, location.span)?;
Ok(Value::Tuple(vec![Value::Type(generic_as_named), numeric_type]))
let numeric_type = option(option_typ.clone(), numeric_type, location.span);
Value::Tuple(vec![Value::Type(generic_as_named), numeric_type])
})
.collect();

Ok(Value::Slice(generics?, slice_item_type))
Ok(Value::Slice(generics, slice_item_type))
}

fn struct_def_hash(arguments: Vec<(Value, Location)>, location: Location) -> IResult<Value> {
Expand Down Expand Up @@ -811,7 +811,7 @@ fn quoted_as_expr(
},
);

option(return_type, value, location.span)
Ok(option(return_type, value, location.span))
}

// fn as_module(quoted: Quoted) -> Option<Module>
Expand All @@ -834,7 +834,7 @@ fn quoted_as_module(
module.map(Value::ModuleDefinition)
});

option(return_type, option_value, location.span)
Ok(option(return_type, option_value, location.span))
}

// fn as_trait_constraint(quoted: Quoted) -> TraitConstraint
Expand Down Expand Up @@ -1146,7 +1146,7 @@ where

let option_value = f(typ)?;

option(return_type, option_value, location.span)
Ok(option(return_type, option_value, location.span))
}

// fn type_eq(_first: Type, _second: Type) -> bool
Expand Down Expand Up @@ -1181,7 +1181,7 @@ fn type_get_trait_impl(
_ => None,
};

option(return_type, option_value, location.span)
Ok(option(return_type, option_value, location.span))
}

// fn implements(self, constraint: TraitConstraint) -> bool
Expand Down Expand Up @@ -1302,7 +1302,7 @@ fn typed_expr_as_function_definition(
} else {
None
};
option(return_type, option_value, location.span)
Ok(option(return_type, option_value, location.span))
}

// fn get_type(self) -> Option<Type>
Expand All @@ -1324,7 +1324,7 @@ fn typed_expr_get_type(
} else {
None
};
option(return_type, option_value, location.span)
Ok(option(return_type, option_value, location.span))
}

// fn as_mutable_reference(self) -> Option<UnresolvedType>
Expand Down Expand Up @@ -1407,80 +1407,97 @@ where
let typ = get_unresolved_type(interner, value)?;

let option_value = f(typ);

option(return_type, option_value, location.span)
Ok(option(return_type, option_value, location.span))
}

// fn zeroed<T>() -> T
fn zeroed(return_type: Type, span: Span) -> IResult<Value> {
fn zeroed(return_type: Type, span: Span) -> Value {
match return_type {
Type::FieldElement => Ok(Value::Field(0u128.into())),
Type::FieldElement => Value::Field(0u128.into()),
Type::Array(length_type, elem) => {
if let Ok(length) = length_type.evaluate_to_u32(span) {
let element = zeroed(elem.as_ref().clone(), span)?;
let element = zeroed(elem.as_ref().clone(), span);
let array = std::iter::repeat(element).take(length as usize).collect();
Ok(Value::Array(array, Type::Array(length_type, elem)))
Value::Array(array, Type::Array(length_type, elem))
} else {
// Assume we can resolve the length later
Ok(Value::Zeroed(Type::Array(length_type, elem)))
Value::Zeroed(Type::Array(length_type, elem))
}
}
Type::Slice(_) => Ok(Value::Slice(im::Vector::new(), return_type)),
Type::Slice(_) => Value::Slice(im::Vector::new(), return_type),
Type::Integer(sign, bits) => match (sign, bits) {
(Signedness::Unsigned, IntegerBitSize::One) => Ok(Value::U8(0)),
(Signedness::Unsigned, IntegerBitSize::Eight) => Ok(Value::U8(0)),
(Signedness::Unsigned, IntegerBitSize::Sixteen) => Ok(Value::U16(0)),
(Signedness::Unsigned, IntegerBitSize::ThirtyTwo) => Ok(Value::U32(0)),
(Signedness::Unsigned, IntegerBitSize::SixtyFour) => Ok(Value::U64(0)),
(Signedness::Signed, IntegerBitSize::One) => Ok(Value::I8(0)),
(Signedness::Signed, IntegerBitSize::Eight) => Ok(Value::I8(0)),
(Signedness::Signed, IntegerBitSize::Sixteen) => Ok(Value::I16(0)),
(Signedness::Signed, IntegerBitSize::ThirtyTwo) => Ok(Value::I32(0)),
(Signedness::Signed, IntegerBitSize::SixtyFour) => Ok(Value::I64(0)),
(Signedness::Unsigned, IntegerBitSize::One) => Value::U8(0),
(Signedness::Unsigned, IntegerBitSize::Eight) => Value::U8(0),
(Signedness::Unsigned, IntegerBitSize::Sixteen) => Value::U16(0),
(Signedness::Unsigned, IntegerBitSize::ThirtyTwo) => Value::U32(0),
(Signedness::Unsigned, IntegerBitSize::SixtyFour) => Value::U64(0),
(Signedness::Signed, IntegerBitSize::One) => Value::I8(0),
(Signedness::Signed, IntegerBitSize::Eight) => Value::I8(0),
(Signedness::Signed, IntegerBitSize::Sixteen) => Value::I16(0),
(Signedness::Signed, IntegerBitSize::ThirtyTwo) => Value::I32(0),
(Signedness::Signed, IntegerBitSize::SixtyFour) => Value::I64(0),
},
Type::Bool => Ok(Value::Bool(false)),
Type::Bool => Value::Bool(false),
Type::String(length_type) => {
if let Ok(length) = length_type.evaluate_to_u32(span) {
Ok(Value::String(Rc::new("\0".repeat(length as usize))))
Value::String(Rc::new("\0".repeat(length as usize)))
} else {
// Assume we can resolve the length later
Ok(Value::Zeroed(Type::String(length_type)))
Value::Zeroed(Type::String(length_type))
}
}
Type::FmtString(length_type, captures) => {
let length = length_type.evaluate_to_u32(span);
let typ = Type::FmtString(length_type, captures);
if let Ok(length) = length {
Ok(Value::FormatString(Rc::new("\0".repeat(length as usize)), typ))
Value::FormatString(Rc::new("\0".repeat(length as usize)), typ)
} else {
// Assume we can resolve the length later
Ok(Value::Zeroed(typ))
Value::Zeroed(typ)
}
}
Type::Unit => Ok(Value::Unit),
Type::Tuple(fields) => Ok(Value::Tuple(try_vecmap(fields, |field| zeroed(field, span))?)),
Type::DataType(struct_type, generics) => {
// TODO: Handle enums
let fields = struct_type.borrow().get_fields(&generics).unwrap();
let mut values = HashMap::default();

for (field_name, field_type) in fields {
let field_value = zeroed(field_type, span)?;
values.insert(Rc::new(field_name), field_value);
}
Type::Unit => Value::Unit,
Type::Tuple(fields) => Value::Tuple(vecmap(fields, |field| zeroed(field, span))),
Type::DataType(data_type, generics) => {
let typ = data_type.borrow();

if let Some(fields) = typ.get_fields(&generics) {
let mut values = HashMap::default();

for (field_name, field_type) in fields {
let field_value = zeroed(field_type, span);
values.insert(Rc::new(field_name), field_value);
}

let typ = Type::DataType(struct_type, generics);
Ok(Value::Struct(values, typ))
drop(typ);
Value::Struct(values, Type::DataType(data_type, generics))
} else if let Some(mut variants) = typ.get_variants(&generics) {
// Since we're defaulting to Vec::new(), this'd allow us to construct 0 element
// variants... `zeroed` is often used for uninitialized values e.g. in a BoundedVec
// though so we'll allow it.
let mut args = Vec::new();
if !variants.is_empty() {
// is_empty & swap_remove let us avoid a .clone() we'd need if we did .get(0)
let (_name, params) = variants.swap_remove(0);
args = vecmap(params, |param| zeroed(param, span));
}

drop(typ);
Value::Enum(0, args, Type::DataType(data_type, generics))
} else {
drop(typ);
Value::Zeroed(Type::DataType(data_type, generics))
}
}
Type::Alias(alias, generics) => zeroed(alias.borrow().get_type(&generics), span),
Type::CheckedCast { to, .. } => zeroed(*to, span),
typ @ Type::Function(..) => {
// Using Value::Zeroed here is probably safer than using FuncId::dummy_id() or similar
Ok(Value::Zeroed(typ))
Value::Zeroed(typ)
}
Type::MutableReference(element) => {
let element = zeroed(*element, span)?;
Ok(Value::Pointer(Shared::new(element), false))
let element = zeroed(*element, span);
Value::Pointer(Shared::new(element), false)
}
// Optimistically assume we can resolve this type later or that the value is unused
Type::TypeVariable(_)
Expand All @@ -1490,7 +1507,7 @@ fn zeroed(return_type: Type, span: Span) -> IResult<Value> {
| Type::Quoted(_)
| Type::Error
| Type::TraitAsType(..)
| Type::NamedGeneric(_, _) => Ok(Value::Zeroed(return_type)),
| Type::NamedGeneric(_, _) => Value::Zeroed(return_type),
}
}

Expand Down Expand Up @@ -1543,7 +1560,7 @@ fn expr_as_assert(

let option_type = tuple_types.pop().unwrap();
let message = message.map(|msg| Value::expression(msg.kind));
let message = option(option_type, message, location.span).ok()?;
let message = option(option_type, message, location.span);

Some(Value::Tuple(vec![predicate, message]))
} else {
Expand Down Expand Up @@ -1589,7 +1606,7 @@ fn expr_as_assert_eq(

let option_type = tuple_types.pop().unwrap();
let message = message.map(|message| Value::expression(message.kind));
let message = option(option_type, message, location.span).ok()?;
let message = option(option_type, message, location.span);

Some(Value::Tuple(vec![lhs, rhs, message]))
} else {
Expand Down Expand Up @@ -1765,7 +1782,7 @@ fn expr_as_constructor(
None
};

option(return_type, option_value, location.span)
Ok(option(return_type, option_value, location.span))
}

// fn as_for(self) -> Option<(Quoted, Expr, Expr)>
Expand Down Expand Up @@ -1865,7 +1882,7 @@ fn expr_as_if(
Some(Value::Tuple(vec![
Value::expression(if_expr.condition.kind),
Value::expression(if_expr.consequence.kind),
alternative.ok()?,
alternative,
]))
} else {
None
Expand Down Expand Up @@ -1948,7 +1965,7 @@ fn expr_as_lambda(
} else {
Some(Value::UnresolvedType(typ.typ))
};
let typ = option(option_unresolved_type.clone(), typ, location.span).unwrap();
let typ = option(option_unresolved_type.clone(), typ, location.span);
Value::Tuple(vec![pattern, typ])
})
.collect();
Expand All @@ -1967,7 +1984,7 @@ fn expr_as_lambda(
Some(return_type)
};
let return_type = return_type.map(Value::UnresolvedType);
let return_type = option(option_unresolved_type, return_type, location.span).ok()?;
let return_type = option(option_unresolved_type, return_type, location.span);

let body = Value::expression(lambda.body.kind);

Expand Down Expand Up @@ -2001,7 +2018,7 @@ fn expr_as_let(
Some(Value::UnresolvedType(let_statement.r#type.typ))
};

let typ = option(option_type, typ, location.span).ok()?;
let typ = option(option_type, typ, location.span);

Some(Value::Tuple(vec![
Value::pattern(let_statement.pattern),
Expand Down Expand Up @@ -2253,7 +2270,7 @@ where
let expr_value = unwrap_expr_value(interner, expr_value);

let option_value = f(expr_value);
option(return_type, option_value, location.span)
Ok(option(return_type, option_value, location.span))
}

// fn resolve(self, in_function: Option<FunctionDefinition>) -> TypedExpr
Expand Down Expand Up @@ -2902,18 +2919,18 @@ fn trait_def_as_trait_constraint(

/// Creates a value that holds an `Option`.
/// `option_type` must be a Type referencing the `Option` type.
pub(crate) fn option(option_type: Type, value: Option<Value>, span: Span) -> IResult<Value> {
pub(crate) fn option(option_type: Type, value: Option<Value>, span: Span) -> Value {
let t = extract_option_generic_type(option_type.clone());

let (is_some, value) = match value {
Some(value) => (Value::Bool(true), value),
None => (Value::Bool(false), zeroed(t, span)?),
None => (Value::Bool(false), zeroed(t, span)),
};

let mut fields = HashMap::default();
fields.insert(Rc::new("_is_some".to_string()), is_some);
fields.insert(Rc::new("_value".to_string()), value);
Ok(Value::Struct(fields, option_type))
Value::Struct(fields, option_type)
}

/// Given a type, assert that it's an Option<T> and return the Type for T
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ fn main() {
let _two = Foo::Couple(1, 2);
let _one = Foo::One(3);
let _none = Foo::None();

// Ensure zeroed works with enums
let _zeroed: Foo = std::mem::zeroed();
}
}

Expand Down

0 comments on commit 90e16ac

Please sign in to comment.