diff --git a/compiler/ast.jou b/compiler/ast.jou index b198f2cd..8020b899 100644 --- a/compiler/ast.jou +++ b/compiler/ast.jou @@ -695,13 +695,15 @@ class AstIfStatement: self->else_body.free() -# match match_obj: +# match match_obj with func: # case ...: # ... # case ...: # ... class AstMatchStatement: match_obj: AstExpression + func_name: byte[100] # empty if there's no "with foo" + func_signature: Signature # populated in typecheck, zero-initialized before typecheck runs cases: AstCase* ncases: int case_underscore: AstBody* # body of "case _" (always last), NULL if no "case _" @@ -719,6 +721,8 @@ class AstMatchStatement: def free(self) -> None: self->match_obj.free() + if self->func_signature.name[0] != '\0': + self->func_signature.free() for i = 0; i < self->ncases; i++: self->cases[i].free() free(self->cases) diff --git a/compiler/builders/ast_to_builder.jou b/compiler/builders/ast_to_builder.jou index 91ff18d4..58461e41 100644 --- a/compiler/builders/ast_to_builder.jou +++ b/compiler/builders/ast_to_builder.jou @@ -468,7 +468,13 @@ class AstToBuilder: otherwise = BuilderBlock{} # will be replaced by loop below for k = 0; k < match_stmt->cases[i].n_case_objs; k++: case_obj = self->build_expression(&match_stmt->cases[i].case_objs[k]) - cond = self->builder->eq(match_obj, case_obj) + if match_stmt->func_name[0] == '\0': + cond = self->builder->eq(match_obj, case_obj) + else: + args = [match_obj, case_obj] + func_ret = self->builder->call(&match_stmt->func_signature, args, 2) + zero = self->builder->integer(func_ret.type, 0) + cond = self->builder->eq(func_ret, zero) otherwise = self->builder->add_block() self->builder->branch(cond, then, otherwise) self->builder->set_current_block(otherwise) diff --git a/compiler/parser.jou b/compiler/parser.jou index c8150f23..66e35300 100644 --- a/compiler/parser.jou +++ b/compiler/parser.jou @@ -938,6 +938,13 @@ class Parser: self->tokens++ result = AstMatchStatement{match_obj = self->parse_expression()} + + if 
self->tokens->is_keyword("with"): + self->tokens++ + if self->tokens->kind != TokenKind.Name: + self->tokens->fail_expected_got("function name") + result.func_name = (self->tokens++)->short_string + self->parse_start_of_body() while self->tokens->kind != TokenKind.Dedent: diff --git a/compiler/tokenizer.jou b/compiler/tokenizer.jou index 7de8643e..2cc31f2b 100644 --- a/compiler/tokenizer.jou +++ b/compiler/tokenizer.jou @@ -26,7 +26,7 @@ def is_keyword(word: byte*) -> bool: "return", "if", "elif", "else", "while", "for", "pass", "break", "continue", "True", "False", "None", "NULL", "void", "noreturn", "and", "or", "not", "self", "as", "sizeof", "assert", - "bool", "byte", "short", "int", "long", "float", "double", "match", "case", + "bool", "byte", "short", "int", "long", "float", "double", "match", "with", "case", ] for i = 0; i < sizeof keywords / sizeof keywords[0]; i++: diff --git a/compiler/typecheck/step3_function_and_method_bodies.jou b/compiler/typecheck/step3_function_and_method_bodies.jou index 21d91ea6..5407b5d7 100644 --- a/compiler/typecheck/step3_function_and_method_bodies.jou +++ b/compiler/typecheck/step3_function_and_method_bodies.jou @@ -603,7 +603,7 @@ def typecheck_function_or_method_call(state: State*, call: AstCall*, self_type: else: function_or_method = "method" - sigstr = sig->to_string(False, False) + sig_string = sig->to_string(False, False) nargs = sig->nargs if self_type != NULL: @@ -614,7 +614,7 @@ def typecheck_function_or_method_call(state: State*, call: AstCall*, self_type: msg, sizeof(msg), "%s %s takes %d argument%s, but it was called with %d argument%s", function_or_method, - sigstr, + sig_string, nargs, plural_s(nargs), call->nargs, @@ -631,7 +631,7 @@ def typecheck_function_or_method_call(state: State*, call: AstCall*, self_type: which_arg = nth(++selfless_counter) snprintf( msg, sizeof msg, - "%s argument of %s %s should have type <to>, not <from>", which_arg, function_or_method, sigstr + "%s argument of %s %s should have type <to>, not <from>", 
which_arg, function_or_method, sig_string ) typecheck_expression_with_implicit_cast(state, &call->args[k++], sig->argtypes[i], msg) @@ -653,7 +653,7 @@ def typecheck_function_or_method_call(state: State*, call: AstCall*, self_type: ) fail(call->args[i].location, msg) - free(sigstr) + free(sig_string) return sig->returntype @@ -1076,26 +1076,52 @@ def typecheck_if_statement(state: State*, ifstmt: AstIfStatement*) -> None: def typecheck_match_statement(state: State*, match_stmt: AstMatchStatement*) -> None: msg: byte[500] - mtype = typecheck_expression_not_void(state, &match_stmt->match_obj) - if mtype->kind != TypeKind.Enum: - # TODO: extend match statements to other equals-comparable values - snprintf(msg, sizeof(msg), "match statements can only be used with enums, not %s", mtype->name) - fail(match_stmt->match_obj.location, msg) + if match_stmt->func_name[0] == '\0': + case_type = typecheck_expression_not_void(state, &match_stmt->match_obj) + if case_type->kind != TypeKind.Enum: + # TODO: extend match statements to other equals-comparable values + snprintf(msg, sizeof(msg), "match statements can only be used with enums, not %s", case_type->name) + fail(match_stmt->match_obj.location, msg) + + # Ensure user checks all possible enum values + nremaining = case_type->enummembers.count + remaining: byte** = malloc(sizeof(remaining[0]) * nremaining) + assert remaining != NULL + for i = 0; i < nremaining; i++: + remaining[i] = case_type->enummembers.names[i] + sig_string: byte* = NULL + else: + sig = state->file_types->find_function(match_stmt->func_name) + if sig == NULL: + snprintf(msg, sizeof(msg), "function '%s' not found", match_stmt->func_name) + fail(match_stmt->match_obj.location, msg) + match_stmt->func_signature = sig->copy() + + # Most of the time, only argument types are relevant, so don't include the return type into sig_string. 
+ sig_string = sig->to_string(False, False) + + if sig->nargs != 2 or sig->returntype != intType or sig->takes_varargs: # TODO: could be more general + if sig->returntype != intType: + # show return type in error message + sig_string = sig->to_string(True, False) + snprintf(msg, sizeof(msg), "cannot match with function %s", sig_string) + fail(match_stmt->match_obj.location, msg) - # Ensure user checks all possible enum values - nremaining = mtype->enummembers.count - remaining: byte** = malloc(sizeof(remaining[0]) * nremaining) - assert remaining != NULL - for i = 0; i < nremaining; i++: - remaining[i] = mtype->enummembers.names[i] + snprintf(msg, sizeof(msg), "cannot match <from> with %s", sig_string) + typecheck_expression_with_implicit_cast(state, &match_stmt->match_obj, sig->argtypes[0], msg) + + case_type = sig->argtypes[1] + remaining = NULL + nremaining = -1 for i = 0; i < match_stmt->ncases; i++: for j = 0; j < match_stmt->cases[i].n_case_objs; j++: case_obj = &match_stmt->cases[i].case_objs[j] - typecheck_expression_with_implicit_cast( - state, case_obj, mtype, - "case value of type <from> cannot be matched against <to>", - ) + if sig_string == NULL: + msg = "case value of type <from> cannot be matched against <to>" + else: + snprintf(msg, sizeof(msg), "case value cannot be <from> when matching with %s", sig_string) + typecheck_expression_with_implicit_cast(state, case_obj, case_type, msg) # The special casing only applies to simple enum member lookups (TheEnum.TheMember syntax) if case_obj->kind != AstExpressionKind.GetEnumMember: @@ -1125,13 +1151,13 @@ def typecheck_match_statement(state: State*, match_stmt: AstMatchStatement*) -> snprintf( msg, sizeof(msg), "enum member %s.%s not handled in match statement", - mtype->name, remaining[0], + case_type->name, remaining[0], ) else: snprintf( msg, sizeof(msg) - 20, "the following %d members of enum %s are not handled in match statement: ", - nremaining, mtype->name, + nremaining, case_type->name, ) for i = 0; i < nremaining; i++: assert 
sizeof(msg) > 300 @@ -1146,6 +1172,7 @@ def typecheck_match_statement(state: State*, match_stmt: AstMatchStatement*) -> fail(match_stmt->match_obj.location, msg) + free(sig_string) free(remaining) if match_stmt->case_underscore != NULL: diff --git a/doc/syntax-spec.md b/doc/syntax-spec.md index 46236c62..29827864 100644 --- a/doc/syntax-spec.md +++ b/doc/syntax-spec.md @@ -107,6 +107,7 @@ Jou has a few different kinds of tokens: - `float` - `double` - `match` + - `with` - `case` - **Newline tokens** occur at the end of a line of code. Lines that only contain spaces and comments do not produce a newline token; diff --git a/tests/404/match_function.jou b/tests/404/match_function.jou new file mode 100644 index 00000000..83b0daf1 --- /dev/null +++ b/tests/404/match_function.jou @@ -0,0 +1,4 @@ +def main() -> int: + match "hello" with lolwatwut: # Error: function 'lolwatwut' not found + case "hi": + pass diff --git a/tests/should_succeed/match.jou b/tests/should_succeed/match.jou index 29c4ee34..e3fe1800 100644 --- a/tests/should_succeed/match.jou +++ b/tests/should_succeed/match.jou @@ -1,4 +1,6 @@ import "stdlib/io.jou" +import "stdlib/mem.jou" +import "stdlib/str.jou" enum Foo: Bar @@ -110,4 +112,32 @@ def main() -> int: case show_evaluation(Foo.Lol, "case 3") | show_evaluation(Foo.Wut, "case 4"): printf("ye\n") + # Make a string that is surely not == "Hello", to make sure strcmp() is called below + s: byte* = malloc(20) + strcpy(s, "Hello") + match s with strcmp: + case "Hi": + printf("Hiii\n") + case "Hello": + printf("Hello there!\n") # Output: Hello there! 
+ case _: + printf("something else\n") + + strcat(s, "lol") + match s with strcmp: + case "Hi": + printf("Hiii\n") + case "Hello": + printf("Hello there!\n") + # no "case _", that's fine + + match s with strcmp: + case "Hi": + printf("Hiii\n") + case "Hello": + printf("Hello there!\n") + case _: + printf("something else\n") # Output: something else + + free(s) return 0 diff --git a/tests/syntax_error/match_bad_with.jou b/tests/syntax_error/match_bad_with.jou new file mode 100644 index 00000000..a23b6ff1 --- /dev/null +++ b/tests/syntax_error/match_bad_with.jou @@ -0,0 +1,2 @@ +def main() -> int: + match foo with 1: # Error: expected function name, got an integer diff --git a/tests/wrong_type/match_function_num_args.jou b/tests/wrong_type/match_function_num_args.jou new file mode 100644 index 00000000..4c9bbc3d --- /dev/null +++ b/tests/wrong_type/match_function_num_args.jou @@ -0,0 +1,8 @@ +def foo(a: byte*) -> int: + return 0 + +def main() -> int: + match "hello" with foo: # Error: cannot match with function foo(a: byte*) + case "hi": + pass + return 0 diff --git a/tests/wrong_type/match_function_return_type.jou b/tests/wrong_type/match_function_return_type.jou new file mode 100644 index 00000000..dae79034 --- /dev/null +++ b/tests/wrong_type/match_function_return_type.jou @@ -0,0 +1,8 @@ +def foo(a: byte*, b: byte*) -> byte*: + return NULL + +def main() -> int: + match "hello" with foo: # Error: cannot match with function foo(a: byte*, b: byte*) -> byte* + case "hi": + pass + return 0 diff --git a/tests/wrong_type/match_function_varargs.jou b/tests/wrong_type/match_function_varargs.jou new file mode 100644 index 00000000..707be376 --- /dev/null +++ b/tests/wrong_type/match_function_varargs.jou @@ -0,0 +1,8 @@ +# Otherwise the correct signature, but has "..." (varargs) which is not supported +declare func(a: byte*, b: byte*, ...) -> int + +def main() -> int: + match "hello" with func: # Error: cannot match with function func(a: byte*, b: byte*, ...) 
+ case "hi": + pass + return 0 diff --git a/tests/wrong_type/match_with_arg1.jou b/tests/wrong_type/match_with_arg1.jou new file mode 100644 index 00000000..c79ed2af --- /dev/null +++ b/tests/wrong_type/match_with_arg1.jou @@ -0,0 +1,7 @@ +import "stdlib/str.jou" + +def main() -> int: + match 1 with strcmp: # Error: cannot match int with strcmp(s1: byte*, s2: byte*) + case "hey": + pass + return 0 diff --git a/tests/wrong_type/match_with_arg2.jou b/tests/wrong_type/match_with_arg2.jou new file mode 100644 index 00000000..3e41b684 --- /dev/null +++ b/tests/wrong_type/match_with_arg2.jou @@ -0,0 +1,7 @@ +import "stdlib/str.jou" + +def main() -> int: + match "hey" with strcmp: + case 1: # Error: case value cannot be int when matching with strcmp(s1: byte*, s2: byte*) + pass + return 0