Skip to content

Commit

Permalink
feat: assert and assert_eq are now expressions (#7313)
Browse files Browse the repository at this point in the history
  • Loading branch information
asterite authored Feb 6, 2025
1 parent 0d156ff commit 9ae3c6c
Show file tree
Hide file tree
Showing 20 changed files with 336 additions and 336 deletions.
51 changes: 51 additions & 0 deletions compiler/noirc_frontend/src/ast/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub enum ExpressionKind {
Index(Box<IndexExpression>),
Call(Box<CallExpression>),
MethodCall(Box<MethodCallExpression>),
Constrain(ConstrainExpression),
Constructor(Box<ConstructorExpression>),
MemberAccess(Box<MemberAccessExpression>),
Cast(Box<CastExpression>),
Expand Down Expand Up @@ -582,6 +583,55 @@ impl BlockExpression {
}
}

#[derive(Debug, PartialEq, Eq, Clone)]
pub struct ConstrainExpression {
pub kind: ConstrainKind,
pub arguments: Vec<Expression>,
pub span: Span,
}

impl Display for ConstrainExpression {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.kind {
ConstrainKind::Assert | ConstrainKind::AssertEq => write!(
f,
"{}({})",
self.kind,
vecmap(&self.arguments, |arg| arg.to_string()).join(", ")
),
ConstrainKind::Constrain => {
write!(f, "constrain {}", &self.arguments[0])
}
}
}
}

#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum ConstrainKind {
Assert,
AssertEq,
Constrain,
}

impl ConstrainKind {
pub fn required_arguments_count(&self) -> usize {
match self {
ConstrainKind::Assert | ConstrainKind::Constrain => 1,
ConstrainKind::AssertEq => 2,
}
}
}

impl Display for ConstrainKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ConstrainKind::Assert => write!(f, "assert"),
ConstrainKind::AssertEq => write!(f, "assert_eq"),
ConstrainKind::Constrain => write!(f, "constrain"),
}
}
}

impl Display for Expression {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.kind.fmt(f)
Expand All @@ -598,6 +648,7 @@ impl Display for ExpressionKind {
Index(index) => index.fmt(f),
Call(call) => call.fmt(f),
MethodCall(call) => call.fmt(f),
Constrain(constrain) => constrain.fmt(f),
Cast(cast) => cast.fmt(f),
Infix(infix) => infix.fmt(f),
If(if_expr) => if_expr.fmt(f),
Expand Down
52 changes: 0 additions & 52 deletions compiler/noirc_frontend/src/ast/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ pub struct Statement {
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum StatementKind {
Let(LetStatement),
Constrain(ConstrainStatement),
Expression(Expression),
Assign(AssignStatement),
For(ForLoopStatement),
Expand Down Expand Up @@ -88,7 +87,6 @@ impl StatementKind {

match self {
StatementKind::Let(_)
| StatementKind::Constrain(_)
| StatementKind::Assign(_)
| StatementKind::Semi(_)
| StatementKind::Break
Expand Down Expand Up @@ -565,55 +563,6 @@ pub enum LValue {
Interned(InternedExpressionKind, Span),
}

#[derive(Debug, PartialEq, Eq, Clone)]
pub struct ConstrainStatement {
pub kind: ConstrainKind,
pub arguments: Vec<Expression>,
pub span: Span,
}

impl Display for ConstrainStatement {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.kind {
ConstrainKind::Assert | ConstrainKind::AssertEq => write!(
f,
"{}({})",
self.kind,
vecmap(&self.arguments, |arg| arg.to_string()).join(", ")
),
ConstrainKind::Constrain => {
write!(f, "constrain {}", &self.arguments[0])
}
}
}
}

#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum ConstrainKind {
Assert,
AssertEq,
Constrain,
}

impl ConstrainKind {
pub fn required_arguments_count(&self) -> usize {
match self {
ConstrainKind::Assert | ConstrainKind::Constrain => 1,
ConstrainKind::AssertEq => 2,
}
}
}

impl Display for ConstrainKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ConstrainKind::Assert => write!(f, "assert"),
ConstrainKind::AssertEq => write!(f, "assert_eq"),
ConstrainKind::Constrain => write!(f, "constrain"),
}
}
}

#[derive(Debug, PartialEq, Eq, Clone)]
pub enum Pattern {
Identifier(Ident),
Expand Down Expand Up @@ -935,7 +884,6 @@ impl Display for StatementKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
StatementKind::Let(let_statement) => let_statement.fmt(f),
StatementKind::Constrain(constrain) => constrain.fmt(f),
StatementKind::Expression(expression) => expression.fmt(f),
StatementKind::Assign(assign) => assign.fmt(f),
StatementKind::For(for_loop) => for_loop.fmt(f),
Expand Down
12 changes: 6 additions & 6 deletions compiler/noirc_frontend/src/ast/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use noirc_errors::Span;
use crate::{
ast::{
ArrayLiteral, AsTraitPath, AssignStatement, BlockExpression, CallExpression,
CastExpression, ConstrainStatement, ConstructorExpression, Expression, ExpressionKind,
CastExpression, ConstrainExpression, ConstructorExpression, Expression, ExpressionKind,
ForLoopStatement, ForRange, Ident, IfExpression, IndexExpression, InfixExpression, LValue,
Lambda, LetStatement, Literal, MemberAccessExpression, MethodCallExpression,
ModuleDeclaration, NoirFunction, NoirStruct, NoirTrait, NoirTraitImpl, NoirTypeAlias, Path,
Expand Down Expand Up @@ -294,7 +294,7 @@ pub trait Visitor {
true
}

fn visit_constrain_statement(&mut self, _: &ConstrainStatement) -> bool {
fn visit_constrain_statement(&mut self, _: &ConstrainExpression) -> bool {
true
}

Expand Down Expand Up @@ -855,6 +855,9 @@ impl Expression {
ExpressionKind::MethodCall(method_call_expression) => {
method_call_expression.accept(self.span, visitor);
}
ExpressionKind::Constrain(constrain) => {
constrain.accept(visitor);
}
ExpressionKind::Constructor(constructor_expression) => {
constructor_expression.accept(self.span, visitor);
}
Expand Down Expand Up @@ -1148,9 +1151,6 @@ impl Statement {
StatementKind::Let(let_statement) => {
let_statement.accept(visitor);
}
StatementKind::Constrain(constrain_statement) => {
constrain_statement.accept(visitor);
}
StatementKind::Expression(expression) => {
expression.accept(visitor);
}
Expand Down Expand Up @@ -1199,7 +1199,7 @@ impl LetStatement {
}
}

impl ConstrainStatement {
impl ConstrainExpression {
pub fn accept(&self, visitor: &mut impl Visitor) {
if visitor.visit_constrain_statement(self) {
self.accept_children(visitor);
Expand Down
74 changes: 65 additions & 9 deletions compiler/noirc_frontend/src/elaborator/expressions.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
use acvm::{AcirField, FieldElement};
use iter_extended::vecmap;
use noirc_errors::{Location, Span};
use noirc_errors::{Location, Span, Spanned};
use rustc_hash::FxHashSet as HashSet;

use crate::{
ast::{
ArrayLiteral, BlockExpression, CallExpression, CastExpression, ConstructorExpression,
Expression, ExpressionKind, Ident, IfExpression, IndexExpression, InfixExpression,
ItemVisibility, Lambda, Literal, MatchExpression, MemberAccessExpression,
MethodCallExpression, Path, PathSegment, PrefixExpression, StatementKind, UnaryOp,
UnresolvedTypeData, UnresolvedTypeExpression,
ArrayLiteral, BinaryOpKind, BlockExpression, CallExpression, CastExpression,
ConstrainExpression, ConstrainKind, ConstructorExpression, Expression, ExpressionKind,
Ident, IfExpression, IndexExpression, InfixExpression, ItemVisibility, Lambda, Literal,
MatchExpression, MemberAccessExpression, MethodCallExpression, Path, PathSegment,
PrefixExpression, StatementKind, UnaryOp, UnresolvedTypeData, UnresolvedTypeExpression,
},
hir::{
comptime::{self, InterpreterError},
Expand All @@ -21,9 +21,9 @@ use crate::{
hir_def::{
expr::{
HirArrayLiteral, HirBinaryOp, HirBlockExpression, HirCallExpression, HirCastExpression,
HirConstructorExpression, HirExpression, HirIdent, HirIfExpression, HirIndexExpression,
HirInfixExpression, HirLambda, HirLiteral, HirMemberAccess, HirMethodCallExpression,
HirPrefixExpression,
HirConstrainExpression, HirConstructorExpression, HirExpression, HirIdent,
HirIfExpression, HirIndexExpression, HirInfixExpression, HirLambda, HirLiteral,
HirMemberAccess, HirMethodCallExpression, HirPrefixExpression,
},
stmt::HirStatement,
traits::{ResolvedTraitBound, TraitConstraint},
Expand Down Expand Up @@ -52,6 +52,7 @@ impl<'context> Elaborator<'context> {
ExpressionKind::Index(index) => self.elaborate_index(*index),
ExpressionKind::Call(call) => self.elaborate_call(*call, expr.span),
ExpressionKind::MethodCall(call) => self.elaborate_method_call(*call, expr.span),
ExpressionKind::Constrain(constrain) => self.elaborate_constrain(constrain),
ExpressionKind::Constructor(constructor) => self.elaborate_constructor(*constructor),
ExpressionKind::MemberAccess(access) => {
return self.elaborate_member_access(*access, expr.span)
Expand Down Expand Up @@ -583,6 +584,61 @@ impl<'context> Elaborator<'context> {
}
}

pub(super) fn elaborate_constrain(
&mut self,
mut expr: ConstrainExpression,
) -> (HirExpression, Type) {
let span = expr.span;
let min_args_count = expr.kind.required_arguments_count();
let max_args_count = min_args_count + 1;
let actual_args_count = expr.arguments.len();

let (message, expr) = if !(min_args_count..=max_args_count).contains(&actual_args_count) {
self.push_err(TypeCheckError::AssertionParameterCountMismatch {
kind: expr.kind,
found: actual_args_count,
span,
});

// Given that we already produced an error, let's make this an `assert(true)` so
// we don't get further errors.
let message = None;
let kind = ExpressionKind::Literal(crate::ast::Literal::Bool(true));
let expr = Expression { kind, span };
(message, expr)
} else {
let message =
(actual_args_count != min_args_count).then(|| expr.arguments.pop().unwrap());
let expr = match expr.kind {
ConstrainKind::Assert | ConstrainKind::Constrain => expr.arguments.pop().unwrap(),
ConstrainKind::AssertEq => {
let rhs = expr.arguments.pop().unwrap();
let lhs = expr.arguments.pop().unwrap();
let span = Span::from(lhs.span.start()..rhs.span.end());
let operator = Spanned::from(span, BinaryOpKind::Equal);
let kind =
ExpressionKind::Infix(Box::new(InfixExpression { lhs, operator, rhs }));
Expression { kind, span }
}
};
(message, expr)
};

let expr_span = expr.span;
let (expr_id, expr_type) = self.elaborate_expression(expr);

// Must type check the assertion message expression so that we instantiate bindings
let msg = message.map(|assert_msg_expr| self.elaborate_expression(assert_msg_expr).0);

self.unify(&expr_type, &Type::Bool, || TypeCheckError::TypeMismatch {
expr_typ: expr_type.to_string(),
expected_typ: Type::Bool.to_string(),
expr_span,
});

(HirExpression::Constrain(HirConstrainExpression(expr_id, self.file, msg)), Type::Unit)
}

/// Elaborates an expression knowing that it has to match a given type.
fn elaborate_expression_with_type(
&mut self,
Expand Down
4 changes: 2 additions & 2 deletions compiler/noirc_frontend/src/elaborator/lints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,7 @@ fn can_return_without_recursing(interner: &NodeInterner, func_id: FuncId, expr_i
// Rust doesn't seem to check the for loop body (it's bounds might mean it's never called).
HirStatement::For(e) => check(e.start_range) && check(e.end_range),
HirStatement::Loop(e) => check(e),
HirStatement::Constrain(_)
| HirStatement::Comptime(_)
HirStatement::Comptime(_)
| HirStatement::Break
| HirStatement::Continue
| HirStatement::Error => true,
Expand All @@ -310,6 +309,7 @@ fn can_return_without_recursing(interner: &NodeInterner, func_id: FuncId, expr_i
HirExpression::MemberAccess(e) => check(e.lhs),
HirExpression::Call(e) => check(e.func) && e.arguments.iter().cloned().all(check),
HirExpression::MethodCall(e) => check(e.object) && e.arguments.iter().cloned().all(check),
HirExpression::Constrain(e) => check(e.0) && e.2.map(check).unwrap_or(true),
HirExpression::Cast(e) => check(e.lhs),
HirExpression::If(e) => {
check(e.condition) && (check(e.consequence) || e.alternative.map(check).unwrap_or(true))
Expand Down
Loading

0 comments on commit 9ae3c6c

Please sign in to comment.