From 7f570c28441cb44c3c028a5854212653f12f4206 Mon Sep 17 00:00:00 2001 From: Francois Brodeur Date: Thu, 30 Nov 2023 23:30:16 -0500 Subject: [PATCH] cmp --- src/eval.rs | 18 ++++++++++++------ src/expr.rs | 24 ++++++++++++++++++++++++ src/infer.rs | 5 +++-- src/lexer.rs | 12 ++++++++++++ src/parser.rs | 41 +++++++++++++++++++++++++++++++++++------ 5 files changed, 86 insertions(+), 14 deletions(-) diff --git a/src/eval.rs b/src/eval.rs index b22d221..89dfd49 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -178,13 +178,17 @@ impl Env { let lhs = lhs.as_int()?; let rhs = self.eval_inner(rhs)?; let rhs = rhs.as_int()?; - let i = match op { - IntBinOp::Plus => lhs + rhs, - IntBinOp::Minus => lhs - rhs, - IntBinOp::Multiply => lhs * rhs, - IntBinOp::Divide => lhs / rhs, + let val = match op { + IntBinOp::Plus => Value::Int(lhs + rhs), + IntBinOp::Minus => Value::Int(lhs - rhs), + IntBinOp::Multiply => Value::Int(lhs * rhs), + IntBinOp::Divide => Value::Int(lhs / rhs), + IntBinOp::LessThan => Value::Bool(lhs < rhs), + IntBinOp::LessThanOrEqual => Value::Bool(lhs <= rhs), + IntBinOp::GreaterThan => Value::Bool(lhs > rhs), + IntBinOp::GreaterThanOrEqual => Value::Bool(lhs >= rhs), }; - Ok(Value::Int(i)) + Ok(val) } Expr::Negate(expr) => { let v = self.eval_inner(expr)?; @@ -308,6 +312,8 @@ mod tests { "let f({x,y}) = x + y in let x = 1 in let y = 2 in f({x,y})", Value::Int(3), ), + ("2 > 1", Value::Bool(true)), + ("2 > 3", Value::Bool(false)), ]; for (expr_str, expected) in cases { let expr = Parser::expr(expr_str).unwrap(); diff --git a/src/expr.rs b/src/expr.rs index b417a4a..ec06ba6 100644 --- a/src/expr.rs +++ b/src/expr.rs @@ -11,6 +11,10 @@ pub enum IntBinOp { Minus, Multiply, Divide, + LessThan, + LessThanOrEqual, + GreaterThan, + GreaterThanOrEqual, } impl fmt::Display for IntBinOp { @@ -20,11 +24,27 @@ impl fmt::Display for IntBinOp { IntBinOp::Minus => "-", IntBinOp::Multiply => "*", IntBinOp::Divide => "/", + IntBinOp::LessThan => "<", + IntBinOp::LessThanOrEqual => "<=", + IntBinOp::GreaterThan => ">", + IntBinOp::GreaterThanOrEqual => ">=", }; write!(f, "{}", op) } } +impl IntBinOp { + pub fn output_ty(&self) -> Type { + match self { + IntBinOp::Plus | IntBinOp::Minus | IntBinOp::Multiply | IntBinOp::Divide => Type::int(), + IntBinOp::LessThan + | IntBinOp::LessThanOrEqual + | IntBinOp::GreaterThan + | IntBinOp::GreaterThanOrEqual => Type::bool(), + } + } +} + #[derive(Clone, Debug, PartialEq)] pub enum Pattern { Var(String), @@ -280,6 +300,10 @@ pub mod util { Expr::EqualEqual(lhs.into(), rhs.into()) } + pub fn gt(lhs: Expr, rhs: Expr) -> Expr { + Expr::IntBinOp(IntBinOp::GreaterThan, lhs.into(), rhs.into()) + } + pub fn match_(val: Expr, cases: Vec<(&str, &str, Expr)>, def: Option<(&str, Expr)>) -> Expr { let cases = cases .into_iter() diff --git a/src/infer.rs b/src/infer.rs index feb3e00..244add4 100644 --- a/src/infer.rs +++ b/src/infer.rs @@ -619,13 +619,13 @@ impl Env { match expr { Expr::Bool(_) => Ok(Type::bool()), Expr::Int(_) => Ok(Type::int()), - Expr::IntBinOp(_, lhs, rhs) => { + Expr::IntBinOp(op, lhs, rhs) => { let ty = Type::int(); let lhs_ty = self.infer_inner(level, lhs)?; self.unify(&ty, &lhs_ty)?; let rhs_ty = self.infer_inner(level, rhs)?; self.unify(&ty, &rhs_ty)?; - Ok(ty) + Ok(op.output_ty()) } Expr::Negate(expr) => { let ty = Type::bool(); @@ -1155,6 +1155,7 @@ mod tests { "forall a b => (a -> b) -> a -> b", ); pass("1 == 1", "bool"); + pass("1 > 2", "bool"); } #[test] diff --git a/src/lexer.rs b/src/lexer.rs index d23060e..99003a6 100644 --- a/src/lexer.rs +++ b/src/lexer.rs @@ -86,6 +86,14 @@ pub enum Token { Elif, #[token("else")] Else, + #[token("<")] + LessThan, + #[token("<=")] + LessThanOrEqual, + #[token(">")] + GreaterThan, + #[token(">=")] + GreaterThanOrEqual, #[regex("[a-zA-Z_][a-zA-Z0-9_]*", ident)] Ident(String), } @@ -132,6 +140,10 @@ impl fmt::Display for Token { Token::Then => "then", Token::Elif => "elif", Token::Else => "else", + Token::LessThan => "<", + Token::LessThanOrEqual => "<=", + Token::GreaterThan => ">", + Token::GreaterThanOrEqual => ">=", }; write!(f, "{}", s) } diff --git a/src/parser.rs b/src/parser.rs index 6cfc78f..29c174a 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -439,6 +439,30 @@ impl<'a> Parser<'a> { } else if self.matches(Token::EqualEqual)? { let rhs = self.expr_inner(r_bp)?; Ok(Expr::EqualEqual(lhs.into(), rhs.into())) + } else if self.matches(Token::LessThan)? { + let rhs = self.expr_inner(r_bp)?; + Ok(Expr::IntBinOp(IntBinOp::LessThan, lhs.into(), rhs.into())) + } else if self.matches(Token::LessThanOrEqual)? { + let rhs = self.expr_inner(r_bp)?; + Ok(Expr::IntBinOp( + IntBinOp::LessThanOrEqual, + lhs.into(), + rhs.into(), + )) + } else if self.matches(Token::GreaterThan)? { + let rhs = self.expr_inner(r_bp)?; + Ok(Expr::IntBinOp( + IntBinOp::GreaterThan, + lhs.into(), + rhs.into(), + )) + } else if self.matches(Token::GreaterThanOrEqual)? { + let rhs = self.expr_inner(r_bp)?; + Ok(Expr::IntBinOp( + IntBinOp::GreaterThanOrEqual, + lhs.into(), + rhs.into(), + )) } else { self.expected( vec![ @@ -480,7 +504,7 @@ impl<'a> Parser<'a> { fn prefix_bp(&self) -> Result { match self.token { - Some(Token::Negate) => Ok(7), + Some(Token::Negate) => Ok(9), None => Err(Error::UnexpectedEof), _ => Err(Error::InvalidPrefix(self.token.clone())), } @@ -489,9 +513,9 @@ impl<'a> Parser<'a> { // TODO: Not sure how to determine precedence fn postfix_bp(&self) -> Option { match self.token { - Some(Token::LParen) => Some(8), - Some(Token::Dot) => Some(10), - Some(Token::Backslash) => Some(9), + Some(Token::LParen) => Some(10), + Some(Token::Dot) => Some(12), + Some(Token::Backslash) => Some(11), _ => None, } } @@ -499,8 +523,12 @@ impl<'a> Parser<'a> { fn infix_bp(&self) -> Option<(u8, u8)> { match self.token { Some(Token::EqualEqual) => Some((2, 1)), - Some(Token::Plus) | Some(Token::Minus) => Some((3, 4)), - Some(Token::Multiply) | Some(Token::Divide) => Some((5, 6)), + Some(Token::Plus) | Some(Token::Minus) => Some((5, 6)), + Some(Token::Multiply) | Some(Token::Divide) => Some((7, 8)), + Some(Token::LessThan) + | Some(Token::LessThanOrEqual) + | Some(Token::GreaterThan) + | Some(Token::GreaterThanOrEqual) => Some((3, 4)), _ => None, } } @@ -731,6 +759,7 @@ mod tests { "false == !true", equalequal(bool(false), negate(bool(true))), ); + pass("1 + 2 > 2", gt(plus(int(1), int(2)), int(2))); } #[test]