diff --git a/compiler/noirc_frontend/src/ast/expression.rs b/compiler/noirc_frontend/src/ast/expression.rs index 70e67c4f0e6..9c18cc0dd34 100644 --- a/compiler/noirc_frontend/src/ast/expression.rs +++ b/compiler/noirc_frontend/src/ast/expression.rs @@ -26,6 +26,7 @@ pub enum ExpressionKind { Index(Box), Call(Box), MethodCall(Box), + Constrain(ConstrainExpression), Constructor(Box), MemberAccess(Box), Cast(Box), @@ -582,6 +583,55 @@ impl BlockExpression { } } +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct ConstrainExpression { + pub kind: ConstrainKind, + pub arguments: Vec, + pub span: Span, +} + +impl Display for ConstrainExpression { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.kind { + ConstrainKind::Assert | ConstrainKind::AssertEq => write!( + f, + "{}({})", + self.kind, + vecmap(&self.arguments, |arg| arg.to_string()).join(", ") + ), + ConstrainKind::Constrain => { + write!(f, "constrain {}", &self.arguments[0]) + } + } + } +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum ConstrainKind { + Assert, + AssertEq, + Constrain, +} + +impl ConstrainKind { + pub fn required_arguments_count(&self) -> usize { + match self { + ConstrainKind::Assert | ConstrainKind::Constrain => 1, + ConstrainKind::AssertEq => 2, + } + } +} + +impl Display for ConstrainKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ConstrainKind::Assert => write!(f, "assert"), + ConstrainKind::AssertEq => write!(f, "assert_eq"), + ConstrainKind::Constrain => write!(f, "constrain"), + } + } +} + impl Display for Expression { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.kind.fmt(f) @@ -598,6 +648,7 @@ impl Display for ExpressionKind { Index(index) => index.fmt(f), Call(call) => call.fmt(f), MethodCall(call) => call.fmt(f), + Constrain(constrain) => constrain.fmt(f), Cast(cast) => cast.fmt(f), Infix(infix) => infix.fmt(f), If(if_expr) => if_expr.fmt(f), diff --git a/compiler/noirc_frontend/src/ast/statement.rs b/compiler/noirc_frontend/src/ast/statement.rs index ed1d26ef149..88d1e97a96f 100644 --- a/compiler/noirc_frontend/src/ast/statement.rs +++ b/compiler/noirc_frontend/src/ast/statement.rs @@ -42,7 +42,6 @@ pub struct Statement { #[derive(Debug, PartialEq, Eq, Clone)] pub enum StatementKind { Let(LetStatement), - Constrain(ConstrainStatement), Expression(Expression), Assign(AssignStatement), For(ForLoopStatement), @@ -88,7 +87,6 @@ impl StatementKind { match self { StatementKind::Let(_) - | StatementKind::Constrain(_) | StatementKind::Assign(_) | StatementKind::Semi(_) | StatementKind::Break @@ -565,55 +563,6 @@ pub enum LValue { Interned(InternedExpressionKind, Span), } -#[derive(Debug, PartialEq, Eq, Clone)] -pub struct ConstrainStatement { - pub kind: ConstrainKind, - pub arguments: Vec, - pub span: Span, -} - -impl Display for ConstrainStatement { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self.kind { - ConstrainKind::Assert | ConstrainKind::AssertEq => write!( - f, - "{}({})", - self.kind, - vecmap(&self.arguments, |arg| arg.to_string()).join(", ") - ), - ConstrainKind::Constrain => { - write!(f, "constrain {}", &self.arguments[0]) - } - } - } -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum ConstrainKind { - Assert, - AssertEq, - Constrain, -} - -impl ConstrainKind { - pub fn required_arguments_count(&self) -> usize { - match self { - ConstrainKind::Assert | ConstrainKind::Constrain => 1, - ConstrainKind::AssertEq => 2, - } - } -} - -impl Display for ConstrainKind { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - ConstrainKind::Assert => write!(f, "assert"), - ConstrainKind::AssertEq => write!(f, "assert_eq"), - ConstrainKind::Constrain => write!(f, "constrain"), - } - } -} - #[derive(Debug, PartialEq, Eq, Clone)] pub enum Pattern { Identifier(Ident), @@ -935,7 +884,6 @@ impl Display for StatementKind { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { StatementKind::Let(let_statement) => let_statement.fmt(f), - StatementKind::Constrain(constrain) => constrain.fmt(f), StatementKind::Expression(expression) => expression.fmt(f), StatementKind::Assign(assign) => assign.fmt(f), StatementKind::For(for_loop) => for_loop.fmt(f), diff --git a/compiler/noirc_frontend/src/ast/visitor.rs b/compiler/noirc_frontend/src/ast/visitor.rs index a43bd0a5d3d..e40c534c3b9 100644 --- a/compiler/noirc_frontend/src/ast/visitor.rs +++ b/compiler/noirc_frontend/src/ast/visitor.rs @@ -4,7 +4,7 @@ use noirc_errors::Span; use crate::{ ast::{ ArrayLiteral, AsTraitPath, AssignStatement, BlockExpression, CallExpression, - CastExpression, ConstrainStatement, ConstructorExpression, Expression, ExpressionKind, + CastExpression, ConstrainExpression, ConstructorExpression, Expression, ExpressionKind, ForLoopStatement, ForRange, Ident, IfExpression, IndexExpression, InfixExpression, LValue, Lambda, LetStatement, Literal, MemberAccessExpression, MethodCallExpression, ModuleDeclaration, NoirFunction, NoirStruct, NoirTrait, NoirTraitImpl, NoirTypeAlias, Path, @@ -294,7 +294,7 @@ pub trait Visitor { true } - fn visit_constrain_statement(&mut self, _: &ConstrainStatement) -> bool { + fn visit_constrain_statement(&mut self, _: &ConstrainExpression) -> bool { true } @@ -855,6 +855,9 @@ impl Expression { ExpressionKind::MethodCall(method_call_expression) => { method_call_expression.accept(self.span, visitor); } + ExpressionKind::Constrain(constrain) => { + constrain.accept(visitor); + } ExpressionKind::Constructor(constructor_expression) => { constructor_expression.accept(self.span, visitor); } @@ -1148,9 +1151,6 @@ impl Statement { StatementKind::Let(let_statement) => { let_statement.accept(visitor); } - StatementKind::Constrain(constrain_statement) => { - constrain_statement.accept(visitor); - } StatementKind::Expression(expression) => { expression.accept(visitor); } @@ -1199,7 +1199,7 @@ impl LetStatement { } } -impl ConstrainStatement { +impl ConstrainExpression { pub fn accept(&self, visitor: &mut impl Visitor) { if visitor.visit_constrain_statement(self) { self.accept_children(visitor); diff --git a/compiler/noirc_frontend/src/elaborator/expressions.rs b/compiler/noirc_frontend/src/elaborator/expressions.rs index 328048bb942..8bee7241d43 100644 --- a/compiler/noirc_frontend/src/elaborator/expressions.rs +++ b/compiler/noirc_frontend/src/elaborator/expressions.rs @@ -1,15 +1,15 @@ use acvm::{AcirField, FieldElement}; use iter_extended::vecmap; -use noirc_errors::{Location, Span}; +use noirc_errors::{Location, Span, Spanned}; use rustc_hash::FxHashSet as HashSet; use crate::{ ast::{ - ArrayLiteral, BlockExpression, CallExpression, CastExpression, ConstructorExpression, - Expression, ExpressionKind, Ident, IfExpression, IndexExpression, InfixExpression, - ItemVisibility, Lambda, Literal, MatchExpression, MemberAccessExpression, - MethodCallExpression, Path, PathSegment, PrefixExpression, StatementKind, UnaryOp, - UnresolvedTypeData, UnresolvedTypeExpression, + ArrayLiteral, BinaryOpKind, BlockExpression, CallExpression, CastExpression, + ConstrainExpression, ConstrainKind, ConstructorExpression, Expression, ExpressionKind, + Ident, IfExpression, IndexExpression, InfixExpression, ItemVisibility, Lambda, Literal, + MatchExpression, MemberAccessExpression, MethodCallExpression, Path, PathSegment, + PrefixExpression, StatementKind, UnaryOp, UnresolvedTypeData, UnresolvedTypeExpression, }, hir::{ comptime::{self, InterpreterError}, @@ -21,9 +21,9 @@ use crate::{ hir_def::{ expr::{ HirArrayLiteral, HirBinaryOp, HirBlockExpression, HirCallExpression, HirCastExpression, - HirConstructorExpression, HirExpression, HirIdent, HirIfExpression, HirIndexExpression, - HirInfixExpression, HirLambda, HirLiteral, HirMemberAccess, HirMethodCallExpression, - HirPrefixExpression, + HirConstrainExpression, HirConstructorExpression, HirExpression, HirIdent, + HirIfExpression, HirIndexExpression, HirInfixExpression, HirLambda, HirLiteral, + HirMemberAccess, HirMethodCallExpression, HirPrefixExpression, }, stmt::HirStatement, traits::{ResolvedTraitBound, TraitConstraint}, @@ -52,6 +52,7 @@ impl<'context> Elaborator<'context> { ExpressionKind::Index(index) => self.elaborate_index(*index), ExpressionKind::Call(call) => self.elaborate_call(*call, expr.span), ExpressionKind::MethodCall(call) => self.elaborate_method_call(*call, expr.span), + ExpressionKind::Constrain(constrain) => self.elaborate_constrain(constrain), ExpressionKind::Constructor(constructor) => self.elaborate_constructor(*constructor), ExpressionKind::MemberAccess(access) => { return self.elaborate_member_access(*access, expr.span) @@ -583,6 +584,61 @@ impl<'context> Elaborator<'context> { } } + pub(super) fn elaborate_constrain( + &mut self, + mut expr: ConstrainExpression, + ) -> (HirExpression, Type) { + let span = expr.span; + let min_args_count = expr.kind.required_arguments_count(); + let max_args_count = min_args_count + 1; + let actual_args_count = expr.arguments.len(); + + let (message, expr) = if !(min_args_count..=max_args_count).contains(&actual_args_count) { + self.push_err(TypeCheckError::AssertionParameterCountMismatch { + kind: expr.kind, + found: actual_args_count, + span, + }); + + // Given that we already produced an error, let's make this an `assert(true)` so + // we don't get further errors. + let message = None; + let kind = ExpressionKind::Literal(crate::ast::Literal::Bool(true)); + let expr = Expression { kind, span }; + (message, expr) + } else { + let message = + (actual_args_count != min_args_count).then(|| expr.arguments.pop().unwrap()); + let expr = match expr.kind { + ConstrainKind::Assert | ConstrainKind::Constrain => expr.arguments.pop().unwrap(), + ConstrainKind::AssertEq => { + let rhs = expr.arguments.pop().unwrap(); + let lhs = expr.arguments.pop().unwrap(); + let span = Span::from(lhs.span.start()..rhs.span.end()); + let operator = Spanned::from(span, BinaryOpKind::Equal); + let kind = + ExpressionKind::Infix(Box::new(InfixExpression { lhs, operator, rhs })); + Expression { kind, span } + } + }; + (message, expr) + }; + + let expr_span = expr.span; + let (expr_id, expr_type) = self.elaborate_expression(expr); + + // Must type check the assertion message expression so that we instantiate bindings + let msg = message.map(|assert_msg_expr| self.elaborate_expression(assert_msg_expr).0); + + self.unify(&expr_type, &Type::Bool, || TypeCheckError::TypeMismatch { + expr_typ: expr_type.to_string(), + expected_typ: Type::Bool.to_string(), + expr_span, + }); + + (HirExpression::Constrain(HirConstrainExpression(expr_id, self.file, msg)), Type::Unit) + } + /// Elaborates an expression knowing that it has to match a given type. fn elaborate_expression_with_type( &mut self, diff --git a/compiler/noirc_frontend/src/elaborator/lints.rs b/compiler/noirc_frontend/src/elaborator/lints.rs index af80dfaa823..7910d8cebdb 100644 --- a/compiler/noirc_frontend/src/elaborator/lints.rs +++ b/compiler/noirc_frontend/src/elaborator/lints.rs @@ -283,8 +283,7 @@ fn can_return_without_recursing(interner: &NodeInterner, func_id: FuncId, expr_i // Rust doesn't seem to check the for loop body (it's bounds might mean it's never called). HirStatement::For(e) => check(e.start_range) && check(e.end_range), HirStatement::Loop(e) => check(e), - HirStatement::Constrain(_) - | HirStatement::Comptime(_) + HirStatement::Comptime(_) | HirStatement::Break | HirStatement::Continue | HirStatement::Error => true, @@ -310,6 +309,7 @@ fn can_return_without_recursing(interner: &NodeInterner, func_id: FuncId, expr_i HirExpression::MemberAccess(e) => check(e.lhs), HirExpression::Call(e) => check(e.func) && e.arguments.iter().cloned().all(check), HirExpression::MethodCall(e) => check(e.object) && e.arguments.iter().cloned().all(check), + HirExpression::Constrain(e) => check(e.0) && e.2.map(check).unwrap_or(true), HirExpression::Cast(e) => check(e.lhs), HirExpression::If(e) => { check(e.condition) && (check(e.consequence) || e.alternative.map(check).unwrap_or(true)) diff --git a/compiler/noirc_frontend/src/elaborator/statements.rs b/compiler/noirc_frontend/src/elaborator/statements.rs index b17052d01ef..c401646332f 100644 --- a/compiler/noirc_frontend/src/elaborator/statements.rs +++ b/compiler/noirc_frontend/src/elaborator/statements.rs @@ -1,9 +1,8 @@ -use noirc_errors::{Location, Span, Spanned}; +use noirc_errors::{Location, Span}; use crate::{ ast::{ - AssignStatement, BinaryOpKind, ConstrainKind, ConstrainStatement, Expression, - ExpressionKind, ForLoopStatement, ForRange, Ident, InfixExpression, ItemVisibility, LValue, + AssignStatement, Expression, ForLoopStatement, ForRange, Ident, ItemVisibility, LValue, LetStatement, Path, Statement, StatementKind, }, hir::{ @@ -15,10 +14,7 @@ use crate::{ }, hir_def::{ expr::HirIdent, - stmt::{ - HirAssignStatement, HirConstrainStatement, HirForStatement, HirLValue, HirLetStatement, - HirStatement, - }, + stmt::{HirAssignStatement, HirForStatement, HirLValue, HirLetStatement, HirStatement}, }, node_interner::{DefinitionId, DefinitionKind, GlobalId, StmtId}, DataType, Type, @@ -38,7 +34,6 @@ impl<'context> Elaborator<'context> { ) -> (HirStatement, Type) { match statement.kind { StatementKind::Let(let_stmt) => self.elaborate_local_let(let_stmt), - StatementKind::Constrain(constrain) => self.elaborate_constrain(constrain), StatementKind::Assign(assign) => self.elaborate_assign(assign), StatementKind::For(for_stmt) => self.elaborate_for(for_stmt), StatementKind::Loop(block, span) => self.elaborate_loop(block, span), @@ -149,61 +144,6 @@ impl<'context> Elaborator<'context> { (HirStatement::Let(let_), Type::Unit) } - pub(super) fn elaborate_constrain( - &mut self, - mut stmt: ConstrainStatement, - ) -> (HirStatement, Type) { - let span = stmt.span; - let min_args_count = stmt.kind.required_arguments_count(); - let max_args_count = min_args_count + 1; - let actual_args_count = stmt.arguments.len(); - - let (message, expr) = if !(min_args_count..=max_args_count).contains(&actual_args_count) { - self.push_err(TypeCheckError::AssertionParameterCountMismatch { - kind: stmt.kind, - found: actual_args_count, - span, - }); - - // Given that we already produced an error, let's make this an `assert(true)` so - // we don't get further errors. - let message = None; - let kind = ExpressionKind::Literal(crate::ast::Literal::Bool(true)); - let expr = Expression { kind, span }; - (message, expr) - } else { - let message = - (actual_args_count != min_args_count).then(|| stmt.arguments.pop().unwrap()); - let expr = match stmt.kind { - ConstrainKind::Assert | ConstrainKind::Constrain => stmt.arguments.pop().unwrap(), - ConstrainKind::AssertEq => { - let rhs = stmt.arguments.pop().unwrap(); - let lhs = stmt.arguments.pop().unwrap(); - let span = Span::from(lhs.span.start()..rhs.span.end()); - let operator = Spanned::from(span, BinaryOpKind::Equal); - let kind = - ExpressionKind::Infix(Box::new(InfixExpression { lhs, operator, rhs })); - Expression { kind, span } - } - }; - (message, expr) - }; - - let expr_span = expr.span; - let (expr_id, expr_type) = self.elaborate_expression(expr); - - // Must type check the assertion message expression so that we instantiate bindings - let msg = message.map(|assert_msg_expr| self.elaborate_expression(assert_msg_expr).0); - - self.unify(&expr_type, &Type::Bool, || TypeCheckError::TypeMismatch { - expr_typ: expr_type.to_string(), - expected_typ: Type::Bool.to_string(), - expr_span, - }); - - (HirStatement::Constrain(HirConstrainStatement(expr_id, self.file, msg)), Type::Unit) - } - pub(super) fn elaborate_assign(&mut self, assign: AssignStatement) -> (HirStatement, Type) { let expr_span = assign.expression.span; let (expression, expr_type) = self.elaborate_expression(assign.expression); diff --git a/compiler/noirc_frontend/src/hir/comptime/display.rs b/compiler/noirc_frontend/src/hir/comptime/display.rs index 1be4bbe61ab..a6927ab3fe8 100644 --- a/compiler/noirc_frontend/src/hir/comptime/display.rs +++ b/compiler/noirc_frontend/src/hir/comptime/display.rs @@ -6,7 +6,7 @@ use noirc_errors::Span; use crate::{ ast::{ ArrayLiteral, AsTraitPath, AssignStatement, BlockExpression, CallExpression, - CastExpression, ConstrainStatement, ConstructorExpression, Expression, ExpressionKind, + CastExpression, ConstrainExpression, ConstructorExpression, Expression, ExpressionKind, ForBounds, ForLoopStatement, ForRange, GenericTypeArgs, IfExpression, IndexExpression, InfixExpression, LValue, Lambda, LetStatement, Literal, MatchExpression, MemberAccessExpression, MethodCallExpression, Pattern, PrefixExpression, Statement, @@ -573,6 +573,12 @@ fn remove_interned_in_expression_kind( ..*call })) } + ExpressionKind::Constrain(constrain) => ExpressionKind::Constrain(ConstrainExpression { + arguments: vecmap(constrain.arguments, |expr| { + remove_interned_in_expression(interner, expr) + }), + ..constrain + }), ExpressionKind::Constructor(constructor) => { ExpressionKind::Constructor(Box::new(ConstructorExpression { fields: vecmap(constructor.fields, |(name, expr)| { @@ -728,12 +734,6 @@ fn remove_interned_in_statement_kind( r#type: remove_interned_in_unresolved_type(interner, let_statement.r#type), ..let_statement }), - StatementKind::Constrain(constrain) => StatementKind::Constrain(ConstrainStatement { - arguments: vecmap(constrain.arguments, |expr| { - remove_interned_in_expression(interner, expr) - }), - ..constrain - }), StatementKind::Expression(expr) => { StatementKind::Expression(remove_interned_in_expression(interner, expr)) } diff --git a/compiler/noirc_frontend/src/hir/comptime/hir_to_display_ast.rs b/compiler/noirc_frontend/src/hir/comptime/hir_to_display_ast.rs index d46484d05fa..3ba7ae42950 100644 --- a/compiler/noirc_frontend/src/hir/comptime/hir_to_display_ast.rs +++ b/compiler/noirc_frontend/src/hir/comptime/hir_to_display_ast.rs @@ -8,7 +8,7 @@ use crate::ast::{ MemberAccessExpression, MethodCallExpression, Path, PathKind, PathSegment, Pattern, PrefixExpression, UnresolvedType, UnresolvedTypeData, UnresolvedTypeExpression, }; -use crate::ast::{ConstrainStatement, Expression, Statement, StatementKind}; +use crate::ast::{ConstrainExpression, Expression, Statement, StatementKind}; use crate::hir_def::expr::{ HirArrayLiteral, HirBlockExpression, HirExpression, HirIdent, HirLiteral, }; @@ -32,20 +32,6 @@ impl HirStatement { let expression = let_stmt.expression.to_display_ast(interner); StatementKind::new_let(pattern, r#type, expression, let_stmt.attributes.clone()) } - HirStatement::Constrain(constrain) => { - let expr = constrain.0.to_display_ast(interner); - let mut arguments = vec![expr]; - if let Some(message) = constrain.2 { - arguments.push(message.to_display_ast(interner)); - } - - // TODO: Find difference in usage between Assert & AssertEq - StatementKind::Constrain(ConstrainStatement { - kind: ConstrainKind::Assert, - arguments, - span, - }) - } HirStatement::Assign(assign) => StatementKind::Assign(AssignStatement { lvalue: assign.lvalue.to_display_ast(interner), expression: assign.expression.to_display_ast(interner), @@ -180,6 +166,20 @@ impl HirExpression { is_macro_call: false, })) } + HirExpression::Constrain(constrain) => { + let expr = constrain.0.to_display_ast(interner); + let mut arguments = vec![expr]; + if let Some(message) = constrain.2 { + arguments.push(message.to_display_ast(interner)); + } + + // TODO: Find difference in usage between Assert & AssertEq + ExpressionKind::Constrain(ConstrainExpression { + kind: ConstrainKind::Assert, + arguments, + span, + }) + } HirExpression::Cast(cast) => { let lhs = cast.lhs.to_display_ast(interner); let r#type = cast.r#type.to_display_ast(); diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter.rs index 33f8e43863e..5f001192dac 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter.rs @@ -14,7 +14,7 @@ use crate::elaborator::Elaborator; use crate::graph::CrateId; use crate::hir::def_map::ModuleId; use crate::hir::type_check::TypeCheckError; -use crate::hir_def::expr::{HirEnumConstructorExpression, ImplKind}; +use crate::hir_def::expr::{HirConstrainExpression, HirEnumConstructorExpression, ImplKind}; use crate::hir_def::function::FunctionBody; use crate::monomorphization::{ perform_impl_bindings, perform_instantiation_bindings, resolve_trait_method, @@ -32,8 +32,8 @@ use crate::{ HirPrefixExpression, }, stmt::{ - HirAssignStatement, HirConstrainStatement, HirForStatement, HirLValue, HirLetStatement, - HirPattern, HirStatement, + HirAssignStatement, HirForStatement, HirLValue, HirLetStatement, HirPattern, + HirStatement, }, types::Kind, }, @@ -532,6 +532,7 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { HirExpression::MemberAccess(access) => self.evaluate_access(access, id), HirExpression::Call(call) => self.evaluate_call(call, id), HirExpression::MethodCall(call) => self.evaluate_method_call(call, id), + HirExpression::Constrain(constrain) => self.evaluate_constrain(constrain), HirExpression::Cast(cast) => self.evaluate_cast(&cast, id), HirExpression::If(if_) => self.evaluate_if(if_, id), HirExpression::Tuple(tuple) => self.evaluate_tuple(tuple), @@ -1560,7 +1561,6 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { pub fn evaluate_statement(&mut self, statement: StmtId) -> IResult { match self.elaborator.interner.statement(&statement) { HirStatement::Let(let_) => self.evaluate_let(let_), - HirStatement::Constrain(constrain) => self.evaluate_constrain(constrain), HirStatement::Assign(assign) => self.evaluate_assign(assign), HirStatement::For(for_) => self.evaluate_for(for_), HirStatement::Loop(expression) => self.evaluate_loop(expression), @@ -1586,7 +1586,7 @@ impl<'local, 'interner> Interpreter<'local, 'interner> { Ok(Value::Unit) } - fn evaluate_constrain(&mut self, constrain: HirConstrainStatement) -> IResult { + fn evaluate_constrain(&mut self, constrain: HirConstrainExpression) -> IResult { match self.evaluate(constrain.0)? { Value::Bool(true) => Ok(Value::Unit), Value::Bool(false) => { diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs index 9abb1b190d5..6655c8977e2 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs @@ -1540,7 +1540,7 @@ fn expr_as_assert( location: Location, ) -> IResult { expr_as(interner, arguments, return_type.clone(), location, |expr| { - if let ExprValue::Statement(StatementKind::Constrain(mut constrain)) = expr { + if let ExprValue::Expression(ExpressionKind::Constrain(mut constrain)) = expr { if constrain.kind == ConstrainKind::Assert && !constrain.arguments.is_empty() && constrain.arguments.len() <= 2 @@ -1580,7 +1580,7 @@ fn expr_as_assert_eq( location: Location, ) -> IResult { expr_as(interner, arguments, return_type.clone(), location, |expr| { - if let ExprValue::Statement(StatementKind::Constrain(mut constrain)) = expr { + if let ExprValue::Expression(ExpressionKind::Constrain(mut constrain)) = expr { if constrain.kind == ConstrainKind::AssertEq && constrain.arguments.len() >= 2 && constrain.arguments.len() <= 3 diff --git a/compiler/noirc_frontend/src/hir_def/expr.rs b/compiler/noirc_frontend/src/hir_def/expr.rs index 00b94411fcd..543c13fac9c 100644 --- a/compiler/noirc_frontend/src/hir_def/expr.rs +++ b/compiler/noirc_frontend/src/hir_def/expr.rs @@ -34,6 +34,7 @@ pub enum HirExpression { MemberAccess(HirMemberAccess), Call(HirCallExpression), MethodCall(HirMethodCallExpression), + Constrain(HirConstrainExpression), Cast(HirCastExpression), If(HirIfExpression), Tuple(Vec), @@ -200,6 +201,13 @@ pub struct HirMethodCallExpression { pub location: Location, } +/// Corresponds to `assert` and `assert_eq` in the source code. +/// This node also contains the FileId of the file the constrain +/// originates from. This is used later in the SSA pass to issue +/// an error if a constrain is found to be always false. +#[derive(Debug, Clone)] +pub struct HirConstrainExpression(pub ExprId, pub FileId, pub Option); + #[derive(Debug, Clone)] pub enum HirMethodReference { /// A method can be defined in a regular `impl` block, in which case diff --git a/compiler/noirc_frontend/src/hir_def/stmt.rs b/compiler/noirc_frontend/src/hir_def/stmt.rs index 8a580e735b1..96ef7161341 100644 --- a/compiler/noirc_frontend/src/hir_def/stmt.rs +++ b/compiler/noirc_frontend/src/hir_def/stmt.rs @@ -3,7 +3,6 @@ use crate::ast::Ident; use crate::node_interner::{ExprId, StmtId}; use crate::token::SecondaryAttribute; use crate::Type; -use fm::FileId; use noirc_errors::{Location, Span}; /// A HirStatement is the result of performing name resolution on @@ -13,7 +12,6 @@ use noirc_errors::{Location, Span}; #[derive(Debug, Clone)] pub enum HirStatement { Let(HirLetStatement), - Constrain(HirConstrainStatement), Assign(HirAssignStatement), For(HirForStatement), Loop(ExprId), @@ -74,13 +72,6 @@ pub struct HirAssignStatement { pub expression: ExprId, } -/// Corresponds to `constrain expr;` in the source code. -/// This node also contains the FileId of the file the constrain -/// originates from. This is used later in the SSA pass to issue -/// an error if a constrain is found to be always false. -#[derive(Debug, Clone)] -pub struct HirConstrainStatement(pub ExprId, pub FileId, pub Option); - #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub enum HirPattern { Identifier(HirIdent), diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index 7ad703523d4..5d81913f4ec 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -557,6 +557,22 @@ impl<'interner> Monomorphizer<'interner> { HirExpression::Call(call) => self.function_call(call, expr)?, + HirExpression::Constrain(constrain) => { + let expr = self.expr(constrain.0)?; + let location = self.interner.expr_location(&constrain.0); + let assert_message = constrain + .2 + .map(|assert_msg_expr| { + self.expr(assert_msg_expr).map(|expr| { + (expr, self.interner.id_type(assert_msg_expr).follow_bindings()) + }) + }) + .transpose()? + .map(Box::new); + + ast::Expression::Constrain(Box::new(expr), location, assert_message) + } + HirExpression::Cast(cast) => { let location = self.interner.expr_location(&expr); let typ = Self::convert_type(&cast.r#type, location)?; @@ -658,21 +674,6 @@ impl<'interner> Monomorphizer<'interner> { fn statement(&mut self, id: StmtId) -> Result { match self.interner.statement(&id) { HirStatement::Let(let_statement) => self.let_statement(let_statement), - HirStatement::Constrain(constrain) => { - let expr = self.expr(constrain.0)?; - let location = self.interner.expr_location(&constrain.0); - let assert_message = constrain - .2 - .map(|assert_msg_expr| { - self.expr(assert_msg_expr).map(|expr| { - (expr, self.interner.id_type(assert_msg_expr).follow_bindings()) - }) - }) - .transpose()? - .map(Box::new); - - Ok(ast::Expression::Constrain(Box::new(expr), location, assert_message)) - } HirStatement::Assign(assign) => self.assign(assign), HirStatement::For(for_loop) => { self.is_range_loop = true; diff --git a/compiler/noirc_frontend/src/parser/parser/expression.rs b/compiler/noirc_frontend/src/parser/parser/expression.rs index eff309154e3..319eefc190a 100644 --- a/compiler/noirc_frontend/src/parser/parser/expression.rs +++ b/compiler/noirc_frontend/src/parser/parser/expression.rs @@ -3,9 +3,10 @@ use noirc_errors::Span; use crate::{ ast::{ - ArrayLiteral, BlockExpression, CallExpression, CastExpression, ConstructorExpression, - Expression, ExpressionKind, Ident, IfExpression, IndexExpression, Literal, MatchExpression, - MemberAccessExpression, MethodCallExpression, Statement, TypePath, UnaryOp, UnresolvedType, + ArrayLiteral, BlockExpression, CallExpression, CastExpression, ConstrainExpression, + ConstrainKind, ConstructorExpression, Expression, ExpressionKind, Ident, IfExpression, + IndexExpression, Literal, MatchExpression, MemberAccessExpression, MethodCallExpression, + Statement, TypePath, UnaryOp, UnresolvedType, }, parser::{labels::ParsingRuleLabel, parser::parse_many::separated_by_comma, ParserErrorReason}, token::{Keyword, Token, TokenKind}, @@ -653,6 +654,7 @@ impl<'a> Parser<'a> { /// | ArrayExpression /// | SliceExpression /// | BlockExpression + /// | ConstrainExpression /// /// QuoteExpression = 'quote' '{' token* '}' /// @@ -696,6 +698,10 @@ impl<'a> Parser<'a> { return Some(ExpressionKind::Block(kind)); } + if let Some(constrain) = self.parse_constrain_expression() { + return Some(ExpressionKind::Constrain(constrain)); + } + None } @@ -800,6 +806,49 @@ impl<'a> Parser<'a> { } } + /// ConstrainExpression + /// = 'constrain' Expression + /// | 'assert' Arguments + /// | 'assert_eq' Arguments + pub(super) fn parse_constrain_expression(&mut self) -> Option { + let start_span = self.current_token_span; + let kind = self.parse_constrain_kind()?; + + Some(match kind { + ConstrainKind::Assert | ConstrainKind::AssertEq => { + let arguments = self.parse_arguments(); + if arguments.is_none() { + self.expected_token(Token::LeftParen); + } + let arguments = arguments.unwrap_or_default(); + + ConstrainExpression { kind, arguments, span: self.span_since(start_span) } + } + ConstrainKind::Constrain => { + self.push_error(ParserErrorReason::ConstrainDeprecated, self.previous_token_span); + + let expression = self.parse_expression_or_error(); + ConstrainExpression { + kind, + arguments: vec![expression], + span: self.span_since(start_span), + } + } + }) + } + + fn parse_constrain_kind(&mut self) -> Option { + if self.eat_keyword(Keyword::Assert) { + Some(ConstrainKind::Assert) + } else if self.eat_keyword(Keyword::AssertEq) { + Some(ConstrainKind::AssertEq) + } else if self.eat_keyword(Keyword::Constrain) { + Some(ConstrainKind::Constrain) + } else { + None + } + } + /// Block = '{' Statement* '}' pub(super) fn parse_block(&mut self) -> Option { if !self.eat_left_brace() { @@ -849,8 +898,8 @@ mod tests { use crate::{ ast::{ - ArrayLiteral, BinaryOpKind, Expression, ExpressionKind, Literal, StatementKind, - UnaryOp, UnresolvedTypeData, + ArrayLiteral, BinaryOpKind, ConstrainKind, Expression, ExpressionKind, Literal, + StatementKind, UnaryOp, UnresolvedTypeData, }, parser::{ parser::tests::{ @@ -1749,4 +1798,45 @@ mod tests { }; assert_eq!(expr.kind.to_string(), "((1 + 2))"); } + + #[test] + fn parses_assert() { + let src = "assert(true, \"good\")"; + let expression = parse_expression_no_errors(src); + let ExpressionKind::Constrain(constrain) = expression.kind else { + panic!("Expected constrain expression"); + }; + assert_eq!(constrain.kind, ConstrainKind::Assert); + assert_eq!(constrain.arguments.len(), 2); + } + + #[test] + fn parses_assert_eq() { + let src = "assert_eq(1, 2, \"bad\")"; + let expression = parse_expression_no_errors(src); + let ExpressionKind::Constrain(constrain) = expression.kind else { + panic!("Expected constrain expression"); + }; + assert_eq!(constrain.kind, ConstrainKind::AssertEq); + assert_eq!(constrain.arguments.len(), 3); + } + + #[test] + fn parses_constrain() { + let src = " + constrain 1 + ^^^^^^^^^ + "; + let (src, span) = get_source_with_error_span(src); + let mut parser = Parser::for_str(&src); + let expression = parser.parse_expression_or_error(); + let ExpressionKind::Constrain(constrain) = expression.kind else { + panic!("Expected constrain expression"); + }; + assert_eq!(constrain.kind, ConstrainKind::Constrain); + assert_eq!(constrain.arguments.len(), 1); + + let reason = get_single_error_reason(&parser.errors, span); + assert!(matches!(reason, ParserErrorReason::ConstrainDeprecated)); + } } diff --git a/compiler/noirc_frontend/src/parser/parser/statement.rs b/compiler/noirc_frontend/src/parser/parser/statement.rs index 37013e91528..f9cc63a364e 100644 --- a/compiler/noirc_frontend/src/parser/parser/statement.rs +++ b/compiler/noirc_frontend/src/parser/parser/statement.rs @@ -2,9 +2,9 @@ use noirc_errors::{Span, Spanned}; use crate::{ ast::{ - AssignStatement, BinaryOp, BinaryOpKind, ConstrainKind, ConstrainStatement, Expression, - ExpressionKind, ForBounds, ForLoopStatement, ForRange, Ident, InfixExpression, LValue, - LetStatement, Statement, StatementKind, + AssignStatement, BinaryOp, BinaryOpKind, Expression, ExpressionKind, ForBounds, + ForLoopStatement, ForRange, Ident, InfixExpression, LValue, LetStatement, Statement, + StatementKind, }, parser::{labels::ParsingRuleLabel, ParserErrorReason}, token::{Attribute, Keyword, Token, TokenKind}, @@ -89,7 +89,6 @@ impl<'a> Parser<'a> { /// | ContinueStatement /// | ReturnStatement /// | LetStatement - /// | ConstrainStatement /// | ComptimeStatement /// | ForStatement /// | LoopStatement @@ -145,10 +144,6 @@ impl<'a> Parser<'a> { return Some(StatementKind::Let(let_statement)); } - if let Some(constrain) = self.parse_constrain_statement() { - return Some(StatementKind::Constrain(constrain)); - } - if self.at_keyword(Keyword::Comptime) { return self.parse_comptime_statement(attributes); } @@ -432,58 +427,12 @@ impl<'a> Parser<'a> { is_global_let: false, }) } - - /// ConstrainStatement - /// = 'constrain' Expression - /// | 'assert' Arguments - /// | 'assert_eq' Arguments - fn parse_constrain_statement(&mut self) -> Option { - let start_span = self.current_token_span; - let kind = self.parse_constrain_kind()?; - - Some(match kind { - ConstrainKind::Assert | ConstrainKind::AssertEq => { - let arguments = self.parse_arguments(); - if arguments.is_none() { - self.expected_token(Token::LeftParen); - } - let arguments = arguments.unwrap_or_default(); - - ConstrainStatement { kind, arguments, span: self.span_since(start_span) } - } - ConstrainKind::Constrain => { - self.push_error(ParserErrorReason::ConstrainDeprecated, self.previous_token_span); - - let expression = self.parse_expression_or_error(); - ConstrainStatement { - kind, - arguments: vec![expression], - span: self.span_since(start_span), - } - } - }) - } - - fn parse_constrain_kind(&mut self) -> Option { - if self.eat_keyword(Keyword::Assert) { - Some(ConstrainKind::Assert) - } else if self.eat_keyword(Keyword::AssertEq) { - Some(ConstrainKind::AssertEq) - } else if self.eat_keyword(Keyword::Constrain) { - Some(ConstrainKind::Constrain) - } else { - None - } - } } #[cfg(test)] mod tests { use crate::{ - ast::{ - ConstrainKind, ExpressionKind, ForRange, LValue, Statement, StatementKind, - UnresolvedTypeData, - }, + ast::{ExpressionKind, ForRange, LValue, Statement, StatementKind, UnresolvedTypeData}, parser::{ parser::tests::{ expect_no_errors, get_single_error, get_single_error_reason, @@ -551,47 +500,6 @@ mod tests { assert_eq!(let_statement.pattern.to_string(), "x"); } - #[test] - fn parses_assert() { - let src = "assert(true, \"good\")"; - let statement = parse_statement_no_errors(src); - let StatementKind::Constrain(constrain) = statement.kind else { - panic!("Expected constrain statement"); - }; - assert_eq!(constrain.kind, ConstrainKind::Assert); - assert_eq!(constrain.arguments.len(), 2); - } - - #[test] - fn parses_assert_eq() { - let src = "assert_eq(1, 2, \"bad\")"; - let statement = parse_statement_no_errors(src); - let StatementKind::Constrain(constrain) = statement.kind else { - panic!("Expected constrain statement"); - }; - assert_eq!(constrain.kind, ConstrainKind::AssertEq); - assert_eq!(constrain.arguments.len(), 3); - } - - #[test] - fn parses_constrain() { - let src = " - constrain 1 - ^^^^^^^^^ - "; - let (src, span) = get_source_with_error_span(src); - let mut parser = Parser::for_str(&src); - let statement = parser.parse_statement_or_error(); - let StatementKind::Constrain(constrain) = statement.kind else { - panic!("Expected constrain statement"); - }; - assert_eq!(constrain.kind, ConstrainKind::Constrain); - assert_eq!(constrain.arguments.len(), 1); - - let reason = get_single_error_reason(&parser.errors, span); - assert!(matches!(reason, ParserErrorReason::ConstrainDeprecated)); - } - #[test] fn parses_comptime_block() { let src = "comptime { 1 }"; @@ -851,4 +759,15 @@ mod tests { }; assert_eq!(block.statements.len(), 2); } + + #[test] + fn parses_let_with_assert() { + let src = "let _ = assert(true);"; + let mut parser = Parser::for_str(src); + let statement = parser.parse_statement_or_error(); + let StatementKind::Let(let_statement) = statement.kind else { + panic!("Expected let"); + }; + assert!(matches!(let_statement.expression.kind, ExpressionKind::Constrain(..))); + } } diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index 64b71a2cf8a..cda6c267ec7 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -900,7 +900,6 @@ fn find_lambda_captures(stmts: &[StmtId], interner: &NodeInterner, result: &mut HirStatement::Expression(expr_id) => expr_id, HirStatement::Let(let_stmt) => let_stmt.expression, HirStatement::Assign(assign_stmt) => assign_stmt.expression, - HirStatement::Constrain(constr_stmt) => constr_stmt.0, HirStatement::Semi(semi_expr) => semi_expr, HirStatement::For(for_loop) => for_loop.block, HirStatement::Loop(block) => block, diff --git a/tooling/lsp/src/requests/inlay_hint.rs b/tooling/lsp/src/requests/inlay_hint.rs index 8e091d1eb04..b9673755da6 100644 --- a/tooling/lsp/src/requests/inlay_hint.rs +++ b/tooling/lsp/src/requests/inlay_hint.rs @@ -575,6 +575,7 @@ fn get_expression_name(expression: &Expression) -> Option { ExpressionKind::Parenthesized(expr) => get_expression_name(expr), ExpressionKind::AsTraitPath(path) => Some(path.impl_item.to_string()), ExpressionKind::TypePath(path) => Some(path.item.to_string()), + ExpressionKind::Constrain(constrain) => Some(constrain.kind.to_string()), ExpressionKind::Constructor(..) | ExpressionKind::Infix(..) | ExpressionKind::Index(..) diff --git a/tooling/lsp/src/requests/signature_help.rs b/tooling/lsp/src/requests/signature_help.rs index 99bd463f44a..4a2609d7ae3 100644 --- a/tooling/lsp/src/requests/signature_help.rs +++ b/tooling/lsp/src/requests/signature_help.rs @@ -8,7 +8,7 @@ use lsp_types::{ use noirc_errors::{Location, Span}; use noirc_frontend::{ ast::{ - CallExpression, ConstrainKind, ConstrainStatement, Expression, FunctionReturnType, + CallExpression, ConstrainExpression, ConstrainKind, Expression, FunctionReturnType, MethodCallExpression, Statement, Visitor, }, hir_def::{function::FuncMeta, stmt::HirPattern}, @@ -383,7 +383,7 @@ impl<'a> Visitor for SignatureFinder<'a> { false } - fn visit_constrain_statement(&mut self, constrain_statement: &ConstrainStatement) -> bool { + fn visit_constrain_statement(&mut self, constrain_statement: &ConstrainExpression) -> bool { constrain_statement.accept_children(self); if self.signature_help.is_some() { diff --git a/tooling/nargo_fmt/src/formatter/expression.rs b/tooling/nargo_fmt/src/formatter/expression.rs index 98eabe10e7e..54d9d2e41f5 100644 --- a/tooling/nargo_fmt/src/formatter/expression.rs +++ b/tooling/nargo_fmt/src/formatter/expression.rs @@ -1,9 +1,10 @@ use noirc_frontend::{ ast::{ ArrayLiteral, BinaryOpKind, BlockExpression, CallExpression, CastExpression, - ConstructorExpression, Expression, ExpressionKind, IfExpression, IndexExpression, - InfixExpression, Lambda, Literal, MatchExpression, MemberAccessExpression, - MethodCallExpression, PrefixExpression, TypePath, UnaryOp, UnresolvedTypeData, + ConstrainExpression, ConstrainKind, ConstructorExpression, Expression, ExpressionKind, + IfExpression, IndexExpression, InfixExpression, Lambda, Literal, MatchExpression, + MemberAccessExpression, MethodCallExpression, PrefixExpression, TypePath, UnaryOp, + UnresolvedTypeData, }, token::{Keyword, Token}, }; @@ -39,6 +40,9 @@ impl<'a, 'b> ChunkFormatter<'a, 'b> { ExpressionKind::MethodCall(method_call) => { group.group(self.format_method_call(*method_call)); } + ExpressionKind::Constrain(constrain) => { + group.group(self.format_constrain(constrain)); + } ExpressionKind::Constructor(constructor) => { group.group(self.format_constructor(*constructor)); } @@ -1145,6 +1149,40 @@ impl<'a, 'b> ChunkFormatter<'a, 'b> { group } + fn format_constrain(&mut self, constrain_statement: ConstrainExpression) -> ChunkGroup { + let mut group = ChunkGroup::new(); + + let keyword = match constrain_statement.kind { + ConstrainKind::Assert => Keyword::Assert, + ConstrainKind::AssertEq => Keyword::AssertEq, + ConstrainKind::Constrain => { + unreachable!("constrain always produces an error, and the formatter doesn't run when there are errors") + } + }; + + group.text(self.chunk(|formatter| { + formatter.write_keyword(keyword); + formatter.write_left_paren(); + })); + + group.kind = GroupKind::ExpressionList { + prefix_width: group.width(), + expressions_count: constrain_statement.arguments.len(), + }; + + self.format_expressions_separated_by_comma( + constrain_statement.arguments, + false, // force trailing comma + &mut group, + ); + + group.text(self.chunk(|formatter| { + formatter.write_right_paren(); + })); + + group + } + pub(super) fn format_block_expression( &mut self, block: BlockExpression, diff --git a/tooling/nargo_fmt/src/formatter/statement.rs b/tooling/nargo_fmt/src/formatter/statement.rs index 751bc419d4a..ae4177c224b 100644 --- a/tooling/nargo_fmt/src/formatter/statement.rs +++ b/tooling/nargo_fmt/src/formatter/statement.rs @@ -1,8 +1,7 @@ use noirc_frontend::{ ast::{ - AssignStatement, ConstrainKind, ConstrainStatement, Expression, ExpressionKind, - ForLoopStatement, ForRange, LetStatement, Pattern, Statement, StatementKind, - UnresolvedType, UnresolvedTypeData, + AssignStatement, Expression, ExpressionKind, ForLoopStatement, ForRange, LetStatement, + Pattern, Statement, StatementKind, UnresolvedType, UnresolvedTypeData, }, token::{Keyword, SecondaryAttribute, Token, TokenKind}, }; @@ -49,9 +48,6 @@ impl<'a, 'b> ChunkFormatter<'a, 'b> { StatementKind::Let(let_statement) => { group.group(self.format_let_statement(let_statement)); } - StatementKind::Constrain(constrain_statement) => { - group.group(self.format_constrain_statement(constrain_statement)); - } StatementKind::Expression(expression) => match expression.kind { ExpressionKind::Block(block) => group.group(self.format_block_expression( block, true, // force multiple lines @@ -153,44 +149,6 @@ impl<'a, 'b> ChunkFormatter<'a, 'b> { group } - fn format_constrain_statement( - &mut self, - constrain_statement: ConstrainStatement, - ) -> ChunkGroup { - let mut group = ChunkGroup::new(); - - let keyword = match constrain_statement.kind { - ConstrainKind::Assert => Keyword::Assert, - ConstrainKind::AssertEq => Keyword::AssertEq, - ConstrainKind::Constrain => { - unreachable!("constrain always produces an error, and the formatter doesn't run when there are errors") - } - }; - - group.text(self.chunk(|formatter| { - formatter.write_keyword(keyword); - formatter.write_left_paren(); - })); - - group.kind = GroupKind::ExpressionList { - prefix_width: group.width(), - expressions_count: constrain_statement.arguments.len(), - }; - - self.format_expressions_separated_by_comma( - constrain_statement.arguments, - false, // force trailing comma - &mut group, - ); - - group.text(self.chunk(|formatter| { - formatter.write_right_paren(); - formatter.write_semicolon(); - })); - - group - } - fn format_assign(&mut self, assign_statement: AssignStatement) -> ChunkGroup { let mut group = ChunkGroup::new(); let mut is_op_assign = false;