Skip to content

Commit

Permalink
Add match statement (#621)
Browse files Browse the repository at this point in the history
  • Loading branch information
Akuli authored Jan 15, 2025
1 parent 51b9799 commit dd88edc
Show file tree
Hide file tree
Showing 25 changed files with 714 additions and 45 deletions.
31 changes: 31 additions & 0 deletions bootstrap_compiler/build_cfg.c
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,33 @@ static void build_loop(
add_jump(st, NULL, condblock, condblock, doneblock);
}

static void build_match_statament(struct State *st, const AstMatchStatement *match_stmt)
{
const LocalVariable *match_obj_enum = build_expression(st, &match_stmt->match_obj);
LocalVariable *match_obj_int = add_local_var(st, intType);
add_unary_op(st, match_stmt->match_obj.location, CF_ENUM_TO_INT32, match_obj_enum, match_obj_int);

CfBlock *done = add_block(st);
for (int i = 0; i < match_stmt->ncases; i++) {
for (AstExpression *caseobj = match_stmt->cases[i].case_objs; caseobj < &match_stmt->cases[i].case_objs[match_stmt->cases[i].n_case_objs]; caseobj++) {
const LocalVariable *case_obj_enum = build_expression(st, caseobj);
LocalVariable *case_obj_int = add_local_var(st, intType);
add_unary_op(st, caseobj->location, CF_ENUM_TO_INT32, case_obj_enum, case_obj_int);

const LocalVariable *cond = build_binop(st, AST_EXPR_EQ, caseobj->location, match_obj_int, case_obj_int, boolType);
CfBlock *then = add_block(st);
CfBlock *otherwise = add_block(st);

add_jump(st, cond, then, otherwise, then);
build_body(st, &match_stmt->cases[i].body);
add_jump(st, NULL, done, done, otherwise);
}
}

build_body(st, &match_stmt->case_underscore);
add_jump(st, NULL, done, done, done);
}

static void build_statement(struct State *st, const AstStatement *stmt)
{
switch(stmt->kind) {
Expand All @@ -891,6 +918,10 @@ static void build_statement(struct State *st, const AstStatement *stmt)
&stmt->data.forloop.body);
break;

case AST_STMT_MATCH:
build_match_statament(st, &stmt->data.match);
break;

case AST_STMT_BREAK:
if (!st->breakstack.len)
fail(stmt->location, "'break' can only be used inside a loop");
Expand Down
12 changes: 12 additions & 0 deletions bootstrap_compiler/free.c
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,18 @@ void free_ast_statement(const AstStatement *stmt)
free(stmt->data.forloop.incr);
free_ast_body(&stmt->data.forloop.body);
break;
case AST_STMT_MATCH:
free_expression(&stmt->data.match.match_obj);
for (int i = 0; i < stmt->data.match.ncases; i++) {
for (AstExpression *caseobj = stmt->data.match.cases[i].case_objs; caseobj < &stmt->data.match.cases[i].case_objs[stmt->data.match.cases[i].n_case_objs]; caseobj++) {
free_expression(caseobj);
}
free(stmt->data.match.cases[i].case_objs);
free_ast_body(&stmt->data.match.cases[i].body);
}
free(stmt->data.match.cases);
free_ast_body(&stmt->data.match.case_underscore);
break;
case AST_STMT_ASSERT:
free_expression(&stmt->data.assertion.condition);
free(stmt->data.assertion.condition_str);
Expand Down
15 changes: 15 additions & 0 deletions bootstrap_compiler/jou_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ typedef struct AstConditionAndBody AstConditionAndBody;
typedef struct AstExpression AstExpression;
typedef struct AstAssignment AstAssignment;
typedef struct AstForLoop AstForLoop;
typedef struct AstCase AstCase;
typedef struct AstMatchStatement AstMatchStatement;
typedef struct AstNameTypeValue AstNameTypeValue;
typedef struct AstIfStatement AstIfStatement;
typedef struct AstStatement AstStatement;
Expand Down Expand Up @@ -264,6 +266,17 @@ struct AstForLoop {
AstStatement *incr;
AstBody body;
};
struct AstCase {
AstExpression *case_objs;
int n_case_objs;
AstBody body;
};
struct AstMatchStatement {
AstExpression match_obj;
AstCase *cases;
int ncases;
AstBody case_underscore;
};
struct AstIfStatement {
AstConditionAndBody *if_and_elifs;
int n_if_and_elifs; // Always >= 1 for the initial "if"
Expand Down Expand Up @@ -327,6 +340,7 @@ struct AstStatement {
AST_STMT_IF,
AST_STMT_WHILE,
AST_STMT_FOR,
AST_STMT_MATCH,
AST_STMT_BREAK,
AST_STMT_CONTINUE,
AST_STMT_DECLARE_LOCAL_VAR,
Expand All @@ -350,6 +364,7 @@ struct AstStatement {
AstConditionAndBody whileloop;
AstIfStatement ifstatement;
AstForLoop forloop;
AstMatchStatement match;
AstNameTypeValue vardecl;
AstAssignment assignment; // also used for inplace operations
AstFunction function;
Expand Down
47 changes: 47 additions & 0 deletions bootstrap_compiler/parse.c
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,7 @@ static void validate_expression_statement(const AstExpression *expr)
}
}

static void parse_start_of_body(ParserState *ps);
static AstBody parse_body(ParserState *ps);

static AstIfStatement parse_if_statement(ParserState *ps)
Expand Down Expand Up @@ -701,6 +702,49 @@ static AstIfStatement parse_if_statement(ParserState *ps)
};
}

static AstMatchStatement parse_match_statement(ParserState *ps)
{
assert(is_keyword(ps->tokens, "match"));
ps->tokens++;

AstMatchStatement result = {.match_obj = parse_expression(ps)};
parse_start_of_body(ps);

while (ps->tokens->type != TOKEN_DEDENT) {
assert(is_keyword(ps->tokens, "case"));
ps->tokens++;

if (ps->tokens->type == TOKEN_NAME
&& strcmp(ps->tokens->data.name, "_") == 0
&& is_operator(&ps->tokens[1], ":"))
{
// case _:
ps->tokens++;
result.case_underscore = parse_body(ps);
} else {
List(AstExpression) case_objs = {0};
while(1){
Append(&case_objs, parse_expression(ps));
if (is_operator(ps->tokens, "|"))
ps->tokens++;
else if (is_operator(ps->tokens, ":"))
break;
else
fail_with_parse_error(ps->tokens, "'|' or ':'");
}
result.cases = realloc(result.cases, sizeof result.cases[0] * (result.ncases + 1));
result.cases[result.ncases++] = (AstCase){
.case_objs = case_objs.ptr,
.n_case_objs = case_objs.len,
.body = parse_body(ps),
};
}
}
ps->tokens++;
return result;
}


// reverse code golfing: https://xkcd.com/1960/
static enum AstStatementKind determine_the_kind_of_a_statement_that_starts_with_an_expression(
const Token *this_token_is_after_that_initial_expression)
Expand Down Expand Up @@ -1041,6 +1085,9 @@ static AstStatement parse_statement(ParserState *ps)
} else if (is_keyword(ps->tokens, "if")) {
result.kind = AST_STMT_IF;
result.data.ifstatement = parse_if_statement(ps);
} else if (is_keyword(ps->tokens, "match")) {
result.kind = AST_STMT_MATCH;
result.data.match = parse_match_statement(ps);
} else if (is_keyword(ps->tokens, "while")) {
ps->tokens++;
result.kind = AST_STMT_WHILE;
Expand Down
4 changes: 4 additions & 0 deletions bootstrap_compiler/print.c
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,10 @@ static void print_ast_statement(const AstStatement *stmt, struct TreePrinter tp)
printf("body:\n");
print_ast_body(&stmt->data.forloop.body, sub);
break;
case AST_STMT_MATCH:
printf("match (printing not implemented)\n");
// TODO: implement printing match statement, if needed for debugging
break;
case AST_STMT_BREAK:
printf("break\n");
break;
Expand Down
4 changes: 2 additions & 2 deletions bootstrap_compiler/tokenize.c
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ static bool is_keyword(const char *s)
"return", "if", "elif", "else", "while", "for", "pass", "break", "continue",
"True", "False", "None", "NULL", "void", "noreturn",
"and", "or", "not", "self", "as", "sizeof", "assert",
"bool", "byte", "short", "int", "long", "float", "double",
"bool", "byte", "short", "int", "long", "float", "double", "match", "case",
};
for (const char **kw = &keywords[0]; kw < &keywords[sizeof(keywords)/sizeof(keywords[0])]; kw++)
if (!strcmp(*kw, s))
Expand Down Expand Up @@ -351,7 +351,7 @@ static const char *read_operator(struct State *st)
// Longer operators are first, so that '==' does not tokenize as '=' '='
"...", "===", "!==",
"==", "!=", "->", "<=", ">=", "++", "--", "+=", "-=", "*=", "/=", "%=", "&&", "||",
".", ",", ":", ";", "=", "(", ")", "{", "}", "[", "]", "&", "%", "*", "/", "+", "-", "<", ">", "!",
".", ",", ":", ";", "=", "(", ")", "{", "}", "[", "]", "&", "%", "*", "/", "+", "-", "<", ">", "!", "|",
NULL,
};

Expand Down
22 changes: 21 additions & 1 deletion bootstrap_compiler/typecheck.c
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,6 @@ static const Type *check_increment_or_decrement(FileTypes *ft, const AstExpressi

static void typecheck_dereferenced_pointer(Location location, const Type *t)
{
// TODO: improved error message for dereferencing void*
if (t->kind != TYPE_POINTER)
fail(location, "the dereference operator '*' is only for pointers, not for %s", t->name);
}
Expand Down Expand Up @@ -1287,6 +1286,23 @@ static void typecheck_if_statement(FileTypes *ft, const AstIfStatement *ifstmt)
typecheck_body(ft, &ifstmt->elsebody);
}

static void typecheck_match_statement(FileTypes *ft, AstMatchStatement *match_stmt)
{
const Type *mtype = typecheck_expression_not_void(ft, &match_stmt->match_obj)->type;
assert(mtype->kind == TYPE_ENUM);

for (int i = 0; i < match_stmt->ncases; i++) {
for (int k = 0; k < match_stmt->cases[i].n_case_objs; k++) {
typecheck_expression_with_implicit_cast(
ft, &match_stmt->cases[i].case_objs[k], mtype,
"case value of type FROM cannot be matched against TO"
);
}
typecheck_body(ft, &match_stmt->cases[i].body);
}
typecheck_body(ft, &match_stmt->case_underscore);
}

static void typecheck_statement(FileTypes *ft, AstStatement *stmt)
{
switch(stmt->kind) {
Expand All @@ -1310,6 +1326,10 @@ static void typecheck_statement(FileTypes *ft, AstStatement *stmt)
typecheck_statement(ft, stmt->data.forloop.incr);
break;

case AST_STMT_MATCH:
typecheck_match_statement(ft, &stmt->data.match);
break;

case AST_STMT_BREAK:
break;

Expand Down
62 changes: 62 additions & 0 deletions compiler/ast.jou
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,7 @@ enum AstStatementKind:
Pass
Return
If
Match
WhileLoop
ForLoop
Break
Expand Down Expand Up @@ -515,6 +516,7 @@ class AstStatement:
classdef: AstClassDef
enumdef: AstEnumDef
assertion: AstAssertion
match_statement: AstMatchStatement

def print(self) -> None:
self->print_with_tree_printer(TreePrinter{})
Expand All @@ -537,6 +539,8 @@ class AstStatement:
self->if_statement.print_with_tree_printer(tp)
elif self->kind == AstStatementKind.ForLoop:
self->for_loop.print_with_tree_printer(tp)
elif self->kind == AstStatementKind.Match:
self->match_statement.print_with_tree_printer(tp)
elif self->kind == AstStatementKind.WhileLoop:
printf("while loop\n")
self->while_loop.print_with_tree_printer(tp, True)
Expand Down Expand Up @@ -601,6 +605,8 @@ class AstStatement:
self->while_loop.free()
if self->kind == AstStatementKind.ForLoop:
self->for_loop.free()
if self->kind == AstStatementKind.Match:
self->match_statement.free()
if (
self->kind == AstStatementKind.DeclareLocalVar
or self->kind == AstStatementKind.GlobalVariableDeclaration
Expand Down Expand Up @@ -696,6 +702,62 @@ class AstIfStatement:
self->else_body.free()


# match match_obj:
# case ...:
# ...
# case ...:
# ...
class AstMatchStatement:
match_obj: AstExpression
cases: AstCase*
ncases: int
case_underscore: AstBody* # body of "case _" (always last), NULL if no "case _"
case_underscore_location: Location # not meaningful if case_underscore == NULL

def print_with_tree_printer(self, tp: TreePrinter) -> None:
printf("match\n")
for i = 0; i < self->ncases; i++:
self->cases[i].print_with_tree_printer(tp, i == self->ncases - 1 and self->case_underscore == NULL)

if self->case_underscore != NULL:
sub = tp.print_prefix(True)
printf("[line %d] body of case _:\n", self->case_underscore_location.lineno)
self->case_underscore->print_with_tree_printer(sub)

def free(self) -> None:
self->match_obj.free()
for i = 0; i < self->ncases; i++:
self->cases[i].free()
free(self->cases)
if self->case_underscore != NULL:
self->case_underscore->free()
free(self->case_underscore)


# case case_obj1 | case_obj2 | case_obj3:
# body
class AstCase:
case_objs: AstExpression*
n_case_objs: int
body: AstBody

def print_with_tree_printer(self, tp: TreePrinter, is_last_case: bool) -> None:
for i = 0; i < self->n_case_objs; i++:
sub = tp.print_prefix(False)
printf("case_obj: ")
self->case_objs[i].print_with_tree_printer(sub)

sub = tp.print_prefix(is_last_case)
printf("body:\n")
self->body.print_with_tree_printer(sub)

def free(self) -> None:
for i = 0; i < self->n_case_objs; i++:
self->case_objs[i].free()
free(self->case_objs)
self->body.free()


# for init; cond; incr:
# ...body...
class AstForLoop:
Expand Down
Loading

0 comments on commit dd88edc

Please sign in to comment.