diff --git a/core/src/ast.rs b/core/src/ast.rs index 22c4ac56..f2cad338 100644 --- a/core/src/ast.rs +++ b/core/src/ast.rs @@ -99,7 +99,7 @@ impl fmt::Display for Bop { } } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq)] pub(crate) enum Expr { Literal(Value), Ident(Ident), @@ -122,6 +122,7 @@ pub(crate) enum Expr { Of(Ident, Box), Assign(Ident, Box), + Equality(bool, Box, Box), Statements(Box, Box), } @@ -202,6 +203,12 @@ impl Expr { a.serialize(write)?; b.serialize(write)?; } + Self::Equality(is_equals, a, b) => { + 16u8.serialize(write)?; + is_equals.serialize(write)?; + a.serialize(write)?; + b.serialize(write)?; + } } Ok(()) } @@ -252,6 +259,11 @@ impl Expr { Box::new(Self::deserialize(read)?), Box::new(Self::deserialize(read)?), ), + 16 => Self::Equality( + bool::deserialize(read)?, + Box::new(Self::deserialize(read)?), + Box::new(Self::deserialize(read)?), + ), _ => return Err(FendError::DeserializationError), }) } @@ -309,6 +321,12 @@ impl Expr { a.format(attrs, ctx, int)?, b.format(attrs, ctx, int)? ), + Self::Equality(is_equals, a, b) => format!( + "{} {} {}", + a.format(attrs, ctx, int)?, + if *is_equals { "==" } else { "!=" }, + b.format(attrs, ctx, int)? + ), }) } } @@ -452,6 +470,12 @@ pub(crate) fn evaluate( let _lhs = evaluate(*a, scope.clone(), attrs, context, int)?; evaluate(*b, scope, attrs, context, int)? } + Expr::Equality(is_equals, a, b) => { + let lhs = evaluate(*a, scope.clone(), attrs, context, int)?; + let rhs = evaluate(*b, scope.clone(), attrs, context, int)?; + + Value::Bool(if is_equals { lhs == rhs } else { lhs != rhs }) + } }) } diff --git a/core/src/ident.rs b/core/src/ident.rs index 0c9fc9a2..4436aa5c 100644 --- a/core/src/ident.rs +++ b/core/src/ident.rs @@ -5,7 +5,7 @@ use crate::{ serialize::{Deserialize, Serialize}, }; -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq)] pub(crate) struct Ident(Cow<'static, str>); impl Ident { diff --git a/core/src/lexer.rs b/core/src/lexer.rs index 88722cb8..4ecc3c4f 100644 --- a/core/src/lexer.rs +++ b/core/src/lexer.rs @@ -36,7 +36,9 @@ pub(crate) enum Symbol { ShiftLeft, ShiftRight, Semicolon, - Equals, // used for assignment + Equals, // used for assignment + DoubleEquals, // used for equality + NotEquals, Combination, Permutation, } @@ -65,6 +67,8 @@ impl fmt::Display for Symbol { Self::ShiftRight => ">>", Self::Semicolon => ";", Self::Equals => "=", + Self::DoubleEquals => "==", + Self::NotEquals => "!=", Self::Combination => "nCr", Self::Permutation => "nPr", }; @@ -505,7 +509,13 @@ fn parse_symbol(ch: char, input: &mut &str) -> FResult { '(' => Symbol::OpenParens, ')' => Symbol::CloseParens, '+' => Symbol::Add, - '!' => Symbol::Factorial, + '!' => { + if test_next('=') { + Symbol::NotEquals + } else { + Symbol::Factorial + } + } // unicode minus sign '-' | '\u{2212}' => Symbol::Sub, '*' | '\u{d7}' | '\u{2715}' => { @@ -523,6 +533,8 @@ fn parse_symbol(ch: char, input: &mut &str) -> FResult { '=' => { if test_next('>') { Symbol::Fn + } else if test_next('=') { + Symbol::DoubleEquals } else { Symbol::Equals } diff --git a/core/src/num/dist.rs b/core/src/num/dist.rs index 902e44d9..997f0761 100644 --- a/core/src/num/dist.rs +++ b/core/src/num/dist.rs @@ -13,7 +13,7 @@ use std::{fmt, io}; use super::real::Real; use super::{Base, Exact, FormattingStyle}; -#[derive(Clone)] +#[derive(Clone, PartialEq, Eq)] pub(crate) struct Dist { // invariant: probabilities must sum to 1 parts: HashMap, diff --git a/core/src/num/unit.rs b/core/src/num/unit.rs index a9306506..7d63e410 100644 --- a/core/src/num/unit.rs +++ b/core/src/num/unit.rs @@ -8,7 +8,7 @@ use crate::scope::Scope; use crate::serialize::{Deserialize, Serialize}; use crate::units::{lookup_default_unit, query_unit_static}; use crate::{ast, ident::Ident}; -use crate::{Attrs, Span, SpanKind}; +use crate::{interrupt::Never, Attrs, Span, SpanKind}; use std::borrow::Cow; use std::collections::HashMap; use std::ops::Neg; @@ -37,6 +37,20 @@ pub(crate) struct Value { simplifiable: bool, } +impl PartialEq for Value { + fn eq(&self, other: &Self) -> bool { + if self.value == other.value && self.unit == other.unit { + return true; + } + match self.clone().sub(other.clone(), &Never) { + Err(_) => false, + Ok(result) => result.is_zero(), + } + } +} + +impl Eq for Value {} + impl Value { pub(crate) fn serialize(&self, write: &mut impl io::Write) -> FResult<()> { self.value.serialize(write)?; @@ -982,7 +996,8 @@ impl fmt::Display for FormattedValue { } } -#[derive(Clone)] +// TODO: equality comparisons should not depend on order +#[derive(Clone, PartialEq, Eq)] struct Unit { components: Vec, } diff --git a/core/src/num/unit/unit_exponent.rs b/core/src/num/unit/unit_exponent.rs index ee03447e..06af42e2 100644 --- a/core/src/num/unit/unit_exponent.rs +++ b/core/src/num/unit/unit_exponent.rs @@ -8,7 +8,7 @@ use crate::Interrupt; use super::{base_unit::BaseUnit, named_unit::NamedUnit}; -#[derive(Clone)] +#[derive(Clone, PartialEq, Eq)] pub(crate) struct UnitExponent { pub(crate) unit: NamedUnit, pub(crate) exponent: Complex, diff --git a/core/src/parser.rs b/core/src/parser.rs index 87af346c..a752b5aa 100644 --- a/core/src/parser.rs +++ b/core/src/parser.rs @@ -464,8 +464,27 @@ fn parse_function(input: &[Token]) -> ParseResult<'_> { Ok((lhs, input)) } -fn parse_assignment(input: &[Token]) -> ParseResult<'_> { +fn parse_equality(input: &[Token]) -> ParseResult<'_> { let (lhs, input) = parse_function(input)?; + if let Ok(((), remaining)) = parse_fixed_symbol(input, Symbol::DoubleEquals) { + let (rhs, remaining) = parse_function(remaining)?; + Ok(( + Expr::Equality(true, Box::new(lhs), Box::new(rhs)), + remaining, + )) + } else if let Ok(((), remaining)) = parse_fixed_symbol(input, Symbol::NotEquals) { + let (rhs, remaining) = parse_function(remaining)?; + Ok(( + Expr::Equality(false, Box::new(lhs), Box::new(rhs)), + remaining, + )) + } else { + Ok((lhs, input)) + } +} + +fn parse_assignment(input: &[Token]) -> ParseResult<'_> { + let (lhs, input) = parse_equality(input)?; if let Ok(((), remaining)) = parse_fixed_symbol(input, Symbol::Equals) { if let Expr::Ident(s) = lhs { let (rhs, remaining) = parse_assignment(remaining)?; diff --git a/core/src/scope.rs b/core/src/scope.rs index aa5ccb7a..7f56f005 100644 --- a/core/src/scope.rs +++ b/core/src/scope.rs @@ -7,7 +7,7 @@ use crate::{ast::Expr, error::Interrupt}; use std::io; use std::sync::Arc; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] enum ScopeValue { //Variable(Value), LazyVariable(Expr, Option>), @@ -55,7 +55,7 @@ impl ScopeValue { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub(crate) struct Scope { ident: Ident, value: ScopeValue, diff --git a/core/src/value.rs b/core/src/value.rs index ecd663d0..0511aa70 100644 --- a/core/src/value.rs +++ b/core/src/value.rs @@ -18,7 +18,7 @@ pub(crate) mod built_in_function; use built_in_function::BuiltInFunction; -#[derive(Clone)] +#[derive(Clone, PartialEq, Eq)] pub(crate) enum Value { Num(Box), BuiltInFunction(BuiltInFunction), diff --git a/core/tests/integration_tests.rs b/core/tests/integration_tests.rs index 27daac92..432bdae3 100644 --- a/core/tests/integration_tests.rs +++ b/core/tests/integration_tests.rs @@ -5873,3 +5873,16 @@ fn test_superscript() { test_eval("200²", "40000"); test_eval("13¹³ days", "302875106592253 days"); } + +#[test] +fn test_equality() { + test_eval("1 + 2 == 3", "true"); + test_eval("1 + 2 != 4", "true"); + test_eval("true == false", "false"); + test_eval("true != false", "true"); + test_eval("2m == 200cm", "true"); + test_eval("2kg == 200cm", "false"); + test_eval("2kg == true", "false"); + test_eval("2.010m == 200cm", "false"); + test_eval("2.000m == approx. 200cm", "true"); +}