Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
FrankBro committed Jan 14, 2024
1 parent 9ce351f commit d8bda35
Showing 1 changed file with 32 additions and 57 deletions.
89 changes: 32 additions & 57 deletions src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use itertools::Itertools;

use crate::{
expr::{
Constraints, Expr, ExprAt, ExprIn, ExprTyped, ExprTypedAt, Id, Level, Pattern, PatternAt,
PatternTypedAt, PositionContext, PositionTypeContext, Type, TypeVar, OK_LABEL,
Constraints, Expr, ExprAt, ExprIn, ExprTypedAt, Id, Level, Pattern, PatternAt,
PatternTypedAt, Type, TypeVar, OK_LABEL,
},
parser::ForAll,
};
Expand Down Expand Up @@ -663,21 +663,11 @@ impl Env {
}
}

fn expr_with_context(
&mut self,
expr: Expr<PositionTypeContext>,
context: PositionContext,
ty: Type,
) -> Result<ExprTypedAt> {
let ty = self.real_ty(ty)?;
Ok(expr.with(context, ty))
}

fn assign_pattern(&mut self, pattern: PatternAt, ty: Type) -> Result<PatternTypedAt> {
match (*pattern.expr, ty) {
(Pattern::Var(name), ty) => {
self.insert_var(name.clone(), ty.clone());
self.expr_with_context(Pattern::Var(name), pattern.context, ty)
Ok(Pattern::Var(name).with(pattern.context, ty))
}
(Pattern::RecordExtend(labels, rest), Type::Record(row)) => {
match *rest.expr {
Expand All @@ -696,12 +686,9 @@ impl Env {
}
}
}
let rest = self.expr_with_context(Pattern::RecordEmpty, rest.context, rest_ty)?;
self.expr_with_context(
Pattern::RecordExtend(label_patterns, rest),
pattern.context,
Type::Record(row),
)
let rest = Pattern::RecordEmpty.with(rest.context, rest_ty);
Ok(Pattern::RecordExtend(label_patterns, rest)
.with(pattern.context, Type::Record(row)))
}
(Pattern::RecordExtend(_, _), ty) => {
let ty = self.ty_to_string(&ty)?;
Expand All @@ -719,7 +706,7 @@ impl Env {
Pattern::Var(name) => {
let ty = self.new_unbound(level);
self.insert_var(name.clone(), ty.clone());
self.expr_with_context(Pattern::Var(name), pattern.context, ty)
Ok(Pattern::Var(name).with(pattern.context, ty))
}
Pattern::RecordExtend(labels, rest) => {
match *rest.expr {
Expand All @@ -736,12 +723,8 @@ impl Env {
}
let rest_ty = self.new_unbound_row(level, constraints);
let ty = Type::Record(Type::RowExtend(label_tys, rest_ty.clone().into()).into());
let rest = self.expr_with_context(Pattern::RecordEmpty, rest.context, rest_ty)?;
self.expr_with_context(
Pattern::RecordExtend(label_exprs, rest),
pattern.context,
ty,
)
let rest = Pattern::RecordEmpty.with(rest.context, rest_ty);
Ok(Pattern::RecordExtend(label_exprs, rest).with(pattern.context, ty))
}
_ => Err(Error::InvalidPattern(pattern.clone())),
}
Expand Down Expand Up @@ -776,33 +759,33 @@ impl Env {

fn infer_inner(&mut self, level: Level, expr: ExprAt) -> Result<ExprTypedAt> {
match *expr.expr {
Expr::Bool(b) => self.expr_with_context(Expr::Bool(b), expr.context, Type::bool()),
Expr::Int(i) => self.expr_with_context(Expr::Int(i), expr.context, Type::int()),
Expr::Bool(b) => Ok(Expr::Bool(b).with(expr.context, Type::bool())),
Expr::Int(i) => Ok(Expr::Int(i).with(expr.context, Type::int())),
Expr::IntBinOp(op, lhs, rhs) => {
let ty = Type::int();
let lhs = self.infer_inner(level, lhs)?;
self.unify(&ty, lhs.ty())?;
let rhs = self.infer_inner(level, rhs)?;
self.unify(&ty, rhs.ty())?;
let ty = op.output_ty();
self.expr_with_context(Expr::IntBinOp(op, lhs, rhs), expr.context, ty)
Ok(Expr::IntBinOp(op, lhs, rhs).with(expr.context, ty))
}
Expr::Negate(value) => {
let ty = Type::bool();
let value = self.infer_inner(level, value)?;
self.unify(&ty, value.ty())?;
self.expr_with_context(Expr::Negate(value), expr.context, ty)
Ok(Expr::Negate(value).with(expr.context, ty))
}
Expr::EqualEqual(lhs, rhs) => {
let lhs = self.infer_inner(level, lhs)?;
let rhs = self.infer_inner(level, rhs)?;
self.unify(lhs.ty(), rhs.ty())?;
self.expr_with_context(Expr::EqualEqual(lhs, rhs), expr.context, Type::bool())
Ok(Expr::EqualEqual(lhs, rhs).with(expr.context, Type::bool()))
}
Expr::Var(name) => {
let ty = self.get_var(&name)?.clone();
let ty = self.instantiate(level, ty)?;
self.expr_with_context(Expr::Var(name), expr.context, ty)
Ok(Expr::Var(name).with(expr.context, ty))
}
Expr::Fun(params, body) => {
let mut param_exprs = Vec::with_capacity(params.len());
Expand All @@ -818,15 +801,15 @@ impl Env {
self.vars = old_vars;
self.wrap = old_wrap;
let ty = Type::Arrow(param_tys, body.ty().clone().into());
self.expr_with_context(Expr::Fun(param_exprs, body), expr.context, ty)
Ok(Expr::Fun(param_exprs, body).with(expr.context, ty))
}
Expr::Let(pattern, value, body) => {
let value = self.infer_inner(level + 1, value)?;
self.generalize(level, value.ty())?;
let pattern = self.assign_pattern(pattern, value.ty().clone())?;
let body = self.infer_inner(level, body)?;
let ty = body.ty().clone();
self.expr_with_context(Expr::Let(pattern, value, body), expr.context, ty)
Ok(Expr::Let(pattern, value, body).with(expr.context, ty))
}
Expr::Call(fun, args) => {
let fun = self.infer_inner(level, fun)?;
Expand All @@ -838,11 +821,11 @@ impl Env {
self.unify(arg.ty(), param)?;
typed_args.push(arg);
}
self.expr_with_context(Expr::Call(fun, typed_args), expr.context, *ret)
Ok(Expr::Call(fun, typed_args).with(expr.context, *ret))
}
Expr::RecordEmpty => {
let ty = Type::Record(Type::RowEmpty.into());
self.expr_with_context(Expr::RecordEmpty, expr.context, ty)
Ok(Expr::RecordEmpty.with(expr.context, ty))
}
Expr::RecordSelect(record, label) => {
let rest = self.new_unbound_row(level, Constraints::singleton(label.clone()));
Expand All @@ -857,7 +840,7 @@ impl Env {
let ret = field;
let record = self.infer_inner(level, record)?;
self.unify(&param, record.ty())?;
self.expr_with_context(Expr::RecordSelect(record, label), expr.context, ret)
Ok(Expr::RecordSelect(record, label).with(expr.context, ret))
}
Expr::RecordRestrict(record, label) => {
let rest = self.new_unbound_row(level, Constraints::singleton(label.clone()));
Expand All @@ -872,7 +855,7 @@ impl Env {
let ret = Type::Record(rest.into());
let record = self.infer_inner(level, record)?;
self.unify(&param, record.ty())?;
self.expr_with_context(Expr::RecordRestrict(record, label), expr.context, ret)
Ok(Expr::RecordRestrict(record, label).with(expr.context, ret))
}
Expr::RecordExtend(labels, record) => {
let mut tys = BTreeMap::new();
Expand All @@ -887,7 +870,7 @@ impl Env {
let record = self.infer_inner(level, record)?;
self.unify(&Type::Record(rest.clone().into()), record.ty())?;
let ty = Type::Record(Type::RowExtend(tys, rest.into()).into());
self.expr_with_context(Expr::RecordExtend(typed_labels, record), expr.context, ty)
Ok(Expr::RecordExtend(typed_labels, record).with(expr.context, ty))
}
Expr::Variant(label, value) => {
let rest = self.new_unbound_row(level, Constraints::singleton(label.clone()));
Expand All @@ -899,14 +882,14 @@ impl Env {
);
let value = self.infer_inner(level, value)?;
self.unify(&param, value.ty())?;
self.expr_with_context(Expr::Variant(label, value), expr.context, ret)
Ok(Expr::Variant(label, value).with(expr.context, ret))
}
Expr::Case(value, cases, None) => {
let ret = self.new_unbound(level);
let value = self.infer_inner(level, value)?;
let (cases_row, cases) = self.infer_cases(level, &ret, Type::RowEmpty, cases)?;
self.unify(value.ty(), &Type::Variant(cases_row.into()))?;
self.expr_with_context(Expr::Case(value, cases, None), expr.context, ret)
Ok(Expr::Case(value, cases, None).with(expr.context, ret))
}
Expr::Case(value, cases, Some((def_var, def_expr))) => {
let constraints = cases.iter().map(|(label, _, _)| label).cloned().collect();
Expand All @@ -920,11 +903,7 @@ impl Env {
let value = self.infer_inner(level, value)?;
let (cases_row, cases) = self.infer_cases(level, &ret, def_variant, cases)?;
self.unify(value.ty(), &Type::Variant(cases_row.into()))?;
self.expr_with_context(
Expr::Case(value, cases, Some((def_var, def_expr))),
expr.context,
ret,
)
Ok(Expr::Case(value, cases, Some((def_var, def_expr))).with(expr.context, ret))
}
Expr::Unwrap(value) => {
let value = self.infer_inner(level, value)?;
Expand All @@ -938,7 +917,7 @@ impl Env {
}
Some(ty) => {
self.wrap_with(labels)?;
self.expr_with_context(Expr::Unwrap(value), expr.context, ty)
Ok(Expr::Unwrap(value).with(expr.context, ty))
}
}
}
Expand Down Expand Up @@ -969,17 +948,13 @@ impl Env {
let (labels, _rest) = self.match_row_ty(&row)?;
let ty =
Type::Variant(Type::RowExtend(labels, Type::RowEmpty.into()).into());
self.expr_with_context(
Expr::If(if_expr, if_body, typed_elifs, else_body),
expr.context,
ty,
)
Ok(Expr::If(if_expr, if_body, typed_elifs, else_body)
.with(expr.context, ty))
}
ty => {
Ok(Expr::If(if_expr, if_body, typed_elifs, else_body)
.with(expr.context, ty))
}
ty => self.expr_with_context(
Expr::If(if_expr, if_body, typed_elifs, else_body),
expr.context,
ty,
),
}
}
}
Expand Down

0 comments on commit d8bda35

Please sign in to comment.