Skip to content

Commit

Permalink
Convert string constants to arrays (#366)
Browse files Browse the repository at this point in the history
  • Loading branch information
Akuli authored Mar 26, 2023
1 parent 7233ba9 commit 5273af6
Show file tree
Hide file tree
Showing 15 changed files with 160 additions and 82 deletions.
4 changes: 1 addition & 3 deletions self_hosted/ast.jou
Original file line number Diff line number Diff line change
Expand Up @@ -787,9 +787,7 @@ class AstClassMember:
elif self->kind == AstClassMemberKind::Method:
printf(" method ")
self->method.signature.print()
tp = TreePrinter{}
strcpy(tp.prefix, " ")
self->method.body.print(tp)
self->method.body.print(TreePrinter{prefix = " "})
else:
assert False

Expand Down
43 changes: 27 additions & 16 deletions self_hosted/create_llvm_ir.jou
Original file line number Diff line number Diff line change
Expand Up @@ -166,14 +166,23 @@ class AstToIR:
block = LLVMAppendBasicBlock(self->llvm_function, name_hint)
LLVMPositionBuilderAtEnd(self->builder, block)

def make_a_string_constant(self, s: byte*) -> LLVMValue*:
array = LLVMConstString(s, strlen(s) as int, False)
global_var = LLVMAddGlobal(self->module, LLVMTypeOf(array), "string_literal")
LLVMSetLinkage(global_var, LLVMLinkage::Private) # This makes it a static global variable
LLVMSetInitializer(global_var, array)

string_type = LLVMPointerType(LLVMInt8Type(), 0)
return LLVMBuildBitCast(self->builder, global_var, string_type, "string_ptr")
# Builds an LLVM value for the string s.
#
# If array_len is -1, the string is placed in a private global variable and
# the return value is a pointer to the start of that static global string.
# Otherwise returns an array value that is exactly array_len bytes long:
# the bytes of s followed by zero bytes up to array_len.
def make_a_string_constant(self, s: byte*, array_len: int) -> LLVMValue*:
if array_len == -1:
# The last argument is LLVM's DontNullTerminate flag. False means that LLVM
# appends the terminating zero byte, so the length passed here excludes it.
array = LLVMConstString(s, strlen(s) as int, False)
global_var = LLVMAddGlobal(self->module, LLVMTypeOf(array), "string_literal")
LLVMSetLinkage(global_var, LLVMLinkage::Private) # This makes it a static global variable
LLVMSetInitializer(global_var, array)
# The global holds an array of bytes; callers get it as a plain byte pointer.
string_type = LLVMPointerType(LLVMInt8Type(), 0)
return LLVMBuildBitCast(self->builder, global_var, string_type, "string_ptr")
else:
# The string and its terminating zero byte must both fit into the array.
assert strlen(s) < array_len
# calloc zero-fills the buffer, so everything after the copied string is
# already zero and no explicit padding loop is needed.
padded = calloc(1, array_len)
strcpy(padded, s)
# DontNullTerminate is True here: padded is already exactly array_len bytes,
# and an extra zero byte would make the array one element too long.
array = LLVMConstString(padded, array_len, True)
# The temporary buffer is not needed after the constant has been created.
free(padded)
return array

def do_cast(self, obj: LLVMValue*, from: Type*, to: Type*) -> LLVMValue*:
# Treat enums as just integers
Expand Down Expand Up @@ -319,8 +328,8 @@ class AstToIR:
assert assert_fail_func != NULL

args = [
self->make_a_string_constant("foo"),
self->make_a_string_constant("bar"),
self->make_a_string_constant("foo", -1),
self->make_a_string_constant("bar", -1),
LLVMConstInt(LLVMInt32Type(), 123, False),
]

Expand Down Expand Up @@ -424,10 +433,8 @@ class AstToIR:
for i = 0; i < call->nargs; i++:
args[k++] = self->do_expression(&call->args[i])

name_hint: byte[100]
if signature->return_type == NULL:
strcpy(name_hint, "")
else:
name_hint: byte[100] = ""
if signature->return_type != NULL:
sprintf(name_hint, "%.20s_return_value", signature->name)

result = LLVMBuildCall2(self->builder, function_type, function, args, k, name_hint)
Expand All @@ -448,7 +455,11 @@ class AstToIR:
)

if ast->kind == AstExpressionKind::String:
result = self->make_a_string_constant(ast->string)
if types->implicit_string_to_array_cast:
array_len = types->implicit_cast_type->array.length
else:
array_len = -1
result = self->make_a_string_constant(ast->string, array_len)
elif ast->kind == AstExpressionKind::Bool:
result = LLVMConstInt(LLVMInt1Type(), ast->bool_value as long, False)
elif ast->kind == AstExpressionKind::Byte:
Expand Down Expand Up @@ -568,7 +579,7 @@ class AstToIR:
assert False

types = self->function_or_method_types->get_expression_types(ast)
if types->implicit_cast_type == NULL:
if types->implicit_cast_type == NULL or types->implicit_string_to_array_cast:
return result
return self->do_cast(result, types->original_type, types->implicit_cast_type)

Expand Down
4 changes: 1 addition & 3 deletions self_hosted/parser.jou
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,9 @@ def parse_function_or_method_signature(tokens: Token**, is_method: bool) -> AstS
if not is_method:
fail((*tokens)->location, "'self' cannot be used here")

the_self: byte[100]
strcpy(the_self, "self")
result.args = realloc(result.args, sizeof result.args[0] * (result.nargs+1))
result.args[result.nargs++] = AstNameTypeValue{
name = the_self,
name = "self",
name_location = (*tokens)->location,
}
++*tokens
Expand Down
2 changes: 1 addition & 1 deletion self_hosted/target.jou
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def init_target() -> void:

if is_windows():
# LLVM's default is x86_64-pc-windows-msvc
strcpy(target.triple, "x86_64-pc-windows-gnu")
target.triple = "x86_64-pc-windows-gnu"
else:
triple = LLVMGetDefaultTargetTriple()
assert strlen(triple) < sizeof target.triple
Expand Down
22 changes: 11 additions & 11 deletions self_hosted/token.jou
Original file line number Diff line number Diff line change
Expand Up @@ -104,33 +104,33 @@ class Token:
def fail_expected_got(self, what_was_expected_instead: byte*) -> void:
got: byte[100]
if self->kind == TokenKind::Short:
strcpy(got, "a short")
got = "a short"
elif self->kind == TokenKind::Int:
strcpy(got, "an integer")
got = "an integer"
elif self->kind == TokenKind::Long:
strcpy(got, "a long integer")
got = "a long integer"
elif self->kind == TokenKind::Float:
strcpy(got, "a float constant")
got = "a float constant"
elif self->kind == TokenKind::Double:
strcpy(got, "a double constant")
got = "a double constant"
elif self->kind == TokenKind::Byte:
strcpy(got, "a byte literal")
got = "a byte literal"
elif self->kind == TokenKind::String:
strcpy(got, "a string")
got = "a string"
elif self->kind == TokenKind::Name:
snprintf(got, sizeof got, "a variable name '%s'", self->short_string)
elif self->kind == TokenKind::Keyword:
snprintf(got, sizeof got, "the '%s' keyword", self->short_string)
elif self->kind == TokenKind::Newline:
strcpy(got, "end of line")
got = "end of line"
elif self->kind == TokenKind::Indent:
strcpy(got, "more indentation")
got = "more indentation"
elif self->kind == TokenKind::Dedent:
strcpy(got, "less indentation")
got = "less indentation"
elif self->kind == TokenKind::Operator:
snprintf(got, sizeof got, "'%s'", self->short_string)
elif self->kind == TokenKind::EndOfFile:
strcpy(got, "end of file")
got = "end of file"
else:
assert False

Expand Down
65 changes: 42 additions & 23 deletions self_hosted/typecheck.jou
Original file line number Diff line number Diff line change
Expand Up @@ -101,26 +101,47 @@ class ExpressionTypes:
expression: AstExpression*
original_type: Type*
implicit_cast_type: Type* # NULL if no implicit casting is needed
implicit_array_to_pointer_cast: bool # Whether a special kind of implicit cast happened
next: ExpressionTypes* # TODO: switch to more efficient structure than linked list?

# Flags to indicate whether special kinds of implicit casts happened
implicit_array_to_pointer_cast: bool # Foo[N] to Foo*
implicit_string_to_array_cast: bool # "..." to byte[N]

# Returns the type that the expression has after implicit casting:
# implicit_cast_type when it is set, otherwise original_type.
def get_type_after_implicit_cast(self) -> Type*:
# original_type must already be set when this is called.
assert self->original_type != NULL
# NULL implicit_cast_type means that no implicit cast is needed.
if self->implicit_cast_type == NULL:
return self->original_type
return self->implicit_cast_type

# TODO: error_location is probably unnecessary, can get location from self->expression
def do_implicit_cast(self, to: Type*, error_location: Location, error_template: byte*) -> void:
# This cannot be called multiple times
assert self->implicit_cast_type == NULL
assert not self->implicit_array_to_pointer_cast
assert not self->implicit_string_to_array_cast

from = self->original_type
if from == to:
return

if (
self->expression->kind == AstExpressionKind::String
and from == byte_type->get_pointer_type()
and to->kind == TypeKind::Array
and to->array.item_type == byte_type
):
string_size = strlen(self->expression->string) + 1
if to->array.length < string_size:
message: byte[100]
snprintf(
message, sizeof message,
"a string of %d bytes (including '\\0') does not fit into %s",
string_size, to->name,
)
fail(error_location, message)
self->implicit_string_to_array_cast = True
# Passing in NULL for error_template can be used to force a cast to happen.
if error_template != NULL and not can_cast_implicitly(from, to):
elif error_template != NULL and not can_cast_implicitly(from, to):
fail_with_implicit_cast_error(error_location, error_template, from, to)

self->implicit_cast_type = to
Expand Down Expand Up @@ -437,7 +458,7 @@ def short_expression_description(expr: AstExpression*) -> byte[200]:
or expr->kind == AstExpressionKind::Bool
or expr->kind == AstExpressionKind::Null
):
strcpy(result, "a constant")
return "a constant"
elif (
expr->kind == AstExpressionKind::Negate
or expr->kind == AstExpressionKind::Add
Expand All @@ -446,7 +467,7 @@ def short_expression_description(expr: AstExpression*) -> byte[200]:
or expr->kind == AstExpressionKind::Divide
or expr->kind == AstExpressionKind::Modulo
):
strcpy(result, "the result of a calculation")
return "the result of a calculation"
elif (
expr->kind == AstExpressionKind::Eq
or expr->kind == AstExpressionKind::Ne
Expand All @@ -455,44 +476,45 @@ def short_expression_description(expr: AstExpression*) -> byte[200]:
or expr->kind == AstExpressionKind::Lt
or expr->kind == AstExpressionKind::Le
):
strcpy(result, "the result of a comparison")
return "the result of a comparison"
elif expr->kind == AstExpressionKind::Call:
sprintf(result, "a %s call", expr->call.function_or_method())
return result
elif expr->kind == AstExpressionKind::Instantiate:
strcpy(result, "a newly created instance")
return "a newly created instance"
elif expr->kind == AstExpressionKind::GetVariable:
strcpy(result, "a variable")
return "a variable"
elif expr->kind == AstExpressionKind::GetEnumMember:
strcpy(result, "an enum member")
return "an enum member"
elif expr->kind == AstExpressionKind::GetClassField:
snprintf(result, sizeof result, "field '%s'", expr->class_field.field_name)
return result
elif expr->kind == AstExpressionKind::As:
strcpy(result, "the result of a cast")
return "the result of a cast"
elif expr->kind == AstExpressionKind::SizeOf:
strcpy(result, "a sizeof expression")
return "a sizeof expression"
elif expr->kind == AstExpressionKind::AddressOf:
subresult = short_expression_description(expr->operands)
snprintf(result, sizeof result, "address of %s", subresult)
return result
elif expr->kind == AstExpressionKind::Dereference:
strcpy(result, "the value of a pointer")
return "the value of a pointer"
elif expr->kind == AstExpressionKind::And:
strcpy(result, "the result of 'and'")
return "the result of 'and'"
elif expr->kind == AstExpressionKind::Or:
strcpy(result, "the result of 'or'")
return "the result of 'or'"
elif expr->kind == AstExpressionKind::Not:
strcpy(result, "the result of 'not'")
return "the result of 'not'"
elif expr->kind == AstExpressionKind::PreIncr or expr->kind == AstExpressionKind::PostIncr:
strcpy(result, "the result of incrementing a value")
return "the result of incrementing a value"
elif expr->kind == AstExpressionKind::PreDecr or expr->kind == AstExpressionKind::PostDecr:
strcpy(result, "the result of decrementing a value")
return "the result of decrementing a value"
elif expr->kind == AstExpressionKind::Indexing:
strcpy(result, "an indexed value")
return "an indexed value"
else:
printf("*** %d\n", expr->kind)
assert False

return result

# The & operator can't go in front of most expressions.
# You can't do &(1 + 2), for example.
#
Expand Down Expand Up @@ -1111,10 +1133,7 @@ class Stage3TypeChecker:
# This is a common error, so try to produce a helpful error message.
error_template: byte[500]
if target_expr->kind == AstExpressionKind::Dereference:
strcpy(
error_template,
"cannot place a value of type <from> into a pointer of type <to>*",
)
error_template = "cannot place a value of type <from> into a pointer of type <to>*"
else:
target_description: byte[200] = short_expression_description(target_expr)
snprintf(
Expand Down
24 changes: 8 additions & 16 deletions self_hosted/types.jou
Original file line number Diff line number Diff line change
Expand Up @@ -134,18 +134,10 @@ global int_type: Type*
global long_type: Type*

def init_types() -> void:
strcpy(void_ptr_type.name, "void*")
void_ptr_type.kind = TypeKind::VoidPointer

strcpy(bool_type.name, "bool")
bool_type.kind = TypeKind::Bool

strcpy(float_type.name, "float")
strcpy(double_type.name, "double")
float_type.size_in_bits = 32
double_type.size_in_bits = 64
float_type.kind = TypeKind::FloatingPoint
double_type.kind = TypeKind::FloatingPoint
void_ptr_type = Type{name = "void*", kind = TypeKind::VoidPointer}
bool_type = Type{name = "bool", kind = TypeKind::Bool}
float_type = Type{name = "float", size_in_bits = 32, kind = TypeKind::FloatingPoint}
double_type = Type{name = "double", size_in_bits = 64, kind = TypeKind::FloatingPoint}

for size = 8; size <= 64; size *= 2:
sprintf(signed_integers[size].name, "<%d-bit signed integer>", size)
Expand All @@ -160,10 +152,10 @@ def init_types() -> void:
int_type = &signed_integers[32]
long_type = &signed_integers[64]

strcpy(byte_type->name, "byte")
strcpy(short_type->name, "short")
strcpy(int_type->name, "int")
strcpy(long_type->name, "long")
byte_type->name = "byte"
short_type->name = "short"
int_type->name = "int"
long_type->name = "long"

def create_opaque_class(name: byte*) -> Type*:
result: Type* = malloc(sizeof *result)
Expand Down
26 changes: 23 additions & 3 deletions src/build_cfg.c
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,24 @@ static const LocalVariable *build_expression(struct State *st, const AstExpressi
return memberptr;
}

if (types && types->implicit_string_to_array_cast) {
assert(types->implicit_cast_type);
assert(types->implicit_cast_type->kind == TYPE_ARRAY);
assert(expr->kind == AST_EXPR_CONSTANT);
assert(expr->data.constant.kind == CONSTANT_STRING);

char *padded = calloc(1, types->implicit_cast_type->data.array.len);
strcpy(padded, expr->data.constant.data.str);

const LocalVariable *result = add_local_var(st, types->implicit_cast_type);
union CfInstructionData data = { .strarray = {
.len = types->implicit_cast_type->data.array.len,
.str = padded,
}};
add_instruction(st, expr->location, CF_STRING_ARRAY, &data, NULL, result);
return result;
}

const LocalVariable *result, *temp;

switch(expr->kind) {
Expand Down Expand Up @@ -586,9 +604,11 @@ static const LocalVariable *build_expression(struct State *st, const AstExpressi
result = build_address_of_expression(st, &expr->data.operands[0]);
break;
case AST_EXPR_SIZEOF:
result = add_local_var(st, longType);
union CfInstructionData data = { .type = get_expr_types(st, &expr->data.operands[0])->type };
add_instruction(st, expr->location, CF_SIZEOF, &data, NULL, result);
{
result = add_local_var(st, longType);
union CfInstructionData data = { .type = get_expr_types(st, &expr->data.operands[0])->type };
add_instruction(st, expr->location, CF_SIZEOF, &data, NULL, result);
}
break;
case AST_EXPR_DEREFERENCE:
temp = build_expression(st, &expr->data.operands[0]);
Expand Down
1 change: 1 addition & 0 deletions src/codegen.c
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ static void codegen_instruction(const struct State *st, const CfInstruction *ins
}
break;
case CF_CONSTANT: setdest(codegen_constant(st, &ins->data.constant)); break;
case CF_STRING_ARRAY: setdest(LLVMConstString(ins->data.strarray.str, ins->data.strarray.len, true)); break;
case CF_SIZEOF: setdest(LLVMSizeOf(codegen_type(ins->data.type))); break;
case CF_ADDRESS_OF_LOCAL_VAR: setdest(get_pointer_to_local_var(st, ins->operands[0])); break;
case CF_ADDRESS_OF_GLOBAL_VAR: setdest(LLVMGetNamedGlobal(st->module, ins->data.globalname)); break;
Expand Down
2 changes: 2 additions & 0 deletions src/free.c
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,8 @@ void free_control_flow_graph_block(const CfGraph *cfg, CfBlock *b)
for (const CfInstruction *ins = b->instructions.ptr; ins < End(b->instructions); ins++) {
if (ins->kind == CF_CONSTANT)
free_constant(&ins->data.constant);
if (ins->kind == CF_STRING_ARRAY)
free(ins->data.strarray.str);
if (ins->kind == CF_CALL)
free_signature(&ins->data.signature);
free(ins->operands);
Expand Down
Loading

0 comments on commit 5273af6

Please sign in to comment.