From 8a7df41895013e6a9f1c8049a9612e38fdfdb158 Mon Sep 17 00:00:00 2001 From: Akuli Date: Wed, 15 Jan 2025 16:45:23 +0200 Subject: [PATCH] Use even more match statements in the compiler (#624) --- compiler/ast.jou | 76 ++++++++------- compiler/build_cf_graph.jou | 96 ++++++++++--------- compiler/cf_graph.jou | 16 ++-- compiler/codegen.jou | 75 +++++++-------- compiler/evaluate.jou | 30 +++--- compiler/main.jou | 51 +++++----- compiler/typecheck/common.jou | 71 +++++++------- .../step3_function_and_method_bodies.jou | 63 ++++++------ compiler/types.jou | 42 ++++---- 9 files changed, 261 insertions(+), 259 deletions(-) diff --git a/compiler/ast.jou b/compiler/ast.jou index 6cef051c..8111175f 100644 --- a/compiler/ast.jou +++ b/compiler/ast.jou @@ -595,42 +595,46 @@ class AstStatement: printf("\n") def free(self) -> None: - if self->kind == AstStatementKind.ExpressionStatement: - self->expression.free() - if self->kind == AstStatementKind.Assert: - self->assertion.free() - if self->kind == AstStatementKind.Return and self->return_value != NULL: - self->return_value->free() - free(self->return_value) - if self->kind == AstStatementKind.If: - self->if_statement.free() - if self->kind == AstStatementKind.WhileLoop: - self->while_loop.free() - if self->kind == AstStatementKind.ForLoop: - self->for_loop.free() - if self->kind == AstStatementKind.Match: - self->match_statement.free() - if ( - self->kind == AstStatementKind.DeclareLocalVar - or self->kind == AstStatementKind.GlobalVariableDeclaration - or self->kind == AstStatementKind.GlobalVariableDefinition - ): - self->var_declaration.free() - if ( - self->kind == AstStatementKind.Assign - or self->kind == AstStatementKind.InPlaceAdd - or self->kind == AstStatementKind.InPlaceSub - or self->kind == AstStatementKind.InPlaceMul - or self->kind == AstStatementKind.InPlaceDiv - or self->kind == AstStatementKind.InPlaceMod - ): - self->assignment.free() - if self->kind == AstStatementKind.Function: - self->function.free() - if 
self->kind == AstStatementKind.Class: - self->classdef.free() - if self->kind == AstStatementKind.Enum: - self->enumdef.free() + match self->kind: + case AstStatementKind.ExpressionStatement: + self->expression.free() + case AstStatementKind.Assert: + self->assertion.free() + case AstStatementKind.Return: + if self->return_value != NULL: + self->return_value->free() + free(self->return_value) + case AstStatementKind.If: + self->if_statement.free() + case AstStatementKind.WhileLoop: + self->while_loop.free() + case AstStatementKind.ForLoop: + self->for_loop.free() + case AstStatementKind.Match: + self->match_statement.free() + case ( + AstStatementKind.DeclareLocalVar + | AstStatementKind.GlobalVariableDeclaration + | AstStatementKind.GlobalVariableDefinition + ): + self->var_declaration.free() + case ( + AstStatementKind.Assign + | AstStatementKind.InPlaceAdd + | AstStatementKind.InPlaceSub + | AstStatementKind.InPlaceMul + | AstStatementKind.InPlaceDiv + | AstStatementKind.InPlaceMod + ): + self->assignment.free() + case AstStatementKind.Function: + self->function.free() + case AstStatementKind.Class: + self->classdef.free() + case AstStatementKind.Enum: + self->enumdef.free() + case AstStatementKind.Pass | AstStatementKind.Break | AstStatementKind.Continue: + pass # Useful for e.g. 
"while condition: body", "if condition: body" diff --git a/compiler/build_cf_graph.jou b/compiler/build_cf_graph.jou index 28921ed5..d17dd1c9 100644 --- a/compiler/build_cf_graph.jou +++ b/compiler/build_cf_graph.jou @@ -425,62 +425,64 @@ class CfBuilder: return result def build_address_of_expression(self, address_of_what: AstExpression*) -> LocalVariable*: - if address_of_what->kind == AstExpressionKind.GetVariable: - ptrtype = address_of_what->types.type->pointer_type() - addr = self->add_var(ptrtype) - - local_var = self->find_var(address_of_what->varname) - if local_var == NULL: - # Global variable (possibly imported from another file) - ins = CfInstruction { - location = address_of_what->location, - kind = CfInstructionKind.AddressOfGlobalVar, - destvar = addr, - } - assert sizeof(ins.globalname) == sizeof(address_of_what->varname) - strcpy(ins.globalname, address_of_what->varname) - self->add_instruction(ins) - else: - self->unary_op(address_of_what->location, CfInstructionKind.AddressOfLocalVar, local_var, addr) - return addr + match address_of_what->kind: + case AstExpressionKind.GetVariable: + ptrtype = address_of_what->types.type->pointer_type() + addr = self->add_var(ptrtype) + + local_var = self->find_var(address_of_what->varname) + if local_var == NULL: + # Global variable (possibly imported from another file) + ins = CfInstruction { + location = address_of_what->location, + kind = CfInstructionKind.AddressOfGlobalVar, + destvar = addr, + } + assert sizeof(ins.globalname) == sizeof(address_of_what->varname) + strcpy(ins.globalname, address_of_what->varname) + self->add_instruction(ins) + else: + self->unary_op(address_of_what->location, CfInstructionKind.AddressOfLocalVar, local_var, addr) + return addr - if address_of_what->kind == AstExpressionKind.Self: - ptrtype = address_of_what->types.type->pointer_type() - addr = self->add_var(ptrtype) + case AstExpressionKind.Self: + ptrtype = address_of_what->types.type->pointer_type() + addr = 
self->add_var(ptrtype) - local_var = self->find_var("self") - assert local_var != NULL - self->unary_op(address_of_what->location, CfInstructionKind.AddressOfLocalVar, local_var, addr) - return addr + local_var = self->find_var("self") + assert local_var != NULL + self->unary_op(address_of_what->location, CfInstructionKind.AddressOfLocalVar, local_var, addr) + return addr - if address_of_what->kind == AstExpressionKind.Dereference: - # &*foo --> just evaluate foo - return self->build_expression(&address_of_what->operands[0]) + case AstExpressionKind.Dereference: + # &*foo --> just evaluate foo + return self->build_expression(&address_of_what->operands[0]) - if address_of_what->kind == AstExpressionKind.GetClassField: - if address_of_what->class_field.uses_arrow_operator: - # &obj->field aka &(obj->field) - obj = self->build_expression(address_of_what->class_field.instance) - else: - # &obj.field aka &(obj.field), evaluate as &(&obj)->field - obj = self->build_address_of_expression(address_of_what->class_field.instance) + case AstExpressionKind.GetClassField: + if address_of_what->class_field.uses_arrow_operator: + # &obj->field aka &(obj->field) + obj = self->build_expression(address_of_what->class_field.instance) + else: + # &obj.field aka &(obj.field), evaluate as &(&obj)->field + obj = self->build_address_of_expression(address_of_what->class_field.instance) - assert obj->type->kind == TypeKind.Pointer - assert obj->type->value_type->kind == TypeKind.Class - return self->build_class_field_pointer(obj, address_of_what->class_field.field_name, address_of_what->location) + assert obj->type->kind == TypeKind.Pointer + assert obj->type->value_type->kind == TypeKind.Class + return self->build_class_field_pointer(obj, address_of_what->class_field.field_name, address_of_what->location) - if address_of_what->kind == AstExpressionKind.Indexing: - ptr = self->build_expression(&address_of_what->operands[0]) - assert ptr->type->kind == TypeKind.Pointer + case 
AstExpressionKind.Indexing: + ptr = self->build_expression(&address_of_what->operands[0]) + assert ptr->type->kind == TypeKind.Pointer - index = self->build_expression(&address_of_what->operands[1]) - assert index->type->is_integer_type() + index = self->build_expression(&address_of_what->operands[1]) + assert index->type->is_integer_type() - result = self->add_var(ptr->type) - self->binary_op(address_of_what->location, CfInstructionKind.PtrAddInt, ptr, index, result) - return result + result = self->add_var(ptr->type) + self->binary_op(address_of_what->location, CfInstructionKind.PtrAddInt, ptr, index, result) + return result - assert False + case _: + assert False def build_call( self, diff --git a/compiler/cf_graph.jou b/compiler/cf_graph.jou index 9b28cb72..26597f08 100644 --- a/compiler/cf_graph.jou +++ b/compiler/cf_graph.jou @@ -12,13 +12,15 @@ import "./types.jou" def very_short_number_type_description(t: Type*) -> byte*: - if t->kind == TypeKind.FloatingPoint: - return "floating" - if t->kind == TypeKind.SignedInteger: - return "signed" - if t->kind == TypeKind.UnsignedInteger: - return "unsigned" - assert False + match t->kind: + case TypeKind.FloatingPoint: + return "floating" + case TypeKind.SignedInteger: + return "signed" + case TypeKind.UnsignedInteger: + return "unsigned" + case _: + assert False class CfStringArray: diff --git a/compiler/codegen.jou b/compiler/codegen.jou index e273b1de..bbd2af16 100644 --- a/compiler/codegen.jou +++ b/compiler/codegen.jou @@ -84,32 +84,29 @@ def codegen_class_type(type: Type*) -> LLVMType*: def codegen_type(type: Type*) -> LLVMType*: - if type->kind == TypeKind.Array: - return LLVMArrayType(codegen_type(type->array.item_type), type->array.len) - if type->is_pointer_type(): - # Element type doesn't matter in new LLVM versions. 
- return LLVMPointerType(LLVMInt8Type(), 0) - if type->kind == TypeKind.FloatingPoint: - if type->size_in_bits == 32: - return LLVMFloatType() - if type->size_in_bits == 64: - return LLVMDoubleType() - assert False - if ( - type->kind == TypeKind.SignedInteger - or type->kind == TypeKind.UnsignedInteger - ): - return LLVMIntType(type->size_in_bits) - if type->kind == TypeKind.Bool: - return LLVMInt1Type() - if type->kind == TypeKind.OpaqueClass: - # this is compiler internal/temporary thing and should never end up here - assert False - if type->kind == TypeKind.Class: - return codegen_class_type(type) - if type->kind == TypeKind.Enum: - return LLVMInt32Type() - assert False + match type->kind: + case TypeKind.Array: + return LLVMArrayType(codegen_type(type->array.item_type), type->array.len) + case TypeKind.Pointer | TypeKind.VoidPointer: + # Element type doesn't matter in new LLVM versions. + return LLVMPointerType(LLVMInt8Type(), 0) + case TypeKind.FloatingPoint: + if type->size_in_bits == 32: + return LLVMFloatType() + if type->size_in_bits == 64: + return LLVMDoubleType() + assert False + case TypeKind.SignedInteger | TypeKind.UnsignedInteger: + return LLVMIntType(type->size_in_bits) + case TypeKind.Bool: + return LLVMInt1Type() + case TypeKind.OpaqueClass: + # this is compiler internal/temporary thing and should never end up here + assert False + case TypeKind.Class: + return codegen_class_type(type) + case TypeKind.Enum: + return LLVMInt32Type() def codegen_function_type(sig: Signature*) -> LLVMType*: @@ -240,19 +237,19 @@ class CodeGen: def do_constant(self, c: Constant*) -> LLVMValue*: - if c->kind == ConstantKind.Bool: - return LLVMConstInt(LLVMInt1Type(), c->boolean as long, False as int) - if c->kind == ConstantKind.Integer: - return LLVMConstInt(codegen_type(c->get_type()), c->integer.value, c->integer.is_signed as int) - if c->kind == ConstantKind.Float or c->kind == ConstantKind.Double: - return LLVMConstRealOfString(codegen_type(c->get_type()), 
c->double_or_float_text) - if c->kind == ConstantKind.Null: - return LLVMConstNull(codegen_type(voidPtrType)) - if c->kind == ConstantKind.String: - return self->do_string(c->str) - if c->kind == ConstantKind.EnumMember: - return LLVMConstInt(LLVMInt32Type(), c->enum_member.memberidx, False as int) - assert False + match c->kind: + case ConstantKind.Bool: + return LLVMConstInt(LLVMInt1Type(), c->boolean as long, False as int) + case ConstantKind.Integer: + return LLVMConstInt(codegen_type(c->get_type()), c->integer.value, c->integer.is_signed as int) + case ConstantKind.Float | ConstantKind.Double: + return LLVMConstRealOfString(codegen_type(c->get_type()), c->double_or_float_text) + case ConstantKind.Null: + return LLVMConstNull(codegen_type(voidPtrType)) + case ConstantKind.String: + return self->do_string(c->str) + case ConstantKind.EnumMember: + return LLVMConstInt(LLVMInt32Type(), c->enum_member.memberidx, False as int) def do_arithmetic_instruction(self, ins: CfInstruction*) -> None: lhs = self->getlocal(ins->operands[0]) diff --git a/compiler/evaluate.jou b/compiler/evaluate.jou index 3fa71202..17912c20 100644 --- a/compiler/evaluate.jou +++ b/compiler/evaluate.jou @@ -22,21 +22,21 @@ def get_special_constant(name: byte*) -> int: def evaluate_condition(expr: AstExpression*) -> bool: - if expr->kind == AstExpressionKind.GetVariable: - v = get_special_constant(expr->varname) - if v == 0: - return False - if v == 1: - return True - - if expr->kind == AstExpressionKind.And: - return evaluate_condition(&expr->operands[0]) and evaluate_condition(&expr->operands[1]) - if expr->kind == AstExpressionKind.Or: - return evaluate_condition(&expr->operands[0]) or evaluate_condition(&expr->operands[1]) - if expr->kind == AstExpressionKind.Not: - return not evaluate_condition(&expr->operands[0]) - - fail(expr->location, "cannot evaluate condition at compile time") + match expr->kind: + case AstExpressionKind.GetVariable: + v = get_special_constant(expr->varname) + if v == 
0: + return False + if v == 1: + return True + case AstExpressionKind.And: + return evaluate_condition(&expr->operands[0]) and evaluate_condition(&expr->operands[1]) + case AstExpressionKind.Or: + return evaluate_condition(&expr->operands[0]) or evaluate_condition(&expr->operands[1]) + case AstExpressionKind.Not: + return not evaluate_condition(&expr->operands[0]) + case _: + fail(expr->location, "cannot evaluate condition at compile time") # returns the statements to replace if statement with diff --git a/compiler/main.jou b/compiler/main.jou index 6d68d06e..1065cc81 100644 --- a/compiler/main.jou +++ b/compiler/main.jou @@ -47,34 +47,29 @@ def defines_main(ast: AstFile*) -> bool: def statement_conflicts_with_an_import(stmt: AstStatement*, importsym: ExportSymbol*) -> bool: - if stmt->kind == AstStatementKind.Function: - return ( - importsym->kind == ExportSymbolKind.Function - and strcmp(importsym->name, stmt->function.signature.name) == 0 - ) - - if ( - stmt->kind == AstStatementKind.GlobalVariableDeclaration - or stmt->kind == AstStatementKind.GlobalVariableDefinition - ): - return ( - importsym->kind == ExportSymbolKind.GlobalVar - and strcmp(importsym->name, stmt->var_declaration.name) == 0 - ) - - if stmt->kind == AstStatementKind.Class: - return ( - importsym->kind == ExportSymbolKind.Type - and strcmp(importsym->name, stmt->classdef.name) == 0 - ) - - if stmt->kind == AstStatementKind.Enum: - return ( - importsym->kind == ExportSymbolKind.Type - and strcmp(importsym->name, stmt->enumdef.name) == 0 - ) - - assert False + match stmt->kind: + case AstStatementKind.Function: + return ( + importsym->kind == ExportSymbolKind.Function + and strcmp(importsym->name, stmt->function.signature.name) == 0 + ) + case AstStatementKind.GlobalVariableDeclaration | AstStatementKind.GlobalVariableDefinition: + return ( + importsym->kind == ExportSymbolKind.GlobalVar + and strcmp(importsym->name, stmt->var_declaration.name) == 0 + ) + case AstStatementKind.Class: + return ( + 
importsym->kind == ExportSymbolKind.Type + and strcmp(importsym->name, stmt->classdef.name) == 0 + ) + case AstStatementKind.Enum: + return ( + importsym->kind == ExportSymbolKind.Type + and strcmp(importsym->name, stmt->enumdef.name) == 0 + ) + case _: + assert False def print_llvm_ir(module: LLVMModule*, is_optimized: bool) -> None: diff --git a/compiler/typecheck/common.jou b/compiler/typecheck/common.jou index 2b68905e..529fa524 100644 --- a/compiler/typecheck/common.jou +++ b/compiler/typecheck/common.jou @@ -173,39 +173,38 @@ def type_from_ast(ft: FileTypes*, asttype: AstType*) -> Type*: snprintf(msg, sizeof(msg), "'%s' cannot be used here because it is not a type", asttype->name) fail(asttype->location, msg) - if asttype->kind == AstTypeKind.Named: - if strcmp(asttype->name, "short") == 0: - return shortType - if strcmp(asttype->name, "int") == 0: - return intType - if strcmp(asttype->name, "long") == 0: - return longType - if strcmp(asttype->name, "byte") == 0: - return byteType - if strcmp(asttype->name, "bool") == 0: - return boolType - if strcmp(asttype->name, "float") == 0: - return floatType - if strcmp(asttype->name, "double") == 0: - return doubleType - - found = ft->find_type(asttype->name) - if found != NULL: - return found - - snprintf(msg, sizeof(msg), "there is no type named '%s'", asttype->name) - fail(asttype->location, msg) - - if asttype->kind == AstTypeKind.Pointer: - if asttype->value_type->is_void(): - return voidPtrType - return type_from_ast(ft, asttype->value_type)->pointer_type() - - if asttype->kind == AstTypeKind.Array: - tmp = type_from_ast(ft, asttype->value_type) - len = evaluate_array_length(asttype->array.length) - if len <= 0: - fail(asttype->array.length->location, "array length must be positive") - return tmp->array_type(len) - - assert False + match asttype->kind: + case AstTypeKind.Named: + if strcmp(asttype->name, "short") == 0: + return shortType + if strcmp(asttype->name, "int") == 0: + return intType + if 
strcmp(asttype->name, "long") == 0: + return longType + if strcmp(asttype->name, "byte") == 0: + return byteType + if strcmp(asttype->name, "bool") == 0: + return boolType + if strcmp(asttype->name, "float") == 0: + return floatType + if strcmp(asttype->name, "double") == 0: + return doubleType + + found = ft->find_type(asttype->name) + if found != NULL: + return found + + snprintf(msg, sizeof(msg), "there is no type named '%s'", asttype->name) + fail(asttype->location, msg) + + case AstTypeKind.Pointer: + if asttype->value_type->is_void(): + return voidPtrType + return type_from_ast(ft, asttype->value_type)->pointer_type() + + case AstTypeKind.Array: + tmp = type_from_ast(ft, asttype->value_type) + len = evaluate_array_length(asttype->array.length) + if len <= 0: + fail(asttype->array.length->location, "array length must be positive") + return tmp->array_type(len) diff --git a/compiler/typecheck/step3_function_and_method_bodies.jou b/compiler/typecheck/step3_function_and_method_bodies.jou index 05f86ab3..904ffc67 100644 --- a/compiler/typecheck/step3_function_and_method_bodies.jou +++ b/compiler/typecheck/step3_function_and_method_bodies.jou @@ -497,13 +497,15 @@ def typecheck_indexing( ) -> Type*: msg: byte[500] - ptrtype = typecheck_expression_not_void(ft, ptrexpr) - if ptrtype->kind == TypeKind.Array: - cast_array_to_pointer(ft->current_fom_types, ptrexpr) - ptrtype = ptrexpr->types.implicit_cast_type - else: - if ptrtype->kind != TypeKind.Pointer: - snprintf(msg, sizeof(msg), "value of type %s cannot be indexed", ptrtype->name) + orig_type = typecheck_expression_not_void(ft, ptrexpr) + match orig_type->kind: + case TypeKind.Pointer: + ptrtype = orig_type + case TypeKind.Array: + cast_array_to_pointer(ft->current_fom_types, ptrexpr) + ptrtype = ptrexpr->types.implicit_cast_type + case _: + snprintf(msg, sizeof(msg), "value of type %s cannot be indexed", orig_type->name) fail(ptrexpr->location, msg) assert ptrtype != NULL @@ -1152,7 +1154,7 @@ def 
typecheck_statement(ft: FileTypes*, stmt: AstStatement*) -> None: ensure_can_take_address(ft->current_fom_types, targetexpr, "cannot assign to %s") if targetexpr->kind == AstExpressionKind.Dereference: - strcpy(msg, "cannot place a value of type <from> into a pointer of type <to>*") + msg = "cannot place a value of type <from> into a pointer of type <to>*" else: desc = short_expression_description(targetexpr) snprintf(msg, sizeof msg, "cannot assign a value of type <from> to %s of type <to>", desc) @@ -1289,24 +1291,27 @@ def typecheck_function_or_method_body(ft: FileTypes*, sig: Signature*, body: Ast def typecheck_step3_function_and_method_bodies(ft: FileTypes*, ast: AstFile*) -> None: for i = 0; i < ast->body.nstatements; i++: stmt = &ast->body.statements[i] - if stmt->kind == AstStatementKind.Function and stmt->function.body.nstatements > 0: - sig = ft->find_function(stmt->function.signature.name) - assert sig != NULL - typecheck_function_or_method_body(ft, sig, &stmt->function.body) - - if stmt->kind == AstStatementKind.Class: - classtype: Type* = NULL - for t = ft->owned_types; t < &ft->owned_types[ft->n_owned_types]; t++: - if strcmp((*t)->name, stmt->classdef.name) == 0: - classtype = *t - break - assert classtype != NULL - - for m = stmt->classdef.members; m < &stmt->classdef.members[stmt->classdef.nmembers]; m++: - if m->kind != AstClassMemberKind.Method: - continue - method = &m->method - - sig = classtype->find_method(method->signature.name) - assert sig != NULL - typecheck_function_or_method_body(ft, sig, &method->body) + match stmt->kind: + case AstStatementKind.Function: + if stmt->function.body.nstatements > 0: + sig = ft->find_function(stmt->function.signature.name) + assert sig != NULL + typecheck_function_or_method_body(ft, sig, &stmt->function.body) + case AstStatementKind.Class: + classtype: Type* = NULL + for t = ft->owned_types; t < &ft->owned_types[ft->n_owned_types]; t++: + if strcmp((*t)->name, stmt->classdef.name) == 0: + classtype = *t + break + assert classtype != NULL 
+ + for m = stmt->classdef.members; m < &stmt->classdef.members[stmt->classdef.nmembers]; m++: + if m->kind != AstClassMemberKind.Method: + continue + method = &m->method + + sig = classtype->find_method(method->signature.name) + assert sig != NULL + typecheck_function_or_method_body(ft, sig, &method->body) + case _: + pass diff --git a/compiler/types.jou b/compiler/types.jou index e4dd78eb..6c6bae56 100644 --- a/compiler/types.jou +++ b/compiler/types.jou @@ -89,23 +89,19 @@ class Type: return &arr->type def short_description(self) -> byte*: - if self->kind == TypeKind.OpaqueClass or self->kind == TypeKind.Class: - return "a class" - if self->kind == TypeKind.Enum: - return "an enum" - if self->kind == TypeKind.VoidPointer or self->kind == TypeKind.Pointer: - return "a pointer type" - if ( - self->kind == TypeKind.SignedInteger - or self->kind == TypeKind.UnsignedInteger - or self->kind == TypeKind.FloatingPoint - ): - return "a number type" - if self->kind == TypeKind.Array: - return "an array type" - if self->kind == TypeKind.Bool: - return "the built-in bool type" - assert False + match self->kind: + case TypeKind.OpaqueClass | TypeKind.Class: + return "a class" + case TypeKind.Enum: + return "an enum" + case TypeKind.Pointer | TypeKind.VoidPointer: + return "a pointer type" + case TypeKind.SignedInteger | TypeKind.UnsignedInteger | TypeKind.FloatingPoint: + return "a number type" + case TypeKind.Array: + return "an array type" + case TypeKind.Bool: + return "the built-in bool type" def find_method(self, name: byte*) -> Signature*: if self->kind != TypeKind.Class: @@ -270,11 +266,13 @@ class Signature: def get_self_class(self) -> Type*: if self->nargs > 0 and strcmp(self->argnames[0], "self") == 0: - if self->argtypes[0]->kind == TypeKind.Pointer: - return self->argtypes[0]->value_type - if self->argtypes[0]->kind == TypeKind.Class: - return self->argtypes[0] - assert False + match self->argtypes[0]->kind: + case TypeKind.Pointer: + return 
self->argtypes[0]->value_type + case TypeKind.Class: + return self->argtypes[0] + case _: + assert False return NULL # Useful for error messages, not much else.