diff --git a/src/infer.rs b/src/infer.rs index ae5fcc6..90a385a 100644 --- a/src/infer.rs +++ b/src/infer.rs @@ -592,11 +592,77 @@ impl Env { pub fn infer(&mut self, expr: ExprAt) -> Result { let expr = self.infer_inner(0, expr)?; - let expr = self.wrapped(expr)?; + let mut expr = self.wrapped(expr)?; self.generalize(-1, expr.ty())?; + self.sanitize_types(&mut expr)?; Ok(expr) } + fn sanitize_types(&self, expr: &mut ExprTypedAt) -> Result<()> { + // TODO: I got lazy here, make this more efficient + expr.context.ty.ty = self.real_ty(expr.context.ty.ty.clone())?; + match expr.expr.as_mut() { + Expr::Bool(_) => Ok(()), + Expr::Int(_) => Ok(()), + Expr::IntBinOp(_, a, b) => { + self.sanitize_types(a)?; + self.sanitize_types(b) + } + Expr::Negate(value) => self.sanitize_types(value), + Expr::EqualEqual(a, b) => { + self.sanitize_types(a)?; + self.sanitize_types(b) + } + Expr::Var(_) => Ok(()), + Expr::Call(fun, args) => { + for arg in args.iter_mut() { + self.sanitize_types(arg)?; + } + self.sanitize_types(fun) + } + Expr::Fun(params, body) => { + for param in params.iter_mut() { + self.sanitize_types(param)?; + } + self.sanitize_types(body) + } + Expr::Let(pattern, value, body) => { + self.sanitize_types(pattern)?; + self.sanitize_types(value)?; + self.sanitize_types(body) + } + Expr::RecordSelect(record, _) => self.sanitize_types(record), + Expr::RecordExtend(labels, rest) => { + for (_, expr) in labels.iter_mut() { + self.sanitize_types(expr)?; + } + self.sanitize_types(rest) + } + Expr::RecordRestrict(record, _) => self.sanitize_types(record), + Expr::RecordEmpty => Ok(()), + Expr::Variant(_, expr) => self.sanitize_types(expr), + Expr::Case(value, cases, def) => { + for (_, _, case) in cases.iter_mut() { + self.sanitize_types(case)?; + } + if let Some((_, expr)) = def { + self.sanitize_types(expr)?; + } + self.sanitize_types(value) + } + Expr::If(if_expr, if_body, elifs, else_body) => { + self.sanitize_types(if_expr)?; + self.sanitize_types(if_body)?; + for (elif_expr, elif_body) in elifs.iter_mut() { + self.sanitize_types(elif_expr)?; + self.sanitize_types(elif_body)?; + } + self.sanitize_types(else_body) + } + Expr::Unwrap(value) => self.sanitize_types(value), + } + } + fn assign_pattern(&mut self, pattern: PatternAt, ty: Type) -> Result { match (*pattern.expr, ty) { (Pattern::Var(name), ty) => { @@ -982,6 +1048,15 @@ impl Env { _ => Ok(ty), } } + Type::Arrow(params, ret) => { + let mut real_params = Vec::with_capacity(params.len()); + for param in params { + let real_param = self.real_ty(param)?; + real_params.push(real_param); + } + let ret = self.real_ty(*ret)?; + Ok(Type::Arrow(real_params, ret.into())) + } _ => Ok(ty), } } diff --git a/src/tests/typed_exprs.rs b/src/tests/typed_exprs.rs index c8e5ded..4e717d8 100644 --- a/src/tests/typed_exprs.rs +++ b/src/tests/typed_exprs.rs @@ -1,6 +1,6 @@ use crate::{ core::make_env, - expr::{Expr, ExprIn, ExprTyped, Type, TypeContext}, + expr::{Expr, ExprIn, ExprTyped, IntBinOp, Type, TypeContext}, parser::Parser, }; @@ -32,10 +32,53 @@ fn let_(pattern: ExprTyped, value: ExprTyped, body: ExprTyped) -> ExprTyped { expr(Expr::Let(pattern, value, body), ty) } +fn fun(params: Vec, body: ExprTyped) -> ExprTyped { + let params_ty = params + .iter() + .map(|param| param.context.ty.clone()) + .collect(); + let ty = Type::Arrow(params_ty, body.context.ty.clone().into()); + expr(Expr::Fun(params, body), ty) +} + +fn add(a: ExprTyped, b: ExprTyped) -> ExprTyped { + let op = IntBinOp::Plus; + let ty = op.output_ty(); + expr(Expr::IntBinOp(op, a, b), ty) +} + +fn call(fun: ExprTyped, args: Vec) -> ExprTyped { + let ty = match fun.context.ty.clone() { + Type::Arrow(_, ret) => *ret, + _ => unreachable!(), + }; + expr(Expr::Call(fun, args), ty) +} + #[test] fn tests() { pass( "let x = 1 in x", let_(var("x", Type::int()), int(1), var("x", Type::int())), ); + pass( + "let add(a, b) = a + b in add(1, 2)", + let_( + var( + "add", + Type::Arrow(vec![Type::int(), Type::int()], Type::int().into()), + ), + fun( + vec![var("a", Type::int()), var("b", Type::int())], + add(var("a", Type::int()), var("b", Type::int())), + ), + call( + var( + "add", + Type::Arrow(vec![Type::int(), Type::int()], Type::int().into()), + ), + vec![int(1), int(2)], + ), + ), + ); }