Skip to content

Commit

Permalink
sanitize infer type output
Browse files Browse the repository at this point in the history
  • Loading branch information
FrankBro committed Jan 14, 2024
1 parent 0e2e3aa commit 9ce351f
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 34 deletions.
166 changes: 133 additions & 33 deletions src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use itertools::Itertools;

use crate::{
expr::{
Constraints, Expr, ExprAt, ExprIn, ExprTypedAt, Id, Level, Pattern, PatternAt,
PatternTypedAt, Type, TypeVar, OK_LABEL,
Constraints, Expr, ExprAt, ExprIn, ExprTyped, ExprTypedAt, Id, Level, Pattern, PatternAt,
PatternTypedAt, PositionContext, PositionTypeContext, Type, TypeVar, OK_LABEL,
},
parser::ForAll,
};
Expand Down Expand Up @@ -592,16 +592,92 @@ impl Env {

pub fn infer(&mut self, expr: ExprAt) -> Result<ExprTypedAt> {
let expr = self.infer_inner(0, expr)?;
let expr = self.wrapped(expr)?;
let mut expr = self.wrapped(expr)?;
self.generalize(-1, expr.ty())?;
self.sanitize_types(&mut expr)?;
Ok(expr)
}

// TODO: I got lazy here, make this more efficient
fn sanitize_types(&self, expr: &mut ExprTypedAt) -> Result<()> {
expr.context.ty.ty = self.real_ty(expr.context.ty.ty.clone())?;
match expr.expr.as_mut() {
Expr::Bool(_) => Ok(()),
Expr::Int(_) => Ok(()),
Expr::IntBinOp(_, a, b) => {
self.sanitize_types(a)?;
self.sanitize_types(b)
}
Expr::Negate(value) => self.sanitize_types(value),
Expr::EqualEqual(a, b) => {
self.sanitize_types(a)?;
self.sanitize_types(b)
}
Expr::Var(_) => Ok(()),
Expr::Call(fun, args) => {
for arg in args.iter_mut() {
self.sanitize_types(arg)?;
}
self.sanitize_types(fun)
}
Expr::Fun(params, body) => {
for param in params.iter_mut() {
self.sanitize_types(param)?;
}
self.sanitize_types(body)
}
Expr::Let(pattern, value, body) => {
self.sanitize_types(pattern)?;
self.sanitize_types(value)?;
self.sanitize_types(body)
}
Expr::RecordSelect(record, _) => self.sanitize_types(record),
Expr::RecordExtend(labels, rest) => {
for (_, expr) in labels.iter_mut() {
self.sanitize_types(expr)?;
}
self.sanitize_types(rest)
}
Expr::RecordRestrict(record, _) => self.sanitize_types(record),
Expr::RecordEmpty => Ok(()),
Expr::Variant(_, expr) => self.sanitize_types(expr),
Expr::Case(value, cases, def) => {
for (_, _, case) in cases.iter_mut() {
self.sanitize_types(case)?;
}
if let Some((_, expr)) = def {
self.sanitize_types(expr)?;
}
self.sanitize_types(value)
}
Expr::If(if_expr, if_body, elifs, else_body) => {
self.sanitize_types(if_expr)?;
self.sanitize_types(if_body)?;
for (elif_expr, elif_body) in elifs.iter_mut() {
self.sanitize_types(elif_expr)?;
self.sanitize_types(elif_body)?;
}
self.sanitize_types(else_body)
}
Expr::Unwrap(value) => self.sanitize_types(value),
}
}

fn expr_with_context(
&mut self,
expr: Expr<PositionTypeContext>,
context: PositionContext,
ty: Type,
) -> Result<ExprTypedAt> {
let ty = self.real_ty(ty)?;
Ok(expr.with(context, ty))
}

fn assign_pattern(&mut self, pattern: PatternAt, ty: Type) -> Result<PatternTypedAt> {
match (*pattern.expr, ty) {
(Pattern::Var(name), ty) => {
self.insert_var(name.clone(), ty.clone());
Ok(Pattern::Var(name).with(pattern.context, ty))
self.expr_with_context(Pattern::Var(name), pattern.context, ty)
}
(Pattern::RecordExtend(labels, rest), Type::Record(row)) => {
match *rest.expr {
Expand All @@ -620,9 +696,12 @@ impl Env {
}
}
}
let rest = Pattern::RecordEmpty.with(rest.context, rest_ty);
Ok(Pattern::RecordExtend(label_patterns, rest)
.with(pattern.context, Type::Record(row)))
let rest = self.expr_with_context(Pattern::RecordEmpty, rest.context, rest_ty)?;
self.expr_with_context(
Pattern::RecordExtend(label_patterns, rest),
pattern.context,
Type::Record(row),
)
}
(Pattern::RecordExtend(_, _), ty) => {
let ty = self.ty_to_string(&ty)?;
Expand All @@ -640,7 +719,7 @@ impl Env {
Pattern::Var(name) => {
let ty = self.new_unbound(level);
self.insert_var(name.clone(), ty.clone());
Ok(Pattern::Var(name).with(pattern.context, ty))
self.expr_with_context(Pattern::Var(name), pattern.context, ty)
}
Pattern::RecordExtend(labels, rest) => {
match *rest.expr {
Expand All @@ -657,8 +736,12 @@ impl Env {
}
let rest_ty = self.new_unbound_row(level, constraints);
let ty = Type::Record(Type::RowExtend(label_tys, rest_ty.clone().into()).into());
let rest = Pattern::RecordEmpty.with(rest.context, rest_ty);
Ok(Pattern::RecordExtend(label_exprs, rest).with(pattern.context, ty))
let rest = self.expr_with_context(Pattern::RecordEmpty, rest.context, rest_ty)?;
self.expr_with_context(
Pattern::RecordExtend(label_exprs, rest),
pattern.context,
ty,
)
}
_ => Err(Error::InvalidPattern(pattern.clone())),
}
Expand Down Expand Up @@ -693,33 +776,33 @@ impl Env {

fn infer_inner(&mut self, level: Level, expr: ExprAt) -> Result<ExprTypedAt> {
match *expr.expr {
Expr::Bool(b) => Ok(Expr::Bool(b).with(expr.context, Type::bool())),
Expr::Int(i) => Ok(Expr::Int(i).with(expr.context, Type::int())),
Expr::Bool(b) => self.expr_with_context(Expr::Bool(b), expr.context, Type::bool()),
Expr::Int(i) => self.expr_with_context(Expr::Int(i), expr.context, Type::int()),
Expr::IntBinOp(op, lhs, rhs) => {
let ty = Type::int();
let lhs = self.infer_inner(level, lhs)?;
self.unify(&ty, lhs.ty())?;
let rhs = self.infer_inner(level, rhs)?;
self.unify(&ty, rhs.ty())?;
let ty = op.output_ty();
Ok(Expr::IntBinOp(op, lhs, rhs).with(expr.context, ty))
self.expr_with_context(Expr::IntBinOp(op, lhs, rhs), expr.context, ty)
}
Expr::Negate(value) => {
let ty = Type::bool();
let value = self.infer_inner(level, value)?;
self.unify(&ty, value.ty())?;
Ok(Expr::Negate(value).with(expr.context, ty))
self.expr_with_context(Expr::Negate(value), expr.context, ty)
}
Expr::EqualEqual(lhs, rhs) => {
let lhs = self.infer_inner(level, lhs)?;
let rhs = self.infer_inner(level, rhs)?;
self.unify(lhs.ty(), rhs.ty())?;
Ok(Expr::EqualEqual(lhs, rhs).with(expr.context, Type::bool()))
self.expr_with_context(Expr::EqualEqual(lhs, rhs), expr.context, Type::bool())
}
Expr::Var(name) => {
let ty = self.get_var(&name)?.clone();
let ty = self.instantiate(level, ty)?;
Ok(Expr::Var(name).with(expr.context, ty))
self.expr_with_context(Expr::Var(name), expr.context, ty)
}
Expr::Fun(params, body) => {
let mut param_exprs = Vec::with_capacity(params.len());
Expand All @@ -735,15 +818,15 @@ impl Env {
self.vars = old_vars;
self.wrap = old_wrap;
let ty = Type::Arrow(param_tys, body.ty().clone().into());
Ok(Expr::Fun(param_exprs, body).with(expr.context, ty))
self.expr_with_context(Expr::Fun(param_exprs, body), expr.context, ty)
}
Expr::Let(pattern, value, body) => {
let value = self.infer_inner(level + 1, value)?;
self.generalize(level, value.ty())?;
let pattern = self.assign_pattern(pattern, value.ty().clone())?;
let body = self.infer_inner(level, body)?;
let ty = body.ty().clone();
Ok(Expr::Let(pattern, value, body).with(expr.context, ty))
self.expr_with_context(Expr::Let(pattern, value, body), expr.context, ty)
}
Expr::Call(fun, args) => {
let fun = self.infer_inner(level, fun)?;
Expand All @@ -755,11 +838,11 @@ impl Env {
self.unify(arg.ty(), param)?;
typed_args.push(arg);
}
Ok(Expr::Call(fun, typed_args).with(expr.context, *ret))
self.expr_with_context(Expr::Call(fun, typed_args), expr.context, *ret)
}
Expr::RecordEmpty => {
let ty = Type::Record(Type::RowEmpty.into());
Ok(Expr::RecordEmpty.with(expr.context, ty))
self.expr_with_context(Expr::RecordEmpty, expr.context, ty)
}
Expr::RecordSelect(record, label) => {
let rest = self.new_unbound_row(level, Constraints::singleton(label.clone()));
Expand All @@ -774,7 +857,7 @@ impl Env {
let ret = field;
let record = self.infer_inner(level, record)?;
self.unify(&param, record.ty())?;
Ok(Expr::RecordSelect(record, label).with(expr.context, ret))
self.expr_with_context(Expr::RecordSelect(record, label), expr.context, ret)
}
Expr::RecordRestrict(record, label) => {
let rest = self.new_unbound_row(level, Constraints::singleton(label.clone()));
Expand All @@ -789,7 +872,7 @@ impl Env {
let ret = Type::Record(rest.into());
let record = self.infer_inner(level, record)?;
self.unify(&param, record.ty())?;
Ok(Expr::RecordRestrict(record, label).with(expr.context, ret))
self.expr_with_context(Expr::RecordRestrict(record, label), expr.context, ret)
}
Expr::RecordExtend(labels, record) => {
let mut tys = BTreeMap::new();
Expand All @@ -804,7 +887,7 @@ impl Env {
let record = self.infer_inner(level, record)?;
self.unify(&Type::Record(rest.clone().into()), record.ty())?;
let ty = Type::Record(Type::RowExtend(tys, rest.into()).into());
Ok(Expr::RecordExtend(typed_labels, record).with(expr.context, ty))
self.expr_with_context(Expr::RecordExtend(typed_labels, record), expr.context, ty)
}
Expr::Variant(label, value) => {
let rest = self.new_unbound_row(level, Constraints::singleton(label.clone()));
Expand All @@ -816,14 +899,14 @@ impl Env {
);
let value = self.infer_inner(level, value)?;
self.unify(&param, value.ty())?;
Ok(Expr::Variant(label, value).with(expr.context, ret))
self.expr_with_context(Expr::Variant(label, value), expr.context, ret)
}
Expr::Case(value, cases, None) => {
let ret = self.new_unbound(level);
let value = self.infer_inner(level, value)?;
let (cases_row, cases) = self.infer_cases(level, &ret, Type::RowEmpty, cases)?;
self.unify(value.ty(), &Type::Variant(cases_row.into()))?;
Ok(Expr::Case(value, cases, None).with(expr.context, ret))
self.expr_with_context(Expr::Case(value, cases, None), expr.context, ret)
}
Expr::Case(value, cases, Some((def_var, def_expr))) => {
let constraints = cases.iter().map(|(label, _, _)| label).cloned().collect();
Expand All @@ -837,7 +920,11 @@ impl Env {
let value = self.infer_inner(level, value)?;
let (cases_row, cases) = self.infer_cases(level, &ret, def_variant, cases)?;
self.unify(value.ty(), &Type::Variant(cases_row.into()))?;
Ok(Expr::Case(value, cases, Some((def_var, def_expr))).with(expr.context, ret))
self.expr_with_context(
Expr::Case(value, cases, Some((def_var, def_expr))),
expr.context,
ret,
)
}
Expr::Unwrap(value) => {
let value = self.infer_inner(level, value)?;
Expand All @@ -851,7 +938,7 @@ impl Env {
}
Some(ty) => {
self.wrap_with(labels)?;
Ok(Expr::Unwrap(value).with(expr.context, ty))
self.expr_with_context(Expr::Unwrap(value), expr.context, ty)
}
}
}
Expand Down Expand Up @@ -882,13 +969,17 @@ impl Env {
let (labels, _rest) = self.match_row_ty(&row)?;
let ty =
Type::Variant(Type::RowExtend(labels, Type::RowEmpty.into()).into());
Ok(Expr::If(if_expr, if_body, typed_elifs, else_body)
.with(expr.context, ty))
}
ty => {
Ok(Expr::If(if_expr, if_body, typed_elifs, else_body)
.with(expr.context, ty))
self.expr_with_context(
Expr::If(if_expr, if_body, typed_elifs, else_body),
expr.context,
ty,
)
}
ty => self.expr_with_context(
Expr::If(if_expr, if_body, typed_elifs, else_body),
expr.context,
ty,
),
}
}
}
Expand Down Expand Up @@ -982,6 +1073,15 @@ impl Env {
_ => Ok(ty),
}
}
Type::Arrow(params, ret) => {
let mut real_params = Vec::with_capacity(params.len());
for param in params {
let real_param = self.real_ty(param)?;
real_params.push(real_param);
}
let ret = self.real_ty(*ret)?;
Ok(Type::Arrow(real_params, ret.into()))
}
_ => Ok(ty),
}
}
Expand Down
45 changes: 44 additions & 1 deletion src/tests/typed_exprs.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
core::make_env,
expr::{Expr, ExprIn, ExprTyped, Type, TypeContext},
expr::{Expr, ExprIn, ExprTyped, IntBinOp, Type, TypeContext},
parser::Parser,
};

Expand Down Expand Up @@ -32,10 +32,53 @@ fn let_(pattern: ExprTyped, value: ExprTyped, body: ExprTyped) -> ExprTyped {
expr(Expr::Let(pattern, value, body), ty)
}

fn fun(params: Vec<ExprTyped>, body: ExprTyped) -> ExprTyped {
let params_ty = params
.iter()
.map(|param| param.context.ty.clone())
.collect();
let ty = Type::Arrow(params_ty, body.context.ty.clone().into());
expr(Expr::Fun(params, body), ty)
}

fn add(a: ExprTyped, b: ExprTyped) -> ExprTyped {
let op = IntBinOp::Plus;
let ty = op.output_ty();
expr(Expr::IntBinOp(op, a, b), ty)
}

fn call(fun: ExprTyped, args: Vec<ExprTyped>) -> ExprTyped {
let ty = match fun.context.ty.clone() {
Type::Arrow(_, ret) => *ret,
_ => unreachable!(),
};
expr(Expr::Call(fun, args), ty)
}

#[test]
fn tests() {
pass(
"let x = 1 in x",
let_(var("x", Type::int()), int(1), var("x", Type::int())),
);
pass(
"let add(a, b) = a + b in add(1, 2)",
let_(
var(
"add",
Type::Arrow(vec![Type::int(), Type::int()], Type::int().into()),
),
fun(
vec![var("a", Type::int()), var("b", Type::int())],
add(var("a", Type::int()), var("b", Type::int())),
),
call(
var(
"add",
Type::Arrow(vec![Type::int(), Type::int()], Type::int().into()),
),
vec![int(1), int(2)],
),
),
);
}

0 comments on commit 9ce351f

Please sign in to comment.