Skip to content

Commit

Permalink
Merge pull request #31 from FrankBro/refactor-context-ty-get
Browse files Browse the repository at this point in the history
refactor
  • Loading branch information
FrankBro authored Jan 10, 2024
2 parents 797327b + ec8e10e commit 41e7b42
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 31 deletions.
4 changes: 4 additions & 0 deletions src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ impl ExprIn<PositionTypeContext> {
expr: self.expr.strip_position(),
}
}

pub fn ty(&self) -> &Type {
&self.context.ty.ty
}
}

#[derive(Clone, Debug, PartialEq)]
Expand Down
62 changes: 31 additions & 31 deletions src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ impl Env {
pub fn infer(&mut self, expr: ExprAt) -> Result<ExprTypedAt> {
let expr = self.infer_inner(0, expr)?;
let expr = self.wrapped(expr)?;
self.generalize(-1, &expr.context.ty.ty)?;
self.generalize(-1, expr.ty())?;
Ok(expr)
}

Expand Down Expand Up @@ -652,7 +652,7 @@ impl Env {
let mut label_tys = BTreeMap::new();
for (label, pat) in labels {
let label_expr = self.infer_pattern(level, pat)?;
label_tys.insert(label.clone(), label_expr.context.ty.ty.clone());
label_tys.insert(label.clone(), label_expr.ty().clone());
label_exprs.insert(label, label_expr);
}
let rest_ty = self.new_unbound_row(level, constraints);
Expand Down Expand Up @@ -684,7 +684,7 @@ impl Env {
if labels.is_empty() {
return Ok(expr);
}
labels.insert(OK_LABEL.to_owned(), expr.context.ty.ty.clone());
labels.insert(OK_LABEL.to_owned(), expr.ty().clone());
let rest = Type::RowEmpty;
let ty = Type::Variant(Type::RowExtend(labels, rest.into()).into());
expr.context.ty.ty = ty;
Expand All @@ -698,22 +698,22 @@ impl Env {
Expr::IntBinOp(op, lhs, rhs) => {
let ty = Type::int();
let lhs = self.infer_inner(level, *lhs)?;
self.unify(&ty, &lhs.context.ty.ty)?;
self.unify(&ty, lhs.ty())?;
let rhs = self.infer_inner(level, *rhs)?;
self.unify(&ty, &rhs.context.ty.ty)?;
self.unify(&ty, rhs.ty())?;
let ty = op.output_ty();
Ok(Expr::IntBinOp(op, lhs.into(), rhs.into()).with(expr.context, ty))
}
Expr::Negate(value) => {
let ty = Type::bool();
let value = self.infer_inner(level, *value)?;
self.unify(&ty, &value.context.ty.ty)?;
self.unify(&ty, value.ty())?;
Ok(Expr::Negate(value.into()).with(expr.context, ty))
}
Expr::EqualEqual(lhs, rhs) => {
let lhs = self.infer_inner(level, *lhs)?;
let rhs = self.infer_inner(level, *rhs)?;
self.unify(&lhs.context.ty.ty, &rhs.context.ty.ty)?;
self.unify(lhs.ty(), rhs.ty())?;
Ok(Expr::EqualEqual(lhs.into(), rhs.into()).with(expr.context, Type::bool()))
}
Expr::Var(name) => {
Expand All @@ -728,31 +728,31 @@ impl Env {
let old_wrap = self.wrap.clone();
for param in params {
let param_expr = self.infer_pattern(level, param)?;
param_tys.push(param_expr.context.ty.ty.clone());
param_tys.push(param_expr.ty().clone());
param_exprs.push(param_expr);
}
let body = self.infer_inner(level, *body)?;
self.vars = old_vars;
self.wrap = old_wrap;
let ty = Type::Arrow(param_tys, body.context.ty.ty.clone().into());
let ty = Type::Arrow(param_tys, body.ty().clone().into());
Ok(Expr::Fun(param_exprs, body.into()).with(expr.context, ty))
}
Expr::Let(pattern, value, body) => {
let value = self.infer_inner(level + 1, *value)?;
self.generalize(level, &value.context.ty.ty)?;
let pattern = self.assign_pattern(*pattern, value.context.ty.ty.clone())?;
self.generalize(level, value.ty())?;
let pattern = self.assign_pattern(*pattern, value.ty().clone())?;
let body = self.infer_inner(level, *body)?;
let ty = body.context.ty.ty.clone();
let ty = body.ty().clone();
Ok(Expr::Let(pattern.into(), value.into(), body.into()).with(expr.context, ty))
}
Expr::Call(fun, args) => {
let fun = self.infer_inner(level, *fun)?;
let (params, ret) = self.match_fun_ty(args.len(), fun.context.ty.ty.clone())?;
let (params, ret) = self.match_fun_ty(args.len(), fun.ty().clone())?;
let mut typed_args = Vec::with_capacity(args.len());
for (i, arg) in args.into_iter().enumerate() {
let arg = self.infer_inner(level, arg)?;
let param = &params[i];
self.unify(&arg.context.ty.ty, param)?;
self.unify(arg.ty(), param)?;
typed_args.push(arg);
}
Ok(Expr::Call(fun.into(), typed_args).with(expr.context, *ret))
Expand All @@ -773,7 +773,7 @@ impl Env {
);
let ret = field;
let record = self.infer_inner(level, *record)?;
self.unify(&param, &record.context.ty.ty)?;
self.unify(&param, record.ty())?;
Ok(Expr::RecordSelect(record.into(), label).with(expr.context, ret))
}
Expr::RecordRestrict(record, label) => {
Expand All @@ -788,7 +788,7 @@ impl Env {
);
let ret = Type::Record(rest.into());
let record = self.infer_inner(level, *record)?;
self.unify(&param, &record.context.ty.ty)?;
self.unify(&param, record.ty())?;
Ok(Expr::RecordRestrict(record.into(), label).with(expr.context, ret))
}
Expr::RecordExtend(labels, record) => {
Expand All @@ -797,12 +797,12 @@ impl Env {
let mut typed_labels = BTreeMap::new();
for (label, expr) in labels {
let expr = self.infer_inner(level, expr)?;
tys.insert(label.clone(), expr.context.ty.ty.clone());
tys.insert(label.clone(), expr.ty().clone());
typed_labels.insert(label, expr);
}
let rest = self.new_unbound_row(level, constraints);
let record = self.infer_inner(level, *record)?;
self.unify(&Type::Record(rest.clone().into()), &record.context.ty.ty)?;
self.unify(&Type::Record(rest.clone().into()), record.ty())?;
let ty = Type::Record(Type::RowExtend(tys, rest.into()).into());
Ok(Expr::RecordExtend(typed_labels, record.into()).with(expr.context, ty))
}
Expand All @@ -815,14 +815,14 @@ impl Env {
.into(),
);
let value = self.infer_inner(level, *value)?;
self.unify(&param, &value.context.ty.ty)?;
self.unify(&param, value.ty())?;
Ok(Expr::Variant(label, value.into()).with(expr.context, ret))
}
Expr::Case(value, cases, None) => {
let ret = self.new_unbound(level);
let value = self.infer_inner(level, *value)?;
let (cases_row, cases) = self.infer_cases(level, &ret, Type::RowEmpty, cases)?;
self.unify(&value.context.ty.ty, &Type::Variant(cases_row.into()))?;
self.unify(value.ty(), &Type::Variant(cases_row.into()))?;
Ok(Expr::Case(value.into(), cases, None).with(expr.context, ret))
}
Expr::Case(value, cases, Some((def_var, def_expr))) => {
Expand All @@ -832,24 +832,24 @@ impl Env {
self.vars
.insert(def_var.clone(), Type::Variant(def_variant.clone().into()));
let def_expr = self.infer_inner(level, *def_expr)?;
let ret = def_expr.context.ty.ty.clone();
let ret = def_expr.ty().clone();
self.vars = old_vars;
let value = self.infer_inner(level, *value)?;
let (cases_row, cases) = self.infer_cases(level, &ret, def_variant, cases)?;
self.unify(&value.context.ty.ty, &Type::Variant(cases_row.into()))?;
self.unify(value.ty(), &Type::Variant(cases_row.into()))?;
Ok(
Expr::Case(value.into(), cases, Some((def_var, def_expr.into())))
.with(expr.context, ret),
)
}
Expr::Unwrap(value) => {
let value = self.infer_inner(level, *value)?;
match &value.context.ty.ty {
match value.ty() {
Type::Variant(rows) => {
let (mut labels, _) = self.match_row_ty(rows)?;
match labels.remove(OK_LABEL) {
None => {
let ty = self.ty_to_string(&value.context.ty.ty)?;
let ty = self.ty_to_string(value.ty())?;
Err(Error::UnwrapMissingOk(ty))
}
Some(ty) => {
Expand All @@ -859,28 +859,28 @@ impl Env {
}
}
_ => {
let ty = self.ty_to_string(&value.context.ty.ty)?;
let ty = self.ty_to_string(value.ty())?;
Err(Error::UnwrapNotVariant(ty))
}
}
}
Expr::If(if_expr, if_body, elifs, else_body) => {
let bool = Type::bool();
let if_expr = self.infer_inner(level, *if_expr)?;
self.unify(&bool, &if_expr.context.ty.ty)?;
self.unify(&bool, if_expr.ty())?;
let if_body = self.infer_inner(level, *if_body)?;
let mut typed_elifs = Vec::with_capacity(elifs.len());
for (elif_expr, elif_body) in elifs {
let elif_expr = self.infer_inner(level, elif_expr)?;
self.unify(&bool, &elif_expr.context.ty.ty)?;
self.unify(&bool, elif_expr.ty())?;
let elif_body = self.infer_inner(level, elif_body)?;
self.unify(&if_body.context.ty.ty, &elif_body.context.ty.ty)?;
self.unify(if_body.ty(), elif_body.ty())?;
typed_elifs.push((elif_expr, elif_body));
}
let else_body = self.infer_inner(level, *else_body)?;
self.unify(&if_body.context.ty.ty, &else_body.context.ty.ty)?;
self.unify(if_body.ty(), else_body.ty())?;
// TODO: if calling a function with an open variant should keep it open
match if_body.context.ty.ty.clone() {
match if_body.ty().clone() {
Type::Variant(row) => {
let (labels, _rest) = self.match_row_ty(&row)?;
let ty =
Expand Down Expand Up @@ -921,7 +921,7 @@ impl Env {
self.vars.insert(var.clone(), variant.clone());
let case = self.infer_inner(level, case)?;
self.vars = old_vars;
self.unify(ret, &case.context.ty.ty)?;
self.unify(ret, case.ty())?;
labels.insert(label.clone(), variant);
typed_cases.push((label, var, case));
}
Expand Down

0 comments on commit 41e7b42

Please sign in to comment.