Skip to content

Commit

Permalink
initial trait work
Browse files Browse the repository at this point in the history
  • Loading branch information
Wren H committed Sep 9, 2024
1 parent f21c4a1 commit 8278d54
Show file tree
Hide file tree
Showing 3 changed files with 326 additions and 26 deletions.
44 changes: 37 additions & 7 deletions lib/frontend/ast.ml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ open Share.Maybe
(* ideally these would be newtypes, but ocaml doesn't have those *)
type resolved = R of int [@@deriving show { with_path = false }]

let fresh_resolved =
let i = ref 0 in
fun () ->
decr i;
!i

type unresolved = U of string
[@@deriving show { with_path = false }]

Expand All @@ -28,7 +34,7 @@ and 'a typ =
| TyPoly of 'a
| TyCustom of 'a * 'a typ list
| TyRef of 'a typ (* mutability shock horror *)
| TyMeta of 'a meta
| TyMeta of 'a meta ref
[@@deriving show { with_path = false }]

and 'a field = Field of 'a * 'a typ
Expand All @@ -41,8 +47,23 @@ let rec force (t : 'a typ) : 'a typ =
| TyCustom (a, b) -> TyCustom (a, List.map force b)
| TyRef a -> TyRef (force a)
| TyMeta m -> begin
match m with Unresolved -> t | Resolved t -> force t
match !m with Unresolved -> t | Resolved t -> force t
end
| _ -> t

let rec to_metas (polys : 'a list) (t : 'a typ) : 'a typ =
let f = to_metas polys in
match force (t : 'a typ) with
| TyTuple t -> TyTuple (List.map f t)
| TyArrow (a, b) -> TyArrow (f a, f b)
| TyPoly x -> begin
if not @@ List.mem x polys then
TyMeta (ref Unresolved)
else
t
end
| TyCustom (x, t) -> TyCustom (x, List.map f t)
| TyRef t -> TyRef (f t)
| _ -> t

let rec instantiate (map : ('a * 'a typ) list) (t : 'a typ) : 'a typ =
Expand All @@ -57,8 +78,6 @@ let rec instantiate (map : ('a * 'a typ) list) (t : 'a typ) : 'a typ =
| TyRef t -> TyRef (f t)
| _ -> t

type binop = Binop of string [@@deriving show { with_path = false }]

type 'a case =
| CaseVar of 'a
| CaseTuple of 'a case list
Expand All @@ -82,7 +101,7 @@ type 'a expr =
| LetIn of data * 'a case * 'a typ option * 'a expr * 'a expr
| Seq of data * 'a expr * 'a expr
| Funccall of data * 'a expr * 'a expr
| Binop of data * binop
| Binop of data * 'a
| Lambda of data * 'a * 'a typ option * 'a expr
| Tuple of data * 'a expr list
| Annot of data * 'a expr * 'a typ
Expand Down Expand Up @@ -133,8 +152,19 @@ type 'a typdef = {
}
[@@deriving show { with_path = false }]

let typdef_and_ctor_to_typ (t : 'a typdef) (i : 'a) : 'a typ =
match t.content with
| Record _ -> failwith "shouldn't be record"
| Sum s ->
let ctor = List.assoc i s in
let custom =
TyCustom (t.name, List.map (fun x -> TyPoly x) t.args)
in
typ_list_to_typ (ctor @ [ custom ])

type 'a trait_bound =
| Bound of 'a * 'a typ list (* trait name, "args" *)
| Bound of
'a * 'a typ list * 'a typ list (* trait name, args, assocs *)
[@@deriving show { with_path = false }]

type ('a, 'p) definition = {
Expand Down Expand Up @@ -164,7 +194,7 @@ type 'a trait = {

type 'a impl = {
data : data;
parent : uuid option;
parent : 'a trait;
args : ('a * 'a typ) list;
assocs : ('a * 'a typ) list;
impls : ('a, yes) definition list;
Expand Down
236 changes: 217 additions & 19 deletions lib/frontend/typecheck.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ open Share.Uuid
open Ast
open Share.Result
open Share.Maybe
open Unify


type ctx = {
(* name, parent *)
Expand Down Expand Up @@ -40,7 +42,7 @@ let search (ctx : ctx) (id : resolved) : (resolved typ, string) result
case' (List.assoc_opt id ctx.ctors) (fun t ->
match t.content with
| Record _ -> failwith "shouldn't be variable-ing a record"
| Sum s -> typ_list_to_typ @@ List.assoc id s);
| Sum s -> typdef_and_ctor_to_typ t id);

case' (List.assoc_opt id ctx.funs) (fun d -> definition_type d);

Expand Down Expand Up @@ -133,29 +135,210 @@ let rec infer (ctx : ctx) (e : resolved expr) :
infer ctx b
| Funccall (_, a, b) ->
let* a'ty = infer ctx a in
(* *)
(* unify the calling type and the argument type to make sure
they're actually compatible
*)
begin
match a'ty with
| TyArrow (q, w) -> failwith "tmp"
| TyArrow (q, w) ->
let* b'ty = infer ctx b in
let* _ = unify ctx.local_polys b'ty q in
ok w
| _ -> err "must function call on function type"
end
| Binop (_, _) -> failwith "tmp"
| Lambda (_, _, _, _) -> failwith "tmp"
| Tuple (_, _) -> failwith "tmp"
| Annot (_, _, _) -> failwith "tmp"
| Match (_, _, _) -> failwith "tmp"
| Project (_, _, _) -> failwith "tmp"
| Ref (_, _) -> failwith "tmp"
| Modify (_, _, _) -> failwith "tmp"
| Record (_, _, _) -> failwith "tmp"
| Binop (_, v) -> search ctx v
| Lambda (_, v, typ, body) ->
(* if we don't have a static type we can
make a meta in order to try and infer the body
gotta remember to "close up" the meta once we're
done though, nonlocal inference is sucky
*)
let typ = begin match typ with
| Some ty -> ty
| None -> TyMeta (ref Unresolved)
end in
let ctx = add_local ctx v typ in
let* body'ty = infer ctx body in
let* _ = match typ with
| TyMeta m ->
(* ocaml ref patterns when *)
begin match !m with
| Unresolved -> err "meta remained unsolved"
| _ -> ok ()
end
| _ -> ok ()
in
ok @@ TyArrow(typ, body'ty)
| Tuple (_, ts) ->
(* i love fp *)
List.map (infer ctx) ts
|> collect
|> Result.map_error (String.concat " ")
|> Result.map (fun x -> TyTuple x)
| Annot (_, x, t) ->
(* switch directions *)
check ctx x t
| Match (_, scrut, cases) ->
(* match is uniquely annoying because of that big old
fold down at the bottom. this makes error propogation
a pain, because you want to keep all of those possible
erroring unifies, but you also don't want the whole thing
to be a mess
*)
let* scrut_typ = infer ctx scrut in
let handle_case (case: 'a case * 'a expr): ('a typ, string) result =
let case, expr = case in
let* vars = break_down_case_pattern ctx case scrut_typ in
let ctx = add_locals ctx vars in
infer ctx expr
in
let* typs = List.map handle_case cases
|> collect
|> Result.map_error (String.concat " ")
in
begin match typs with
(* an empty match can return anything, because it can never be
matched on
*)
| [] -> ok @@ TyMeta (ref Unresolved)
| x :: xs ->
(* TODO: don't ignore errors here *)
List.fold_left (fun a b -> unify' ctx.local_polys a b; a) x xs
|> ok
end
| Project (_, x, i) ->
let* x'ty = infer ctx x in
begin match x'ty with
| TyCustom (nm, args) ->
let typ = List.find (fun (x: 'a typdef) -> x.name = nm) ctx.types in
begin match typ.content with
| Record fields ->
(* we have to consider the case in which the record is
parameterized therefore, while we know the field we
are working with, we need to up type arguments and
perform an instantiation
*)
let map = List.combine typ.args args in
let Field(_, typ) = List.nth fields i in
ok @@ instantiate map typ
| Sum _ -> err "should be record not sum"
end
| _ -> err "can't be record and not record"
end
| Ref (_, f) ->
let* t = infer ctx f in
ok @@ TyRef t
| Modify (_, old, new') ->
let* t = search ctx old in
let* _ = check ctx new' t in
ok @@ TyTuple []
| Record (_, nm, fields) ->
(* this case is mildly annoying, because we have to deal
with instantiation more explicitly
*)
failwith "records"
in
let uuid = get_uuid e in
add_type uuid ty;
ok (force ty)

and check (ctx : ctx) (e : resolved expr) (t : resolved typ) :
(resolved typ, string) result =
failwith "hi"
(* here, we only consider the cases where checking something
would actually benefit typechecking as a whole - therefore,
there's a whole bunch of stuff that we just defer straight back
to infer
i think there's technically some trickery that can be done
with regards to Funccall, but it's late and i can't think of it
right now
TODO: look at that later
*)
let* ty = match e with
| LetIn (_, case, ty, head, body) ->
let* head'ty = match ty with
| Some t -> check ctx head t
| None -> infer ctx head
in
let* vars = break_down_case_pattern ctx case head'ty in
let ctx = add_locals ctx vars in
check ctx body t
| Seq (_, a, b) ->
let* _ = check ctx a (TyTuple []) in
check ctx b t
| Lambda (_, v, ty, body) ->
begin match t with
| TyArrow(q, w) ->
let* ty = begin match ty with
(* logically if we're checking
(fun (x: t) -> ...)
against
q -> w
then t == q
*)
| Some t -> unify ctx.local_polys q t
| None -> ok q
end
in
let ctx = add_local ctx v ty in
check ctx body w
| _ -> err "lambda cannot be non-function type"
end
| Tuple (_, ts) ->
begin match t with
| TyTuple s ->
if List.length ts <> List.length s then
err "uneq tuple lengths"
else
(* i love fp round 2 *)
List.map2 (check ctx) ts s
|> collect
|> Result.map_error (String.concat " ")
|> Result.map (fun x -> TyTuple x)
| _ -> err "must be tuple type"
end
| Match (_, scrut, cases) ->
(* see comments in infer
should probably factor all this out tbh
*)
let* scrut_typ = infer ctx scrut in
let handle_case (case: 'a case * 'a expr): ('a typ, string) result =
let case, expr = case in
let* vars = break_down_case_pattern ctx case scrut_typ in
let ctx = add_locals ctx vars in
(* only difference is that we can check this time *)
check ctx expr t
in
let* typs = List.map handle_case cases
|> collect
|> Result.map_error (String.concat " ")
in
begin match typs with
(* an empty match can return anything, because it can never be
matched on
*)
| [] -> ok @@ TyMeta (ref Unresolved)
| x :: xs ->
(* TODO: don't ignore errors here *)
List.fold_left (fun a b -> unify' ctx.local_polys a b; a) x xs
|> ok
end
| Ref (_, r) ->
begin match t with
| TyRef r't ->
check ctx r r't
| _ -> err "cannot check ref against not ref"
end
| _ ->
(* in the general case, we defer to infer (hehe) and then
come back here and "check our work" with unify
*)
let* ty = infer ctx e in
let* _ = unify ctx.local_polys ty t in
ok ty
in
let uuid = get_uuid e in
add_type uuid ty;
ok (force ty)

let typecheck_definition (ctx : ctx) (d : (resolved, yes) definition)
: (unit, string) result =
Expand All @@ -165,8 +348,8 @@ let typecheck_definition (ctx : ctx) (d : (resolved, yes) definition)
let ctx =
{
ctx with
locals = args;
local_polys = polys;
locals = ctx.locals @ args;
local_polys = ctx.local_polys @ polys;
(* yay recursion *)
funs = self :: ctx.funs;
}
Expand All @@ -175,6 +358,23 @@ let typecheck_definition (ctx : ctx) (d : (resolved, yes) definition)
let* _ = check ctx body d.return in
ok ()

let typecheck_impl (ctx: ctx) (i: resolved impl): (unit, string) result =
(* TODO: check that args match the trait *)
List.map (typecheck_definition ctx) i.impls
|> collect
|> Result.map_error (String.concat " ")
|> Result.map (fun _ -> ())

let typecheck_toplevel (ctx: ctx) (t: resolved toplevel):
(unit, string) result =
match t with
| Typdef _ -> ok ()
| Trait _ -> ok ()
| Impl i ->
typecheck_impl ctx i
| Definition d ->
typecheck_definition ctx d

let gather (t : resolved toplevel list) : ctx =
let ctx = empty_ctx () in
List.fold_left
Expand All @@ -201,14 +401,12 @@ let gather (t : resolved toplevel list) : ctx =
{ acc with traitfuns = (a.name, t) :: acc.traitfuns })
ctx t.functions
| Impl _ ->
(* we don't do anything here yet
TODO: typecheck impl'd functions
*)
(* we don't do anything here *)
ctx
| Definition d ->
{ ctx with funs = (d.name, forget_body d) :: ctx.funs })
ctx t

let typecheck_toplevel (t : resolved toplevel list) : unit =
let ctx = gather t in
failwith "temp"

Loading

0 comments on commit 8278d54

Please sign in to comment.