Skip to content

Commit

Permalink
Merge pull request #104 from el2724/ed_recursive_type_inference
Browse files Browse the repository at this point in the history
type inference for recursive functions
  • Loading branch information
edwadli committed Dec 1, 2015
2 parents 3db895b + 11feaab commit 704e505
Show file tree
Hide file tree
Showing 15 changed files with 167 additions and 36 deletions.
146 changes: 113 additions & 33 deletions src/typed_ast.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
open Core.Std
open Sast

exception Cant_infer_type of string

(* environment *)
type symbol_table = {
parent: symbol_table option;
Expand All @@ -15,6 +17,10 @@ type environment = {
types: Sast.tdefault list;
}

let rec get_top_scope scope = match scope.parent with
| None -> scope
| Some(scope) -> get_top_scope scope

let rec has_cycle_rec nodes (self, children) seen =
match List.find !seen ~f:(fun n -> n = self) with
| Some(_) -> true
Expand Down Expand Up @@ -54,10 +60,10 @@ let rec find_field types t access_list =



let unique_add l item =
if List.mem !l item
then l := !l
else l := item :: !l
let replace_add_fun l item =
let (Sast.FunDef(name, tparams, (_, _))) = item in
l := item :: List.filter !l
~f:(fun (Sast.FunDef(n, tps, (_, _))) -> name <> n || tparams <> tps)

let rec find_variable (scope: symbol_table) name =
match List.find scope.variables ~f:(fun (s, _) -> s = name) with
Expand All @@ -68,6 +74,9 @@ let rec find_variable (scope: symbol_table) name =
end
| Some(x) -> x

let find_seen_function functions name tparams =
List.find functions ~f:(fun (n,tps) -> name=n && tps=tparams)

let find_function functions name num_args = match
List.find functions ~f:(function (n, Ast.FunDef(fname,_,_)) -> name = fname && num_args = n)
with
Expand Down Expand Up @@ -117,16 +126,16 @@ let chord_of sexpr =
| _ -> failwith "This expression is not chordable"
end, Ast.Type("chord")

let rec sast_expr env tfuns_ref e =
let sast_expr_env = sast_expr env tfuns_ref in
let rec sast_expr ?(seen_funs = []) ?(force = false) env tfuns_ref e =
let sast_expr_env = sast_expr ~seen_funs:seen_funs env tfuns_ref in
match e with
| Ast.LitBool(x) -> Sast.LitBool(x), Ast.Bool
| Ast.LitInt(x) -> Sast.LitInt(x), Ast.Int
| Ast.LitFloat(x) -> Sast.LitFloat(x), Ast.Float
| Ast.LitStr(x) -> Sast.LitStr(x), Ast.String
| Ast.Binop(lexpr, op, rexpr) ->
let lexprt = sast_expr env tfuns_ref lexpr in
let rexprt = sast_expr env tfuns_ref rexpr in
let lexprt = sast_expr_env lexpr in
let rexprt = sast_expr_env rexpr in
let (_, lt) = lexprt in
let (_, rt) = rexprt in
let opfailwith constraint_str =
Expand Down Expand Up @@ -183,7 +192,7 @@ let rec sast_expr env tfuns_ref e =
else failwith "left side expression of zip must of float or array of float"
end
| Ast.Uniop(op, expr) ->
let exprt = sast_expr env tfuns_ref expr in
let exprt = sast_expr_env expr in
let (_, t) = exprt in
begin match op with
| Ast.Not when t = Ast.Bool -> Sast.Uniop(op, exprt), t
Expand Down Expand Up @@ -212,7 +221,7 @@ let rec sast_expr env tfuns_ref e =
| Ast.FunApply(name, arg_exprs) ->
(* Common code for all function calls *)
(* get typed versions of input expressions *)
let arg_texprs = List.map arg_exprs ~f:(sast_expr env tfuns_ref) in
let arg_texprs = List.map arg_exprs ~f:(sast_expr_env) in
(* get types of the typed arguments *)
let arg_types = List.map arg_texprs ~f:(fun (_, t) -> t) in

Expand All @@ -221,22 +230,58 @@ let rec sast_expr env tfuns_ref e =
then
(* find the function *)
let num_args = List.length arg_exprs in
let (_,Ast.FunDef(_,params,expr)) = find_function env.functions name num_args in
let nh_fun_sig = find_function env.functions name num_args in
let (_,Ast.FunDef(_,params,expr)) = nh_fun_sig in
(* zip params with input types *)
let tparams = match List.zip params arg_types with
| None -> failwith "Internal error: Mismatched lengths of types and arguments while type checking function call"
| Some(x) -> x
in
(* check if types of inputs can be used with this function and UPDATE tfuns_ref *)
let (sexpr, t) = try check_function_type tparams expr tfuns_ref env
with Failure(reason) ->
Log.info "Function template type check failed. Inner exception: %s" reason;
failwith ("Incorrect types passed into function "^name)
in begin
(* UPDATE tfuns_ref and return the sast node *)
ignore(unique_add tfuns_ref (Sast.FunDef(name, tparams, (sexpr, t))));
Sast.FunApply(NhFunction(name), arg_texprs), t
end
let has_seen_fun = (find_seen_function seen_funs name tparams) <> None in
(* check if function type already inferred *)
match List.find !tfuns_ref ~f:(fun (Sast.FunDef(n,tps,_)) -> name=n && tparams=tps) with
(* already know function signature and not forced to re-infer subexpression types
(or is forced but loop encountered, so need to use previous result anyway) *)
| Some(Sast.FunDef(_,_,(_,t))) when (force && has_seen_fun) || not force ->
Sast.FunApply(NhFunction(name), arg_texprs), t
(* don't know function signature, or do know signature but
forced to re-infer subexpression types *)
| _ ->
let seen_funs =
if has_seen_fun
(* if the function is already seen, we are in a loop - can't infer type;
roll back to most recent conditional and see what we can do there
(conditional catches Cant_infer_type exception) *)
then raise
(Cant_infer_type("Can't infer type of recursive call to nh function "^name))
(* function not seen yet; we need to infer subexpression types, so mark this function as seen *)
else (name,tparams):: seen_funs
in
let try_check_function_type force =
(* forcing type inference means that cached results in
tfuns_ref will be ignored unless a loop is encountered *)
try check_function_type tparams expr tfuns_ref env seen_funs force
with Failure(reason) ->
Log.info "Function template type check failed. Inner exception: %s" reason;
failwith ("Incorrect types passed into function "^name)
in
(* check if types of inputs can be used with this function *)
let (sexpr, t) = try_check_function_type force in
begin
(* UPDATE tfuns_ref *)
ignore(replace_add_fun tfuns_ref (Sast.FunDef(name, tparams, (sexpr, t))));
(* check if it is safe to re-infer all subexpression types (ie tfun_ref is fully updated)
and that we are not already trying to re-infer all subexpression types *)
if List.length seen_funs = 1 && not force
(* go back in to resolve all types;
second pass guarantees no placeholder conditionals are in descendants *)
then let (sexpr, t) = try_check_function_type true in
ignore(replace_add_fun tfuns_ref (Sast.FunDef(name, tparams, (sexpr, t))));
Sast.FunApply(NhFunction(name), arg_texprs), t
(* pass sast back up the ast so higher expressions can infer type;
could still have placeholder conditionals in descendants *)
else Sast.FunApply(NhFunction(name), arg_texprs), t
end

(* C++ function calls *)
else
Expand Down Expand Up @@ -266,7 +311,7 @@ let rec sast_expr env tfuns_ref e =
extern_functions=env.extern_functions; types=env.types; }
| _ -> env
end
in let texpr = sast_expr env tfuns_ref expr in
in let texpr = sast_expr ~seen_funs:seen_funs env tfuns_ref expr in
(texpr :: texprs, env)
)
in
Expand Down Expand Up @@ -301,12 +346,46 @@ let rec sast_expr env tfuns_ref e =
let (condition, condition_t) = sast_expr_env condition in
if condition_t <> Ast.Bool then
failwith (sprintf "Condition must be a bool expression (%s found)" (Ast.string_of_type condition_t)) else
let (case_true, case_true_t) = sast_expr_env case_true
and (case_false, case_false_t) = sast_expr_env case_false in
if case_true_t <> case_false_t then
failwith (sprintf "Both expressions in a conditional must have the same type (%s and %s found)"
(Ast.string_of_type case_true_t) (Ast.string_of_type case_false_t)) else
Sast.Conditional( (condition, condition_t), (case_true, case_true_t), (case_false, case_false_t) ), case_true_t

let try_sast_expr_env expr =
try Some(sast_expr_env expr)
(* Note that Cant_infer_type is raised when an nh function has already
been seen higher up in the ast and it's not in tfuns_ref either *)
(* Also note that this case won't trigger if force was set to true (from FunApply) *)
with Cant_infer_type(_) -> None
in
(* try to infer type of true branch *)
begin match try_sast_expr_env case_true with
(* true branch type inference failed,
ie tfuns_ref is missing an nh function that has been used higher up in the ast *)
| None ->
(* see if other case can infer type *)
begin match try_sast_expr_env case_false with
(* neither branch terminates *)
| None -> failwith "Couldn't infer type of either branch of conditional"
(* only false branch terminates, assume entire conditional is of that type;
return fake sast with correct type so that tfuns_ref can be updated *)
| Some((_,t)) ->
let fake_sexpr = (Sast.LitUnit, t) in
Sast.Conditional((condition, condition_t), fake_sexpr, fake_sexpr), t
end
(* true branch type inference successful, now check false branch *)
| Some((case_true, case_true_t)) ->
(* see if other case can infer type *)
begin match try_sast_expr_env case_false with
(* both branch types have been inferred, check if types are the same *)
| Some((case_false, case_false_t)) -> if case_true_t <> case_false_t
then failwith (sprintf "Both expressions in a conditional must have the same type (%s and %s found)"
(Ast.string_of_type case_true_t) (Ast.string_of_type case_false_t))
else Sast.Conditional( (condition, condition_t), (case_true, case_true_t), (case_false, case_false_t) ), case_true_t
(* false branch type inference failed, ie tfuns_ref is missing an nh function that
has been used higher up in the ast;
true branch terminates, assume entire conditional is of that type;
return fake sast with correct type so that tfuns_ref can be updated *)
| None -> let fake_sexpr = (Sast.LitUnit, case_true_t) in
Sast.Conditional((condition, condition_t), fake_sexpr, fake_sexpr), case_true_t
end
end

| For(loop_var_name, items, body) ->
ignore (loop_var_name, items, body); failwith "Type checking not implemented for For"
Expand All @@ -315,7 +394,7 @@ let rec sast_expr env tfuns_ref e =
ignore msg; failwith "Type checking not implemented for Throw"

| Assign(names, expr) ->
let (value, tvalue) = sast_expr env tfuns_ref expr in
let (value, tvalue) = sast_expr_env expr in
begin match names with
| [] -> failwith "Internal error: Assign(names, _) had empty string list"
| name :: fields -> try begin
Expand All @@ -341,7 +420,7 @@ let rec sast_expr env tfuns_ref e =
| None -> failwith ("type "^typename^" not found")
in
let fields = List.map defaults ~f:(fun (n,(_,t)) -> (n,t)) in
let sexprs = List.map init_list ~f:(sast_expr env tfuns_ref) in
let sexprs = List.map init_list ~f:(sast_expr_env) in
let varname = function
| Init(name,(_,t)),_
when begin match List.find fields ~f:(fun (n,_) -> n = name) with
Expand Down Expand Up @@ -380,14 +459,15 @@ let rec sast_expr env tfuns_ref e =



and check_function_type tparams expr tfuns_ref env =
and check_function_type tparams expr tfuns_ref env seen_funs force =
let env' = {
scope = { variables = tparams; parent = env.scope.parent };
(* allow global scope *)
scope = { variables = tparams; parent = Some(get_top_scope env.scope) };
functions = env.functions;
extern_functions = env.extern_functions;
types = env.types;
} in
sast_expr env' tfuns_ref expr
sast_expr ~seen_funs:seen_funs ~force:force env' tfuns_ref expr

and typed_typedefs env tfuns_ref typedefs =
(* Assuming the users have defined the types in the correct order *)
Expand Down
11 changes: 11 additions & 0 deletions test/fundef_mutual_rec.nh
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@

fun First x = (
Print (StringOfInt x)
if x > 1 then (Print "s"; Second x; "not") else "done"
)

fun Second x = (
Print (First (x-1))
)

Print (First 5)
1 change: 1 addition & 0 deletions test/fundef_mutual_rec.out
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
5s4s3s2s1donenotnotnotnot
5 changes: 5 additions & 0 deletions test/fundef_rec.nh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@

fun Factorial x = if x <= 1 then 1 else x * Factorial (x-1)

x = Factorial 3
Print (StringOfInt x)
1 change: 1 addition & 0 deletions test/fundef_rec.out
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
6
5 changes: 5 additions & 0 deletions test/fundef_rec_notail.nh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@

fun Factorial x = if x > 1 then x * Factorial (x-1) else 1

x = Factorial 3
Print (StringOfInt x)
1 change: 1 addition & 0 deletions test/fundef_rec_notail.out
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
6
5 changes: 5 additions & 0 deletions test/fundef_simple.nh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@

fun Hello x = x

PrintEndline (Hello "what")
PrintEndline (StringOfInt (Hello 1))
2 changes: 2 additions & 0 deletions test/fundef_simple.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
what
1
4 changes: 2 additions & 2 deletions test/simple_assign.nh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

x = 5
x = "x"
y = 10

Print (StringOfInt x)
Print x
Print "\n"
Print (StringOfInt y)
2 changes: 1 addition & 1 deletion test/simple_assign.out
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
5
x
10
3 changes: 3 additions & 0 deletions test/simple_varref.nh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@

x = 5
x
Empty file added test/simple_varref.out
Empty file.
16 changes: 16 additions & 0 deletions test/typedef_chained.nh
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@

type first = {
f = "hello"
}

type second = {
s = init first
}

type third = {
t = init second
}

x = init third

Print x$t$s$f
1 change: 1 addition & 0 deletions test/typedef_chained.out
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
hello

0 comments on commit 704e505

Please sign in to comment.