Skip to content

Commit

Permalink
Merge pull request #33 from FrankBro/sanitize-infer-type-output
Browse files Browse the repository at this point in the history
sanitize infer type output
  • Loading branch information
FrankBro authored Jan 14, 2024
2 parents 0e2e3aa + 561daaa commit 1f210c1
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 2 deletions.
77 changes: 76 additions & 1 deletion src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -592,11 +592,77 @@ impl Env {

pub fn infer(&mut self, expr: ExprAt) -> Result<ExprTypedAt> {
let expr = self.infer_inner(0, expr)?;
let expr = self.wrapped(expr)?;
let mut expr = self.wrapped(expr)?;
self.generalize(-1, expr.ty())?;
self.sanitize_types(&mut expr)?;
Ok(expr)
}

fn sanitize_types(&self, expr: &mut ExprTypedAt) -> Result<()> {
// TODO: I got lazy here, make this more efficient
expr.context.ty.ty = self.real_ty(expr.context.ty.ty.clone())?;
match expr.expr.as_mut() {
Expr::Bool(_) => Ok(()),
Expr::Int(_) => Ok(()),
Expr::IntBinOp(_, a, b) => {
self.sanitize_types(a)?;
self.sanitize_types(b)
}
Expr::Negate(value) => self.sanitize_types(value),
Expr::EqualEqual(a, b) => {
self.sanitize_types(a)?;
self.sanitize_types(b)
}
Expr::Var(_) => Ok(()),
Expr::Call(fun, args) => {
for arg in args.iter_mut() {
self.sanitize_types(arg)?;
}
self.sanitize_types(fun)
}
Expr::Fun(params, body) => {
for param in params.iter_mut() {
self.sanitize_types(param)?;
}
self.sanitize_types(body)
}
Expr::Let(pattern, value, body) => {
self.sanitize_types(pattern)?;
self.sanitize_types(value)?;
self.sanitize_types(body)
}
Expr::RecordSelect(record, _) => self.sanitize_types(record),
Expr::RecordExtend(labels, rest) => {
for (_, expr) in labels.iter_mut() {
self.sanitize_types(expr)?;
}
self.sanitize_types(rest)
}
Expr::RecordRestrict(record, _) => self.sanitize_types(record),
Expr::RecordEmpty => Ok(()),
Expr::Variant(_, expr) => self.sanitize_types(expr),
Expr::Case(value, cases, def) => {
for (_, _, case) in cases.iter_mut() {
self.sanitize_types(case)?;
}
if let Some((_, expr)) = def {
self.sanitize_types(expr)?;
}
self.sanitize_types(value)
}
Expr::If(if_expr, if_body, elifs, else_body) => {
self.sanitize_types(if_expr)?;
self.sanitize_types(if_body)?;
for (elif_expr, elif_body) in elifs.iter_mut() {
self.sanitize_types(elif_expr)?;
self.sanitize_types(elif_body)?;
}
self.sanitize_types(else_body)
}
Expr::Unwrap(value) => self.sanitize_types(value),
}
}

fn assign_pattern(&mut self, pattern: PatternAt, ty: Type) -> Result<PatternTypedAt> {
match (*pattern.expr, ty) {
(Pattern::Var(name), ty) => {
Expand Down Expand Up @@ -982,6 +1048,15 @@ impl Env {
_ => Ok(ty),
}
}
Type::Arrow(params, ret) => {
let mut real_params = Vec::with_capacity(params.len());
for param in params {
let real_param = self.real_ty(param)?;
real_params.push(real_param);
}
let ret = self.real_ty(*ret)?;
Ok(Type::Arrow(real_params, ret.into()))
}
_ => Ok(ty),
}
}
Expand Down
45 changes: 44 additions & 1 deletion src/tests/typed_exprs.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
core::make_env,
expr::{Expr, ExprIn, ExprTyped, Type, TypeContext},
expr::{Expr, ExprIn, ExprTyped, IntBinOp, Type, TypeContext},
parser::Parser,
};

Expand Down Expand Up @@ -32,10 +32,53 @@ fn let_(pattern: ExprTyped, value: ExprTyped, body: ExprTyped) -> ExprTyped {
expr(Expr::Let(pattern, value, body), ty)
}

fn fun(params: Vec<ExprTyped>, body: ExprTyped) -> ExprTyped {
let params_ty = params
.iter()
.map(|param| param.context.ty.clone())
.collect();
let ty = Type::Arrow(params_ty, body.context.ty.clone().into());
expr(Expr::Fun(params, body), ty)
}

fn add(a: ExprTyped, b: ExprTyped) -> ExprTyped {
let op = IntBinOp::Plus;
let ty = op.output_ty();
expr(Expr::IntBinOp(op, a, b), ty)
}

fn call(fun: ExprTyped, args: Vec<ExprTyped>) -> ExprTyped {
let ty = match fun.context.ty.clone() {
Type::Arrow(_, ret) => *ret,
_ => unreachable!(),
};
expr(Expr::Call(fun, args), ty)
}

#[test]
fn tests() {
pass(
"let x = 1 in x",
let_(var("x", Type::int()), int(1), var("x", Type::int())),
);
pass(
"let add(a, b) = a + b in add(1, 2)",
let_(
var(
"add",
Type::Arrow(vec![Type::int(), Type::int()], Type::int().into()),
),
fun(
vec![var("a", Type::int()), var("b", Type::int())],
add(var("a", Type::int()), var("b", Type::int())),
),
call(
var(
"add",
Type::Arrow(vec![Type::int(), Type::int()], Type::int().into()),
),
vec![int(1), int(2)],
),
),
);
}

0 comments on commit 1f210c1

Please sign in to comment.