Skip to content

Commit

Permalink
Fix prim defs to use an internal GADT to avoid warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
polytypic committed Feb 22, 2020
1 parent 89edfee commit 1d77593
Showing 1 changed file with 84 additions and 82 deletions.
166 changes: 84 additions & 82 deletions prim.ml
Original file line number Diff line number Diff line change
Expand Up @@ -54,90 +54,92 @@ let is_poly {typ = ts1, ts2} = List.mem VarT ts1 || List.mem VarT ts2

let typs = [BoolT; IntT; CharT; TextT]

type 'a def =
| VoidD: unit def
| BoolD: bool def
| IntD: int def
| CharD: char def
| TextD: string def
| VarD: const def
| ProdD: 'a def * 'b def -> ('a * 'b) def

let (&) l r = ProdD (l, r)

let rec typs_of: type a. a def -> typ list = function
| VoidD -> []
| BoolD -> [BoolT]
| IntD -> [IntT]
| CharD -> [CharT]
| TextD -> [TextT]
| VarD -> [VarT]
| ProdD (l, r) -> typs_of l @ typs_of r

let rec inj: type a. a def -> a -> const list -> const list = function
| VoidD -> fun () vs -> vs
| BoolD -> fun v vs -> BoolV v :: vs
| IntD -> fun v vs -> IntV v :: vs
| CharD -> fun v vs -> CharV v :: vs
| TextD -> fun v vs -> TextV v :: vs
| VarD -> fun v vs -> v :: vs
| ProdD (lD, rD) ->
let injL = inj lD and injR = inj rD in fun (l, r) vs -> injL l (injR r vs)

let rec prj: type a. a def -> const list -> a * const list = function
| VoidD -> fun vs -> ((), vs)
| BoolD -> (function (BoolV v :: vs) -> (v, vs) | _ -> failwith "bool")
| IntD -> (function (IntV v :: vs) -> (v, vs) | _ -> failwith "int")
| CharD -> (function (CharV v :: vs) -> (v, vs) | _ -> failwith "char")
| TextD -> (function (TextV v :: vs) -> (v, vs) | _ -> failwith "text")
| VarD -> (function (v :: vs) -> (v, vs) | _ -> failwith "var")
| ProdD (lD, rD) ->
let prjL = prj lD and prjR = prj rD in
fun vs -> let (l, vs) = prjL vs in let (r, vs) = prjR vs in ((l, r), vs)

let def name inD outD fn = {
name = name;
typ = typs_of inD, typs_of outD;
fn = let inj = inj outD and prj = prj inD in
fun vs -> let (v, vs) = prj vs in assert (vs = []); inj (fn v) []
}

let funs =
[
{name = "==";
typ = [VarT; VarT], [BoolT];
fn = fun [x1; x2] -> [BoolV(x1 = x2)]};
{name = "<>";
typ = [VarT; VarT], [BoolT];
fn = fun [x1; x2] -> [BoolV(x1 <> x2)]};

{name = "true";
typ = [], [BoolT];
fn = fun [] -> [BoolV(true)]};
{name = "false";
typ = [], [BoolT];
fn = fun [] -> [BoolV(false)]};

{name = "Int.+";
typ = [IntT; IntT], [IntT];
fn = fun [IntV i1; IntV i2] -> [IntV(i1 + i2)]};
{name = "Int.-";
typ = [IntT; IntT], [IntT];
fn = fun [IntV i1; IntV i2] -> [IntV(i1 - i2)]};
{name = "Int.*";
typ = [IntT; IntT], [IntT];
fn = fun [IntV i1; IntV i2] -> [IntV(i1 * i2)]};
{name = "Int./";
typ = [IntT; IntT], [IntT];
fn = fun [IntV i1; IntV i2] -> [IntV(i1 / i2)]};
{name = "Int.%";
typ = [IntT; IntT], [IntT];
fn = fun [IntV i1; IntV i2] -> [IntV(i1 mod i2)]};
{name = "Int.<";
typ = [IntT; IntT], [BoolT];
fn = fun [x1; x2] -> [BoolV(x1 < x2)]};
{name = "Int.>";
typ = [IntT; IntT], [BoolT];
fn = fun [x1; x2] -> [BoolV(x1 > x2)]};
{name = "Int.<=";
typ = [IntT; IntT], [BoolT];
fn = fun [x1; x2] -> [BoolV(x1 <= x2)]};
{name = "Int.>=";
typ = [IntT; IntT], [BoolT];
fn = fun [x1; x2] -> [BoolV(x1 >= x2)]};
{name = "Int.print";
typ = [IntT], [];
fn = fun [IntV i] -> print_int i; flush_all (); []};

{name = "Char.toInt";
typ = [CharT], [IntT];
fn = fun [CharV c] -> [IntV(Char.code c)]};
{name = "Char.fromInt";
typ = [IntT], [CharT];
fn = fun [IntV i] -> [CharV(Char.chr i)]};
{name = "Char.print";
typ = [CharT], [];
fn = fun [CharV c] -> print_char c; flush_all (); []};

{name = "Text.++";
typ = [TextT; TextT], [TextT];
fn = fun [TextV t1; TextV t2] -> [TextV(t1 ^ t2)]};
{name = "Text.<";
typ = [TextT; TextT], [BoolT];
fn = fun [x1; x2] -> [BoolV(x1 < x2)]};
{name = "Text.>";
typ = [TextT; TextT], [BoolT];
fn = fun [x1; x2] -> [BoolV(x1 > x2)]};
{name = "Text.<=";
typ = [TextT; TextT], [BoolT];
fn = fun [x1; x2] -> [BoolV(x1 <= x2)]};
{name = "Text.>=";
typ = [TextT; TextT], [BoolT];
fn = fun [x1; x2] -> [BoolV(x1 >= x2)]};
{name = "Text.length";
typ = [TextT], [IntT];
fn = fun [TextV t] -> [IntV(String.length t)]};
{name = "Text.sub";
typ = [TextT; IntT], [CharT];
fn = fun [TextV t; IntV i] -> [CharV(t.[i])]};
{name = "Text.fromChar";
typ = [CharT], [TextT];
fn = fun [CharV c] -> [TextV(String.make 1 c)]};
{name = "Text.print";
typ = [TextT], [];
fn = fun [TextV t] -> print_string t; flush_all (); []};
def "==" (VarD & VarD) BoolD (fun (x1, x2) -> x1 = x2);
def "<>" (VarD & VarD) BoolD (fun (x1, x2) -> x1 <> x2);

def "true" VoidD BoolD (fun () -> true);
def "false" VoidD BoolD (fun () -> false);

def "Int.+" (IntD & IntD) IntD (fun (i1, i2) -> i1 + i2);
def "Int.-" (IntD & IntD) IntD (fun (i1, i2) -> i1 - i2);
def "Int.*" (IntD & IntD) IntD (fun (i1, i2) -> i1 * i2);
def "Int./" (IntD & IntD) IntD (fun (i1, i2) -> i1 / i2);
def "Int.%" (IntD & IntD) IntD (fun (i1, i2) -> i1 mod i2);

def "Int.<" (IntD & IntD) BoolD (fun (i1, i2) -> i1 < i2);
def "Int.>" (IntD & IntD) BoolD (fun (i1, i2) -> i1 > i2);
def "Int.<=" (IntD & IntD) BoolD (fun (i1, i2) -> i1 <= i2);
def "Int.>=" (IntD & IntD) BoolD (fun (i1, i2) -> i1 >= i2);

def "Int.print" IntD VoidD (fun i -> print_int i; flush_all ());

def "Char.toInt" CharD IntD Char.code;
def "Char.fromInt" IntD CharD Char.chr;

def "Char.print" CharD VoidD (fun c -> print_char c; flush_all ());

def "Text.++" (TextD & TextD) TextD (fun (t1, t2) -> t1 ^ t2);

def "Text.<" (TextD & TextD) BoolD (fun (i1, i2) -> i1 < i2);
def "Text.>" (TextD & TextD) BoolD (fun (i1, i2) -> i1 > i2);
def "Text.<=" (TextD & TextD) BoolD (fun (i1, i2) -> i1 <= i2);
def "Text.>=" (TextD & TextD) BoolD (fun (i1, i2) -> i1 >= i2);

def "Text.length" TextD IntD String.length;
def "Text.sub" (TextD & IntD) CharD (fun (t, i) -> t.[i]);
def "Text.fromChar" CharD TextD (String.make 1);

def "Text.print" TextD VoidD (fun t -> print_string t; flush_all ());
]

let fun_of_string name =
Expand Down

0 comments on commit 1d77593

Please sign in to comment.