diff --git a/src/typed_ast.ml b/src/typed_ast.ml index ca8bbd7..43e9520 100644 --- a/src/typed_ast.ml +++ b/src/typed_ast.ml @@ -2,6 +2,8 @@ open Core.Std open Sast +exception Cant_infer_type of string + (* environment *) type symbol_table = { parent: symbol_table option; @@ -15,6 +17,10 @@ type environment = { types: Sast.tdefault list; } +let rec get_top_scope scope = match scope.parent with + | None -> scope + | Some(scope) -> get_top_scope scope + let rec has_cycle_rec nodes (self, children) seen = match List.find !seen ~f:(fun n -> n = self) with | Some(_) -> true @@ -54,10 +60,10 @@ let rec find_field types t access_list = -let unique_add l item = - if List.mem !l item - then l := !l - else l := item :: !l +let replace_add_fun l item = + let (Sast.FunDef(name, tparams, (_, _))) = item in + l := item :: List.filter !l + ~f:(fun (Sast.FunDef(n, tps, (_, _))) -> name <> n || tparams <> tps) let rec find_variable (scope: symbol_table) name = match List.find scope.variables ~f:(fun (s, _) -> s = name) with @@ -68,6 +74,9 @@ let rec find_variable (scope: symbol_table) name = end | Some(x) -> x +let find_seen_function functions name tparams = + List.find functions ~f:(fun (n,tps) -> name=n && tps=tparams) + let find_function functions name num_args = match List.find functions ~f:(function (n, Ast.FunDef(fname,_,_)) -> name = fname && num_args = n) with @@ -117,16 +126,16 @@ let chord_of sexpr = | _ -> failwith "This expression is not chordable" end, Ast.Type("chord") -let rec sast_expr env tfuns_ref e = - let sast_expr_env = sast_expr env tfuns_ref in +let rec sast_expr ?(seen_funs = []) ?(force = false) env tfuns_ref e = + let sast_expr_env = sast_expr ~seen_funs:seen_funs env tfuns_ref in match e with | Ast.LitBool(x) -> Sast.LitBool(x), Ast.Bool | Ast.LitInt(x) -> Sast.LitInt(x), Ast.Int | Ast.LitFloat(x) -> Sast.LitFloat(x), Ast.Float | Ast.LitStr(x) -> Sast.LitStr(x), Ast.String | Ast.Binop(lexpr, op, rexpr) -> - let lexprt = sast_expr env tfuns_ref lexpr in - let rexprt = sast_expr env tfuns_ref rexpr in + let lexprt = sast_expr_env lexpr in + let rexprt = sast_expr_env rexpr in let (_, lt) = lexprt in let (_, rt) = rexprt in let opfailwith constraint_str = @@ -183,7 +192,7 @@ let rec sast_expr env tfuns_ref e = else failwith "left side expression of zip must of float or array of float" end | Ast.Uniop(op, expr) -> - let exprt = sast_expr env tfuns_ref expr in + let exprt = sast_expr_env expr in let (_, t) = exprt in begin match op with | Ast.Not when t = Ast.Bool -> Sast.Uniop(op, exprt), t @@ -212,7 +221,7 @@ let rec sast_expr env tfuns_ref e = | Ast.FunApply(name, arg_exprs) -> (* Common code for all function calls *) (* get typed versions of input expressions *) - let arg_texprs = List.map arg_exprs ~f:(sast_expr env tfuns_ref) in + let arg_texprs = List.map arg_exprs ~f:(sast_expr_env) in (* get types of the typed arguments *) let arg_types = List.map arg_texprs ~f:(fun (_, t) -> t) in @@ -221,22 +230,58 @@ let rec sast_expr env tfuns_ref e = then (* find the function *) let num_args = List.length arg_exprs in - let (_,Ast.FunDef(_,params,expr)) = find_function env.functions name num_args in + let nh_fun_sig = find_function env.functions name num_args in + let (_,Ast.FunDef(_,params,expr)) = nh_fun_sig in (* zip params with input types *) let tparams = match List.zip params arg_types with | None -> failwith "Internal error: Mismatched lengths of types and arguments while type checking function call" | Some(x) -> x in - (* check if types of inputs can be used with this function and UPDATE tfuns_ref *) - let (sexpr, t) = try check_function_type tparams expr tfuns_ref env - with Failure(reason) -> - Log.info "Function template type check failed. Inner exception: %s" reason; - failwith ("Incorrect types passed into function "^name) - in begin - (* UPDATE tfuns_ref and return the sast node *) - ignore(unique_add tfuns_ref (Sast.FunDef(name, tparams, (sexpr, t)))); - Sast.FunApply(NhFunction(name), arg_texprs), t - end + let has_seen_fun = (find_seen_function seen_funs name tparams) <> None in + (* check if function type already inferred *) + match List.find !tfuns_ref ~f:(fun (Sast.FunDef(n,tps,_)) -> name=n && tparams=tps) with + (* already know function signature and not forced to re-infer subexpression types + (or is forced but loop encountered, so need to use previous result anyway) *) + | Some(Sast.FunDef(_,_,(_,t))) when (force && has_seen_fun) || not force -> + Sast.FunApply(NhFunction(name), arg_texprs), t + (* don't know function signature, or do know signature but + forced to re-infer subexpression types *) + | _ -> + let seen_funs = + if has_seen_fun + (* if the function is already seen, we are in a loop - can't infer type; + roll back to most recent conditional and see what we can do there + (conditional catches Cant_infer_type exception) *) + then raise + (Cant_infer_type("Can't infer type of recursive call to nh function "^name)) + (* function not seen yet; we need to infer subexpression types, so mark this function as seen *) + else (name,tparams):: seen_funs + in + let try_check_function_type force = + (* forcing type inference means that cached results in + tfuns_ref will be ignored unless a loop is encountered *) + try check_function_type tparams expr tfuns_ref env seen_funs force + with Failure(reason) -> + Log.info "Function template type check failed. Inner exception: %s" reason; + failwith ("Incorrect types passed into function "^name) + in + (* check if types of inputs can be used with this function *) + let (sexpr, t) = try_check_function_type force in + begin + (* UPDATE tfuns_ref *) + ignore(replace_add_fun tfuns_ref (Sast.FunDef(name, tparams, (sexpr, t)))); + (* check if it is safe to re-infer all subexpression types (ie tfun_ref is fully updated) + and that we are not already trying to re-infer all subexpression types *) + if List.length seen_funs = 1 && not force + (* go back in to resolve all types; + second pass guarantees no placeholder conditionals are in descendants *) + then let (sexpr, t) = try_check_function_type true in + ignore(replace_add_fun tfuns_ref (Sast.FunDef(name, tparams, (sexpr, t)))); + Sast.FunApply(NhFunction(name), arg_texprs), t + (* pass sast back up the ast so higher expressions can infer type; + could still have placeholder conditionals in descendants *) + else Sast.FunApply(NhFunction(name), arg_texprs), t + end (* C++ function calls *) else @@ -266,7 +311,7 @@ let rec sast_expr env tfuns_ref e = extern_functions=env.extern_functions; types=env.types; } | _ -> env end - in let texpr = sast_expr env tfuns_ref expr in + in let texpr = sast_expr ~seen_funs:seen_funs env tfuns_ref expr in (texpr :: texprs, env) ) in @@ -301,12 +346,46 @@ let rec sast_expr env tfuns_ref e = let (condition, condition_t) = sast_expr_env condition in if condition_t <> Ast.Bool then failwith (sprintf "Condition must be a bool expression (%s found)" (Ast.string_of_type condition_t)) else - let (case_true, case_true_t) = sast_expr_env case_true - and (case_false, case_false_t) = sast_expr_env case_false in - if case_true_t <> case_false_t then - failwith (sprintf "Both expressions in a conditional must have the same type (%s and %s found)" - (Ast.string_of_type case_true_t) (Ast.string_of_type case_false_t)) else - Sast.Conditional( (condition, condition_t), (case_true, case_true_t), (case_false, case_false_t) ), case_true_t + + let try_sast_expr_env expr = + try Some(sast_expr_env expr) + (* Note that Cant_infer_type is raised when an nh function has already + been seen higher up in the ast and it's not in tfuns_ref either *) + (* Also note that this case won't trigger if force was set to true (from FunApply) *) + with Cant_infer_type(_) -> None + in + (* try to infer type of true branch *) + begin match try_sast_expr_env case_true with + (* true branch type inference failed, + ie tfuns_ref is missing an nh function that has been used higher up in the ast *) + | None -> + (* see if other case can infer type *) + begin match try_sast_expr_env case_false with + (* neither branch terminates *) + | None -> failwith "Couldn't infer type of either branch of conditional" + (* only false branch terminates, assume entire conditional is of that type; + return fake sast with correct type so that tfuns_ref can be updated *) + | Some((_,t)) -> + let fake_sexpr = (Sast.LitUnit, t) in + Sast.Conditional((condition, condition_t), fake_sexpr, fake_sexpr), t + end + (* true branch type inference successful, now check false branch *) + | Some((case_true, case_true_t)) -> + (* see if other case can infer type *) + begin match try_sast_expr_env case_false with + (* both branch types have been inferred, check if types are the same *) + | Some((case_false, case_false_t)) -> if case_true_t <> case_false_t + then failwith (sprintf "Both expressions in a conditional must have the same type (%s and %s found)" + (Ast.string_of_type case_true_t) (Ast.string_of_type case_false_t)) + else Sast.Conditional( (condition, condition_t), (case_true, case_true_t), (case_false, case_false_t) ), case_true_t + (* false branch type inference failed, ie tfuns_ref is missing an nh function that + has been used higher up in the ast; + true branch terminates, assume entire conditional is of that type; + return fake sast with correct type so that tfuns_ref can be updated *) + | None -> let fake_sexpr = (Sast.LitUnit, case_true_t) in + Sast.Conditional((condition, condition_t), fake_sexpr, fake_sexpr), case_true_t + end + end | For(loop_var_name, items, body) -> ignore (loop_var_name, items, body); failwith "Type checking not implemented for For" @@ -315,7 +394,7 @@ let rec sast_expr env tfuns_ref e = ignore msg; failwith "Type checking not implemented for Throw" | Assign(names, expr) -> - let (value, tvalue) = sast_expr env tfuns_ref expr in + let (value, tvalue) = sast_expr_env expr in begin match names with | [] -> failwith "Internal error: Assign(names, _) had empty string list" | name :: fields -> try begin @@ -341,7 +420,7 @@ let rec sast_expr env tfuns_ref e = | None -> failwith ("type "^typename^" not found") in let fields = List.map defaults ~f:(fun (n,(_,t)) -> (n,t)) in - let sexprs = List.map init_list ~f:(sast_expr env tfuns_ref) in + let sexprs = List.map init_list ~f:(sast_expr_env) in let varname = function | Init(name,(_,t)),_ when begin match List.find fields ~f:(fun (n,_) -> n = name) with @@ -380,14 +459,15 @@ let rec sast_expr env tfuns_ref e = -and check_function_type tparams expr tfuns_ref env = +and check_function_type tparams expr tfuns_ref env seen_funs force = let env' = { - scope = { variables = tparams; parent = env.scope.parent }; + (* allow global scope *) + scope = { variables = tparams; parent = Some(get_top_scope env.scope) }; functions = env.functions; extern_functions = env.extern_functions; types = env.types; } in - sast_expr env' tfuns_ref expr + sast_expr ~seen_funs:seen_funs ~force:force env' tfuns_ref expr and typed_typedefs env tfuns_ref typedefs = (* Assuming the users have defined the types in the correct order *) diff --git a/test/fundef_mutual_rec.nh b/test/fundef_mutual_rec.nh new file mode 100644 index 0000000..fe50a76 --- /dev/null +++ b/test/fundef_mutual_rec.nh @@ -0,0 +1,11 @@ + +fun First x = ( + Print (StringOfInt x) + if x > 1 then (Print "s"; Second x; "not") else "done" +) + +fun Second x = ( + Print (First (x-1)) +) + +Print (First 5) diff --git a/test/fundef_mutual_rec.out b/test/fundef_mutual_rec.out new file mode 100644 index 0000000..ff441e6 --- /dev/null +++ b/test/fundef_mutual_rec.out @@ -0,0 +1 @@ +5s4s3s2s1donenotnotnotnot \ No newline at end of file diff --git a/test/fundef_rec.nh b/test/fundef_rec.nh new file mode 100644 index 0000000..b103617 --- /dev/null +++ b/test/fundef_rec.nh @@ -0,0 +1,5 @@ + +fun Factorial x = if x <= 1 then 1 else x * Factorial (x-1) + +x = Factorial 3 +Print (StringOfInt x) diff --git a/test/fundef_rec.out b/test/fundef_rec.out new file mode 100644 index 0000000..62f9457 --- /dev/null +++ b/test/fundef_rec.out @@ -0,0 +1 @@ +6 \ No newline at end of file diff --git a/test/fundef_rec_notail.nh b/test/fundef_rec_notail.nh new file mode 100644 index 0000000..f072df8 --- /dev/null +++ b/test/fundef_rec_notail.nh @@ -0,0 +1,5 @@ + +fun Factorial x = if x > 1 then x * Factorial (x-1) else 1 + +x = Factorial 3 +Print (StringOfInt x) diff --git a/test/fundef_rec_notail.out b/test/fundef_rec_notail.out new file mode 100644 index 0000000..62f9457 --- /dev/null +++ b/test/fundef_rec_notail.out @@ -0,0 +1 @@ +6 \ No newline at end of file diff --git a/test/fundef_simple.nh b/test/fundef_simple.nh new file mode 100644 index 0000000..a2c397e --- /dev/null +++ b/test/fundef_simple.nh @@ -0,0 +1,5 @@ + +fun Hello x = x + +PrintEndline (Hello "what") +PrintEndline (StringOfInt (Hello 1)) diff --git a/test/fundef_simple.out b/test/fundef_simple.out new file mode 100644 index 0000000..acb2571 --- /dev/null +++ b/test/fundef_simple.out @@ -0,0 +1,2 @@ +what +1 diff --git a/test/simple_assign.nh b/test/simple_assign.nh index 94faedd..47d8638 100644 --- a/test/simple_assign.nh +++ b/test/simple_assign.nh @@ -1,7 +1,7 @@ -x = 5 +x = "x" y = 10 -Print (StringOfInt x) +Print x Print "\n" Print (StringOfInt y) diff --git a/test/simple_assign.out b/test/simple_assign.out index 81882f0..63cf67f 100644 --- a/test/simple_assign.out +++ b/test/simple_assign.out @@ -1,2 +1,2 @@ -5 +x 10 \ No newline at end of file diff --git a/test/simple_varref.nh b/test/simple_varref.nh new file mode 100644 index 0000000..461a23f --- /dev/null +++ b/test/simple_varref.nh @@ -0,0 +1,3 @@ + +x = 5 +x diff --git a/test/simple_varref.out b/test/simple_varref.out new file mode 100644 index 0000000..e69de29 diff --git a/test/typedef_chained.nh b/test/typedef_chained.nh new file mode 100644 index 0000000..8c8eac8 --- /dev/null +++ b/test/typedef_chained.nh @@ -0,0 +1,16 @@ + +type first = { + f = "hello" +} + +type second = { + s = init first +} + +type third = { + t = init second +} + +x = init third + +Print x$t$s$f diff --git a/test/typedef_chained.out b/test/typedef_chained.out new file mode 100644 index 0000000..b6fc4c6 --- /dev/null +++ b/test/typedef_chained.out @@ -0,0 +1 @@ +hello \ No newline at end of file