diff --git a/src/eval.rs b/src/eval.rs index 89dfd49..12c4c0a 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -3,7 +3,7 @@ use std::collections::{BTreeMap, HashMap}; use itertools::Itertools; -use crate::expr::{Expr, IntBinOp, Pattern}; +use crate::expr::{Expr, IntBinOp, Pattern, OK_LABEL}; #[derive(Debug)] pub enum Error { @@ -16,6 +16,7 @@ pub enum Error { UnexpectedNumberOfArguments, LabelNotFound(String), NoCase, + UnwrapNotVariant(Value), } impl fmt::Display for Error { @@ -30,6 +31,7 @@ impl fmt::Display for Error { Error::UnexpectedNumberOfArguments => write!(f, "unexpected number of arguments"), Error::LabelNotFound(label) => write!(f, "label not found: {}", label), Error::NoCase => write!(f, "no case"), + Error::UnwrapNotVariant(val) => write!(f, "unwrap on a non-variant: {}", val), } } } @@ -50,7 +52,7 @@ impl fmt::Display for Function { } impl Function { - fn apply(&mut self, args: Vec) -> Result { + fn apply(&mut self, args: Vec) -> Result { if self.params.len() != args.len() { return Err(Error::UnexpectedNumberOfArguments); } @@ -136,6 +138,20 @@ impl Value { } } +enum Wrap { + Value(Value), + Wrap(String, Box), +} + +impl Wrap { + fn value(self) -> Value { + match self { + Wrap::Value(val) => val, + Wrap::Wrap(name, val) => Value::Variant(name, val), + } + } +} + #[derive(Clone, Debug, Default, PartialEq)] pub struct Env { vars: HashMap, @@ -143,7 +159,8 @@ pub struct Env { impl Env { pub fn eval(&mut self, expr: &Expr) -> Result { - self.eval_inner(expr) + let wrap = self.eval_inner(expr)?; + Ok(wrap.value()) } fn eval_pattern(&mut self, pattern: &Pattern, value: Value) -> Result<()> { @@ -169,15 +186,19 @@ impl Env { Ok(()) } - fn eval_inner(&mut self, expr: &Expr) -> Result { + fn eval_inner(&mut self, expr: &Expr) -> Result { match expr { - Expr::Bool(b) => Ok(Value::Bool(*b)), - Expr::Int(i) => Ok(Value::Int(*i)), + Expr::Bool(b) => Ok(Wrap::Value(Value::Bool(*b))), + Expr::Int(i) => Ok(Wrap::Value(Value::Int(*i))), Expr::IntBinOp(op, lhs, rhs) => { - let lhs = self.eval_inner(lhs)?; - let lhs = lhs.as_int()?; - let rhs = self.eval_inner(rhs)?; - let rhs = rhs.as_int()?; + let lhs = match self.eval_inner(lhs)? { + Wrap::Value(value) => value.as_int()?, + Wrap::Wrap(name, val) => return Ok(Wrap::Value(Value::Variant(name, val))), + }; + let rhs = match self.eval_inner(rhs)? { + Wrap::Value(value) => value.as_int()?, + Wrap::Wrap(name, val) => return Ok(Wrap::Value(Value::Variant(name, val))), + }; let val = match op { IntBinOp::Plus => Value::Int(lhs + rhs), IntBinOp::Minus => Value::Int(lhs - rhs), @@ -188,78 +209,113 @@ impl Env { IntBinOp::GreaterThan => Value::Bool(lhs > rhs), IntBinOp::GreaterThanOrEqual => Value::Bool(lhs >= rhs), }; - Ok(val) + Ok(Wrap::Value(val)) } Expr::Negate(expr) => { - let v = self.eval_inner(expr)?; - let b = v.as_bool()?; - Ok(Value::Bool(!b)) + let b = match self.eval_inner(expr)? { + Wrap::Value(value) => value.as_bool()?, + Wrap::Wrap(name, val) => return Ok(Wrap::Value(Value::Variant(name, val))), + }; + Ok(Wrap::Value(Value::Bool(!b))) } Expr::EqualEqual(lhs, rhs) => { - let lhs = self.eval_inner(lhs)?; - let rhs = self.eval_inner(rhs)?; - Ok(Value::Bool(lhs == rhs)) + let lhs = match self.eval_inner(lhs)? { + Wrap::Value(value) => value, + Wrap::Wrap(name, val) => return Ok(Wrap::Value(Value::Variant(name, val))), + }; + let rhs = match self.eval_inner(rhs)? { + Wrap::Value(value) => value, + Wrap::Wrap(name, val) => return Ok(Wrap::Value(Value::Variant(name, val))), + }; + Ok(Wrap::Value(Value::Bool(lhs == rhs))) } Expr::Var(s) => { let value = self .vars .get(s) .ok_or_else(|| Error::VarNotFound(s.clone()))?; - Ok(value.clone()) + Ok(Wrap::Value(value.clone())) } Expr::Call(fun, args) => { - let fun = self.eval_inner(fun)?; - let mut fun = fun.as_function()?; - let args = args - .iter() - .map(|arg| self.eval_inner(arg)) - .collect::>()?; - fun.apply(args) + let mut fun = match self.eval_inner(fun)? { + Wrap::Value(value) => value.as_function()?, + Wrap::Wrap(name, val) => return Ok(Wrap::Value(Value::Variant(name, val))), + }; + let mut values = Vec::with_capacity(args.len()); + for arg in args { + let value = match self.eval_inner(arg)? { + Wrap::Value(value) => value, + Wrap::Wrap(name, val) => return Ok(Wrap::Value(Value::Variant(name, val))), + }; + values.push(value); + } + fun.apply(values) } Expr::Fun(params, body) => { let env = self.clone(); let params = params.clone(); let body = *body.clone(); let fun = Function { env, params, body }; - Ok(Value::Function(fun)) + Ok(Wrap::Value(Value::Function(fun))) } Expr::Let(var, val, body) => { - let val = self.eval_inner(val)?; + let val = match self.eval_inner(val)? { + Wrap::Value(value) => value, + Wrap::Wrap(name, val) => return Ok(Wrap::Value(Value::Variant(name, val))), + }; self.eval_pattern(var, val)?; self.eval_inner(body) } Expr::RecordSelect(record, label) => { - let record = self.eval_inner(record)?; + let record = match self.eval_inner(record)? { + Wrap::Value(value) => value, + Wrap::Wrap(name, val) => return Ok(Wrap::Value(Value::Variant(name, val))), + }; let record = record.as_record()?; let val = record .get(label) .ok_or_else(|| Error::LabelNotFound(label.clone()))?; - Ok(val.clone()) + Ok(Wrap::Value(val.clone())) } Expr::RecordExtend(labels, record) => { - let record = self.eval_inner(record)?; + let record = match self.eval_inner(record)? { + Wrap::Value(value) => value, + Wrap::Wrap(name, val) => return Ok(Wrap::Value(Value::Variant(name, val))), + }; let record = record.as_record()?; let mut record = record.clone(); for (label, expr) in labels { - let value = self.eval_inner(expr)?; + let value = match self.eval_inner(expr)? { + Wrap::Value(value) => value, + Wrap::Wrap(name, val) => return Ok(Wrap::Value(Value::Variant(name, val))), + }; record.insert(label.clone(), value); } - Ok(Value::Record(record)) + Ok(Wrap::Value(Value::Record(record))) } Expr::RecordRestrict(record, label) => { - let record = self.eval_inner(record)?; + let record = match self.eval_inner(record)? { + Wrap::Value(value) => value, + Wrap::Wrap(name, val) => return Ok(Wrap::Value(Value::Variant(name, val))), + }; let record = record.as_record()?; let mut record = record.clone(); record.remove(label); - Ok(Value::Record(record)) + Ok(Wrap::Value(Value::Record(record))) } - Expr::RecordEmpty => Ok(Value::Record(BTreeMap::new())), + Expr::RecordEmpty => Ok(Wrap::Value(Value::Record(BTreeMap::new()))), Expr::Variant(label, expr) => { - let value = self.eval_inner(expr)?; - Ok(Value::Variant(label.clone(), value.into())) + let value = match self.eval_inner(expr)? { + Wrap::Value(value) => value, + Wrap::Wrap(name, val) => return Ok(Wrap::Value(Value::Variant(name, val))), + }; + Ok(Wrap::Value(Value::Variant(label.clone(), value.into()))) } Expr::Case(value, cases, def) => { - let value = self.eval_inner(value)?; + let value = match self.eval_inner(value)? { + Wrap::Value(value) => value, + Wrap::Wrap(name, val) => return Ok(Wrap::Value(Value::Variant(name, val))), + }; let (label, value) = value.as_variant()?; for (case_label, case_var, case_body) in cases { if label == case_label { @@ -279,16 +335,35 @@ impl Env { Err(Error::NoCase) } Expr::If(if_expr, if_body, elifs, else_body) => { - if self.eval_inner(if_expr)?.as_bool()? { + let b = match self.eval_inner(if_expr)? { + Wrap::Value(value) => value.as_bool()?, + Wrap::Wrap(name, val) => return Ok(Wrap::Value(Value::Variant(name, val))), + }; + if b { return self.eval_inner(if_body); } for (elif_expr, elif_body) in elifs { - if self.eval_inner(elif_expr)?.as_bool()? { + let b = match self.eval_inner(elif_expr)? { + Wrap::Value(value) => value.as_bool()?, + Wrap::Wrap(name, val) => return Ok(Wrap::Value(Value::Variant(name, val))), + }; + if b { return self.eval_inner(elif_body); } } self.eval_inner(else_body) } + Expr::Unwrap(expr) => { + let val = match self.eval_inner(expr)? { + Wrap::Value(value) => value, + Wrap::Wrap(name, val) => return Ok(Wrap::Value(Value::Variant(name, val))), + }; + match val { + Value::Variant(name, val) if name == OK_LABEL => Ok(Wrap::Value(*val)), + Value::Variant(name, val) => Ok(Wrap::Wrap(name, val)), + val => Err(Error::UnwrapNotVariant(val)), + } + } } } } diff --git a/src/expr.rs b/src/expr.rs index ec06ba6..7676e95 100644 --- a/src/expr.rs +++ b/src/expr.rs @@ -5,6 +5,8 @@ use std::{ use itertools::Itertools; +pub const OK_LABEL: &str = "ok"; + #[derive(Clone, Debug, PartialEq)] pub enum IntBinOp { Plus, @@ -103,6 +105,7 @@ pub enum Expr { Option<(String, Box)>, ), If(Box, Box, Vec<(Expr, Expr)>, Box), + Unwrap(Box), } impl fmt::Display for Expr { @@ -139,6 +142,7 @@ impl fmt::Display for Expr { Expr::Variant(label, value) => write!(f, ":{} {}", label, value), Expr::Case(_, _, _) => todo!(), Expr::If(_, _, _, _) => todo!(), + Expr::Unwrap(expr) => write!(f, "{}?", expr), } } } diff --git a/src/infer.rs b/src/infer.rs index 244add4..4028d01 100644 --- a/src/infer.rs +++ b/src/infer.rs @@ -1,10 +1,10 @@ use core::fmt; -use std::collections::{BTreeMap, BTreeSet, HashMap}; +use std::collections::{btree_map::Entry, BTreeMap, BTreeSet, HashMap}; use itertools::Itertools; use crate::{ - expr::{Constraints, Expr, Id, Level, Pattern, Type, TypeVar}, + expr::{Constraints, Expr, Id, Level, Pattern, Type, TypeVar, OK_LABEL}, parser::ForAll, }; @@ -17,11 +17,13 @@ pub enum Error { NotARow(String), RecursiveRowType, UnexpectedNumberOfArguments, - ExpectedFunction, + ExpectedFunction(String), VariableNotFound(String), CannotInjectConstraintsInto(String), RowConstraintFailed(String), RecordPatternNotRecord(String), + UnwrapMissingOk(String), + UnwrapNotVariant(String), } impl fmt::Display for Error { @@ -38,7 +40,7 @@ impl fmt::Display for Error { } Error::RecursiveRowType => write!(f, "recursive row type"), Error::UnexpectedNumberOfArguments => write!(f, "unexpected number of arguments"), - Error::ExpectedFunction => write!(f, "expected a function"), + Error::ExpectedFunction(ty) => write!(f, "expected a function, got: {}", ty), Error::VariableNotFound(var) => write!(f, "variable not found: {}", var), Error::CannotInjectConstraintsInto(ty) => { write!(f, "cannot inject constraints into: {}", ty) @@ -47,6 +49,10 @@ impl fmt::Display for Error { write!(f, "row constraint failed for label: {}", label) } Error::RecordPatternNotRecord(ty) => write!(f, "record pattern not a record: {}", ty), + Error::UnwrapMissingOk(ty) => write!(f, "unwrap missing ok case, type: {}", ty), + Error::UnwrapNotVariant(ty) => { + write!(f, "unwrap on something other than variant: {}", ty) + } } } } @@ -81,6 +87,7 @@ impl BTreeMapExt for BTreeMap { pub struct Env { pub vars: HashMap, pub type_vars: Vec, + wrap: BTreeMap, } impl Env { @@ -549,11 +556,15 @@ impl Env { } TypeVar::Link(ty) => self.match_fun_ty(num_params, ty), TypeVar::UnboundRow(_, _) | TypeVar::Generic | TypeVar::GenericRow(_) => { - Err(Error::ExpectedFunction) + let ty = self.ty_to_string(&ty)?; + Err(Error::ExpectedFunction(ty)) } } } - _ => Err(Error::ExpectedFunction), + _ => { + let ty = self.ty_to_string(&ty)?; + Err(Error::ExpectedFunction(ty)) + } } } } @@ -572,6 +583,7 @@ impl Env { pub fn infer(&mut self, expr: &Expr) -> Result { let ty = self.infer_inner(0, expr)?; + let ty = self.wrapped(ty)?; self.generalize(-1, &ty)?; Ok(ty) } @@ -615,6 +627,31 @@ impl Env { } } + fn wrap_with(&mut self, labels: BTreeMap) -> Result<()> { + for (label, ty) in labels { + match self.wrap.entry(label) { + Entry::Vacant(v) => { + v.insert(ty); + } + Entry::Occupied(o) => { + let old_ty = o.get().clone(); + self.unify(&old_ty, &ty)?; + } + } + } + Ok(()) + } + + fn wrapped(&mut self, ty: Type) -> Result { + let mut labels = std::mem::take(&mut self.wrap); + if labels.is_empty() { + return Ok(ty); + } + labels.insert(OK_LABEL.to_owned(), ty); + let rest = Type::RowEmpty; + Ok(Type::Variant(Type::RowExtend(labels, rest.into()).into())) + } + fn infer_inner(&mut self, level: Level, expr: &Expr) -> Result { match expr { Expr::Bool(_) => Ok(Type::bool()), @@ -641,24 +678,28 @@ impl Env { } Expr::Var(name) => { let ty = self.get_var(name)?.clone(); - self.instantiate(level, ty) + let ty = self.instantiate(level, ty)?; + Ok(ty) } Expr::Fun(params, body) => { let mut param_tys = Vec::with_capacity(params.len()); let old_vars = self.vars.clone(); + let old_wrap = self.wrap.clone(); for param in params { let param_ty = self.infer_pattern(level, param); param_tys.push(param_ty); } let ret_ty = self.infer_inner(level, body)?; self.vars = old_vars; + self.wrap = old_wrap; Ok(Type::Arrow(param_tys, ret_ty.into())) } Expr::Let(pattern, value, body) => { let var_ty = self.infer_inner(level + 1, value)?; self.generalize(level, &var_ty)?; self.assign_pattern(pattern, var_ty)?; - self.infer_inner(level, body) + let ty = self.infer_inner(level, body)?; + Ok(ty) } Expr::Call(f, args) => { let f_ty = self.infer_inner(level, f)?; @@ -746,6 +787,28 @@ impl Env { self.unify(&expr, &Type::Variant(cases_row.into()))?; Ok(ret) } + Expr::Unwrap(expr) => { + let ty = self.infer_inner(level, expr)?; + match &ty { + Type::Variant(rows) => { + let (mut labels, _) = self.match_row_ty(rows)?; + match labels.remove(OK_LABEL) { + None => { + let ty = self.ty_to_string(&ty)?; + Err(Error::UnwrapMissingOk(ty)) + } + Some(ty) => { + self.wrap_with(labels)?; + Ok(ty) + } + } + } + _ => { + let ty = self.ty_to_string(&ty)?; + Err(Error::UnwrapNotVariant(ty)) + } + } + } Expr::If(if_expr, if_body, elifs, else_body) => { let bool = Type::bool(); let ty = self.infer_inner(level, if_expr)?; @@ -1137,7 +1200,7 @@ mod tests { "forall a b => a -> b -> b", ); fail("fun(x) -> x(x)", Error::RecursiveType); - fail("one(id)", Error::ExpectedFunction); + fail("one(id)", Error::ExpectedFunction("int".to_owned())); pass( "fun(f) -> let x = fun(g, y) -> let _ = g(y) in eq(f, g) in x", "forall a b => (a -> b) -> (a -> b, a) -> bool", diff --git a/src/lexer.rs b/src/lexer.rs index 99003a6..edd40eb 100644 --- a/src/lexer.rs +++ b/src/lexer.rs @@ -94,6 +94,8 @@ pub enum Token { GreaterThan, #[token(">=")] GreaterThanOrEqual, + #[token("?")] + QuestionMark, #[regex("[a-zA-Z_][a-zA-Z0-9_]*", ident)] Ident(String), } @@ -144,6 +146,7 @@ impl fmt::Display for Token { Token::LessThanOrEqual => "<=", Token::GreaterThan => ">", Token::GreaterThanOrEqual => ">=", + Token::QuestionMark => "?", }; write!(f, "{}", s) } diff --git a/src/main.rs b/src/main.rs index c1a2c36..9a0c12c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,12 +6,12 @@ use std::{ use parser::Parser; -mod core; -mod eval; -mod expr; -mod infer; -mod lexer; -mod parser; +pub mod core; +pub mod eval; +pub mod expr; +pub mod infer; +pub mod lexer; +pub mod parser; #[derive(Debug)] enum Error { @@ -158,3 +158,7 @@ fn readme_test() { #[cfg(test)] #[path = "tests/ifs.rs"] mod ifs; + +#[cfg(test)] +#[path = "tests/unwraps.rs"] +mod unwraps; diff --git a/src/parser.rs b/src/parser.rs index 29c174a..c7ffb94 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -408,6 +408,10 @@ impl<'a> Parser<'a> { Ok(Expr::RecordRestrict(lhs.into(), field)) } + fn unwrap_expr(&mut self, lhs: Expr) -> Result { + Ok(Expr::Unwrap(lhs.into())) + } + fn expr_postfix(&mut self, lhs: Expr) -> Result { if self.matches(Token::LParen)? { self.call_expr(lhs) @@ -415,6 +419,8 @@ impl<'a> Parser<'a> { self.record_select_expr(lhs) } else if self.matches(Token::Backslash)? { self.record_restrict_expr(lhs) + } else if self.matches(Token::QuestionMark)? { + self.unwrap_expr(lhs) } else { self.expected( vec![Token::LParen, Token::Dot, Token::Backslash], @@ -504,7 +510,7 @@ impl<'a> Parser<'a> { fn prefix_bp(&self) -> Result { match self.token { - Some(Token::Negate) => Ok(9), + Some(Token::Negate) => Ok(10), None => Err(Error::UnexpectedEof), _ => Err(Error::InvalidPrefix(self.token.clone())), } @@ -513,9 +519,10 @@ impl<'a> Parser<'a> { // TODO: Not sure how to determine precedence fn postfix_bp(&self) -> Option { match self.token { - Some(Token::LParen) => Some(10), - Some(Token::Dot) => Some(12), - Some(Token::Backslash) => Some(11), + Some(Token::LParen) => Some(11), + Some(Token::Dot) => Some(13), + Some(Token::Backslash) => Some(12), + Some(Token::QuestionMark) => Some(9), _ => None, } } diff --git a/src/tests/unwraps.rs b/src/tests/unwraps.rs new file mode 100644 index 0000000..712efc1 --- /dev/null +++ b/src/tests/unwraps.rs @@ -0,0 +1,86 @@ +use std::collections::BTreeMap; + +use crate::{ + eval::{self, Value}, + infer, + parser::Parser, +}; + +struct Env { + infer: infer::Env, + eval: eval::Env, +} + +impl Env { + fn new() -> Self { + let infer = infer::Env::default(); + let eval = eval::Env::default(); + let mut env = Self { infer, eval }; + env.add( + "let safe_div(n, d) = if d == 0 then :div_by_zero {} else :ok (n / d)", + "(int, int) -> [div_by_zero: {}, ok: int]", + ); + env.add( + "let safe_minus(x, y) = if y < 0 then :would_add {} else :ok (x - y)", + "(int, int) -> [ok: int, would_add: {}]", + ); + env + } + + fn add(&mut self, source: &str, source_ty: &str) { + let (forall, ty) = Parser::ty(source_ty).unwrap(); + let expected_ty = self.infer.replace_ty_constants_with_vars(forall, ty); + let expr = Parser::repl(source).unwrap(); + let actual_ty = self.infer.infer(&expr).unwrap(); + let expected_ty = self.infer.ty_to_string(&expected_ty).unwrap(); + let actual_ty = self.infer.ty_to_string(&actual_ty).unwrap(); + assert_eq!(expected_ty, actual_ty); + let _ = self.eval.eval(&expr).unwrap(); + } +} + +#[track_caller] +fn pass(source: &str, source_ty: &str, expected_val: Value) { + let mut env = Env::new(); + let (forall, ty) = Parser::ty(source_ty).unwrap(); + let expected_ty = env.infer.replace_ty_constants_with_vars(forall, ty); + let expr = Parser::expr(source).unwrap(); + let actual_ty = env.infer.infer(&expr).unwrap(); + let expected_ty = env.infer.ty_to_string(&expected_ty).unwrap(); + let actual_ty = env.infer.ty_to_string(&actual_ty).unwrap(); + assert_eq!(expected_ty, actual_ty); + let actual_val = env.eval.eval(&expr).unwrap(); + assert_eq!(expected_val, actual_val); +} + +#[test] +fn unwrap_ok() { + pass( + "safe_div(2, 2)?", + "[div_by_zero: {}, ok: int]", + Value::Int(1), + ); + pass( + "safe_div(2, 0)?", + "[div_by_zero: {}, ok: int]", + Value::Variant( + "div_by_zero".to_owned(), + Value::Record(BTreeMap::new()).into(), + ), + ); + pass( + "let x = safe_div(2, 2)? in x", + "[div_by_zero: {}, ok: int]", + Value::Int(1), + ); + pass( + "let x = safe_div(2, 0)? in x", + "[div_by_zero: {}, ok: int]", + Value::Variant("div_by_zero".to_owned(), Value::record(Vec::new()).into()), + ); + pass( + "let x = safe_div(2, 2)? in let y = safe_minus(x, 1)? in y", + "[div_by_zero: {}, ok: int, would_add: {}]", + Value::Int(0), + ); +}