From 968f04af5030c5747097367495d9abb8faa9a1f1 Mon Sep 17 00:00:00 2001 From: Akuli Date: Tue, 28 Jan 2025 04:41:17 +0200 Subject: [PATCH] Determine type of strings with bidirectional type inference --- compiler/builders/ast_to_builder.jou | 13 ++- .../step3_function_and_method_bodies.jou | 84 ++++++++++--------- compiler/types_in_ast.jou | 1 - tests/should_succeed/array.jou | 16 +++- tests/should_succeed/sizeof.jou | 42 +--------- 5 files changed, 65 insertions(+), 91 deletions(-) diff --git a/compiler/builders/ast_to_builder.jou b/compiler/builders/ast_to_builder.jou index 58461e41..228cb583 100644 --- a/compiler/builders/ast_to_builder.jou +++ b/compiler/builders/ast_to_builder.jou @@ -252,7 +252,12 @@ class AstToBuilder: def build_expression_without_implicit_cast(self, expr: AstExpression*) -> BuilderValue: match expr->kind: case AstExpressionKind.String: - return self->builder->string(expr->string) + if expr->types.orig_type == byteType->pointer_type(): + return self->builder->string(expr->string) + else: + assert expr->types.orig_type->kind == TypeKind.Array + assert expr->types.orig_type->array.item_type == byteType + return self->builder->string_array(expr->string, expr->types.orig_type->array.len) case AstExpressionKind.Byte: return self->builder->integer(byteType, expr->byte_value) case AstExpressionKind.Short: @@ -359,12 +364,6 @@ class AstToBuilder: if expr->types.implicit_array_to_pointer_cast: return self->builder->cast(self->build_address_of_expression(expr), expr->types.implicit_cast_type) - if expr->types.implicit_string_to_array_cast: - assert expr->types.implicit_cast_type != NULL - assert expr->types.implicit_cast_type->kind == TypeKind.Array - assert expr->kind == AstExpressionKind.String - return self->builder->string_array(expr->string, expr->types.implicit_cast_type->array.len) - raw = self->build_expression_without_implicit_cast(expr) if expr->types.orig_type == NULL and expr->types.implicit_cast_type == NULL: # Function/method call that returns no value diff --git a/compiler/typecheck/step3_function_and_method_bodies.jou b/compiler/typecheck/step3_function_and_method_bodies.jou index 6c0c4e6c..7bf4b171 100644 --- a/compiler/typecheck/step3_function_and_method_bodies.jou +++ b/compiler/typecheck/step3_function_and_method_bodies.jou @@ -282,25 +282,12 @@ def do_implicit_cast( if from == to: return - if ( - expr->kind == AstExpressionKind.String - and from == byteType->pointer_type() - and to->kind == TypeKind.Array - and to->array.item_type == byteType - ): - string_size = strlen(expr->string) + 1 - if to->array.len < string_size: - msg: byte[500] - snprintf(msg, sizeof(msg), "a string of %d bytes (including '\\0') does not fit into %s", string_size, to->name) - fail(location, msg) - expr->types.implicit_string_to_array_cast = True # Passing in NULL for errormsg_template can be used to "force" a cast to happen. - elif errormsg_template != NULL and not can_cast_implicitly(from, to): + if errormsg_template != NULL and not can_cast_implicitly(from, to): fail_with_implicit_cast_error(location, errormsg_template, from, to) expr->types.implicit_cast_type = to expr->types.implicit_array_to_pointer_cast = (from->kind == TypeKind.Array and to->is_pointer_type()) - if expr->types.implicit_array_to_pointer_cast: ensure_can_take_address( fom, @@ -334,8 +321,8 @@ def do_explicit_cast(fom: FunctionOrMethodTypes*, expr: AstExpression*, to: Type cast_array_to_pointer(fom, expr) -def typecheck_expression_not_void(state: State*, expr: AstExpression*) -> Type*: - typecheck_expression(state, expr) +def typecheck_expression_not_void(state: State*, expr: AstExpression*, type_hint: Type*) -> Type*: + typecheck_expression(state, expr, type_hint) if expr->types.orig_type != NULL: # The happy path. Evaluating the expression results in a value. return expr->types.orig_type @@ -354,7 +341,8 @@ def typecheck_expression_with_implicit_cast( casttype: Type*, errormsg_template: byte*, ) -> None: - typecheck_expression_not_void(state, expr) + assert casttype != NULL + typecheck_expression_not_void(state, expr, casttype) do_implicit_cast(state->fom_types, expr, casttype, expr->location, errormsg_template) @@ -488,7 +476,7 @@ def check_increment_or_decrement(state: State*, expr: AstExpression*) -> Type*: case _: assert False - t = typecheck_expression_not_void(state, &expr->operands[0]) + t = typecheck_expression_not_void(state, &expr->operands[0], NULL) if not t->is_number_type() and t->kind != TypeKind.Pointer: msg: byte[500] snprintf(msg, sizeof(msg), bad_type_fmt, t->name) @@ -514,7 +502,7 @@ def typecheck_indexing( ) -> Type*: msg: byte[500] - orig_type = typecheck_expression_not_void(state, ptrexpr) + orig_type = typecheck_expression_not_void(state, ptrexpr, NULL) match orig_type->kind: case TypeKind.Pointer: ptrtype = orig_type @@ -528,7 +516,7 @@ def typecheck_indexing( assert ptrtype != NULL assert ptrtype->kind == TypeKind.Pointer - indextype = typecheck_expression_not_void(state, indexexpr) + indextype = typecheck_expression_not_void(state, indexexpr, NULL) assert indextype != NULL if not indextype->is_integer_type(): snprintf(msg, sizeof(msg), "the index inside [...] must be an integer, not %s", indextype->name) @@ -637,7 +625,7 @@ def typecheck_function_or_method_call(state: State*, call: AstCall*, self_type: for i = k; i < call->nargs; i++: # This code runs for varargs, e.g. the things to format in printf(). - t = typecheck_expression_not_void(state, &call->args[i]) + t = typecheck_expression_not_void(state, &call->args[i], NULL) if t->kind == TypeKind.Array: cast_array_to_pointer(state->fom_types, &call->args[i]) elif (t->is_integer_type() and t->size_in_bits < 32) or t == boolType: @@ -845,7 +833,7 @@ def handle_conflicting_class_field_and_enum_member_syntax(state: State*, expr: A } -def typecheck_expression(state: State*, expr: AstExpression*) -> None: +def typecheck_expression(state: State*, expr: AstExpression*, type_hint: Type*) -> None: msg: byte[500] result: Type* = NULL @@ -869,7 +857,11 @@ def typecheck_expression(state: State*, expr: AstExpression*) -> None: case AstExpressionKind.Null: result = voidPtrType case AstExpressionKind.String: - result = byteType->pointer_type() + # Use array string if type hint given, e.g. foo: byte[100] = "hello" + if type_hint != NULL and type_hint->kind == TypeKind.Array and type_hint->array.item_type == byteType: + result = type_hint + else: + result = byteType->pointer_type() case AstExpressionKind.GetEnumMember: result = state->file_types->find_type(expr->enum_member.enum_name) @@ -884,22 +876,31 @@ def typecheck_expression(state: State*, expr: AstExpression*) -> None: fail(expr->location, msg) case AstExpressionKind.SizeOf: - typecheck_expression_not_void(state, &expr->operands[0]) + if expr->operands[0].kind == AstExpressionKind.String: + # sizeof("foo") should be 4 + obj_type_hint = byteType->array_type((strlen(expr->operands[0].string) as int) + 1) + else: + obj_type_hint = NULL + typecheck_expression_not_void(state, &expr->operands[0], obj_type_hint) result = longType case AstExpressionKind.Instantiate: result = typecheck_instantiation(state, &expr->instantiation, expr->location) case AstExpressionKind.Array: + if type_hint != NULL and type_hint->kind == TypeKind.Array: + item_type_hint = type_hint->array.item_type + else: + item_type_hint = NULL n = expr->array.length for i = 0; i < n; i++: - typecheck_expression_not_void(state, &expr->array.items[i]) + typecheck_expression_not_void(state, &expr->array.items[i], item_type_hint) membertype = cast_array_members_to_a_common_type(state->fom_types, expr->location, expr->array) result = membertype->array_type(n) case AstExpressionKind.GetClassField: if expr->class_field.uses_arrow_operator: - temptype = typecheck_expression_not_void(state, expr->class_field.instance) + temptype = typecheck_expression_not_void(state, expr->class_field.instance, NULL) if temptype->kind != TypeKind.Pointer or temptype->value_type->kind != TypeKind.Class: snprintf( msg, sizeof(msg), @@ -910,7 +911,7 @@ def typecheck_expression(state: State*, expr: AstExpression*) -> None: fail(expr->location, msg) result = typecheck_class_field(temptype->value_type, expr->class_field.field_name, expr->location)->type else: - temptype = typecheck_expression_not_void(state, expr->class_field.instance) + temptype = typecheck_expression_not_void(state, expr->class_field.instance, NULL) if temptype->kind != TypeKind.Class: snprintf( msg, sizeof(msg), @@ -929,7 +930,7 @@ def typecheck_expression(state: State*, expr: AstExpression*) -> None: if expr->call.method_call_self == NULL: result = typecheck_function_or_method_call(state, &expr->call, NULL, expr->location) elif expr->call.uses_arrow_operator: - temptype = typecheck_expression_not_void(state, expr->call.method_call_self) + temptype = typecheck_expression_not_void(state, expr->call.method_call_self, NULL) if temptype->kind != TypeKind.Pointer or temptype->value_type->kind != TypeKind.Class: snprintf(msg, sizeof(msg), "left side of '->' operator must be a pointer to an instance of a class, not %s", @@ -939,7 +940,7 @@ def typecheck_expression(state: State*, expr: AstExpression*) -> None: fail(expr->location, msg) result = typecheck_function_or_method_call(state, &expr->call, temptype->value_type, expr->location) else: - temptype = typecheck_expression_not_void(state, expr->call.method_call_self) + temptype = typecheck_expression_not_void(state, expr->call.method_call_self, NULL) if temptype->kind != TypeKind.Class: snprintf(msg, sizeof(msg), "left side of '.' operator must be an instance of a class, not %s", @@ -974,7 +975,7 @@ def typecheck_expression(state: State*, expr: AstExpression*) -> None: result = typecheck_indexing(state, &expr->operands[0], &expr->operands[1]) case AstExpressionKind.AddressOf: - result = typecheck_expression_not_void(state, &expr->operands[0])->pointer_type() + result = typecheck_expression_not_void(state, &expr->operands[0], NULL)->pointer_type() ensure_can_take_address(state->fom_types, &expr->operands[0], "the '&' operator cannot be used with %s") case AstExpressionKind.GetVariable: @@ -989,7 +990,7 @@ def typecheck_expression(state: State*, expr: AstExpression*) -> None: result = selfvar->type case AstExpressionKind.Dereference: - temptype = typecheck_expression_not_void(state, &expr->operands[0]) + temptype = typecheck_expression_not_void(state, &expr->operands[0], NULL) typecheck_dereferenced_pointer(expr->location, temptype) result = temptype->value_type @@ -1008,7 +1009,7 @@ def typecheck_expression(state: State*, expr: AstExpression*) -> None: result = boolType case AstExpressionKind.Negate: - result = typecheck_expression_not_void(state, &expr->operands[0]) + result = typecheck_expression_not_void(state, &expr->operands[0], NULL) if result->kind != TypeKind.SignedInteger and result->kind != TypeKind.FloatingPoint: snprintf(msg, sizeof(msg), "value after '-' must be a float or double or a signed integer, not %s", @@ -1028,8 +1029,8 @@ def typecheck_expression(state: State*, expr: AstExpression*) -> None: | AstExpressionKind.Lt | AstExpressionKind.Le ): - typecheck_expression_not_void(state, &expr->operands[0]) - typecheck_expression_not_void(state, &expr->operands[1]) + typecheck_expression_not_void(state, &expr->operands[0], NULL) + typecheck_expression_not_void(state, &expr->operands[1], NULL) result = check_binop(state->fom_types, expr->kind, expr->location, &expr->operands[0], &expr->operands[1]) case ( @@ -1041,8 +1042,9 @@ def typecheck_expression(state: State*, expr: AstExpression*) -> None: result = check_increment_or_decrement(state, expr) case AstExpressionKind.As: - typecheck_expression_not_void(state, &expr->as_->value) + # TODO: test this: sizeof("foo" as byte[100]) result = type_from_ast(state->file_types, &expr->as_->type) + typecheck_expression_not_void(state, &expr->as_->value, result) do_explicit_cast(state->fom_types, &expr->as_->value, result, expr->location) case _: @@ -1080,7 +1082,7 @@ def typecheck_match_statement(state: State*, match_stmt: AstMatchStatement*) -> nremaining = -1 if match_stmt->func_name[0] == '\0': - case_type = typecheck_expression_not_void(state, &match_stmt->match_obj) + case_type = typecheck_expression_not_void(state, &match_stmt->match_obj, NULL) match case_type->kind: case TypeKind.SignedInteger | TypeKind.UnsignedInteger: @@ -1226,11 +1228,11 @@ def typecheck_statement(state: State*, stmt: AstStatement*) -> None: and state->find_any_var(targetexpr->varname) == NULL ): # Making a new variable. Use the type of the value being assigned. - type = typecheck_expression_not_void(state, valueexpr) + type = typecheck_expression_not_void(state, valueexpr, NULL) state->fom_types->add_variable(type, targetexpr->varname) else: # Convert value to the type of an existing variable or other assignment target. - targettype = typecheck_expression_not_void(state, targetexpr) + targettype = typecheck_expression_not_void(state, targetexpr, NULL) ensure_can_take_address(state->fom_types, targetexpr, "cannot assign to %s") if targetexpr->kind == AstExpressionKind.Dereference: @@ -1250,8 +1252,8 @@ def typecheck_statement(state: State*, stmt: AstStatement*) -> None: targetexpr = &stmt->assignment.target valueexpr = &stmt->assignment.value - targettype = typecheck_expression_not_void(state, targetexpr) - value_type = typecheck_expression_not_void(state, valueexpr) + targettype = typecheck_expression_not_void(state, targetexpr, NULL) + value_type = typecheck_expression_not_void(state, valueexpr, NULL) ensure_can_take_address(state->fom_types, targetexpr, "cannot assign to %s") match stmt->kind: @@ -1330,7 +1332,7 @@ def typecheck_statement(state: State*, stmt: AstStatement*) -> None: "initial value for variable of type cannot be of type ") case AstStatementKind.ExpressionStatement: - typecheck_expression(state, &stmt->expression) + typecheck_expression(state, &stmt->expression, NULL) case AstStatementKind.Assert: typecheck_expression_with_implicit_cast(state, &stmt->expression, boolType, "assertion must be a bool, not ") diff --git a/compiler/types_in_ast.jou b/compiler/types_in_ast.jou index 6e9f7ff7..b033dde3 100644 --- a/compiler/types_in_ast.jou +++ b/compiler/types_in_ast.jou @@ -22,7 +22,6 @@ class ExpressionTypes: # Flags to indicate whether special kinds of implicit casts happened implicit_array_to_pointer_cast: bool # Foo[N] to Foo* - implicit_string_to_array_cast: bool # "..." to byte[N] class LocalVariable: diff --git a/tests/should_succeed/array.jou b/tests/should_succeed/array.jou index edda5bc6..d4d12e58 100644 --- a/tests/should_succeed/array.jou +++ b/tests/should_succeed/array.jou @@ -1,4 +1,4 @@ -declare printf(fmt: byte*, ...) -> int +import "stdlib/io.jou" # c can't do this def make_array() -> int[3]: @@ -51,4 +51,18 @@ def main() -> int: if strings[i] == NULL: printf("strings[%d] is NULL\n", i) + # array of fixed-size strings (#683) + # Output: hello + # Output: world + # Output: test + # Output: foo + # Output: bar + # Output: bazzybaz + strings50: byte[50][3] = ["hello", "world", "test"] + for i = 0; i < 3; i++: + puts(strings50[i]) + strings50 = ["foo", "bar", "bazzybaz"] + for i = 0; i < 3; i++: + puts(strings50[i]) + return 0 diff --git a/tests/should_succeed/sizeof.jou b/tests/should_succeed/sizeof.jou index f5c7d3e7..3af910d5 100644 --- a/tests/should_succeed/sizeof.jou +++ b/tests/should_succeed/sizeof.jou @@ -1,45 +1,5 @@ import "stdlib/io.jou" -import "stdlib/mem.jou" - -def side_effect() -> int: - printf("Side Effect !!!!!\n") - return 123 - -class Foo: - a: int - b: long - c: byte - -# See issue #224. -def ensure_sizeof_isnt_too_small_in_a_weird_corner_case() -> None: - value = Foo{a=1, b=2, c='x'} - # We need the heap allocation, because otherwise the optimizer happens to make things work. - ptr = malloc(50) as Foo* - memcpy(ptr, &value, sizeof value) - # If sizeof is too small, this prints garbage. - printf("%c\n", ptr->c) # Output: x - free(ptr) def main() -> int: - ensure_sizeof_isnt_too_small_in_a_weird_corner_case() - - bo: bool - by: byte - n: int - m: long - - printf("%lld %lld %lld %lld\n", sizeof bo, sizeof by, sizeof n, sizeof m) # Output: 1 1 4 8 - - # test that operator precedence works - printf("%lld\n", sizeof by + sizeof n + sizeof m) # Output: 13 - - arr: long[100] - printf("%lld\n", sizeof arr) # Output: 800 - - # The "array length trick" - printf("%lld\n", sizeof arr / sizeof arr[0]) # Output: 100 - - # Evaluating a sizeof has no side effects. - printf("%lld\n", sizeof side_effect()) # Output: 4 - + printf("%lld\n", sizeof("hello")) return 0