Skip to content

Commit

Permalink
Determine type of strings with bidirectional type inference
Browse files Browse the repository at this point in the history
  • Loading branch information
Akuli committed Jan 28, 2025
1 parent e014aae commit 968f04a
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 91 deletions.
13 changes: 6 additions & 7 deletions compiler/builders/ast_to_builder.jou
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,12 @@ class AstToBuilder:
def build_expression_without_implicit_cast(self, expr: AstExpression*) -> BuilderValue:
match expr->kind:
case AstExpressionKind.String:
return self->builder->string(expr->string)
if expr->types.orig_type == byteType->pointer_type():
return self->builder->string(expr->string)
else:
assert expr->types.orig_type->kind == TypeKind.Array
assert expr->types.orig_type->array.item_type == byteType
return self->builder->string_array(expr->string, expr->types.orig_type->array.len)
case AstExpressionKind.Byte:
return self->builder->integer(byteType, expr->byte_value)
case AstExpressionKind.Short:
Expand Down Expand Up @@ -359,12 +364,6 @@ class AstToBuilder:
if expr->types.implicit_array_to_pointer_cast:
return self->builder->cast(self->build_address_of_expression(expr), expr->types.implicit_cast_type)

if expr->types.implicit_string_to_array_cast:
assert expr->types.implicit_cast_type != NULL
assert expr->types.implicit_cast_type->kind == TypeKind.Array
assert expr->kind == AstExpressionKind.String
return self->builder->string_array(expr->string, expr->types.implicit_cast_type->array.len)

raw = self->build_expression_without_implicit_cast(expr)
if expr->types.orig_type == NULL and expr->types.implicit_cast_type == NULL:
# Function/method call that returns no value
Expand Down
84 changes: 43 additions & 41 deletions compiler/typecheck/step3_function_and_method_bodies.jou
Original file line number Diff line number Diff line change
Expand Up @@ -282,25 +282,12 @@ def do_implicit_cast(
if from == to:
return

if (
expr->kind == AstExpressionKind.String
and from == byteType->pointer_type()
and to->kind == TypeKind.Array
and to->array.item_type == byteType
):
string_size = strlen(expr->string) + 1
if to->array.len < string_size:
msg: byte[500]
snprintf(msg, sizeof(msg), "a string of %d bytes (including '\\0') does not fit into %s", string_size, to->name)
fail(location, msg)
expr->types.implicit_string_to_array_cast = True
# Passing in NULL for errormsg_template can be used to "force" a cast to happen.
elif errormsg_template != NULL and not can_cast_implicitly(from, to):
if errormsg_template != NULL and not can_cast_implicitly(from, to):
fail_with_implicit_cast_error(location, errormsg_template, from, to)

expr->types.implicit_cast_type = to
expr->types.implicit_array_to_pointer_cast = (from->kind == TypeKind.Array and to->is_pointer_type())

if expr->types.implicit_array_to_pointer_cast:
ensure_can_take_address(
fom,
Expand Down Expand Up @@ -334,8 +321,8 @@ def do_explicit_cast(fom: FunctionOrMethodTypes*, expr: AstExpression*, to: Type
cast_array_to_pointer(fom, expr)


def typecheck_expression_not_void(state: State*, expr: AstExpression*) -> Type*:
typecheck_expression(state, expr)
def typecheck_expression_not_void(state: State*, expr: AstExpression*, type_hint: Type*) -> Type*:
typecheck_expression(state, expr, type_hint)
if expr->types.orig_type != NULL:
# The happy path. Evaluating the expression results in a value.
return expr->types.orig_type
Expand All @@ -354,7 +341,8 @@ def typecheck_expression_with_implicit_cast(
casttype: Type*,
errormsg_template: byte*,
) -> None:
typecheck_expression_not_void(state, expr)
assert casttype != NULL
typecheck_expression_not_void(state, expr, casttype)
do_implicit_cast(state->fom_types, expr, casttype, expr->location, errormsg_template)


Expand Down Expand Up @@ -488,7 +476,7 @@ def check_increment_or_decrement(state: State*, expr: AstExpression*) -> Type*:
case _:
assert False

t = typecheck_expression_not_void(state, &expr->operands[0])
t = typecheck_expression_not_void(state, &expr->operands[0], NULL)
if not t->is_number_type() and t->kind != TypeKind.Pointer:
msg: byte[500]
snprintf(msg, sizeof(msg), bad_type_fmt, t->name)
Expand All @@ -514,7 +502,7 @@ def typecheck_indexing(
) -> Type*:
msg: byte[500]

orig_type = typecheck_expression_not_void(state, ptrexpr)
orig_type = typecheck_expression_not_void(state, ptrexpr, NULL)
match orig_type->kind:
case TypeKind.Pointer:
ptrtype = orig_type
Expand All @@ -528,7 +516,7 @@ def typecheck_indexing(
assert ptrtype != NULL
assert ptrtype->kind == TypeKind.Pointer

indextype = typecheck_expression_not_void(state, indexexpr)
indextype = typecheck_expression_not_void(state, indexexpr, NULL)
assert indextype != NULL
if not indextype->is_integer_type():
snprintf(msg, sizeof(msg), "the index inside [...] must be an integer, not %s", indextype->name)
Expand Down Expand Up @@ -637,7 +625,7 @@ def typecheck_function_or_method_call(state: State*, call: AstCall*, self_type:

for i = k; i < call->nargs; i++:
# This code runs for varargs, e.g. the things to format in printf().
t = typecheck_expression_not_void(state, &call->args[i])
t = typecheck_expression_not_void(state, &call->args[i], NULL)
if t->kind == TypeKind.Array:
cast_array_to_pointer(state->fom_types, &call->args[i])
elif (t->is_integer_type() and t->size_in_bits < 32) or t == boolType:
Expand Down Expand Up @@ -845,7 +833,7 @@ def handle_conflicting_class_field_and_enum_member_syntax(state: State*, expr: A
}


def typecheck_expression(state: State*, expr: AstExpression*) -> None:
def typecheck_expression(state: State*, expr: AstExpression*, type_hint: Type*) -> None:
msg: byte[500]
result: Type* = NULL

Expand All @@ -869,7 +857,11 @@ def typecheck_expression(state: State*, expr: AstExpression*) -> None:
case AstExpressionKind.Null:
result = voidPtrType
case AstExpressionKind.String:
result = byteType->pointer_type()
# Use the hinted array type when the type hint is a byte array, e.g. foo: byte[100] = "hello"
if type_hint != NULL and type_hint->kind == TypeKind.Array and type_hint->array.item_type == byteType:
result = type_hint
else:
result = byteType->pointer_type()

case AstExpressionKind.GetEnumMember:
result = state->file_types->find_type(expr->enum_member.enum_name)
Expand All @@ -884,22 +876,31 @@ def typecheck_expression(state: State*, expr: AstExpression*) -> None:
fail(expr->location, msg)

case AstExpressionKind.SizeOf:
typecheck_expression_not_void(state, &expr->operands[0])
if expr->operands[0].kind == AstExpressionKind.String:
# sizeof("foo") should be 4
obj_type_hint = byteType->array_type((strlen(expr->operands[0].string) as int) + 1)
else:
obj_type_hint = NULL
typecheck_expression_not_void(state, &expr->operands[0], obj_type_hint)
result = longType

case AstExpressionKind.Instantiate:
result = typecheck_instantiation(state, &expr->instantiation, expr->location)

case AstExpressionKind.Array:
if type_hint != NULL and type_hint->kind == TypeKind.Array:
item_type_hint = type_hint->array.item_type
else:
item_type_hint = NULL
n = expr->array.length
for i = 0; i < n; i++:
typecheck_expression_not_void(state, &expr->array.items[i])
typecheck_expression_not_void(state, &expr->array.items[i], item_type_hint)
membertype = cast_array_members_to_a_common_type(state->fom_types, expr->location, expr->array)
result = membertype->array_type(n)

case AstExpressionKind.GetClassField:
if expr->class_field.uses_arrow_operator:
temptype = typecheck_expression_not_void(state, expr->class_field.instance)
temptype = typecheck_expression_not_void(state, expr->class_field.instance, NULL)
if temptype->kind != TypeKind.Pointer or temptype->value_type->kind != TypeKind.Class:
snprintf(
msg, sizeof(msg),
Expand All @@ -910,7 +911,7 @@ def typecheck_expression(state: State*, expr: AstExpression*) -> None:
fail(expr->location, msg)
result = typecheck_class_field(temptype->value_type, expr->class_field.field_name, expr->location)->type
else:
temptype = typecheck_expression_not_void(state, expr->class_field.instance)
temptype = typecheck_expression_not_void(state, expr->class_field.instance, NULL)
if temptype->kind != TypeKind.Class:
snprintf(
msg, sizeof(msg),
Expand All @@ -929,7 +930,7 @@ def typecheck_expression(state: State*, expr: AstExpression*) -> None:
if expr->call.method_call_self == NULL:
result = typecheck_function_or_method_call(state, &expr->call, NULL, expr->location)
elif expr->call.uses_arrow_operator:
temptype = typecheck_expression_not_void(state, expr->call.method_call_self)
temptype = typecheck_expression_not_void(state, expr->call.method_call_self, NULL)
if temptype->kind != TypeKind.Pointer or temptype->value_type->kind != TypeKind.Class:
snprintf(msg, sizeof(msg),
"left side of '->' operator must be a pointer to an instance of a class, not %s",
Expand All @@ -939,7 +940,7 @@ def typecheck_expression(state: State*, expr: AstExpression*) -> None:
fail(expr->location, msg)
result = typecheck_function_or_method_call(state, &expr->call, temptype->value_type, expr->location)
else:
temptype = typecheck_expression_not_void(state, expr->call.method_call_self)
temptype = typecheck_expression_not_void(state, expr->call.method_call_self, NULL)
if temptype->kind != TypeKind.Class:
snprintf(msg, sizeof(msg),
"left side of '.' operator must be an instance of a class, not %s",
Expand Down Expand Up @@ -974,7 +975,7 @@ def typecheck_expression(state: State*, expr: AstExpression*) -> None:
result = typecheck_indexing(state, &expr->operands[0], &expr->operands[1])

case AstExpressionKind.AddressOf:
result = typecheck_expression_not_void(state, &expr->operands[0])->pointer_type()
result = typecheck_expression_not_void(state, &expr->operands[0], NULL)->pointer_type()
ensure_can_take_address(state->fom_types, &expr->operands[0], "the '&' operator cannot be used with %s")

case AstExpressionKind.GetVariable:
Expand All @@ -989,7 +990,7 @@ def typecheck_expression(state: State*, expr: AstExpression*) -> None:
result = selfvar->type

case AstExpressionKind.Dereference:
temptype = typecheck_expression_not_void(state, &expr->operands[0])
temptype = typecheck_expression_not_void(state, &expr->operands[0], NULL)
typecheck_dereferenced_pointer(expr->location, temptype)
result = temptype->value_type

Expand All @@ -1008,7 +1009,7 @@ def typecheck_expression(state: State*, expr: AstExpression*) -> None:
result = boolType

case AstExpressionKind.Negate:
result = typecheck_expression_not_void(state, &expr->operands[0])
result = typecheck_expression_not_void(state, &expr->operands[0], NULL)
if result->kind != TypeKind.SignedInteger and result->kind != TypeKind.FloatingPoint:
snprintf(msg, sizeof(msg),
"value after '-' must be a float or double or a signed integer, not %s",
Expand All @@ -1028,8 +1029,8 @@ def typecheck_expression(state: State*, expr: AstExpression*) -> None:
| AstExpressionKind.Lt
| AstExpressionKind.Le
):
typecheck_expression_not_void(state, &expr->operands[0])
typecheck_expression_not_void(state, &expr->operands[1])
typecheck_expression_not_void(state, &expr->operands[0], NULL)
typecheck_expression_not_void(state, &expr->operands[1], NULL)
result = check_binop(state->fom_types, expr->kind, expr->location, &expr->operands[0], &expr->operands[1])

case (
Expand All @@ -1041,8 +1042,9 @@ def typecheck_expression(state: State*, expr: AstExpression*) -> None:
result = check_increment_or_decrement(state, expr)

case AstExpressionKind.As:
typecheck_expression_not_void(state, &expr->as_->value)
# TODO: test this: sizeof("foo" as byte[100])
result = type_from_ast(state->file_types, &expr->as_->type)
typecheck_expression_not_void(state, &expr->as_->value, result)
do_explicit_cast(state->fom_types, &expr->as_->value, result, expr->location)

case _:
Expand Down Expand Up @@ -1080,7 +1082,7 @@ def typecheck_match_statement(state: State*, match_stmt: AstMatchStatement*) ->
nremaining = -1

if match_stmt->func_name[0] == '\0':
case_type = typecheck_expression_not_void(state, &match_stmt->match_obj)
case_type = typecheck_expression_not_void(state, &match_stmt->match_obj, NULL)

match case_type->kind:
case TypeKind.SignedInteger | TypeKind.UnsignedInteger:
Expand Down Expand Up @@ -1226,11 +1228,11 @@ def typecheck_statement(state: State*, stmt: AstStatement*) -> None:
and state->find_any_var(targetexpr->varname) == NULL
):
# Making a new variable. Use the type of the value being assigned.
type = typecheck_expression_not_void(state, valueexpr)
type = typecheck_expression_not_void(state, valueexpr, NULL)
state->fom_types->add_variable(type, targetexpr->varname)
else:
# Convert value to the type of an existing variable or other assignment target.
targettype = typecheck_expression_not_void(state, targetexpr)
targettype = typecheck_expression_not_void(state, targetexpr, NULL)
ensure_can_take_address(state->fom_types, targetexpr, "cannot assign to %s")

if targetexpr->kind == AstExpressionKind.Dereference:
Expand All @@ -1250,8 +1252,8 @@ def typecheck_statement(state: State*, stmt: AstStatement*) -> None:
targetexpr = &stmt->assignment.target
valueexpr = &stmt->assignment.value

targettype = typecheck_expression_not_void(state, targetexpr)
value_type = typecheck_expression_not_void(state, valueexpr)
targettype = typecheck_expression_not_void(state, targetexpr, NULL)
value_type = typecheck_expression_not_void(state, valueexpr, NULL)
ensure_can_take_address(state->fom_types, targetexpr, "cannot assign to %s")

match stmt->kind:
Expand Down Expand Up @@ -1330,7 +1332,7 @@ def typecheck_statement(state: State*, stmt: AstStatement*) -> None:
"initial value for variable of type <to> cannot be of type <from>")

case AstStatementKind.ExpressionStatement:
typecheck_expression(state, &stmt->expression)
typecheck_expression(state, &stmt->expression, NULL)

case AstStatementKind.Assert:
typecheck_expression_with_implicit_cast(state, &stmt->expression, boolType, "assertion must be a bool, not <from>")
Expand Down
1 change: 0 additions & 1 deletion compiler/types_in_ast.jou
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ class ExpressionTypes:

# Flags to indicate whether special kinds of implicit casts happened
implicit_array_to_pointer_cast: bool # Foo[N] to Foo*
implicit_string_to_array_cast: bool # "..." to byte[N]


class LocalVariable:
Expand Down
16 changes: 15 additions & 1 deletion tests/should_succeed/array.jou
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
declare printf(fmt: byte*, ...) -> int
import "stdlib/io.jou"

# C can't do this
def make_array() -> int[3]:
Expand Down Expand Up @@ -51,4 +51,18 @@ def main() -> int:
if strings[i] == NULL:
printf("strings[%d] is NULL\n", i)

# array of fixed-size strings (#683)
# Output: hello
# Output: world
# Output: test
# Output: foo
# Output: bar
# Output: bazzybaz
strings50: byte[50][3] = ["hello", "world", "test"]
for i = 0; i < 3; i++:
puts(strings50[i])
strings50 = ["foo", "bar", "bazzybaz"]
for i = 0; i < 3; i++:
puts(strings50[i])

return 0
42 changes: 1 addition & 41 deletions tests/should_succeed/sizeof.jou
Original file line number Diff line number Diff line change
@@ -1,45 +1,5 @@
import "stdlib/io.jou"
import "stdlib/mem.jou"

def side_effect() -> int:
printf("Side Effect !!!!!\n")
return 123

class Foo:
a: int
b: long
c: byte

# See issue #224.
def ensure_sizeof_isnt_too_small_in_a_weird_corner_case() -> None:
value = Foo{a=1, b=2, c='x'}
# We need the heap allocation, because otherwise the optimizer happens to make things work.
ptr = malloc(50) as Foo*
memcpy(ptr, &value, sizeof value)
# If sizeof is too small, this prints garbage.
printf("%c\n", ptr->c) # Output: x
free(ptr)

def main() -> int:
ensure_sizeof_isnt_too_small_in_a_weird_corner_case()

bo: bool
by: byte
n: int
m: long

printf("%lld %lld %lld %lld\n", sizeof bo, sizeof by, sizeof n, sizeof m) # Output: 1 1 4 8

# test that operator precedence works
printf("%lld\n", sizeof by + sizeof n + sizeof m) # Output: 13

arr: long[100]
printf("%lld\n", sizeof arr) # Output: 800

# The "array length trick"
printf("%lld\n", sizeof arr / sizeof arr[0]) # Output: 100

# Evaluating a sizeof has no side effects.
printf("%lld\n", sizeof side_effect()) # Output: 4

printf("%lld\n", sizeof("hello"))
return 0

0 comments on commit 968f04a

Please sign in to comment.