Skip to content

Commit

Permalink
Use even more match statements in the compiler (#624)
Browse files Browse the repository at this point in the history
  • Loading branch information
Akuli authored Jan 15, 2025
1 parent 766ca36 commit 8a7df41
Show file tree
Hide file tree
Showing 9 changed files with 261 additions and 259 deletions.
76 changes: 40 additions & 36 deletions compiler/ast.jou
Original file line number Diff line number Diff line change
Expand Up @@ -595,42 +595,46 @@ class AstStatement:
printf("\n")

def free(self) -> None:
if self->kind == AstStatementKind.ExpressionStatement:
self->expression.free()
if self->kind == AstStatementKind.Assert:
self->assertion.free()
if self->kind == AstStatementKind.Return and self->return_value != NULL:
self->return_value->free()
free(self->return_value)
if self->kind == AstStatementKind.If:
self->if_statement.free()
if self->kind == AstStatementKind.WhileLoop:
self->while_loop.free()
if self->kind == AstStatementKind.ForLoop:
self->for_loop.free()
if self->kind == AstStatementKind.Match:
self->match_statement.free()
if (
self->kind == AstStatementKind.DeclareLocalVar
or self->kind == AstStatementKind.GlobalVariableDeclaration
or self->kind == AstStatementKind.GlobalVariableDefinition
):
self->var_declaration.free()
if (
self->kind == AstStatementKind.Assign
or self->kind == AstStatementKind.InPlaceAdd
or self->kind == AstStatementKind.InPlaceSub
or self->kind == AstStatementKind.InPlaceMul
or self->kind == AstStatementKind.InPlaceDiv
or self->kind == AstStatementKind.InPlaceMod
):
self->assignment.free()
if self->kind == AstStatementKind.Function:
self->function.free()
if self->kind == AstStatementKind.Class:
self->classdef.free()
if self->kind == AstStatementKind.Enum:
self->enumdef.free()
match self->kind:
case AstStatementKind.ExpressionStatement:
self->expression.free()
case AstStatementKind.Assert:
self->assertion.free()
case AstStatementKind.Return:
if self->return_value != NULL:
self->return_value->free()
free(self->return_value)
case AstStatementKind.If:
self->if_statement.free()
case AstStatementKind.WhileLoop:
self->while_loop.free()
case AstStatementKind.ForLoop:
self->for_loop.free()
case AstStatementKind.Match:
self->match_statement.free()
case (
AstStatementKind.DeclareLocalVar
| AstStatementKind.GlobalVariableDeclaration
| AstStatementKind.GlobalVariableDefinition
):
self->var_declaration.free()
case (
AstStatementKind.Assign
| AstStatementKind.InPlaceAdd
| AstStatementKind.InPlaceSub
| AstStatementKind.InPlaceMul
| AstStatementKind.InPlaceDiv
| AstStatementKind.InPlaceMod
):
self->assignment.free()
case AstStatementKind.Function:
self->function.free()
case AstStatementKind.Class:
self->classdef.free()
case AstStatementKind.Enum:
self->enumdef.free()
case AstStatementKind.Pass | AstStatementKind.Break | AstStatementKind.Continue:
pass


# Useful for e.g. "while condition: body", "if condition: body"
Expand Down
96 changes: 49 additions & 47 deletions compiler/build_cf_graph.jou
Original file line number Diff line number Diff line change
Expand Up @@ -425,62 +425,64 @@ class CfBuilder:
return result

def build_address_of_expression(self, address_of_what: AstExpression*) -> LocalVariable*:
if address_of_what->kind == AstExpressionKind.GetVariable:
ptrtype = address_of_what->types.type->pointer_type()
addr = self->add_var(ptrtype)

local_var = self->find_var(address_of_what->varname)
if local_var == NULL:
# Global variable (possibly imported from another file)
ins = CfInstruction {
location = address_of_what->location,
kind = CfInstructionKind.AddressOfGlobalVar,
destvar = addr,
}
assert sizeof(ins.globalname) == sizeof(address_of_what->varname)
strcpy(ins.globalname, address_of_what->varname)
self->add_instruction(ins)
else:
self->unary_op(address_of_what->location, CfInstructionKind.AddressOfLocalVar, local_var, addr)
return addr
match address_of_what->kind:
case AstExpressionKind.GetVariable:
ptrtype = address_of_what->types.type->pointer_type()
addr = self->add_var(ptrtype)

local_var = self->find_var(address_of_what->varname)
if local_var == NULL:
# Global variable (possibly imported from another file)
ins = CfInstruction {
location = address_of_what->location,
kind = CfInstructionKind.AddressOfGlobalVar,
destvar = addr,
}
assert sizeof(ins.globalname) == sizeof(address_of_what->varname)
strcpy(ins.globalname, address_of_what->varname)
self->add_instruction(ins)
else:
self->unary_op(address_of_what->location, CfInstructionKind.AddressOfLocalVar, local_var, addr)
return addr

if address_of_what->kind == AstExpressionKind.Self:
ptrtype = address_of_what->types.type->pointer_type()
addr = self->add_var(ptrtype)
case AstExpressionKind.Self:
ptrtype = address_of_what->types.type->pointer_type()
addr = self->add_var(ptrtype)

local_var = self->find_var("self")
assert local_var != NULL
self->unary_op(address_of_what->location, CfInstructionKind.AddressOfLocalVar, local_var, addr)
return addr
local_var = self->find_var("self")
assert local_var != NULL
self->unary_op(address_of_what->location, CfInstructionKind.AddressOfLocalVar, local_var, addr)
return addr

if address_of_what->kind == AstExpressionKind.Dereference:
# &*foo --> just evaluate foo
return self->build_expression(&address_of_what->operands[0])
case AstExpressionKind.Dereference:
# &*foo --> just evaluate foo
return self->build_expression(&address_of_what->operands[0])

if address_of_what->kind == AstExpressionKind.GetClassField:
if address_of_what->class_field.uses_arrow_operator:
# &obj->field aka &(obj->field)
obj = self->build_expression(address_of_what->class_field.instance)
else:
# &obj.field aka &(obj.field), evaluate as &(&obj)->field
obj = self->build_address_of_expression(address_of_what->class_field.instance)
case AstExpressionKind.GetClassField:
if address_of_what->class_field.uses_arrow_operator:
# &obj->field aka &(obj->field)
obj = self->build_expression(address_of_what->class_field.instance)
else:
# &obj.field aka &(obj.field), evaluate as &(&obj)->field
obj = self->build_address_of_expression(address_of_what->class_field.instance)

assert obj->type->kind == TypeKind.Pointer
assert obj->type->value_type->kind == TypeKind.Class
return self->build_class_field_pointer(obj, address_of_what->class_field.field_name, address_of_what->location)
assert obj->type->kind == TypeKind.Pointer
assert obj->type->value_type->kind == TypeKind.Class
return self->build_class_field_pointer(obj, address_of_what->class_field.field_name, address_of_what->location)

if address_of_what->kind == AstExpressionKind.Indexing:
ptr = self->build_expression(&address_of_what->operands[0])
assert ptr->type->kind == TypeKind.Pointer
case AstExpressionKind.Indexing:
ptr = self->build_expression(&address_of_what->operands[0])
assert ptr->type->kind == TypeKind.Pointer

index = self->build_expression(&address_of_what->operands[1])
assert index->type->is_integer_type()
index = self->build_expression(&address_of_what->operands[1])
assert index->type->is_integer_type()

result = self->add_var(ptr->type)
self->binary_op(address_of_what->location, CfInstructionKind.PtrAddInt, ptr, index, result)
return result
result = self->add_var(ptr->type)
self->binary_op(address_of_what->location, CfInstructionKind.PtrAddInt, ptr, index, result)
return result

assert False
case _:
assert False

def build_call(
self,
Expand Down
16 changes: 9 additions & 7 deletions compiler/cf_graph.jou
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@ import "./types.jou"


def very_short_number_type_description(t: Type*) -> byte*:
if t->kind == TypeKind.FloatingPoint:
return "floating"
if t->kind == TypeKind.SignedInteger:
return "signed"
if t->kind == TypeKind.UnsignedInteger:
return "unsigned"
assert False
match t->kind:
case TypeKind.FloatingPoint:
return "floating"
case TypeKind.SignedInteger:
return "signed"
case TypeKind.UnsignedInteger:
return "unsigned"
case _:
assert False


class CfStringArray:
Expand Down
75 changes: 36 additions & 39 deletions compiler/codegen.jou
Original file line number Diff line number Diff line change
Expand Up @@ -84,32 +84,29 @@ def codegen_class_type(type: Type*) -> LLVMType*:


def codegen_type(type: Type*) -> LLVMType*:
if type->kind == TypeKind.Array:
return LLVMArrayType(codegen_type(type->array.item_type), type->array.len)
if type->is_pointer_type():
# Element type doesn't matter in new LLVM versions.
return LLVMPointerType(LLVMInt8Type(), 0)
if type->kind == TypeKind.FloatingPoint:
if type->size_in_bits == 32:
return LLVMFloatType()
if type->size_in_bits == 64:
return LLVMDoubleType()
assert False
if (
type->kind == TypeKind.SignedInteger
or type->kind == TypeKind.UnsignedInteger
):
return LLVMIntType(type->size_in_bits)
if type->kind == TypeKind.Bool:
return LLVMInt1Type()
if type->kind == TypeKind.OpaqueClass:
# this is compiler internal/temporary thing and should never end up here
assert False
if type->kind == TypeKind.Class:
return codegen_class_type(type)
if type->kind == TypeKind.Enum:
return LLVMInt32Type()
assert False
match type->kind:
case TypeKind.Array:
return LLVMArrayType(codegen_type(type->array.item_type), type->array.len)
case TypeKind.Pointer | TypeKind.VoidPointer:
# Element type doesn't matter in new LLVM versions.
return LLVMPointerType(LLVMInt8Type(), 0)
case TypeKind.FloatingPoint:
if type->size_in_bits == 32:
return LLVMFloatType()
if type->size_in_bits == 64:
return LLVMDoubleType()
assert False
case TypeKind.SignedInteger | TypeKind.UnsignedInteger:
return LLVMIntType(type->size_in_bits)
case TypeKind.Bool:
return LLVMInt1Type()
case TypeKind.OpaqueClass:
# this is compiler internal/temporary thing and should never end up here
assert False
case TypeKind.Class:
return codegen_class_type(type)
case TypeKind.Enum:
return LLVMInt32Type()


def codegen_function_type(sig: Signature*) -> LLVMType*:
Expand Down Expand Up @@ -240,19 +237,19 @@ class CodeGen:


def do_constant(self, c: Constant*) -> LLVMValue*:
if c->kind == ConstantKind.Bool:
return LLVMConstInt(LLVMInt1Type(), c->boolean as long, False as int)
if c->kind == ConstantKind.Integer:
return LLVMConstInt(codegen_type(c->get_type()), c->integer.value, c->integer.is_signed as int)
if c->kind == ConstantKind.Float or c->kind == ConstantKind.Double:
return LLVMConstRealOfString(codegen_type(c->get_type()), c->double_or_float_text)
if c->kind == ConstantKind.Null:
return LLVMConstNull(codegen_type(voidPtrType))
if c->kind == ConstantKind.String:
return self->do_string(c->str)
if c->kind == ConstantKind.EnumMember:
return LLVMConstInt(LLVMInt32Type(), c->enum_member.memberidx, False as int)
assert False
match c->kind:
case ConstantKind.Bool:
return LLVMConstInt(LLVMInt1Type(), c->boolean as long, False as int)
case ConstantKind.Integer:
return LLVMConstInt(codegen_type(c->get_type()), c->integer.value, c->integer.is_signed as int)
case ConstantKind.Float | ConstantKind.Double:
return LLVMConstRealOfString(codegen_type(c->get_type()), c->double_or_float_text)
case ConstantKind.Null:
return LLVMConstNull(codegen_type(voidPtrType))
case ConstantKind.String:
return self->do_string(c->str)
case ConstantKind.EnumMember:
return LLVMConstInt(LLVMInt32Type(), c->enum_member.memberidx, False as int)

def do_arithmetic_instruction(self, ins: CfInstruction*) -> None:
lhs = self->getlocal(ins->operands[0])
Expand Down
30 changes: 15 additions & 15 deletions compiler/evaluate.jou
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,21 @@ def get_special_constant(name: byte*) -> int:


def evaluate_condition(expr: AstExpression*) -> bool:
if expr->kind == AstExpressionKind.GetVariable:
v = get_special_constant(expr->varname)
if v == 0:
return False
if v == 1:
return True

if expr->kind == AstExpressionKind.And:
return evaluate_condition(&expr->operands[0]) and evaluate_condition(&expr->operands[1])
if expr->kind == AstExpressionKind.Or:
return evaluate_condition(&expr->operands[0]) or evaluate_condition(&expr->operands[1])
if expr->kind == AstExpressionKind.Not:
return not evaluate_condition(&expr->operands[0])

fail(expr->location, "cannot evaluate condition at compile time")
match expr->kind:
case AstExpressionKind.GetVariable:
v = get_special_constant(expr->varname)
if v == 0:
return False
if v == 1:
return True
case AstExpressionKind.And:
return evaluate_condition(&expr->operands[0]) and evaluate_condition(&expr->operands[1])
case AstExpressionKind.Or:
return evaluate_condition(&expr->operands[0]) or evaluate_condition(&expr->operands[1])
case AstExpressionKind.Not:
return not evaluate_condition(&expr->operands[0])
case _:
fail(expr->location, "cannot evaluate condition at compile time")


# returns the statements to replace if statement with
Expand Down
Loading

0 comments on commit 8a7df41

Please sign in to comment.