From 7d3db00d4d0db3a254fb2e197a9bb375148ae4f9 Mon Sep 17 00:00:00 2001 From: frectonz Date: Sun, 3 Mar 2024 10:22:48 +0300 Subject: [PATCH 1/4] Implement equality check --- core/src/ast.rs | 23 ++++++++++++++++++++++- core/src/ident.rs | 2 +- core/src/lexer.rs | 6 +++++- core/src/num/dist.rs | 2 +- core/src/num/unit.rs | 4 ++-- core/src/num/unit/unit_exponent.rs | 2 +- core/src/parser.rs | 12 +++++++++++- core/src/scope.rs | 4 ++-- core/src/value.rs | 2 +- core/tests/integration_tests.rs | 6 ++++++ 10 files changed, 52 insertions(+), 11 deletions(-) diff --git a/core/src/ast.rs b/core/src/ast.rs index 22c4ac56..f7bbe844 100644 --- a/core/src/ast.rs +++ b/core/src/ast.rs @@ -99,7 +99,7 @@ impl fmt::Display for Bop { } } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq)] pub(crate) enum Expr { Literal(Value), Ident(Ident), @@ -122,6 +122,7 @@ pub(crate) enum Expr { Of(Ident, Box), Assign(Ident, Box), + Equality(Box, Box), Statements(Box, Box), } @@ -202,6 +203,11 @@ impl Expr { a.serialize(write)?; b.serialize(write)?; } + Self::Equality(a, b) => { + 16u8.serialize(write)?; + a.serialize(write)?; + b.serialize(write)?; + } } Ok(()) } @@ -252,6 +258,10 @@ impl Expr { Box::new(Self::deserialize(read)?), Box::new(Self::deserialize(read)?), ), + 16 => Self::Equality( + Box::new(Self::deserialize(read)?), + Box::new(Self::deserialize(read)?), + ), _ => return Err(FendError::DeserializationError), }) } @@ -309,6 +319,11 @@ impl Expr { a.format(attrs, ctx, int)?, b.format(attrs, ctx, int)? ), + Self::Equality(a, b) => format!( + "{} == {}", + a.format(attrs, ctx, int)?, + b.format(attrs, ctx, int)? + ), }) } } @@ -452,6 +467,12 @@ pub(crate) fn evaluate( let _lhs = evaluate(*a, scope.clone(), attrs, context, int)?; evaluate(*b, scope, attrs, context, int)? } + Expr::Equality(a, b) => { + let lhs = evaluate(*a, scope.clone(), attrs, context, int)?; + let rhs = evaluate(*b, scope.clone(), attrs, context, int)?; + + Value::Bool(lhs == rhs) + } }) } diff --git a/core/src/ident.rs b/core/src/ident.rs index 0c9fc9a2..4436aa5c 100644 --- a/core/src/ident.rs +++ b/core/src/ident.rs @@ -5,7 +5,7 @@ use crate::{ serialize::{Deserialize, Serialize}, }; -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq)] pub(crate) struct Ident(Cow<'static, str>); impl Ident { diff --git a/core/src/lexer.rs b/core/src/lexer.rs index 88722cb8..a060b593 100644 --- a/core/src/lexer.rs +++ b/core/src/lexer.rs @@ -36,7 +36,8 @@ pub(crate) enum Symbol { ShiftLeft, ShiftRight, Semicolon, - Equals, // used for assignment + Equals, // used for assignment + DoubleEquals, // used for equality Combination, Permutation, } @@ -65,6 +66,7 @@ impl fmt::Display for Symbol { Self::ShiftRight => ">>", Self::Semicolon => ";", Self::Equals => "=", + Self::DoubleEquals => "==", Self::Combination => "nCr", Self::Permutation => "nPr", }; @@ -523,6 +525,8 @@ fn parse_symbol(ch: char, input: &mut &str) -> FResult { '=' => { if test_next('>') { Symbol::Fn + } else if test_next('=') { + Symbol::DoubleEquals } else { Symbol::Equals } diff --git a/core/src/num/dist.rs b/core/src/num/dist.rs index 902e44d9..997f0761 100644 --- a/core/src/num/dist.rs +++ b/core/src/num/dist.rs @@ -13,7 +13,7 @@ use std::{fmt, io}; use super::real::Real; use super::{Base, Exact, FormattingStyle}; -#[derive(Clone)] +#[derive(Clone, PartialEq, Eq)] pub(crate) struct Dist { // invariant: probabilities must sum to 1 parts: HashMap, diff --git a/core/src/num/unit.rs b/core/src/num/unit.rs index a9306506..9445ab4d 100644 --- a/core/src/num/unit.rs +++ b/core/src/num/unit.rs @@ -25,7 +25,7 @@ use unit_exponent::UnitExponent; use super::Exact; -#[derive(Clone)] +#[derive(Clone, PartialEq, Eq)] #[allow(clippy::pedantic)] pub(crate) struct Value { #[allow(clippy::struct_field_names)] @@ -982,7 +982,7 @@ impl fmt::Display for FormattedValue { } } -#[derive(Clone)] +#[derive(Clone, PartialEq, Eq)] struct Unit { components: Vec, } diff --git a/core/src/num/unit/unit_exponent.rs b/core/src/num/unit/unit_exponent.rs index ee03447e..06af42e2 100644 --- a/core/src/num/unit/unit_exponent.rs +++ b/core/src/num/unit/unit_exponent.rs @@ -8,7 +8,7 @@ use crate::Interrupt; use super::{base_unit::BaseUnit, named_unit::NamedUnit}; -#[derive(Clone)] +#[derive(Clone, PartialEq, Eq)] pub(crate) struct UnitExponent { pub(crate) unit: NamedUnit, pub(crate) exponent: Complex, diff --git a/core/src/parser.rs b/core/src/parser.rs index 87af346c..75f737d5 100644 --- a/core/src/parser.rs +++ b/core/src/parser.rs @@ -464,8 +464,18 @@ fn parse_function(input: &[Token]) -> ParseResult<'_> { Ok((lhs, input)) } -fn parse_assignment(input: &[Token]) -> ParseResult<'_> { +fn parse_equality(input: &[Token]) -> ParseResult<'_> { let (lhs, input) = parse_function(input)?; + if let Ok(((), remaining)) = parse_fixed_symbol(input, Symbol::DoubleEquals) { + let (rhs, remaining) = parse_function(remaining)?; + Ok((Expr::Equality(Box::new(lhs), Box::new(rhs)), remaining)) + } else { + Ok((lhs, input)) + } +} + +fn parse_assignment(input: &[Token]) -> ParseResult<'_> { + let (lhs, input) = parse_equality(input)?; if let Ok(((), remaining)) = parse_fixed_symbol(input, Symbol::Equals) { if let Expr::Ident(s) = lhs { let (rhs, remaining) = parse_assignment(remaining)?; diff --git a/core/src/scope.rs b/core/src/scope.rs index aa5ccb7a..7f56f005 100644 --- a/core/src/scope.rs +++ b/core/src/scope.rs @@ -7,7 +7,7 @@ use crate::{ast::Expr, error::Interrupt}; use std::io; use std::sync::Arc; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] enum ScopeValue { //Variable(Value), LazyVariable(Expr, Option>), @@ -55,7 +55,7 @@ impl ScopeValue { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub(crate) struct Scope { ident: Ident, value: ScopeValue, diff --git a/core/src/value.rs b/core/src/value.rs index ecd663d0..0511aa70 100644 --- a/core/src/value.rs +++ b/core/src/value.rs @@ -18,7 +18,7 @@ pub(crate) mod built_in_function; use built_in_function::BuiltInFunction; -#[derive(Clone)] +#[derive(Clone, PartialEq, Eq)] pub(crate) enum Value { Num(Box), BuiltInFunction(BuiltInFunction), diff --git a/core/tests/integration_tests.rs b/core/tests/integration_tests.rs index 27daac92..d8092620 100644 --- a/core/tests/integration_tests.rs +++ b/core/tests/integration_tests.rs @@ -5873,3 +5873,9 @@ fn test_superscript() { test_eval("200²", "40000"); test_eval("13¹³ days", "302875106592253 days"); } + +#[test] +fn test_equality() { + test_eval("1 + 2 == 3", "true"); + test_eval("true == false", "false"); +} From 0e8e192ee3a0a4d0cd31b34bcd5d475bb38352f0 Mon Sep 17 00:00:00 2001 From: frectonz Date: Sun, 3 Mar 2024 13:07:10 +0300 Subject: [PATCH 2/4] Implement inequality check --- core/src/ast.rs | 15 +++++++++------ core/src/lexer.rs | 10 +++++++++- core/src/parser.rs | 11 ++++++++++- core/tests/integration_tests.rs | 2 ++ 4 files changed, 30 insertions(+), 8 deletions(-) diff --git a/core/src/ast.rs b/core/src/ast.rs index f7bbe844..f2cad338 100644 --- a/core/src/ast.rs +++ b/core/src/ast.rs @@ -122,7 +122,7 @@ pub(crate) enum Expr { Of(Ident, Box), Assign(Ident, Box), - Equality(Box, Box), + Equality(bool, Box, Box), Statements(Box, Box), } @@ -203,8 +203,9 @@ impl Expr { a.serialize(write)?; b.serialize(write)?; } - Self::Equality(a, b) => { + Self::Equality(is_equals, a, b) => { 16u8.serialize(write)?; + is_equals.serialize(write)?; a.serialize(write)?; b.serialize(write)?; } @@ -259,6 +260,7 @@ impl Expr { Box::new(Self::deserialize(read)?), ), 16 => Self::Equality( + bool::deserialize(read)?, Box::new(Self::deserialize(read)?), Box::new(Self::deserialize(read)?), ), @@ -319,9 +321,10 @@ impl Expr { a.format(attrs, ctx, int)?, b.format(attrs, ctx, int)? ), - Self::Equality(a, b) => format!( - "{} == {}", + Self::Equality(is_equals, a, b) => format!( + "{} {} {}", a.format(attrs, ctx, int)?, + if *is_equals { "==" } else { "!=" }, b.format(attrs, ctx, int)? ), }) @@ -467,11 +470,11 @@ pub(crate) fn evaluate( let _lhs = evaluate(*a, scope.clone(), attrs, context, int)?; evaluate(*b, scope, attrs, context, int)? } - Expr::Equality(a, b) => { + Expr::Equality(is_equals, a, b) => { let lhs = evaluate(*a, scope.clone(), attrs, context, int)?; let rhs = evaluate(*b, scope.clone(), attrs, context, int)?; - Value::Bool(lhs == rhs) + Value::Bool(if is_equals { lhs == rhs } else { lhs != rhs }) } }) } diff --git a/core/src/lexer.rs b/core/src/lexer.rs index a060b593..4ecc3c4f 100644 --- a/core/src/lexer.rs +++ b/core/src/lexer.rs @@ -38,6 +38,7 @@ pub(crate) enum Symbol { Semicolon, Equals, // used for assignment DoubleEquals, // used for equality + NotEquals, Combination, Permutation, } @@ -67,6 +68,7 @@ impl fmt::Display for Symbol { Self::Semicolon => ";", Self::Equals => "=", Self::DoubleEquals => "==", + Self::NotEquals => "!=", Self::Combination => "nCr", Self::Permutation => "nPr", }; @@ -507,7 +509,13 @@ fn parse_symbol(ch: char, input: &mut &str) -> FResult { '(' => Symbol::OpenParens, ')' => Symbol::CloseParens, '+' => Symbol::Add, - '!' => Symbol::Factorial, + '!' => { + if test_next('=') { + Symbol::NotEquals + } else { + Symbol::Factorial + } + } // unicode minus sign '-' | '\u{2212}' => Symbol::Sub, '*' | '\u{d7}' | '\u{2715}' => { diff --git a/core/src/parser.rs b/core/src/parser.rs index 75f737d5..a752b5aa 100644 --- a/core/src/parser.rs +++ b/core/src/parser.rs @@ -468,7 +468,16 @@ fn parse_equality(input: &[Token]) -> ParseResult<'_> { let (lhs, input) = parse_function(input)?; if let Ok(((), remaining)) = parse_fixed_symbol(input, Symbol::DoubleEquals) { let (rhs, remaining) = parse_function(remaining)?; - Ok((Expr::Equality(Box::new(lhs), Box::new(rhs)), remaining)) + Ok(( + Expr::Equality(true, Box::new(lhs), Box::new(rhs)), + remaining, + )) + } else if let Ok(((), remaining)) = parse_fixed_symbol(input, Symbol::NotEquals) { + let (rhs, remaining) = parse_function(remaining)?; + Ok(( + Expr::Equality(false, Box::new(lhs), Box::new(rhs)), + remaining, + )) } else { Ok((lhs, input)) } diff --git a/core/tests/integration_tests.rs b/core/tests/integration_tests.rs index d8092620..527b1dae 100644 --- a/core/tests/integration_tests.rs +++ b/core/tests/integration_tests.rs @@ -5877,5 +5877,7 @@ fn test_superscript() { #[test] fn test_equality() { test_eval("1 + 2 == 3", "true"); + test_eval("1 + 2 != 4", "true"); test_eval("true == false", "false"); + test_eval("true != false", "true"); } From c3e7795f50071f61213ca9da448efe510b4e8e0f Mon Sep 17 00:00:00 2001 From: frectonz Date: Sun, 3 Mar 2024 14:40:12 +0300 Subject: [PATCH 3/4] Custom `PartialEq` impl for numbers and support for comparisons of different units --- core/src/ast.rs | 7 ++++++- core/src/num/unit.rs | 8 +++++++- core/tests/integration_tests.rs | 1 + 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/core/src/ast.rs b/core/src/ast.rs index f2cad338..fa363273 100644 --- a/core/src/ast.rs +++ b/core/src/ast.rs @@ -474,7 +474,12 @@ pub(crate) fn evaluate( let lhs = evaluate(*a, scope.clone(), attrs, context, int)?; let rhs = evaluate(*b, scope.clone(), attrs, context, int)?; - Value::Bool(if is_equals { lhs == rhs } else { lhs != rhs }) + if let (Value::Num(l), Value::Num(r)) = (&lhs, &rhs) { + let a = l.clone().sub(*r.clone(), int)?; + Value::Bool(if is_equals { a.is_zero() } else { !a.is_zero() }) + } else { + Value::Bool(if is_equals { lhs == rhs } else { lhs != rhs }) + } } }) } diff --git a/core/src/num/unit.rs b/core/src/num/unit.rs index 9445ab4d..a1746641 100644 --- a/core/src/num/unit.rs +++ b/core/src/num/unit.rs @@ -25,7 +25,7 @@ use unit_exponent::UnitExponent; use super::Exact; -#[derive(Clone, PartialEq, Eq)] +#[derive(Clone, Eq)] #[allow(clippy::pedantic)] pub(crate) struct Value { #[allow(clippy::struct_field_names)] @@ -37,6 +37,12 @@ pub(crate) struct Value { simplifiable: bool, } +impl PartialEq for Value { + fn eq(&self, other: &Self) -> bool { + self.value == other.value && self.unit == other.unit + } +} + impl Value { pub(crate) fn serialize(&self, write: &mut impl io::Write) -> FResult<()> { self.value.serialize(write)?; diff --git a/core/tests/integration_tests.rs b/core/tests/integration_tests.rs index 527b1dae..3eb7e7eb 100644 --- a/core/tests/integration_tests.rs +++ b/core/tests/integration_tests.rs @@ -5880,4 +5880,5 @@ fn test_equality() { test_eval("1 + 2 != 4", "true"); test_eval("true == false", "false"); test_eval("true != false", "true"); + test_eval("2m == 200cm", "true"); } From 842a769ba0a78d8114ae305c35af0a50dab67a15 Mon Sep 17 00:00:00 2001 From: printfn Date: Tue, 5 Mar 2024 08:39:13 +0000 Subject: [PATCH 4/4] Minor fixes and more tests --- core/src/ast.rs | 7 +------ core/src/num/unit.rs | 15 ++++++++++++--- core/tests/integration_tests.rs | 4 ++++ 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/core/src/ast.rs b/core/src/ast.rs index fa363273..f2cad338 100644 --- a/core/src/ast.rs +++ b/core/src/ast.rs @@ -474,12 +474,7 @@ pub(crate) fn evaluate( let lhs = evaluate(*a, scope.clone(), attrs, context, int)?; let rhs = evaluate(*b, scope.clone(), attrs, context, int)?; - if let (Value::Num(l), Value::Num(r)) = (&lhs, &rhs) { - let a = l.clone().sub(*r.clone(), int)?; - Value::Bool(if is_equals { a.is_zero() } else { !a.is_zero() }) - } else { - Value::Bool(if is_equals { lhs == rhs } else { lhs != rhs }) - } + Value::Bool(if is_equals { lhs == rhs } else { lhs != rhs }) } }) } diff --git a/core/src/num/unit.rs b/core/src/num/unit.rs index a1746641..7d63e410 100644 --- a/core/src/num/unit.rs +++ b/core/src/num/unit.rs @@ -8,7 +8,7 @@ use crate::scope::Scope; use crate::serialize::{Deserialize, Serialize}; use crate::units::{lookup_default_unit, query_unit_static}; use crate::{ast, ident::Ident}; -use crate::{Attrs, Span, SpanKind}; +use crate::{interrupt::Never, Attrs, Span, SpanKind}; use std::borrow::Cow; use std::collections::HashMap; use std::ops::Neg; @@ -25,7 +25,7 @@ use unit_exponent::UnitExponent; use super::Exact; -#[derive(Clone, Eq)] +#[derive(Clone)] #[allow(clippy::pedantic)] pub(crate) struct Value { #[allow(clippy::struct_field_names)] @@ -39,10 +39,18 @@ pub(crate) struct Value { impl PartialEq for Value { fn eq(&self, other: &Self) -> bool { - self.value == other.value && self.unit == other.unit + if self.value == other.value && self.unit == other.unit { + return true; + } + match self.clone().sub(other.clone(), &Never) { + Err(_) => false, + Ok(result) => result.is_zero(), + } } } +impl Eq for Value {} + impl Value { pub(crate) fn serialize(&self, write: &mut impl io::Write) -> FResult<()> { self.value.serialize(write)?; @@ -988,6 +996,7 @@ impl fmt::Display for FormattedValue { } } +// TODO: equality comparisons should not depend on order #[derive(Clone, PartialEq, Eq)] struct Unit { components: Vec, diff --git a/core/tests/integration_tests.rs b/core/tests/integration_tests.rs index 3eb7e7eb..432bdae3 100644 --- a/core/tests/integration_tests.rs +++ b/core/tests/integration_tests.rs @@ -5881,4 +5881,8 @@ fn test_equality() { test_eval("true == false", "false"); test_eval("true != false", "true"); test_eval("2m == 200cm", "true"); + test_eval("2kg == 200cm", "false"); + test_eval("2kg == true", "false"); + test_eval("2.010m == 200cm", "false"); + test_eval("2.000m == approx. 200cm", "true"); }