diff --git a/bootstrap_compiler/build_cfg.c b/bootstrap_compiler/build_cfg.c
index 1d801a43..d1d182bd 100644
--- a/bootstrap_compiler/build_cfg.c
+++ b/bootstrap_compiler/build_cfg.c
@@ -865,6 +865,48 @@ static void build_loop(
     add_jump(st, NULL, condblock, condblock, doneblock);
 }
 
+// Builds the control flow of a match statement.
+//
+// The matched value and each case value are converted to int (CF_ENUM_TO_INT32)
+// and compared with ==. Case values are evaluated in source order, and only until
+// one of them matches. The body of "case A | B:" is emitted once, into a block
+// shared by all of its alternatives.
+static void build_match_statament(struct State *st, const AstMatchStatement *match_stmt)
+{
+    const LocalVariable *match_obj_enum = build_expression(st, &match_stmt->match_obj);
+    LocalVariable *match_obj_int = add_local_var(st, intType);
+    add_unary_op(st, match_stmt->match_obj.location, CF_ENUM_TO_INT32, match_obj_enum, match_obj_int);
+
+    CfBlock *done = add_block(st);
+    for (int i = 0; i < match_stmt->ncases; i++) {
+        const AstCase *c = &match_stmt->cases[i];
+        if (c->n_case_objs <= 0)
+            continue;  // no alternatives, so nothing can select this body
+
+        CfBlock *then = add_block(st);
+        CfBlock *otherwise = NULL;
+        for (int k = 0; k < c->n_case_objs; k++) {
+            const AstExpression *caseobj = &c->case_objs[k];
+            const LocalVariable *case_obj_enum = build_expression(st, caseobj);
+            LocalVariable *case_obj_int = add_local_var(st, intType);
+            add_unary_op(st, caseobj->location, CF_ENUM_TO_INT32, case_obj_enum, case_obj_int);
+            const LocalVariable *cond = build_binop(st, AST_EXPR_EQ, caseobj->location, match_obj_int, case_obj_int, boolType);
+
+            // A failed comparison continues in "otherwise". After the last alternative,
+            // make "then" the current block instead, so that the body is emitted there.
+            otherwise = add_block(st);
+            add_jump(st, cond, then, otherwise, k == c->n_case_objs - 1 ? then : otherwise);
+        }
+
+        build_body(st, &c->body);
+        add_jump(st, NULL, done, done, otherwise);
+    }
+
+    // Body of "case _". The parser leaves it zero-initialized when there is no "case _".
+    build_body(st, &match_stmt->case_underscore);
+    add_jump(st, NULL, done, done, done);
+}
+
 static void build_statement(struct State *st, const AstStatement *stmt)
 {
     switch(stmt->kind) {
@@ -891,6 +933,10 @@ static void build_statement(struct State *st, const AstStatement *stmt)
             &stmt->data.forloop.body);
         break;
 
+    case AST_STMT_MATCH:
+        build_match_statament(st, &stmt->data.match);
+        break;
+
     case AST_STMT_BREAK:
         if (!st->breakstack.len)
             fail(stmt->location, "'break' can only be used inside a loop");
diff --git a/bootstrap_compiler/free.c b/bootstrap_compiler/free.c
index 
e7839ef2..a88c4786 100644 --- a/bootstrap_compiler/free.c +++ b/bootstrap_compiler/free.c @@ -155,6 +155,18 @@ void free_ast_statement(const AstStatement *stmt) free(stmt->data.forloop.incr); free_ast_body(&stmt->data.forloop.body); break; + case AST_STMT_MATCH: + free_expression(&stmt->data.match.match_obj); + for (int i = 0; i < stmt->data.match.ncases; i++) { + for (AstExpression *caseobj = stmt->data.match.cases[i].case_objs; caseobj < &stmt->data.match.cases[i].case_objs[stmt->data.match.cases[i].n_case_objs]; caseobj++) { + free_expression(caseobj); + } + free(stmt->data.match.cases[i].case_objs); + free_ast_body(&stmt->data.match.cases[i].body); + } + free(stmt->data.match.cases); + free_ast_body(&stmt->data.match.case_underscore); + break; case AST_STMT_ASSERT: free_expression(&stmt->data.assertion.condition); free(stmt->data.assertion.condition_str); diff --git a/bootstrap_compiler/jou_compiler.h b/bootstrap_compiler/jou_compiler.h index 47f2f0c5..d6a1b552 100644 --- a/bootstrap_compiler/jou_compiler.h +++ b/bootstrap_compiler/jou_compiler.h @@ -23,6 +23,8 @@ typedef struct AstConditionAndBody AstConditionAndBody; typedef struct AstExpression AstExpression; typedef struct AstAssignment AstAssignment; typedef struct AstForLoop AstForLoop; +typedef struct AstCase AstCase; +typedef struct AstMatchStatement AstMatchStatement; typedef struct AstNameTypeValue AstNameTypeValue; typedef struct AstIfStatement AstIfStatement; typedef struct AstStatement AstStatement; @@ -264,6 +266,17 @@ struct AstForLoop { AstStatement *incr; AstBody body; }; +struct AstCase { + AstExpression *case_objs; + int n_case_objs; + AstBody body; +}; +struct AstMatchStatement { + AstExpression match_obj; + AstCase *cases; + int ncases; + AstBody case_underscore; +}; struct AstIfStatement { AstConditionAndBody *if_and_elifs; int n_if_and_elifs; // Always >= 1 for the initial "if" @@ -327,6 +340,7 @@ struct AstStatement { AST_STMT_IF, AST_STMT_WHILE, AST_STMT_FOR, + AST_STMT_MATCH, 
AST_STMT_BREAK, AST_STMT_CONTINUE, AST_STMT_DECLARE_LOCAL_VAR,
@@ -350,6 +364,7 @@ struct AstStatement {
         AstConditionAndBody whileloop;
         AstIfStatement ifstatement;
         AstForLoop forloop;
+        AstMatchStatement match;
         AstNameTypeValue vardecl;
         AstAssignment assignment;  // also used for inplace operations
         AstFunction function;
diff --git a/bootstrap_compiler/parse.c b/bootstrap_compiler/parse.c
index 1c555af3..8cb8f794 100644
--- a/bootstrap_compiler/parse.c
+++ b/bootstrap_compiler/parse.c
@@ -674,6 +674,7 @@ static void validate_expression_statement(const AstExpression *expr)
     }
 }
 
+static void parse_start_of_body(ParserState *ps);
 static AstBody parse_body(ParserState *ps);
 
 static AstIfStatement parse_if_statement(ParserState *ps)
@@ -701,6 +702,61 @@ static AstIfStatement parse_if_statement(ParserState *ps)
     };
 }
 
+// Parses a match statement:
+//
+//    match match_obj:
+//        case a | b:
+//            body
+//        case _:
+//            body
+//
+// The body of "case _:" goes to case_underscore. Every other case becomes an AstCase.
+static AstMatchStatement parse_match_statement(ParserState *ps)
+{
+    assert(is_keyword(ps->tokens, "match"));  // checked by the caller
+    ps->tokens++;
+
+    AstMatchStatement result = {.match_obj = parse_expression(ps)};
+    parse_start_of_body(ps);
+
+    while (ps->tokens->type != TOKEN_DEDENT) {
+        // Not an assert: this is a mistake in the code being compiled, not a compiler bug.
+        if (!is_keyword(ps->tokens, "case"))
+            fail_with_parse_error(ps->tokens, "the 'case' keyword");
+        ps->tokens++;
+
+        if (ps->tokens->type == TOKEN_NAME
+            && strcmp(ps->tokens->data.name, "_") == 0
+            && is_operator(&ps->tokens[1], ":"))
+        {
+            // case _:
+            ps->tokens++;
+            result.case_underscore = parse_body(ps);
+        } else {
+            List(AstExpression) case_objs = {0};
+            while(1){
+                Append(&case_objs, parse_expression(ps));
+                if (is_operator(ps->tokens, "|"))
+                    ps->tokens++;
+                else if (is_operator(ps->tokens, ":"))
+                    break;
+                else
+                    fail_with_parse_error(ps->tokens, "'|' or ':'");
+            }
+            result.cases = realloc(result.cases, sizeof result.cases[0] * (result.ncases + 1));
+            assert(result.cases != NULL);
+            result.cases[result.ncases++] = (AstCase){
+                .case_objs = case_objs.ptr,
+                .n_case_objs = case_objs.len,
+                .body = parse_body(ps),
+            };
+        }
+    }
+    ps->tokens++;  // skip the DEDENT that ends the match statement
+    return result;
+}
+
+
 // reverse code golfing: https://xkcd.com/1960/
 static enum AstStatementKind 
determine_the_kind_of_a_statement_that_starts_with_an_expression( const Token *this_token_is_after_that_initial_expression) @@ -1041,6 +1085,9 @@ static AstStatement parse_statement(ParserState *ps) } else if (is_keyword(ps->tokens, "if")) { result.kind = AST_STMT_IF; result.data.ifstatement = parse_if_statement(ps); + } else if (is_keyword(ps->tokens, "match")) { + result.kind = AST_STMT_MATCH; + result.data.match = parse_match_statement(ps); } else if (is_keyword(ps->tokens, "while")) { ps->tokens++; result.kind = AST_STMT_WHILE; diff --git a/bootstrap_compiler/print.c b/bootstrap_compiler/print.c index a20a062c..72892c39 100644 --- a/bootstrap_compiler/print.c +++ b/bootstrap_compiler/print.c @@ -364,6 +364,10 @@ static void print_ast_statement(const AstStatement *stmt, struct TreePrinter tp) printf("body:\n"); print_ast_body(&stmt->data.forloop.body, sub); break; + case AST_STMT_MATCH: + printf("match (printing not implemented)\n"); + // TODO: implement printing match statement, if needed for debugging + break; case AST_STMT_BREAK: printf("break\n"); break; diff --git a/bootstrap_compiler/tokenize.c b/bootstrap_compiler/tokenize.c index 33a236c2..a0303177 100644 --- a/bootstrap_compiler/tokenize.c +++ b/bootstrap_compiler/tokenize.c @@ -222,7 +222,7 @@ static bool is_keyword(const char *s) "return", "if", "elif", "else", "while", "for", "pass", "break", "continue", "True", "False", "None", "NULL", "void", "noreturn", "and", "or", "not", "self", "as", "sizeof", "assert", - "bool", "byte", "short", "int", "long", "float", "double", + "bool", "byte", "short", "int", "long", "float", "double", "match", "case", }; for (const char **kw = &keywords[0]; kw < &keywords[sizeof(keywords)/sizeof(keywords[0])]; kw++) if (!strcmp(*kw, s)) @@ -351,7 +351,7 @@ static const char *read_operator(struct State *st) // Longer operators are first, so that '==' does not tokenize as '=' '=' "...", "===", "!==", "==", "!=", "->", "<=", ">=", "++", "--", "+=", "-=", "*=", "/=", "%=", 
"&&", "||", - ".", ",", ":", ";", "=", "(", ")", "{", "}", "[", "]", "&", "%", "*", "/", "+", "-", "<", ">", "!", + ".", ",", ":", ";", "=", "(", ")", "{", "}", "[", "]", "&", "%", "*", "/", "+", "-", "<", ">", "!", "|", NULL, }; diff --git a/bootstrap_compiler/typecheck.c b/bootstrap_compiler/typecheck.c index ce60b385..224572ca 100644 --- a/bootstrap_compiler/typecheck.c +++ b/bootstrap_compiler/typecheck.c @@ -794,7 +794,6 @@ static const Type *check_increment_or_decrement(FileTypes *ft, const AstExpressi static void typecheck_dereferenced_pointer(Location location, const Type *t) { - // TODO: improved error message for dereferencing void* if (t->kind != TYPE_POINTER) fail(location, "the dereference operator '*' is only for pointers, not for %s", t->name); } @@ -1287,6 +1286,23 @@ static void typecheck_if_statement(FileTypes *ft, const AstIfStatement *ifstmt) typecheck_body(ft, &ifstmt->elsebody); } +static void typecheck_match_statement(FileTypes *ft, AstMatchStatement *match_stmt) +{ + const Type *mtype = typecheck_expression_not_void(ft, &match_stmt->match_obj)->type; + assert(mtype->kind == TYPE_ENUM); + + for (int i = 0; i < match_stmt->ncases; i++) { + for (int k = 0; k < match_stmt->cases[i].n_case_objs; k++) { + typecheck_expression_with_implicit_cast( + ft, &match_stmt->cases[i].case_objs[k], mtype, + "case value of type FROM cannot be matched against TO" + ); + } + typecheck_body(ft, &match_stmt->cases[i].body); + } + typecheck_body(ft, &match_stmt->case_underscore); +} + static void typecheck_statement(FileTypes *ft, AstStatement *stmt) { switch(stmt->kind) { @@ -1310,6 +1326,10 @@ static void typecheck_statement(FileTypes *ft, AstStatement *stmt) typecheck_statement(ft, stmt->data.forloop.incr); break; + case AST_STMT_MATCH: + typecheck_match_statement(ft, &stmt->data.match); + break; + case AST_STMT_BREAK: break; diff --git a/compiler/ast.jou b/compiler/ast.jou index bc7fd191..09606f69 100644 --- a/compiler/ast.jou +++ b/compiler/ast.jou @@ -482,6 
+482,7 @@ enum AstStatementKind: Pass Return If + Match WhileLoop ForLoop Break @@ -515,6 +516,7 @@ class AstStatement: classdef: AstClassDef enumdef: AstEnumDef assertion: AstAssertion + match_statement: AstMatchStatement def print(self) -> None: self->print_with_tree_printer(TreePrinter{}) @@ -537,6 +539,8 @@ class AstStatement: self->if_statement.print_with_tree_printer(tp) elif self->kind == AstStatementKind.ForLoop: self->for_loop.print_with_tree_printer(tp) + elif self->kind == AstStatementKind.Match: + self->match_statement.print_with_tree_printer(tp) elif self->kind == AstStatementKind.WhileLoop: printf("while loop\n") self->while_loop.print_with_tree_printer(tp, True) @@ -601,6 +605,8 @@ class AstStatement: self->while_loop.free() if self->kind == AstStatementKind.ForLoop: self->for_loop.free() + if self->kind == AstStatementKind.Match: + self->match_statement.free() if ( self->kind == AstStatementKind.DeclareLocalVar or self->kind == AstStatementKind.GlobalVariableDeclaration @@ -696,6 +702,62 @@ class AstIfStatement: self->else_body.free() +# match match_obj: +# case ...: +# ... +# case ...: +# ... 
+class AstMatchStatement: + match_obj: AstExpression + cases: AstCase* + ncases: int + case_underscore: AstBody* # body of "case _" (always last), NULL if no "case _" + case_underscore_location: Location # not meaningful if case_underscore == NULL + + def print_with_tree_printer(self, tp: TreePrinter) -> None: + printf("match\n") + for i = 0; i < self->ncases; i++: + self->cases[i].print_with_tree_printer(tp, i == self->ncases - 1 and self->case_underscore == NULL) + + if self->case_underscore != NULL: + sub = tp.print_prefix(True) + printf("[line %d] body of case _:\n", self->case_underscore_location.lineno) + self->case_underscore->print_with_tree_printer(sub) + + def free(self) -> None: + self->match_obj.free() + for i = 0; i < self->ncases; i++: + self->cases[i].free() + free(self->cases) + if self->case_underscore != NULL: + self->case_underscore->free() + free(self->case_underscore) + + +# case case_obj1 | case_obj2 | case_obj3: +# body +class AstCase: + case_objs: AstExpression* + n_case_objs: int + body: AstBody + + def print_with_tree_printer(self, tp: TreePrinter, is_last_case: bool) -> None: + for i = 0; i < self->n_case_objs; i++: + sub = tp.print_prefix(False) + printf("case_obj: ") + self->case_objs[i].print_with_tree_printer(sub) + + sub = tp.print_prefix(is_last_case) + printf("body:\n") + self->body.print_with_tree_printer(sub) + + def free(self) -> None: + for i = 0; i < self->n_case_objs; i++: + self->case_objs[i].free() + free(self->case_objs) + self->body.free() + + # for init; cond; incr: # ...body... 
class AstForLoop: diff --git a/compiler/build_cf_graph.jou b/compiler/build_cf_graph.jou index 550ce372..721835ff 100644 --- a/compiler/build_cf_graph.jou +++ b/compiler/build_cf_graph.jou @@ -340,9 +340,9 @@ class CfBuilder: if not t->is_integer_type() and not t->is_pointer_type(): msg: byte[500] if diff == 1: - snprintf(msg, sizeof(msg), "cannot increment a value of type %s", t->name) + snprintf(msg, sizeof(msg), "cannot increment a value of type %s", t->name) else: - snprintf(msg, sizeof(msg), "cannot decrement a value of type %s", t->name) + snprintf(msg, sizeof(msg), "cannot decrement a value of type %s", t->name) fail(location, msg) old_value = self->add_var(t) @@ -814,7 +814,7 @@ class CfBuilder: done = self->add_block() for i = 0; i < ifstmt->n_if_and_elifs; i++: - cond: LocalVariable* = self->build_expression(&ifstmt->if_and_elifs[i].condition) + cond = self->build_expression(&ifstmt->if_and_elifs[i].condition) then = self->add_block() otherwise = self->add_block() @@ -825,6 +825,35 @@ class CfBuilder: self->build_body(&ifstmt->else_body) self->jump(NULL, done, done, done) + def build_match_statement(self, match_stmt: AstMatchStatement*) -> None: + match_obj_enum = self->build_expression(&match_stmt->match_obj) + match_obj_int = self->add_var(intType) + self->unary_op(match_stmt->match_obj.location, CfInstructionKind.EnumToInt32, match_obj_enum, match_obj_int) + + done = self->add_block() + for i = 0; i < match_stmt->ncases; i++: + then = self->add_block() + otherwise: CfBlock* = NULL + for k = 0; k < match_stmt->cases[i].n_case_objs; k++: + case_obj_ast = &match_stmt->cases[i].case_objs[k] + case_obj_enum = self->build_expression(case_obj_ast) + case_obj_int = self->add_var(intType) + self->unary_op(case_obj_ast->location, CfInstructionKind.EnumToInt32, case_obj_enum, case_obj_int) + cond = self->build_binop(AstExpressionKind.Eq, case_obj_ast->location, match_obj_int, case_obj_int, boolType) + + otherwise = self->add_block() + self->jump(cond, then, 
otherwise, otherwise) + + assert otherwise != NULL + + self->current_block = then + self->build_body(&match_stmt->cases[i].body) + self->jump(NULL, done, done, otherwise) + + if match_stmt->case_underscore != NULL: + self->build_body(match_stmt->case_underscore) + self->jump(NULL, done, done, done) + def build_assert(self, assert_location: Location, assertion: AstAssertion*) -> None: condvar = self->build_expression(&assertion->condition) @@ -934,6 +963,8 @@ class CfBuilder: self->build_loop( stmt->for_loop.init, &stmt->for_loop.cond, stmt->for_loop.incr, &stmt->for_loop.body) + elif stmt->kind == AstStatementKind.Match: + self->build_match_statement(&stmt->match_statement) elif stmt->kind == AstStatementKind.Break: if self->nloops == 0: fail(stmt->location, "'break' can only be used inside a loop") diff --git a/compiler/parser.jou b/compiler/parser.jou index c0765dba..67ada548 100644 --- a/compiler/parser.jou +++ b/compiler/parser.jou @@ -856,6 +856,55 @@ class Parser: body = self->parse_body(), } + def parse_match_statement(self) -> AstMatchStatement: + assert self->tokens->is_keyword("match") + self->tokens++ + + result = AstMatchStatement{match_obj = self->parse_expression()} + self->parse_start_of_body() + + while self->tokens->kind != TokenKind.Dedent: + if not self->tokens->is_keyword("case"): + self->tokens->fail_expected_got("the 'case' keyword") + if result.case_underscore != NULL: + fail( + self->tokens->location, + "this case will never run, because 'case _:' above matches anything", + ) + self->tokens++ + + if ( + self->tokens->kind == TokenKind.Name + and strcmp(self->tokens->short_string, "_") == 0 + and self->tokens[1].is_operator(":") + ): + # case _: + result.case_underscore_location = (self->tokens++)->location + result.case_underscore = malloc(sizeof(*result.case_underscore)) + assert result.case_underscore != NULL + *result.case_underscore = self->parse_body() + else: + case_objs: AstExpression* = NULL + n_case_objs = 0 + while True: + 
case_objs = realloc(case_objs, sizeof(case_objs[0]) * (n_case_objs + 1)) + case_objs[n_case_objs++] = self->parse_expression() + if self->tokens->is_operator("|"): + self->tokens++ + elif self->tokens->is_operator(":"): + break + else: + self->tokens->fail_expected_got("'|' or ':'") + result.cases = realloc(result.cases, sizeof result.cases[0] * (result.ncases + 1)) + result.cases[result.ncases++] = AstCase{ + case_objs = case_objs, + n_case_objs = n_case_objs, + body = self->parse_body(), + } + self->tokens++ + + return result + # Parses the "x: int" part of "x, y, z: int", leaving "y, z: int" to be parsed later. def parse_first_of_multiple_local_var_declares(self) -> AstNameTypeValue: assert self->tokens->kind == TokenKind.Name @@ -976,6 +1025,13 @@ class Parser: while_loop = self->parse_while_loop(), } + if self->tokens->is_keyword("match"): + return AstStatement{ + location = self->tokens->location, + kind = AstStatementKind.Match, + match_statement = self->parse_match_statement(), + } + if ( self->tokens[0].kind == TokenKind.Name and self->tokens[1].is_operator(",") diff --git a/compiler/tokenizer.jou b/compiler/tokenizer.jou index 6f103bd3..c8686094 100644 --- a/compiler/tokenizer.jou +++ b/compiler/tokenizer.jou @@ -26,7 +26,7 @@ def is_keyword(word: byte*) -> bool: "return", "if", "elif", "else", "while", "for", "pass", "break", "continue", "True", "False", "None", "NULL", "void", "noreturn", "and", "or", "not", "self", "as", "sizeof", "assert", - "bool", "byte", "short", "int", "long", "float", "double", + "bool", "byte", "short", "int", "long", "float", "double", "match", "case", ] for i = 0; i < sizeof keywords / sizeof keywords[0]; i++: @@ -382,7 +382,7 @@ class Tokenizer: # Longer operators are first, so that '==' does not tokenize as '=' '=' "...", "===", "!==", "==", "!=", "->", "<=", ">=", "++", "--", "+=", "-=", "*=", "/=", "%=", "&&", "||", - ".", ",", ":", ";", "=", "(", ")", "{", "}", "[", "]", "&", "%", "*", "/", "+", "-", "<", ">", "!", + ".", 
",", ":", ";", "=", "(", ")", "{", "}", "[", "]", "&", "%", "*", "/", "+", "-", "<", ">", "!", "|", ] operator: byte[100] diff --git a/compiler/typecheck/step3_function_and_method_bodies.jou b/compiler/typecheck/step3_function_and_method_bodies.jou index 84b42c8a..1c9ef654 100644 --- a/compiler/typecheck/step3_function_and_method_bodies.jou +++ b/compiler/typecheck/step3_function_and_method_bodies.jou @@ -1033,6 +1033,85 @@ def typecheck_if_statement(ft: FileTypes*, ifstmt: AstIfStatement*) -> None: typecheck_body(ft, &ifstmt->else_body) + +def typecheck_match_statement(ft: FileTypes*, match_stmt: AstMatchStatement*) -> None: + msg: byte[500] + + mtype = typecheck_expression_not_void(ft, &match_stmt->match_obj) + if mtype->kind != TypeKind.Enum: + # TODO: extend match statements to other equals-comparable values + snprintf(msg, sizeof(msg), "match statements can only be used with enums, not %s", mtype->name) + fail(match_stmt->match_obj.location, msg) + + # Ensure user checks all possible enum values + nremaining = mtype->enummembers.count + remaining: byte** = malloc(sizeof(remaining[0]) * nremaining) + assert remaining != NULL + for i = 0; i < nremaining; i++: + remaining[i] = mtype->enummembers.names[i] + + for i = 0; i < match_stmt->ncases; i++: + for j = 0; j < match_stmt->cases[i].n_case_objs; j++: + case_obj = &match_stmt->cases[i].case_objs[j] + typecheck_expression_with_implicit_cast( + ft, case_obj, mtype, + "case value of type <from> cannot be matched against <to>", + ) + + # The special casing only applies to simple enum member lookups (TheEnum.TheMember syntax) + if case_obj->kind != AstExpressionKind.GetEnumMember: + nremaining = -1 + + if nremaining != -1: + # We are matching against TheEnum.member. Try to find and remove it from remaining members. 
+ member = case_obj->enum_member.member_name + found = False + for k = 0; k < nremaining; k++: + if strcmp(remaining[k], member) == 0: + memmove(&remaining[k], &remaining[k+1], sizeof(remaining[0]) * (--nremaining - k)) + found = True + break + + if not found: + snprintf(msg, sizeof(msg), "enum member %s is handled twice", member) + fail(case_obj->location, msg) + + typecheck_body(ft, &match_stmt->cases[i].body) + + # Do not complain if there is a seemingly unnecessary 'case _', because it + # may be run by casting an integer to enum with 'as'. However, we can + # complain if handling for enum members is missing. + if nremaining > 0 and match_stmt->case_underscore == NULL: + if nremaining == 1: + snprintf( + msg, sizeof(msg), + "enum member %s.%s not handled in match statement", + mtype->name, remaining[0], + ) + else: + snprintf( + msg, sizeof(msg) - 20, + "the following %d members of enum %s are not handled in match statement: ", + nremaining, mtype->name, + ) + for i = 0; i < nremaining; i++: + assert sizeof(msg) > 300 + if strlen(msg) + strlen(remaining[i]) < 200: + strcat(msg, remaining[i]) + strcat(msg, ", ") + else: + strcat(msg, "...") + break + if ends_with(msg, ", "): + msg[strlen(msg) - 2] = '\0' + + fail(match_stmt->match_obj.location, msg) + + free(remaining) + + if match_stmt->case_underscore != NULL: + typecheck_body(ft, match_stmt->case_underscore) + + def typecheck_statement(ft: FileTypes*, stmt: AstStatement*) -> None: msg: byte[500] @@ -1053,6 +1132,9 @@ def typecheck_statement(ft: FileTypes*, stmt: AstStatement*) -> None: typecheck_body(ft, &stmt->for_loop.body) typecheck_statement(ft, stmt->for_loop.incr) + elif stmt->kind == AstStatementKind.Match: + typecheck_match_statement(ft, &stmt->match_statement) + elif ( stmt->kind == AstStatementKind.Break or stmt->kind == AstStatementKind.Continue diff --git a/doc/enums.md b/doc/enums.md index 3c60bc23..dfb3b3e7 100644 --- a/doc/enums.md +++ b/doc/enums.md @@ -5,22 +5,29 @@ TL;DR: ```python import 
"stdlib/io.jou" -enum Foo: +enum Thingy: + Foo Bar Baz def main() -> int: - thing = Foo.Bar + thing = Thingy.Bar - if thing == Foo.Bar: - printf("It's bar\n") # Output: It's bar - elif thing == Foo.Baz: - printf("It's baz\n") - else: - assert False # never happens + match thing: + case Thingy.Foo: + printf("It's foo\n") + case Thingy.Bar | Thingy.Baz: + printf("It's bar or baz\n") # Output: It's bar or baz + + match thing: + case Thingy.Foo: + printf("It's foo\n") + case _: + printf("It's not foo\n") # Output: It's not foo - printf("%d\n", Foo.Bar) # Output: 0 - printf("%d\n", Foo.Baz) # Output: 1 + printf("%d\n", Thingy.Foo as int) # Output: 0 + printf("%d\n", Thingy.Bar as int) # Output: 1 + printf("%d\n", Thingy.Baz as int) # Output: 2 return 0 ``` @@ -77,18 +84,17 @@ enum Operation: FloorDivide # 7 / 3 produces 2 def calculate(a: double, b: double, op: Operation) -> double: - if op == Operation.Add: - return a + b - if op == Operation.Subtract: - return a - b - if op == Operation.Multiply: - return a * b - if op == Operation.Divide: - return a / b - if op == Operation.FloorDivide: - return floor(a / b) - else: - assert False # not possible + match op: + case Operation.Add: + return a + b + case Operation.Subtract: + return a - b + case Operation.Multiply: + return a * b + case Operation.Divide: + return a / b + case Operation.FloorDivide: + return floor(a / b) def main() -> int: printf("%f\n", calculate(7, 3, Operation.Divide)) # Output: 2.333333 @@ -96,8 +102,66 @@ def main() -> int: return 0 ``` -Here `enum Operation` defines a new enum, which has 5 possible values. -You can then use `if` statements to check which value an instance of `Operation` is. +Here `enum Operation` defines a new enum, which has 5 members. +You will get a compiler error if you don't handle all 5 members in the `match` statement. +This is useful especially when adding a new feature to a large existing program. 
+Instead of a `match` statement, you can also use `if` and `elif` with enums, +but then the compiler won't complain if you don't handle all enum members. + +Sometimes you don't want to list all enum members in a `match` statement. +You can use `case _:` to catch all remaining enum members: + +```python +import "stdlib/io.jou" + +enum Operation: + Add + Subtract + Multiply + Divide + FloorDivide + +def main() -> int: + op = Operation.Divide + + match op: + case Operation.Add: + printf("It's adding\n") + case Operation.Subtract: + printf("It's subtracting\n") + case _: + printf("Not adding or subtracting\n") # Output: Not adding or subtracting + + return 0 +``` + +You can also combine multiple cases with `|`. +Read the `|` operator as "or" when it's used in match statements. + +```python +import "stdlib/io.jou" + +enum Operation: + Add + Subtract + Multiply + Divide + FloorDivide + +def main() -> int: + op = Operation.FloorDivide + + match op: + case Operation.Divide | Operation.FloorDivide: + printf("It's dividing\n") # Output: It's dividing + case _: + pass + + return 0 +``` + +Here `case _: pass` is needed to ignore the enum members that were not mentioned, +because without it, you will get a compiler error saying that you didn't handle all possible values. ## Integer conversions @@ -138,7 +202,8 @@ def main() -> int: ``` You can also convert integers to enums, -but note that the result might not correspond with any member of the enum: +but note that the result might not correspond with any member of the enum. +In a `match` statement, only `case _:` matches these values. 
```python import "stdlib/io.jou" @@ -149,17 +214,18 @@ enum Operation: Multiply def main() -> int: - wat = 7 as Operation - - if wat == Operation.Add: - printf("Add\n") - elif wat == Operation.Subtract: - printf("Subtract\n") - elif wat == Operation.Multiply: - printf("Multiply\n") - else: - # Output: something else 7 - printf("something else %d\n", wat as int) + wat = 42 as Operation + + match wat: + case Operation.Add: + printf("Add\n") + case Operation.Subtract: + printf("Subtract\n") + case Operation.Multiply: + printf("Multiply\n") + case _: + # Output: something else 42 + printf("something else %d\n", wat as int) return 0 ``` diff --git a/doc/syntax-spec.md b/doc/syntax-spec.md index 1c1e6829..46236c62 100644 --- a/doc/syntax-spec.md +++ b/doc/syntax-spec.md @@ -106,6 +106,8 @@ Jou has a few different kinds of tokens: - `long` - `float` - `double` + - `match` + - `case` - **Newline tokens** occur at the end of a line of code. Lines that only contain spaces and comments do not produce a newline token; this ensures that blank lines are ignored as they should be. @@ -114,7 +116,7 @@ Jou has a few different kinds of tokens: Indent tokens always occur just after newline tokens. It is an error if the code is indented with tabs or with some indentation size other than 4 spaces. - **Dedent tokens** are added whenever the amount of indentation decreases by 4 spaces. -- **Operator tokens** are any of the following: `... == != -> <= >= ++ -- += -= *= /= %= . , : ; = ( ) { } [ ] & % * / + - < >` +- **Operator tokens** are any of the following: `... == != -> <= >= ++ -- += -= *= /= %= . , : ; = ( ) { } [ ] & % * / + - < > |` Note that `a = = b` and `a == b` do different things: `a = = b` tokenizes as 4 tokens (and the parser errors when it sees the tokens) while `a == b` tokenizes as 3 tokens. 
diff --git a/tests/other_errors/match_missing_1.jou b/tests/other_errors/match_missing_1.jou new file mode 100644 index 00000000..b154f5c5 --- /dev/null +++ b/tests/other_errors/match_missing_1.jou @@ -0,0 +1,11 @@ +enum Foo: + One + Two + Three + +def blah() -> None: + match Foo.One: # Error: enum member Foo.Three not handled in match statement + case Foo.One: + pass + case Foo.Two: + pass diff --git a/tests/other_errors/match_missing_1_or.jou b/tests/other_errors/match_missing_1_or.jou new file mode 100644 index 00000000..942d5a7d --- /dev/null +++ b/tests/other_errors/match_missing_1_or.jou @@ -0,0 +1,12 @@ +enum Foo: + One + Two + Three + Four + +def blah() -> None: + match Foo.One: # Error: enum member Foo.Four not handled in match statement + case Foo.One | Foo.Two: + pass + case Foo.Three: + pass diff --git a/tests/other_errors/match_missing_2.jou b/tests/other_errors/match_missing_2.jou new file mode 100644 index 00000000..311a8b4b --- /dev/null +++ b/tests/other_errors/match_missing_2.jou @@ -0,0 +1,12 @@ +enum Foo: + One + Two + Three + Four + +def blah() -> None: + match Foo.One: # Error: the following 2 members of enum Foo are not handled in match statement: Three, Four + case Foo.One: + pass + case Foo.Two: + pass diff --git a/tests/other_errors/match_missing_28.jou b/tests/other_errors/match_missing_28.jou new file mode 100644 index 00000000..443e6816 --- /dev/null +++ b/tests/other_errors/match_missing_28.jou @@ -0,0 +1,39 @@ +enum First30: + # This is AI generated code :) + One + Two + Three + Four + Five + Six + Seven + Eight + Nine + Ten + Eleven + Twelve + Thirteen + Fourteen + Fifteen + Sixteen + Seventeen + Eighteen + Nineteen + Twenty + TwentyOne + TwentyTwo + TwentyThree + TwentyFour + TwentyFive + TwentySix + TwentySeven + TwentyEight + TwentyNine + Thirty + +def blah() -> None: + match First30.One: # Error: the following 28 members of enum First30 are not handled in match statement: Three, Four, Five, Six, Seven, Eight, Nine, Ten, Eleven, 
Twelve, Thirteen, Fourteen, Fifteen, Sixteen, Seventeen, Eighteen, ... + case First30.One: + pass + case First30.Two: + pass diff --git a/tests/other_errors/match_twice.jou b/tests/other_errors/match_twice.jou new file mode 100644 index 00000000..e5301a2c --- /dev/null +++ b/tests/other_errors/match_twice.jou @@ -0,0 +1,15 @@ +enum Foo: + One + Two + Three + +def blah() -> None: + match Foo.One: + case Foo.One: + pass + case Foo.Two: + pass + case Foo.Three: + pass + case Foo.Two: # Error: enum member Two is handled twice + pass diff --git a/tests/should_succeed/match.jou b/tests/should_succeed/match.jou new file mode 100644 index 00000000..29309412 --- /dev/null +++ b/tests/should_succeed/match.jou @@ -0,0 +1,103 @@ +import "stdlib/io.jou" + +enum Foo: + Bar + Baz + Lol + Wut + + +def show_evaluation(foo: Foo, msg: byte*) -> Foo: + puts(msg) + return foo + + +def main() -> int: + f = Foo.Bar + match f: + case Foo.Bar: + printf("Bar\n") # Output: Bar + case Foo.Baz: + printf("Baz\n") + case _: + printf("Other\n") + + match Foo.Bar: + case Foo.Bar | Foo.Baz: + printf("yay\n") # Output: yay + case _: + printf("nope\n") + + match Foo.Baz: + case Foo.Bar | Foo.Baz: + printf("yay\n") # Output: yay + case _: + printf("nope\n") + + f = 69 as Foo + match f: + case Foo.Bar: + printf("nope\n") + case Foo.Baz: + printf("nope\n") + case Foo.Lol: + printf("nope\n") + case Foo.Wut: + printf("nope\n") + # There's no compiler warning for this "case _", because as you can + # see it's not as unnecessary as you think :) + case _: + printf("Other!!!! %d\n", f as int) # Output: Other!!!! 69 + + # "case _" is not needed, will do nothing if enum is outside allowed range + f = Foo.Lol + match f: + case Foo.Bar: + printf("nope\n") + case Foo.Baz: + printf("nope\n") + case Foo.Lol: + printf("Hey! :)\n") # Output: Hey! 
:) + case Foo.Wut: + printf("nope\n") + + f = 12345 as Foo + match f: + case Foo.Bar: + printf("nope\n") + case Foo.Baz: + printf("nope\n") + case Foo.Lol: + printf("nope\n") + case Foo.Wut: + printf("nope\n") + + # Test evaluation order. + # + # Output: match obj + # Output: case 1 + # Output: case 2 + # Output: case 3 + # Output: ye + match show_evaluation(Foo.Lol, "match obj"): + case show_evaluation(Foo.Bar, "case 1"): + printf("nope\n") + case show_evaluation(Foo.Baz, "case 2"): + printf("nope\n") + case show_evaluation(Foo.Lol, "case 3"): + printf("ye\n") + case show_evaluation(Foo.Wut, "case 4"): + printf("nope\n") + + # Output: match obj + # Output: case 1 + # Output: case 2 + # Output: case 3 + # Output: ye + match show_evaluation(Foo.Lol, "match obj"): + case show_evaluation(Foo.Bar, "case 1") | show_evaluation(Foo.Baz, "case 2"): + printf("nope\n") + case show_evaluation(Foo.Lol, "case 3") | show_evaluation(Foo.Wut, "case 4"): + printf("ye\n") + + return 0 diff --git a/tests/syntax_error/match_bad_case.jou b/tests/syntax_error/match_bad_case.jou new file mode 100644 index 00000000..75cb7ec1 --- /dev/null +++ b/tests/syntax_error/match_bad_case.jou @@ -0,0 +1,9 @@ +enum Foo: + One + Two + Three + +def blah(f: Foo) -> None: + match f: + case Foo.One Foo.Two: # Error: expected '|' or ':', got a variable name 'Foo' + printf("hi\n") diff --git a/tests/syntax_error/match_no_case_keyword.jou b/tests/syntax_error/match_no_case_keyword.jou new file mode 100644 index 00000000..a9ace72c --- /dev/null +++ b/tests/syntax_error/match_no_case_keyword.jou @@ -0,0 +1,9 @@ +enum Foo: + One + Two + Three + +def blah(f: Foo) -> None: + match f: + Foo.One: # Error: expected the 'case' keyword, got a variable name 'Foo' + printf("hi\n") diff --git a/tests/syntax_error/match_underscore_not_last.jou b/tests/syntax_error/match_underscore_not_last.jou new file mode 100644 index 00000000..334aadee --- /dev/null +++ b/tests/syntax_error/match_underscore_not_last.jou @@ -0,0 +1,13 
@@ +enum Foo: + One + Two + Three + +def blah() -> None: + match Foo.One: + case Foo.One: + pass + case _: + pass + case Foo.Two: # Error: this case will never run, because 'case _:' above matches anything + pass diff --git a/tests/wrong_type/match_not_enum.jou b/tests/wrong_type/match_not_enum.jou new file mode 100644 index 00000000..7607fcee --- /dev/null +++ b/tests/wrong_type/match_not_enum.jou @@ -0,0 +1,5 @@ +def main() -> int: + match 123: # Error: match statements can only be used with enums, not int + case 123: + pass + return 0 diff --git a/tests/wrong_type/match_wrong_enum.jou b/tests/wrong_type/match_wrong_enum.jou new file mode 100644 index 00000000..9b9d6781 --- /dev/null +++ b/tests/wrong_type/match_wrong_enum.jou @@ -0,0 +1,13 @@ +enum Foo: + One + Two + +enum Bar: + Three + Four + Five + +def bruh() -> None: + match Foo.One: + case Bar.Three: # Error: case value of type Bar cannot be matched against Foo + printf("hi\n")