Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(experimental): Implement zeroed for enums #7252

Merged
merged 2 commits into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 80 additions & 63 deletions compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ impl<'local, 'context> Interpreter<'local, 'context> {
"unresolved_type_is_bool" => unresolved_type_is_bool(interner, arguments, location),
"unresolved_type_is_field" => unresolved_type_is_field(interner, arguments, location),
"unresolved_type_is_unit" => unresolved_type_is_unit(interner, arguments, location),
"zeroed" => zeroed(return_type, location.span),
"zeroed" => Ok(zeroed(return_type, location.span)),
_ => {
let item = format!("Comptime evaluation for builtin function '{name}'");
Err(InterpreterError::Unimplemented { item, location })
Expand Down Expand Up @@ -499,21 +499,21 @@ fn struct_def_generics(
_ => return Err(InterpreterError::TypeMismatch { expected, actual, location }),
};

let generics: IResult<_> = struct_def
let generics = struct_def
.generics
.iter()
.map(|generic| -> IResult<Value> {
.map(|generic| {
let generic_as_named = generic.clone().as_named_generic();
let numeric_type = match generic_as_named.kind() {
Kind::Numeric(numeric_type) => Some(Value::Type(*numeric_type)),
_ => None,
};
let numeric_type = option(option_typ.clone(), numeric_type, location.span)?;
Ok(Value::Tuple(vec![Value::Type(generic_as_named), numeric_type]))
let numeric_type = option(option_typ.clone(), numeric_type, location.span);
Value::Tuple(vec![Value::Type(generic_as_named), numeric_type])
})
.collect();

Ok(Value::Slice(generics?, slice_item_type))
Ok(Value::Slice(generics, slice_item_type))
}

fn struct_def_hash(arguments: Vec<(Value, Location)>, location: Location) -> IResult<Value> {
Expand Down Expand Up @@ -811,7 +811,7 @@ fn quoted_as_expr(
},
);

option(return_type, value, location.span)
Ok(option(return_type, value, location.span))
}

// fn as_module(quoted: Quoted) -> Option<Module>
Expand All @@ -834,7 +834,7 @@ fn quoted_as_module(
module.map(Value::ModuleDefinition)
});

option(return_type, option_value, location.span)
Ok(option(return_type, option_value, location.span))
}

// fn as_trait_constraint(quoted: Quoted) -> TraitConstraint
Expand Down Expand Up @@ -1146,7 +1146,7 @@ where

let option_value = f(typ)?;

option(return_type, option_value, location.span)
Ok(option(return_type, option_value, location.span))
}

// fn type_eq(_first: Type, _second: Type) -> bool
Expand Down Expand Up @@ -1181,7 +1181,7 @@ fn type_get_trait_impl(
_ => None,
};

option(return_type, option_value, location.span)
Ok(option(return_type, option_value, location.span))
}

// fn implements(self, constraint: TraitConstraint) -> bool
Expand Down Expand Up @@ -1302,7 +1302,7 @@ fn typed_expr_as_function_definition(
} else {
None
};
option(return_type, option_value, location.span)
Ok(option(return_type, option_value, location.span))
}

// fn get_type(self) -> Option<Type>
Expand All @@ -1324,7 +1324,7 @@ fn typed_expr_get_type(
} else {
None
};
option(return_type, option_value, location.span)
Ok(option(return_type, option_value, location.span))
}

// fn as_mutable_reference(self) -> Option<UnresolvedType>
Expand Down Expand Up @@ -1407,80 +1407,97 @@ where
let typ = get_unresolved_type(interner, value)?;

let option_value = f(typ);

option(return_type, option_value, location.span)
Ok(option(return_type, option_value, location.span))
}

// fn zeroed<T>() -> T
fn zeroed(return_type: Type, span: Span) -> IResult<Value> {
fn zeroed(return_type: Type, span: Span) -> Value {
match return_type {
Type::FieldElement => Ok(Value::Field(0u128.into())),
Type::FieldElement => Value::Field(0u128.into()),
Type::Array(length_type, elem) => {
if let Ok(length) = length_type.evaluate_to_u32(span) {
let element = zeroed(elem.as_ref().clone(), span)?;
let element = zeroed(elem.as_ref().clone(), span);
let array = std::iter::repeat(element).take(length as usize).collect();
Ok(Value::Array(array, Type::Array(length_type, elem)))
Value::Array(array, Type::Array(length_type, elem))
} else {
// Assume we can resolve the length later
Ok(Value::Zeroed(Type::Array(length_type, elem)))
Value::Zeroed(Type::Array(length_type, elem))
}
}
Type::Slice(_) => Ok(Value::Slice(im::Vector::new(), return_type)),
Type::Slice(_) => Value::Slice(im::Vector::new(), return_type),
Type::Integer(sign, bits) => match (sign, bits) {
(Signedness::Unsigned, IntegerBitSize::One) => Ok(Value::U8(0)),
(Signedness::Unsigned, IntegerBitSize::Eight) => Ok(Value::U8(0)),
(Signedness::Unsigned, IntegerBitSize::Sixteen) => Ok(Value::U16(0)),
(Signedness::Unsigned, IntegerBitSize::ThirtyTwo) => Ok(Value::U32(0)),
(Signedness::Unsigned, IntegerBitSize::SixtyFour) => Ok(Value::U64(0)),
(Signedness::Signed, IntegerBitSize::One) => Ok(Value::I8(0)),
(Signedness::Signed, IntegerBitSize::Eight) => Ok(Value::I8(0)),
(Signedness::Signed, IntegerBitSize::Sixteen) => Ok(Value::I16(0)),
(Signedness::Signed, IntegerBitSize::ThirtyTwo) => Ok(Value::I32(0)),
(Signedness::Signed, IntegerBitSize::SixtyFour) => Ok(Value::I64(0)),
(Signedness::Unsigned, IntegerBitSize::One) => Value::U8(0),
(Signedness::Unsigned, IntegerBitSize::Eight) => Value::U8(0),
(Signedness::Unsigned, IntegerBitSize::Sixteen) => Value::U16(0),
(Signedness::Unsigned, IntegerBitSize::ThirtyTwo) => Value::U32(0),
(Signedness::Unsigned, IntegerBitSize::SixtyFour) => Value::U64(0),
(Signedness::Signed, IntegerBitSize::One) => Value::I8(0),
(Signedness::Signed, IntegerBitSize::Eight) => Value::I8(0),
(Signedness::Signed, IntegerBitSize::Sixteen) => Value::I16(0),
(Signedness::Signed, IntegerBitSize::ThirtyTwo) => Value::I32(0),
(Signedness::Signed, IntegerBitSize::SixtyFour) => Value::I64(0),
},
Type::Bool => Ok(Value::Bool(false)),
Type::Bool => Value::Bool(false),
Type::String(length_type) => {
if let Ok(length) = length_type.evaluate_to_u32(span) {
Ok(Value::String(Rc::new("\0".repeat(length as usize))))
Value::String(Rc::new("\0".repeat(length as usize)))
} else {
// Assume we can resolve the length later
Ok(Value::Zeroed(Type::String(length_type)))
Value::Zeroed(Type::String(length_type))
}
}
Type::FmtString(length_type, captures) => {
let length = length_type.evaluate_to_u32(span);
let typ = Type::FmtString(length_type, captures);
if let Ok(length) = length {
Ok(Value::FormatString(Rc::new("\0".repeat(length as usize)), typ))
Value::FormatString(Rc::new("\0".repeat(length as usize)), typ)
} else {
// Assume we can resolve the length later
Ok(Value::Zeroed(typ))
Value::Zeroed(typ)
}
}
Type::Unit => Ok(Value::Unit),
Type::Tuple(fields) => Ok(Value::Tuple(try_vecmap(fields, |field| zeroed(field, span))?)),
Type::DataType(struct_type, generics) => {
// TODO: Handle enums
let fields = struct_type.borrow().get_fields(&generics).unwrap();
let mut values = HashMap::default();

for (field_name, field_type) in fields {
let field_value = zeroed(field_type, span)?;
values.insert(Rc::new(field_name), field_value);
}
Type::Unit => Value::Unit,
Type::Tuple(fields) => Value::Tuple(vecmap(fields, |field| zeroed(field, span))),
Type::DataType(data_type, generics) => {
let typ = data_type.borrow();

if let Some(fields) = typ.get_fields(&generics) {
let mut values = HashMap::default();

for (field_name, field_type) in fields {
let field_value = zeroed(field_type, span);
values.insert(Rc::new(field_name), field_value);
}

let typ = Type::DataType(struct_type, generics);
Ok(Value::Struct(values, typ))
drop(typ);
Value::Struct(values, Type::DataType(data_type, generics))
} else if let Some(mut variants) = typ.get_variants(&generics) {
// Since we're defaulting to Vec::new(), this'd allow us to construct 0 element
// variants... `zeroed` is often used for uninitialized values e.g. in a BoundedVec
// though so we'll allow it.
let mut args = Vec::new();
if !variants.is_empty() {
// is_empty & swap_remove let us avoid a .clone() we'd need if we did .get(0)
let (_name, params) = variants.swap_remove(0);
args = vecmap(params, |param| zeroed(param, span));
}

drop(typ);
Value::Enum(0, args, Type::DataType(data_type, generics))
} else {
jfecher marked this conversation as resolved.
Show resolved Hide resolved
drop(typ);
Value::Zeroed(Type::DataType(data_type, generics))
}
}
Type::Alias(alias, generics) => zeroed(alias.borrow().get_type(&generics), span),
Type::CheckedCast { to, .. } => zeroed(*to, span),
typ @ Type::Function(..) => {
// Using Value::Zeroed here is probably safer than using FuncId::dummy_id() or similar
Ok(Value::Zeroed(typ))
Value::Zeroed(typ)
}
Type::MutableReference(element) => {
let element = zeroed(*element, span)?;
Ok(Value::Pointer(Shared::new(element), false))
let element = zeroed(*element, span);
Value::Pointer(Shared::new(element), false)
}
// Optimistically assume we can resolve this type later or that the value is unused
Type::TypeVariable(_)
Expand All @@ -1490,7 +1507,7 @@ fn zeroed(return_type: Type, span: Span) -> IResult<Value> {
| Type::Quoted(_)
| Type::Error
| Type::TraitAsType(..)
| Type::NamedGeneric(_, _) => Ok(Value::Zeroed(return_type)),
| Type::NamedGeneric(_, _) => Value::Zeroed(return_type),
}
}

Expand Down Expand Up @@ -1543,7 +1560,7 @@ fn expr_as_assert(

let option_type = tuple_types.pop().unwrap();
let message = message.map(|msg| Value::expression(msg.kind));
let message = option(option_type, message, location.span).ok()?;
let message = option(option_type, message, location.span);

Some(Value::Tuple(vec![predicate, message]))
} else {
Expand Down Expand Up @@ -1589,7 +1606,7 @@ fn expr_as_assert_eq(

let option_type = tuple_types.pop().unwrap();
let message = message.map(|message| Value::expression(message.kind));
let message = option(option_type, message, location.span).ok()?;
let message = option(option_type, message, location.span);

Some(Value::Tuple(vec![lhs, rhs, message]))
} else {
Expand Down Expand Up @@ -1765,7 +1782,7 @@ fn expr_as_constructor(
None
};

option(return_type, option_value, location.span)
Ok(option(return_type, option_value, location.span))
}

// fn as_for(self) -> Option<(Quoted, Expr, Expr)>
Expand Down Expand Up @@ -1865,7 +1882,7 @@ fn expr_as_if(
Some(Value::Tuple(vec![
Value::expression(if_expr.condition.kind),
Value::expression(if_expr.consequence.kind),
alternative.ok()?,
alternative,
]))
} else {
None
Expand Down Expand Up @@ -1948,7 +1965,7 @@ fn expr_as_lambda(
} else {
Some(Value::UnresolvedType(typ.typ))
};
let typ = option(option_unresolved_type.clone(), typ, location.span).unwrap();
let typ = option(option_unresolved_type.clone(), typ, location.span);
Value::Tuple(vec![pattern, typ])
})
.collect();
Expand All @@ -1967,7 +1984,7 @@ fn expr_as_lambda(
Some(return_type)
};
let return_type = return_type.map(Value::UnresolvedType);
let return_type = option(option_unresolved_type, return_type, location.span).ok()?;
let return_type = option(option_unresolved_type, return_type, location.span);

let body = Value::expression(lambda.body.kind);

Expand Down Expand Up @@ -2001,7 +2018,7 @@ fn expr_as_let(
Some(Value::UnresolvedType(let_statement.r#type.typ))
};

let typ = option(option_type, typ, location.span).ok()?;
let typ = option(option_type, typ, location.span);

Some(Value::Tuple(vec![
Value::pattern(let_statement.pattern),
Expand Down Expand Up @@ -2253,7 +2270,7 @@ where
let expr_value = unwrap_expr_value(interner, expr_value);

let option_value = f(expr_value);
option(return_type, option_value, location.span)
Ok(option(return_type, option_value, location.span))
}

// fn resolve(self, in_function: Option<FunctionDefinition>) -> TypedExpr
Expand Down Expand Up @@ -2902,18 +2919,18 @@ fn trait_def_as_trait_constraint(

/// Creates a value that holds an `Option`.
/// `option_type` must be a Type referencing the `Option` type.
pub(crate) fn option(option_type: Type, value: Option<Value>, span: Span) -> IResult<Value> {
pub(crate) fn option(option_type: Type, value: Option<Value>, span: Span) -> Value {
let t = extract_option_generic_type(option_type.clone());

let (is_some, value) = match value {
Some(value) => (Value::Bool(true), value),
None => (Value::Bool(false), zeroed(t, span)?),
None => (Value::Bool(false), zeroed(t, span)),
};

let mut fields = HashMap::default();
fields.insert(Rc::new("_is_some".to_string()), is_some);
fields.insert(Rc::new("_value".to_string()), value);
Ok(Value::Struct(fields, option_type))
Value::Struct(fields, option_type)
}

/// Given a type, assert that it's an Option<T> and return the Type for T
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ fn main() {
let _two = Foo::Couple(1, 2);
let _one = Foo::One(3);
let _none = Foo::None;

// Ensure zeroed works with enums
let _zeroed: Foo = std::mem::zeroed();
}
}

Expand Down
Loading