Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

type inference for recursive functions #104

Merged
merged 1 commit into from
Dec 1, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 113 additions & 33 deletions src/typed_ast.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
open Core.Std
open Sast

exception Cant_infer_type of string

(* environment *)
type symbol_table = {
parent: symbol_table option;
Expand All @@ -15,6 +17,10 @@ type environment = {
types: Sast.tdefault list;
}

let rec get_top_scope scope = match scope.parent with
| None -> scope
| Some(scope) -> get_top_scope scope

let rec has_cycle_rec nodes (self, children) seen =
match List.find !seen ~f:(fun n -> n = self) with
| Some(_) -> true
Expand Down Expand Up @@ -54,10 +60,10 @@ let rec find_field types t access_list =



let unique_add l item =
if List.mem !l item
then l := !l
else l := item :: !l
let replace_add_fun l item =
let (Sast.FunDef(name, tparams, (_, _))) = item in
l := item :: List.filter !l
~f:(fun (Sast.FunDef(n, tps, (_, _))) -> name <> n || tparams <> tps)

let rec find_variable (scope: symbol_table) name =
match List.find scope.variables ~f:(fun (s, _) -> s = name) with
Expand All @@ -68,6 +74,9 @@ let rec find_variable (scope: symbol_table) name =
end
| Some(x) -> x

let find_seen_function functions name tparams =
List.find functions ~f:(fun (n,tps) -> name=n && tps=tparams)

let find_function functions name num_args = match
List.find functions ~f:(function (n, Ast.FunDef(fname,_,_)) -> name = fname && num_args = n)
with
Expand Down Expand Up @@ -117,16 +126,16 @@ let chord_of sexpr =
| _ -> failwith "This expression is not chordable"
end, Ast.Type("chord")

let rec sast_expr env tfuns_ref e =
let sast_expr_env = sast_expr env tfuns_ref in
let rec sast_expr ?(seen_funs = []) ?(force = false) env tfuns_ref e =
let sast_expr_env = sast_expr ~seen_funs:seen_funs env tfuns_ref in
match e with
| Ast.LitBool(x) -> Sast.LitBool(x), Ast.Bool
| Ast.LitInt(x) -> Sast.LitInt(x), Ast.Int
| Ast.LitFloat(x) -> Sast.LitFloat(x), Ast.Float
| Ast.LitStr(x) -> Sast.LitStr(x), Ast.String
| Ast.Binop(lexpr, op, rexpr) ->
let lexprt = sast_expr env tfuns_ref lexpr in
let rexprt = sast_expr env tfuns_ref rexpr in
let lexprt = sast_expr_env lexpr in
let rexprt = sast_expr_env rexpr in
let (_, lt) = lexprt in
let (_, rt) = rexprt in
let opfailwith constraint_str =
Expand Down Expand Up @@ -183,7 +192,7 @@ let rec sast_expr env tfuns_ref e =
else failwith "left side expression of zip must of float or array of float"
end
| Ast.Uniop(op, expr) ->
let exprt = sast_expr env tfuns_ref expr in
let exprt = sast_expr_env expr in
let (_, t) = exprt in
begin match op with
| Ast.Not when t = Ast.Bool -> Sast.Uniop(op, exprt), t
Expand Down Expand Up @@ -212,7 +221,7 @@ let rec sast_expr env tfuns_ref e =
| Ast.FunApply(name, arg_exprs) ->
(* Common code for all function calls *)
(* get typed versions of input expressions *)
let arg_texprs = List.map arg_exprs ~f:(sast_expr env tfuns_ref) in
let arg_texprs = List.map arg_exprs ~f:(sast_expr_env) in
(* get types of the typed arguments *)
let arg_types = List.map arg_texprs ~f:(fun (_, t) -> t) in

Expand All @@ -221,22 +230,58 @@ let rec sast_expr env tfuns_ref e =
then
(* find the function *)
let num_args = List.length arg_exprs in
let (_,Ast.FunDef(_,params,expr)) = find_function env.functions name num_args in
let nh_fun_sig = find_function env.functions name num_args in
let (_,Ast.FunDef(_,params,expr)) = nh_fun_sig in
(* zip params with input types *)
let tparams = match List.zip params arg_types with
| None -> failwith "Internal error: Mismatched lengths of types and arguments while type checking function call"
| Some(x) -> x
in
(* check if types of inputs can be used with this function and UPDATE tfuns_ref *)
let (sexpr, t) = try check_function_type tparams expr tfuns_ref env
with Failure(reason) ->
Log.info "Function template type check failed. Inner exception: %s" reason;
failwith ("Incorrect types passed into function "^name)
in begin
(* UPDATE tfuns_ref and return the sast node *)
ignore(unique_add tfuns_ref (Sast.FunDef(name, tparams, (sexpr, t))));
Sast.FunApply(NhFunction(name), arg_texprs), t
end
let has_seen_fun = (find_seen_function seen_funs name tparams) <> None in
(* check if function type already inferred *)
match List.find !tfuns_ref ~f:(fun (Sast.FunDef(n,tps,_)) -> name=n && tparams=tps) with
(* already know function signature and not forced to re-infer subexpression types
(or is forced but loop encountered, so need to use previous result anyway) *)
| Some(Sast.FunDef(_,_,(_,t))) when (force && has_seen_fun) || not force ->
Sast.FunApply(NhFunction(name), arg_texprs), t
(* don't know function signature, or do know signature but
forced to re-infer subexpression types *)
| _ ->
let seen_funs =
if has_seen_fun
(* if the function is already seen, we are in a loop - can't infer type;
roll back to most recent conditional and see what we can do there
(conditional catches Cant_infer_type exception) *)
then raise
(Cant_infer_type("Can't infer type of recursive call to nh function "^name))
(* function not seen yet; we need to infer subexpression types, so mark this function as seen *)
else (name,tparams):: seen_funs
in
let try_check_function_type force =
(* forcing type inference means that cached results in
tfuns_ref will be ignored unless a loop is encountered *)
try check_function_type tparams expr tfuns_ref env seen_funs force
with Failure(reason) ->
Log.info "Function template type check failed. Inner exception: %s" reason;
failwith ("Incorrect types passed into function "^name)
in
(* check if types of inputs can be used with this function *)
let (sexpr, t) = try_check_function_type force in
begin
(* UPDATE tfuns_ref *)
ignore(replace_add_fun tfuns_ref (Sast.FunDef(name, tparams, (sexpr, t))));
(* check if it is safe to re-infer all subexpression types (ie tfun_ref is fully updated)
and that we are not already trying to re-infer all subexpression types *)
if List.length seen_funs = 1 && not force
(* go back in to resolve all types;
second pass guarantees no placeholder conditionals are in descendants *)
then let (sexpr, t) = try_check_function_type true in
ignore(replace_add_fun tfuns_ref (Sast.FunDef(name, tparams, (sexpr, t))));
Sast.FunApply(NhFunction(name), arg_texprs), t
(* pass sast back up the ast so higher expressions can infer type;
could still have placeholder conditionals in descendants *)
else Sast.FunApply(NhFunction(name), arg_texprs), t
end

(* C++ function calls *)
else
Expand Down Expand Up @@ -266,7 +311,7 @@ let rec sast_expr env tfuns_ref e =
extern_functions=env.extern_functions; types=env.types; }
| _ -> env
end
in let texpr = sast_expr env tfuns_ref expr in
in let texpr = sast_expr ~seen_funs:seen_funs env tfuns_ref expr in
(texpr :: texprs, env)
)
in
Expand Down Expand Up @@ -301,12 +346,46 @@ let rec sast_expr env tfuns_ref e =
let (condition, condition_t) = sast_expr_env condition in
if condition_t <> Ast.Bool then
failwith (sprintf "Condition must be a bool expression (%s found)" (Ast.string_of_type condition_t)) else
let (case_true, case_true_t) = sast_expr_env case_true
and (case_false, case_false_t) = sast_expr_env case_false in
if case_true_t <> case_false_t then
failwith (sprintf "Both expressions in a conditional must have the same type (%s and %s found)"
(Ast.string_of_type case_true_t) (Ast.string_of_type case_false_t)) else
Sast.Conditional( (condition, condition_t), (case_true, case_true_t), (case_false, case_false_t) ), case_true_t

let try_sast_expr_env expr =
try Some(sast_expr_env expr)
(* Note that Cant_infer_type is raised when an nh function has already
been seen higher up in the ast and it's not in tfuns_ref either *)
(* Also note that this case won't trigger if force was set to true (from FunApply) *)
with Cant_infer_type(_) -> None
in
(* try to infer type of true branch *)
begin match try_sast_expr_env case_true with
(* true branch type inference failed,
ie tfuns_ref is missing an nh function that has been used higher up in the ast *)
| None ->
(* see if other case can infer type *)
begin match try_sast_expr_env case_false with
(* neither branch terminates *)
| None -> failwith "Couldn't infer type of either branch of conditional"
(* only false branch terminates, assume entire conditional is of that type;
return fake sast with correct type so that tfuns_ref can be updated *)
| Some((_,t)) ->
let fake_sexpr = (Sast.LitUnit, t) in
Sast.Conditional((condition, condition_t), fake_sexpr, fake_sexpr), t
end
(* true branch type inference successful, now check false branch *)
| Some((case_true, case_true_t)) ->
(* see if other case can infer type *)
begin match try_sast_expr_env case_false with
(* both branch types have been inferred, check if types are the same *)
| Some((case_false, case_false_t)) -> if case_true_t <> case_false_t
then failwith (sprintf "Both expressions in a conditional must have the same type (%s and %s found)"
(Ast.string_of_type case_true_t) (Ast.string_of_type case_false_t))
else Sast.Conditional( (condition, condition_t), (case_true, case_true_t), (case_false, case_false_t) ), case_true_t
(* false branch type inference failed, ie tfuns_ref is missing an nh function that
has been used higher up in the ast;
true branch terminates, assume entire conditional is of that type;
return fake sast with correct type so that tfuns_ref can be updated *)
| None -> let fake_sexpr = (Sast.LitUnit, case_true_t) in
Sast.Conditional((condition, condition_t), fake_sexpr, fake_sexpr), case_true_t
end
end
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bit of copy-pasted code here. It's not too bad so I filed tech debt #106 to fix it.


| For(loop_var_name, items, body) ->
ignore (loop_var_name, items, body); failwith "Type checking not implemented for For"
Expand All @@ -315,7 +394,7 @@ let rec sast_expr env tfuns_ref e =
ignore msg; failwith "Type checking not implemented for Throw"

| Assign(names, expr) ->
let (value, tvalue) = sast_expr env tfuns_ref expr in
let (value, tvalue) = sast_expr_env expr in
begin match names with
| [] -> failwith "Internal error: Assign(names, _) had empty string list"
| name :: fields -> try begin
Expand All @@ -341,7 +420,7 @@ let rec sast_expr env tfuns_ref e =
| None -> failwith ("type "^typename^" not found")
in
let fields = List.map defaults ~f:(fun (n,(_,t)) -> (n,t)) in
let sexprs = List.map init_list ~f:(sast_expr env tfuns_ref) in
let sexprs = List.map init_list ~f:(sast_expr_env) in
let varname = function
| Init(name,(_,t)),_
when begin match List.find fields ~f:(fun (n,_) -> n = name) with
Expand Down Expand Up @@ -380,14 +459,15 @@ let rec sast_expr env tfuns_ref e =



and check_function_type tparams expr tfuns_ref env =
and check_function_type tparams expr tfuns_ref env seen_funs force =
let env' = {
scope = { variables = tparams; parent = env.scope.parent };
(* allow global scope *)
scope = { variables = tparams; parent = Some(get_top_scope env.scope) };
functions = env.functions;
extern_functions = env.extern_functions;
types = env.types;
} in
sast_expr env' tfuns_ref expr
sast_expr ~seen_funs:seen_funs ~force:force env' tfuns_ref expr

and typed_typedefs env tfuns_ref typedefs =
(* Assuming the users have defined the types in the correct order *)
Expand Down
11 changes: 11 additions & 0 deletions test/fundef_mutual_rec.nh
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@

fun First x = (
Print (StringOfInt x)
if x > 1 then (Print "s"; Second x; "not") else "done"
)

fun Second x = (
Print (First (x-1))
)

Print (First 5)
1 change: 1 addition & 0 deletions test/fundef_mutual_rec.out
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
5s4s3s2s1donenotnotnotnot
5 changes: 5 additions & 0 deletions test/fundef_rec.nh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@

fun Factorial x = if x <= 1 then 1 else x * Factorial (x-1)

x = Factorial 3
Print (StringOfInt x)
1 change: 1 addition & 0 deletions test/fundef_rec.out
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
6
5 changes: 5 additions & 0 deletions test/fundef_rec_notail.nh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@

fun Factorial x = if x > 1 then x * Factorial (x-1) else 1

x = Factorial 3
Print (StringOfInt x)
1 change: 1 addition & 0 deletions test/fundef_rec_notail.out
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
6
5 changes: 5 additions & 0 deletions test/fundef_simple.nh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@

fun Hello x = x

PrintEndline (Hello "what")
PrintEndline (StringOfInt (Hello 1))
2 changes: 2 additions & 0 deletions test/fundef_simple.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
what
1
4 changes: 2 additions & 2 deletions test/simple_assign.nh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

x = 5
x = "x"
y = 10

Print (StringOfInt x)
Print x
Print "\n"
Print (StringOfInt y)
2 changes: 1 addition & 1 deletion test/simple_assign.out
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
5
x
10
3 changes: 3 additions & 0 deletions test/simple_varref.nh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@

x = 5
x
Empty file added test/simple_varref.out
Empty file.
16 changes: 16 additions & 0 deletions test/typedef_chained.nh
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@

type first = {
f = "hello"
}

type second = {
s = init first
}

type third = {
t = init second
}

x = init third

Print x$t$s$f
1 change: 1 addition & 0 deletions test/typedef_chained.out
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
hello