Skip to content

Commit

Permalink
Add "match ... with ..." syntax (#696)
Browse files Browse the repository at this point in the history
  • Loading branch information
Akuli authored Jan 26, 2025
1 parent 8d044cc commit b339b1b
Show file tree
Hide file tree
Showing 14 changed files with 143 additions and 24 deletions.
6 changes: 5 additions & 1 deletion compiler/ast.jou
Original file line number Diff line number Diff line change
Expand Up @@ -695,13 +695,15 @@ class AstIfStatement:
self->else_body.free()


# match match_obj:
# match match_obj with func:
# case ...:
# ...
# case ...:
# ...
class AstMatchStatement:
match_obj: AstExpression
func_name: byte[100] # empty if there's no "with foo"
func_signature: Signature # populated in typecheck, zero-initialized before typecheck runs
cases: AstCase*
ncases: int
case_underscore: AstBody* # body of "case _" (always last), NULL if no "case _"
Expand All @@ -719,6 +721,8 @@ class AstMatchStatement:

def free(self) -> None:
self->match_obj.free()
if self->func_signature.name[0] != '\0':
self->func_signature.free()
for i = 0; i < self->ncases; i++:
self->cases[i].free()
free(self->cases)
Expand Down
8 changes: 7 additions & 1 deletion compiler/builders/ast_to_builder.jou
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,13 @@ class AstToBuilder:
otherwise = BuilderBlock{} # will be replaced by loop below
for k = 0; k < match_stmt->cases[i].n_case_objs; k++:
case_obj = self->build_expression(&match_stmt->cases[i].case_objs[k])
cond = self->builder->eq(match_obj, case_obj)
if match_stmt->func_name[0] == '\0':
cond = self->builder->eq(match_obj, case_obj)
else:
args = [match_obj, case_obj]
func_ret = self->builder->call(&match_stmt->func_signature, args, 2)
zero = self->builder->integer(func_ret.type, 0)
cond = self->builder->eq(func_ret, zero)
otherwise = self->builder->add_block()
self->builder->branch(cond, then, otherwise)
self->builder->set_current_block(otherwise)
Expand Down
7 changes: 7 additions & 0 deletions compiler/parser.jou
Original file line number Diff line number Diff line change
Expand Up @@ -938,6 +938,13 @@ class Parser:
self->tokens++

result = AstMatchStatement{match_obj = self->parse_expression()}

if self->tokens->is_keyword("with"):
self->tokens++
if self->tokens->kind != TokenKind.Name:
self->tokens->fail_expected_got("function name")
result.func_name = (self->tokens++)->short_string

self->parse_start_of_body()

while self->tokens->kind != TokenKind.Dedent:
Expand Down
2 changes: 1 addition & 1 deletion compiler/tokenizer.jou
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def is_keyword(word: byte*) -> bool:
"return", "if", "elif", "else", "while", "for", "pass", "break", "continue",
"True", "False", "None", "NULL", "void", "noreturn",
"and", "or", "not", "self", "as", "sizeof", "assert",
"bool", "byte", "short", "int", "long", "float", "double", "match", "case",
"bool", "byte", "short", "int", "long", "float", "double", "match", "with", "case",
]

for i = 0; i < sizeof keywords / sizeof keywords[0]; i++:
Expand Down
69 changes: 48 additions & 21 deletions compiler/typecheck/step3_function_and_method_bodies.jou
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@ def typecheck_function_or_method_call(state: State*, call: AstCall*, self_type:
else:
function_or_method = "method"

sigstr = sig->to_string(False, False)
sig_string = sig->to_string(False, False)

nargs = sig->nargs
if self_type != NULL:
Expand All @@ -614,7 +614,7 @@ def typecheck_function_or_method_call(state: State*, call: AstCall*, self_type:
msg, sizeof(msg),
"%s %s takes %d argument%s, but it was called with %d argument%s",
function_or_method,
sigstr,
sig_string,
nargs,
plural_s(nargs),
call->nargs,
Expand All @@ -631,7 +631,7 @@ def typecheck_function_or_method_call(state: State*, call: AstCall*, self_type:
which_arg = nth(++selfless_counter)
snprintf(
msg, sizeof msg,
"%s argument of %s %s should have type <to>, not <from>", which_arg, function_or_method, sigstr
"%s argument of %s %s should have type <to>, not <from>", which_arg, function_or_method, sig_string
)
typecheck_expression_with_implicit_cast(state, &call->args[k++], sig->argtypes[i], msg)

Expand All @@ -653,7 +653,7 @@ def typecheck_function_or_method_call(state: State*, call: AstCall*, self_type:
)
fail(call->args[i].location, msg)

free(sigstr)
free(sig_string)
return sig->returntype


Expand Down Expand Up @@ -1076,26 +1076,52 @@ def typecheck_if_statement(state: State*, ifstmt: AstIfStatement*) -> None:
def typecheck_match_statement(state: State*, match_stmt: AstMatchStatement*) -> None:
msg: byte[500]

mtype = typecheck_expression_not_void(state, &match_stmt->match_obj)
if mtype->kind != TypeKind.Enum:
# TODO: extend match statements to other equals-comparable values
snprintf(msg, sizeof(msg), "match statements can only be used with enums, not %s", mtype->name)
fail(match_stmt->match_obj.location, msg)
if match_stmt->func_name[0] == '\0':
case_type = typecheck_expression_not_void(state, &match_stmt->match_obj)
if case_type->kind != TypeKind.Enum:
# TODO: extend match statements to other equals-comparable values
snprintf(msg, sizeof(msg), "match statements can only be used with enums, not %s", case_type->name)
fail(match_stmt->match_obj.location, msg)

# Ensure user checks all possible enum values
nremaining = case_type->enummembers.count
remaining: byte** = malloc(sizeof(remaining[0]) * nremaining)
assert remaining != NULL
for i = 0; i < nremaining; i++:
remaining[i] = case_type->enummembers.names[i]
sig_string: byte* = NULL
else:
sig = state->file_types->find_function(match_stmt->func_name)
if sig == NULL:
snprintf(msg, sizeof(msg), "function '%s' not found", match_stmt->func_name)
fail(match_stmt->match_obj.location, msg)
match_stmt->func_signature = sig->copy()

# Most of the time, only the argument types are relevant, so don't include the return type in sig_string.
sig_string = sig->to_string(False, False)

if sig->nargs != 2 or sig->returntype != intType or sig->takes_varargs: # TODO: could be more general
if sig->returntype != intType:
# show return type in error message
sig_string = sig->to_string(True, False)
snprintf(msg, sizeof(msg), "cannot match with function %s", sig_string)
fail(match_stmt->match_obj.location, msg)

# Ensure user checks all possible enum values
nremaining = mtype->enummembers.count
remaining: byte** = malloc(sizeof(remaining[0]) * nremaining)
assert remaining != NULL
for i = 0; i < nremaining; i++:
remaining[i] = mtype->enummembers.names[i]
snprintf(msg, sizeof(msg), "cannot match <from> with %s", sig_string)
typecheck_expression_with_implicit_cast(state, &match_stmt->match_obj, sig->argtypes[0], msg)

case_type = sig->argtypes[1]
remaining = NULL
nremaining = -1

for i = 0; i < match_stmt->ncases; i++:
for j = 0; j < match_stmt->cases[i].n_case_objs; j++:
case_obj = &match_stmt->cases[i].case_objs[j]
typecheck_expression_with_implicit_cast(
state, case_obj, mtype,
"case value of type <from> cannot be matched against <to>",
)
if sig_string == NULL:
msg = "case value of type <from> cannot be matched against <to>"
else:
snprintf(msg, sizeof(msg), "case value cannot be <from> when matching with %s", sig_string)
typecheck_expression_with_implicit_cast(state, case_obj, case_type, msg)

# The special casing only applies to simple enum member lookups (TheEnum.TheMember syntax)
if case_obj->kind != AstExpressionKind.GetEnumMember:
Expand Down Expand Up @@ -1125,13 +1151,13 @@ def typecheck_match_statement(state: State*, match_stmt: AstMatchStatement*) ->
snprintf(
msg, sizeof(msg),
"enum member %s.%s not handled in match statement",
mtype->name, remaining[0],
case_type->name, remaining[0],
)
else:
snprintf(
msg, sizeof(msg) - 20,
"the following %d members of enum %s are not handled in match statement: ",
nremaining, mtype->name,
nremaining, case_type->name,
)
for i = 0; i < nremaining; i++:
assert sizeof(msg) > 300
Expand All @@ -1146,6 +1172,7 @@ def typecheck_match_statement(state: State*, match_stmt: AstMatchStatement*) ->

fail(match_stmt->match_obj.location, msg)

free(sig_string)
free(remaining)

if match_stmt->case_underscore != NULL:
Expand Down
1 change: 1 addition & 0 deletions doc/syntax-spec.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ Jou has a few different kinds of tokens:
- `float`
- `double`
- `match`
- `with`
- `case`
- **Newline tokens** occur at the end of a line of code.
Lines that only contain spaces and comments do not produce a newline token;
Expand Down
4 changes: 4 additions & 0 deletions tests/404/match_function.jou
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Negative test: the name after "with" does not refer to any known function,
# so the type checker is expected to reject the match statement with the
# message shown in the trailing comment on the "match" line.
def main() -> int:
match "hello" with lolwatwut: # Error: function 'lolwatwut' not found
case "hi":
pass
30 changes: 30 additions & 0 deletions tests/should_succeed/match.jou
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import "stdlib/io.jou"
import "stdlib/mem.jou"
import "stdlib/str.jou"

enum Foo:
Bar
Expand Down Expand Up @@ -110,4 +112,32 @@ def main() -> int:
case show_evaluation(Foo.Lol, "case 3") | show_evaluation(Foo.Wut, "case 4"):
printf("ye\n")

# Make a string that is surely not == "Hello", to make sure strcmp() is called below
s: byte* = malloc(20)
strcpy(s, "Hello")
match s with strcmp:
case "Hi":
printf("Hiii\n")
case "Hello":
printf("Hello there!\n") # Output: Hello there!
case _:
printf("something else\n")

strcat(s, "lol")
match s with strcmp:
case "Hi":
printf("Hiii\n")
case "Hello":
printf("Hello there!\n")
# no "case _", that's fine

match s with strcmp:
case "Hi":
printf("Hiii\n")
case "Hello":
printf("Hello there!\n")
case _:
printf("something else\n") # Output: something else

free(s)
return 0
2 changes: 2 additions & 0 deletions tests/syntax_error/match_bad_with.jou
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Negative test: "with" must be followed by a name token. Here it is followed
# by an integer literal, so the parser is expected to stop with the message
# shown in the trailing comment. The file intentionally ends right after the
# "match" line.
def main() -> int:
match foo with 1: # Error: expected function name, got an integer
8 changes: 8 additions & 0 deletions tests/wrong_type/match_function_num_args.jou
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Helper with an int return type but only one parameter. A function used in
# "match ... with ..." must take exactly two arguments, so this one is not usable there.
def foo(a: byte*) -> int:
return 0

# Negative test: matching with a function that does not take exactly two
# arguments must be rejected. The return type is int, so the expected message
# shows the signature without a return type.
def main() -> int:
match "hello" with foo: # Error: cannot match with function foo(a: byte*)
case "hi":
pass
return 0
8 changes: 8 additions & 0 deletions tests/wrong_type/match_function_return_type.jou
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Helper with two parameters but a byte* return type. A function used in
# "match ... with ..." must return int, so this one is not usable there.
def foo(a: byte*, b: byte*) -> byte*:
return NULL

# Negative test: matching with a function whose return type is not int must be
# rejected. In this case the expected message includes the return type
# ("-> byte*") so that the reason for the rejection is visible.
def main() -> int:
match "hello" with foo: # Error: cannot match with function foo(a: byte*, b: byte*) -> byte*
case "hi":
pass
return 0
8 changes: 8 additions & 0 deletions tests/wrong_type/match_function_varargs.jou
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Otherwise the correct signature, but has "..." (varargs) which is not supported
declare func(a: byte*, b: byte*, ...) -> int

# Negative test: a function that takes varargs is rejected by "match ... with ..."
# even when it has two named arguments and returns int.
def main() -> int:
match "hello" with func: # Error: cannot match with function func(a: byte*, b: byte*, ...)
case "hi":
pass
return 0
7 changes: 7 additions & 0 deletions tests/wrong_type/match_with_arg1.jou
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import "stdlib/str.jou"

# Negative test: the matched value must be implicitly castable to the type of
# the function's first argument. An int cannot be passed as strcmp's first
# argument (byte*), so the "match" line itself is expected to be rejected.
def main() -> int:
match 1 with strcmp: # Error: cannot match int with strcmp(s1: byte*, s2: byte*)
case "hey":
pass
return 0
7 changes: 7 additions & 0 deletions tests/wrong_type/match_with_arg2.jou
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import "stdlib/str.jou"

# Negative test: each case value must be implicitly castable to the type of
# the function's second argument. The matched string is fine here, but the
# int case value cannot be passed as strcmp's second argument (byte*), so the
# "case" line is the one expected to be rejected.
def main() -> int:
match "hey" with strcmp:
case 1: # Error: case value cannot be int when matching with strcmp(s1: byte*, s2: byte*)
pass
return 0

0 comments on commit b339b1b

Please sign in to comment.