diff --git a/compiler/builders/ast_to_builder.jou b/compiler/builders/ast_to_builder.jou index 83b0497a..45cc3777 100644 --- a/compiler/builders/ast_to_builder.jou +++ b/compiler/builders/ast_to_builder.jou @@ -508,7 +508,22 @@ class AstToBuilder: if match_stmt->case_underscore != NULL: self->build_body(match_stmt->case_underscore) - self->builder->jump(done) + + if ( + match_stmt->case_underscore == NULL + and match_stmt->match_obj.types.implicit_cast_type->kind == TypeKind.Enum + ): + # The one corner case where match statement invokes UB: + # - User is matching over an enum + # - All enum members are handled (otherwise error in typecheck) + # - The value stored in the enum is not a valid value of the enum + # - There is no "case _" to catch the invalid value + # + # See also: doc/match.md + self->builder->unreachable() + else: + self->builder->jump(done) + self->builder->set_current_block(done) def build_assert(self, assert_location: Location, assertion: AstAssertion*) -> None: diff --git a/compiler/typecheck/step3_function_and_method_bodies.jou b/compiler/typecheck/step3_function_and_method_bodies.jou index 391c5a65..ef756e88 100644 --- a/compiler/typecheck/step3_function_and_method_bodies.jou +++ b/compiler/typecheck/step3_function_and_method_bodies.jou @@ -1146,11 +1146,11 @@ def typecheck_match_statement(state: State*, match_stmt: AstMatchStatement*) -> snprintf(msg, sizeof(msg), "case value cannot be when matching with %s", sig_string) typecheck_expression_with_implicit_cast(state, case_obj, case_type, msg) - # The special casing only applies to simple enum member lookups (TheEnum.TheMember syntax) - if case_obj->kind != AstExpressionKind.GetEnumMember: - nremaining = -1 - if nremaining != -1: + if case_obj->kind != AstExpressionKind.GetEnumMember: + # Matching an enum but it's too dynamic, not simply TheEnum.Member + snprintf(msg, sizeof(msg), "'case' value must be %s.something when matching a value of enum %s", case_type->name, case_type->name) + fail(case_obj->location, msg) # We are matching against TheEnum.member. Try to find and remove it from remaining members. member = case_obj->enum_member.member_name found = False diff --git a/tests/other_errors/match_enum_too_dynamic.jou b/tests/other_errors/match_enum_too_dynamic.jou new file mode 100644 index 00000000..5cbf0872 --- /dev/null +++ b/tests/other_errors/match_enum_too_dynamic.jou @@ -0,0 +1,11 @@ +enum Thing: + Foo + Bar + Baz + +def do_stuff(t1: Thing, t2: Thing) -> None: + match t1: + # This is forbidden because it would be hard for the compiler to know + # which values have been handled and which haven't + case t2: # Error: 'case' value must be Thing.something when matching a value of enum Thing + pass diff --git a/tests/should_succeed/match.jou b/tests/should_succeed/match.jou index 821ef0c7..dc99c803 100644 --- a/tests/should_succeed/match.jou +++ b/tests/should_succeed/match.jou @@ -9,9 +9,9 @@ enum Foo: Wut -def show_evaluation(foo: Foo, msg: byte*) -> Foo: +def show_evaluation(value: int, msg: byte*) -> int: puts(msg) - return foo + return value def main() -> int: @@ -80,6 +80,8 @@ def main() -> int: printf("Hey! :)\n") # Output: Hey! :) case Foo.Wut: printf("nope\n") + case _: + printf("Other!!!\n") f = 12345 as Foo match f: @@ -91,6 +93,8 @@ def main() -> int: printf("nope\n") case Foo.Wut: printf("nope\n") + case _: + printf("Other!!!\n") # Output: Other!!! # Test evaluation order. # @@ -99,14 +103,14 @@ def main() -> int: # Output: case 2 # Output: case 3 # Output: ye - match show_evaluation(Foo.Lol, "match obj"): - case show_evaluation(Foo.Bar, "case 1"): + match show_evaluation(3, "match obj"): + case show_evaluation(1, "case 1"): printf("nope\n") - case show_evaluation(Foo.Baz, "case 2"): + case show_evaluation(2, "case 2"): printf("nope\n") - case show_evaluation(Foo.Lol, "case 3"): + case show_evaluation(3, "case 3"): printf("ye\n") - case show_evaluation(Foo.Wut, "case 4"): + case show_evaluation(4, "case 4"): printf("nope\n") # Output: match obj @@ -114,10 +118,10 @@ def main() -> int: # Output: case 2 # Output: case 3 # Output: ye - match show_evaluation(Foo.Lol, "match obj"): - case show_evaluation(Foo.Bar, "case 1") | show_evaluation(Foo.Baz, "case 2"): + match show_evaluation(3, "match obj"): + case show_evaluation(1, "case 1") | show_evaluation(2, "case 2"): printf("nope\n") - case show_evaluation(Foo.Lol, "case 3") | show_evaluation(Foo.Wut, "case 4"): + case show_evaluation(3, "case 3") | show_evaluation(4, "case 4"): printf("ye\n") # Make a string that is surely not == "Hello", to make sure strcmp() is called below