From 057f1bfd6eafc7bb1b4346f9550926120f2da605 Mon Sep 17 00:00:00 2001 From: Akuli Date: Thu, 30 Jan 2025 03:03:01 +0200 Subject: [PATCH] Introduce Undefined Value Graphs (UVGs) (#717) --- .../other_errors/undefined_variable.jou | 29 ++ ...lue_warning.jou => undefined_variable.jou} | 0 compiler/builders/ast_to_builder.jou | 135 ++++--- compiler/builders/either_builder.jou | 370 ++++++++++++++++++ compiler/builders/llvm_builder.jou | 220 +++++------ compiler/builders/uvg_builder.jou | 249 ++++++++++++ compiler/command_line_args.jou | 12 +- compiler/main.jou | 13 + compiler/uvg.jou | 153 ++++++++ compiler/uvg_analyze.jou | 7 + doc/{ => compiler_internals}/syntax-spec.md | 0 doc/compiler_internals/uvg.md | 182 +++++++++ 12 files changed, 1191 insertions(+), 179 deletions(-) create mode 100644 broken_tests/other_errors/undefined_variable.jou rename broken_tests/should_succeed/{undefined_value_warning.jou => undefined_variable.jou} (100%) create mode 100644 compiler/builders/either_builder.jou create mode 100644 compiler/builders/uvg_builder.jou create mode 100644 compiler/uvg.jou create mode 100644 compiler/uvg_analyze.jou rename doc/{ => compiler_internals}/syntax-spec.md (100%) create mode 100644 doc/compiler_internals/uvg.md diff --git a/broken_tests/other_errors/undefined_variable.jou b/broken_tests/other_errors/undefined_variable.jou new file mode 100644 index 00000000..5983e88d --- /dev/null +++ b/broken_tests/other_errors/undefined_variable.jou @@ -0,0 +1,29 @@ +import "stdlib/io.jou" + +@public +def maybe_undefined(n: int) -> None: + for i = 0; i < n; i++: + message = "Hi" + puts(message) # Warning: the value of 'message' may be undefined + +@public +def surely_undefined_loop() -> None: + while False: + message = "Hi" # Warning: this code will never run + puts(message) # Warning: the value of 'message' is undefined + +@public +def surely_undefined_annotation() -> None: + x: byte* + puts(x) # Warning: the value of 'x' is undefined + +@public +def surely_undefined_assignments() -> None: + a: int + b = &a + c = b + d = c + e = *d # TODO: should emit warning, but this is too "advanced" for UVGs + +def main() -> int: + return 0 diff --git a/broken_tests/should_succeed/undefined_value_warning.jou b/broken_tests/should_succeed/undefined_variable.jou similarity index 100% rename from broken_tests/should_succeed/undefined_value_warning.jou rename to broken_tests/should_succeed/undefined_variable.jou diff --git a/compiler/builders/ast_to_builder.jou b/compiler/builders/ast_to_builder.jou index c5f367d8..83b0497a 100644 --- a/compiler/builders/ast_to_builder.jou +++ b/compiler/builders/ast_to_builder.jou @@ -6,30 +6,44 @@ import "../errors_and_warnings.jou" import "../evaluate.jou" import "../types.jou" import "../types_in_ast.jou" -import "./llvm_builder.jou" +import "./either_builder.jou" class LocalVar: name: byte[100] # All local variables are represented as pointers to stack space, even # if they are never reassigned. LLVM will optimize the mess. - ptr: BuilderValue + ptr: EitherBuilderValue class Loop: - on_break: BuilderBlock - on_continue: BuilderBlock + on_break: EitherBuilderBlock + on_continue: EitherBuilderBlock class AstToBuilder: - builder: Builder* + builder: EitherBuilder* locals: LocalVar* nlocals: int loops: Loop* nloops: int returns_a_value: bool - def begin_function(self, sig: Signature*, locals: LocalVariable*, nlocals: int, public: bool) -> None: + location: Location + + # Returns old location. Use it to restore the location when you're done. + def set_location(self, location: Location) -> Location: + old = self->location + self->location = location + # If no reasonable location is available (e.g. implicit return at end of function), + # continue using the previous location for now. + if location.path != NULL and location.lineno != 0: + self->builder->set_location(location) + return old + + def begin_function(self, sig: Signature*, location: Location, locals: LocalVariable*, nlocals: int, public: bool) -> None: + old = self->set_location(location) + # First n local variables are the arguments assert sig->nargs >= 0 assert sig->nargs <= nlocals @@ -45,7 +59,7 @@ class AstToBuilder: for i = 0; i < nlocals; i++: var_name = locals[i].name var_type = locals[i].type - var_ptr = self->builder->stack_alloc(var_type) + var_ptr = self->builder->stack_alloc(var_type, var_name) self->locals[i] = LocalVar{name = var_name, ptr = var_ptr} if i < sig->nargs: # First n local variables are the function arguments @@ -56,11 +70,10 @@ class AstToBuilder: jou_startup_sig = Signature{name = "_jou_startup"} self->builder->call(&jou_startup_sig, NULL, 0) + self->set_location(old) + def end_function(self) -> None: - if self->returns_a_value: - self->builder->unreachable() - else: - self->builder->ret(NULL) # implicit "return" when falling off end of function + self->builder->ret(NULL) # implicit "return" when falling off end of function self->builder->end_function() def local_var_exists(self, name: byte*) -> bool: @@ -69,27 +82,27 @@ class AstToBuilder: return True return False - def local_var_ptr(self, name: byte*) -> BuilderValue: + def local_var_ptr(self, name: byte*) -> EitherBuilderValue: for i = 0; i < self->nlocals; i++: if strcmp(self->locals[i].name, name) == 0: return self->locals[i].ptr assert False - def build_function_call(self, call: AstCall*) -> BuilderValue: + def build_function_call(self, call: AstCall*) -> EitherBuilderValue: assert call->method_call_self == NULL assert call->nargs <= 100 - args: BuilderValue[100] + args: EitherBuilderValue[100] for i = 0; i < call->nargs; i++: args[i] = self->build_expression(&call->args[i]) return self->builder->call(call->called_signature, args, call->nargs) - def build_method_call(self, call: AstCall*) -> BuilderValue: + def build_method_call(self, call: AstCall*) -> EitherBuilderValue: assert call->method_call_self != NULL # leave room for self assert call->nargs <= 99 - args: BuilderValue[100] + args: EitherBuilderValue[100] k = 0 @@ -108,7 +121,7 @@ class AstToBuilder: return self->builder->call(call->called_signature, args, k) - def build_binop(self, op: AstExpressionKind, lhs: BuilderValue, rhs: BuilderValue) -> BuilderValue: + def build_binop(self, op: AstExpressionKind, lhs: EitherBuilderValue, rhs: EitherBuilderValue) -> EitherBuilderValue: match op: case AstExpressionKind.Eq: return self->builder->eq(lhs, rhs) @@ -142,11 +155,11 @@ class AstToBuilder: new_value = self->build_binop(op, old_value, rhs_value) self->builder->set_ptr(lhs_ptr, new_value) - def build_instantiation(self, class_type: Type*, inst: AstInstantiation*) -> BuilderValue: + def build_instantiation(self, class_type: Type*, inst: AstInstantiation*) -> EitherBuilderValue: assert class_type != NULL assert class_type->kind == TypeKind.Class - inst_ptr = self->builder->stack_alloc(class_type) + inst_ptr = self->builder->stack_alloc(class_type, NULL) self->builder->memset_to_zero(inst_ptr) for i = 0; i < inst->nfields; i++: field_ptr = self->builder->class_field_pointer(inst_ptr, inst->field_names[i]) @@ -155,7 +168,7 @@ class AstToBuilder: return self->builder->dereference(inst_ptr) - def build_and(self, lhsexpr: AstExpression*, rhsexpr: AstExpression*) -> BuilderValue: + def build_and(self, lhsexpr: AstExpression*, rhsexpr: AstExpression*) -> EitherBuilderValue: # Must be careful with side effects. # # # lhs returning False means we don't evaluate rhs @@ -166,7 +179,7 @@ class AstToBuilder: lhstrue = self->builder->add_block() lhsfalse = self->builder->add_block() done = self->builder->add_block() - resultptr = self->builder->stack_alloc(boolType) + resultptr = self->builder->stack_alloc(boolType, NULL) # if lhs: self->builder->branch(self->build_expression(lhsexpr), lhstrue, lhsfalse) @@ -185,7 +198,7 @@ class AstToBuilder: return self->builder->dereference(resultptr) - def build_or(self, lhsexpr: AstExpression*, rhsexpr: AstExpression*) -> BuilderValue: + def build_or(self, lhsexpr: AstExpression*, rhsexpr: AstExpression*) -> EitherBuilderValue: # Must be careful with side effects. # # # lhs returning True means we don't evaluate rhs @@ -196,7 +209,7 @@ class AstToBuilder: lhstrue = self->builder->add_block() lhsfalse = self->builder->add_block() done = self->builder->add_block() - resultptr = self->builder->stack_alloc(boolType) + resultptr = self->builder->stack_alloc(boolType, NULL) # if lhs: self->builder->branch(self->build_expression(lhsexpr), lhstrue, lhsfalse) @@ -215,7 +228,7 @@ class AstToBuilder: return self->builder->dereference(resultptr) - def build_increment_or_decrement(self, inner: AstExpression*, pre: bool, diff: int) -> BuilderValue: + def build_increment_or_decrement(self, inner: AstExpression*, pre: bool, diff: int) -> EitherBuilderValue: assert diff == 1 or diff == -1 # 1=increment, -1=decrement ptr = self->build_address_of_expression(inner) @@ -234,11 +247,11 @@ class AstToBuilder: else: return old_value - def build_array(self, t: Type*, items: AstExpression*, nitems: int) -> BuilderValue: + def build_array(self, t: Type*, items: AstExpression*, nitems: int) -> EitherBuilderValue: assert t->kind == TypeKind.Array assert t->array.len == nitems - arr_ptr = self->builder->stack_alloc(t) + arr_ptr = self->builder->stack_alloc(t, NULL) first_item_ptr = self->builder->cast(arr_ptr, t->array.item_type->pointer_type()) for i = 0; i < nitems; i++: @@ -249,7 +262,7 @@ class AstToBuilder: return self->builder->dereference(arr_ptr) - def build_expression_without_implicit_cast(self, expr: AstExpression*) -> BuilderValue: + def build_expression_without_implicit_cast(self, expr: AstExpression*) -> EitherBuilderValue: match expr->kind: case AstExpressionKind.String: if expr->types.orig_type == byteType->pointer_type(): @@ -309,7 +322,7 @@ class AstToBuilder: # We need to copy, because it's not always possible to evaluate &foo. # For example, consider evaluating some_function().some_field. instance = self->build_expression(expr->class_field.instance) - ptr = self->builder->stack_alloc(instance.type) + ptr = self->builder->stack_alloc(expr->class_field.instance->types.implicit_cast_type, NULL) self->builder->set_ptr(ptr, instance) fieldptr = self->builder->class_field_pointer(ptr, expr->class_field.field_name) return self->builder->dereference(fieldptr) @@ -360,21 +373,28 @@ class AstToBuilder: return self->builder->not_(self->build_expression(&expr->operands[0])) assert False - def build_expression(self, expr: AstExpression*) -> BuilderValue: - if expr->types.implicit_array_to_pointer_cast: - return self->builder->cast(self->build_address_of_expression(expr), expr->types.implicit_cast_type) + def build_expression(self, expr: AstExpression*) -> EitherBuilderValue: + old_location = self->set_location(expr->location) - raw = self->build_expression_without_implicit_cast(expr) - if expr->types.orig_type == NULL and expr->types.implicit_cast_type == NULL: - # Function/method call that returns no value - assert expr->kind == AstExpressionKind.Call - return BuilderValue{} + if expr->types.implicit_array_to_pointer_cast: + result = self->builder->cast(self->build_address_of_expression(expr), expr->types.implicit_cast_type) else: - assert expr->types.orig_type != NULL - assert expr->types.implicit_cast_type != NULL - return self->builder->cast(raw, expr->types.implicit_cast_type) + raw = self->build_expression_without_implicit_cast(expr) + if expr->types.orig_type == NULL and expr->types.implicit_cast_type == NULL: + # Function/method call that returns no value + assert expr->kind == AstExpressionKind.Call + result = EitherBuilderValue{} + else: + assert expr->types.orig_type != NULL + assert expr->types.implicit_cast_type != NULL + result = self->builder->cast(raw, expr->types.implicit_cast_type) + + self->set_location(old_location) + return result + + def build_address_of_expression(self, expr: AstExpression*) -> EitherBuilderValue: + old_location = self->set_location(expr->location) - def build_address_of_expression(self, expr: AstExpression*) -> BuilderValue: match expr->kind: case AstExpressionKind.GetClassField: if expr->class_field.uses_arrow_operator: @@ -383,25 +403,28 @@ class AstToBuilder: else: # &obj.field = &obj + memory offset ptr = self->build_address_of_expression(expr->class_field.instance) - return self->builder->class_field_pointer(ptr, expr->class_field.field_name) + result = self->builder->class_field_pointer(ptr, expr->class_field.field_name) case AstExpressionKind.Self: - return self->local_var_ptr("self") + result = self->local_var_ptr("self") case AstExpressionKind.GetVariable: if self->local_var_exists(expr->varname): - return self->local_var_ptr(expr->varname) + result = self->local_var_ptr(expr->varname) else: - return self->builder->global_var_ptr(expr->varname, expr->types.orig_type) + result = self->builder->global_var_ptr(expr->varname, expr->types.orig_type) case AstExpressionKind.Indexing: # &ptr[index] = ptr + memory offset ptr = self->build_expression(&expr->operands[0]) index = self->build_expression(&expr->operands[1]) - return self->builder->indexed_pointer(ptr, index) + result = self->builder->indexed_pointer(ptr, index) case AstExpressionKind.Dereference: # &*ptr = ptr - return self->build_expression(&expr->operands[0]) + result = self->build_expression(&expr->operands[0]) case _: assert False + self->set_location(old_location) + return result + def build_if_statement(self, ifst: AstIfStatement*) -> None: done = self->builder->add_block() for i = 0; i < ifst->n_if_and_elifs; i++: @@ -464,7 +487,7 @@ class AstToBuilder: done = self->builder->add_block() for i = 0; i < match_stmt->ncases; i++: then = self->builder->add_block() - otherwise = BuilderBlock{} # will be replaced by loop below + otherwise = EitherBuilderBlock{} # will be replaced by loop below for k = 0; k < match_stmt->cases[i].n_case_objs; k++: case_obj = self->build_expression(&match_stmt->cases[i].case_objs[k]) if match_stmt->func_name[0] == '\0': @@ -472,7 +495,7 @@ class AstToBuilder: else: args = [match_obj, case_obj] func_ret = self->builder->call(&match_stmt->func_signature, args, 2) - zero = self->builder->integer(func_ret.type, 0) + zero = self->builder->integer(match_stmt->func_signature.returntype, 0) cond = self->builder->eq(func_ret, zero) otherwise = self->builder->add_block() self->builder->branch(cond, then, otherwise) @@ -519,6 +542,8 @@ class AstToBuilder: self->builder->set_current_block(ok_block) def build_statement(self, stmt: AstStatement*) -> None: + old_location = self->set_location(stmt->location) + match stmt->kind: case AstStatementKind.If: self->build_if_statement(&stmt->if_statement) @@ -575,21 +600,23 @@ class AstToBuilder: # other statements shouldn't occur inside functions/methods assert False + self->set_location(old_location) + def build_body(self, body: AstBody*) -> None: for i = 0; i < body->nstatements; i++: self->build_statement(&body->statements[i]) @public -def feed_ast_to_builder(ast: AstFunctionOrMethod*, builder: Builder*) -> None: +def feed_ast_to_builder(func_ast: AstFunctionOrMethod*, func_location: Location, builder: EitherBuilder*) -> None: public = ( - ast->public - or ast->types.signature.is_main_function() - or ast->types.signature.get_self_class() != NULL + func_ast->public + or func_ast->types.signature.is_main_function() + or func_ast->types.signature.get_self_class() != NULL ) ast2ir = AstToBuilder{builder = builder} - ast2ir.begin_function(&ast->types.signature, ast->types.locals, ast->types.nlocals, public) - ast2ir.build_body(&ast->body) + ast2ir.begin_function(&func_ast->types.signature, func_location, func_ast->types.locals, func_ast->types.nlocals, public) + ast2ir.build_body(&func_ast->body) ast2ir.end_function() free(ast2ir.locals) free(ast2ir.loops) diff --git a/compiler/builders/either_builder.jou b/compiler/builders/either_builder.jou new file mode 100644 index 00000000..bc6f34e7 --- /dev/null +++ b/compiler/builders/either_builder.jou @@ -0,0 +1,370 @@ +# This file implements a builder that can build either LLVM or UVG, depending +# on which is needed. +# +# Ideally there would be inheritance, so that you could simply pass in +# whichever builder you need to the code that visits AST and calls methods on +# the builder, but Jou has no way to do inheritance yet. + +import "stdlib/mem.jou" + +import "../errors_and_warnings.jou" +import "../constants.jou" +import "../llvm.jou" +import "../types.jou" +import "../uvg.jou" +import "./llvm_builder.jou" +import "./uvg_builder.jou" + + +class EitherBuilderValue: + lvalue: LBuilderValue + uvalue: int + + +class EitherBuilderBlock: + lblock: LLVMBasicBlock* + ublock: UvgBlock* # funny naming :) + + +class EitherBuilder: + lbuilder: LBuilder* # may be NULL + ubuilder: UBuilder* # may be NULL + + def begin_function(self, sig: Signature*, public: bool) -> None: + if self->lbuilder != NULL: + self->lbuilder->begin_function(sig, public) + if self->ubuilder != NULL: + self->ubuilder->begin_function(sig, public) + + def end_function(self) -> None: + if self->lbuilder != NULL: + self->lbuilder->end_function() + if self->ubuilder != NULL: + self->ubuilder->end_function() + + def set_location(self, location: Location) -> None: + if self->lbuilder != NULL: + self->lbuilder->set_location(location) + if self->ubuilder != NULL: + self->ubuilder->set_location(location) + + # Allocates enough stack space in the function to hold a value of given type. + # Returns a pointer to the stack space. + # If allocated stack memory is not a local variable, varname must be NULL. + def stack_alloc(self, t: Type*, varname: byte*) -> EitherBuilderValue: + result = EitherBuilderValue{} + if self->lbuilder != NULL: + result.lvalue = self->lbuilder->stack_alloc(t, varname) + if self->ubuilder != NULL: + result.uvalue = self->ubuilder->stack_alloc(t, varname) + return result + + # *ptr = value + def set_ptr(self, ptr: EitherBuilderValue, value: EitherBuilderValue) -> None: + if self->lbuilder != NULL: + self->lbuilder->set_ptr(ptr.lvalue, value.lvalue) + else: + self->ubuilder->set_ptr(ptr.uvalue, value.uvalue) + + # *ptr + def dereference(self, ptr: EitherBuilderValue) -> EitherBuilderValue: + result = EitherBuilderValue{} + if self->lbuilder != NULL: + result.lvalue = self->lbuilder->dereference(ptr.lvalue) + if self->ubuilder != NULL: + result.uvalue = self->ubuilder->dereference(ptr.uvalue) + return result + + # Returns &ptr[index] + def indexed_pointer(self, ptr: EitherBuilderValue, index: EitherBuilderValue) -> EitherBuilderValue: + result = EitherBuilderValue{} + if self->lbuilder != NULL: + result.lvalue = self->lbuilder->indexed_pointer(ptr.lvalue, index.lvalue) + if self->ubuilder != NULL: + result.uvalue = self->ubuilder->indexed_pointer(ptr.uvalue, index.uvalue) + return result + + # Returns &ptr->field + def class_field_pointer(self, ptr: EitherBuilderValue, field_name: byte*) -> EitherBuilderValue: + result = EitherBuilderValue{} + if self->lbuilder != NULL: + result.lvalue = self->lbuilder->class_field_pointer(ptr.lvalue, field_name) + if self->ubuilder != NULL: + result.uvalue = self->ubuilder->class_field_pointer(ptr.uvalue, field_name) + return result + + # Returns &global_variable. Type needs to be passed in because a new builder + # unaware of global variables is created for each function. + def global_var_ptr(self, name: byte*, var_type: Type*) -> EitherBuilderValue: + result = EitherBuilderValue{} + if self->lbuilder != NULL: + result.lvalue = self->lbuilder->global_var_ptr(name, var_type) + if self->ubuilder != NULL: + result.uvalue = self->ubuilder->global_var_ptr(name, var_type) + return result + + # Returns the i'th argument given to function + def get_argument(self, i: int, argtype: Type*) -> EitherBuilderValue: + result = EitherBuilderValue{} + if self->lbuilder != NULL: + result.lvalue = self->lbuilder->get_argument(i, argtype) + if self->ubuilder != NULL: + result.uvalue = self->ubuilder->get_argument(i, argtype) + return result + + # Function or method call. If method, self with the correct type must be included in args. + def call(self, sig: Signature*, args: EitherBuilderValue*, nargs: int) -> EitherBuilderValue: + result = EitherBuilderValue{} + if self->lbuilder != NULL: + largs: LBuilderValue* = malloc(sizeof(largs[0]) * nargs) + assert largs != NULL + for i = 0; i < nargs; i++: + largs[i] = args[i].lvalue + result.lvalue = self->lbuilder->call(sig, largs, nargs) + free(largs) + if self->ubuilder != NULL: + uargs: int* = malloc(sizeof(uargs[0]) * nargs) + assert uargs != NULL + for i = 0; i < nargs; i++: + uargs[i] = args[i].uvalue + result.uvalue = self->ubuilder->call(sig, uargs, nargs) + free(uargs) + return result + + # string as array of bytes + def string_array(self, s: byte*, array_size: int) -> EitherBuilderValue: + result = EitherBuilderValue{} + if self->lbuilder != NULL: + result.lvalue = self->lbuilder->string_array(s, array_size) + if self->ubuilder != NULL: + result.uvalue = self->ubuilder->string_array(s, array_size) + return result + + # string as '\0' terminated pointer + def string(self, s: byte*) -> EitherBuilderValue: + result = EitherBuilderValue{} + if self->lbuilder != NULL: + result.lvalue = self->lbuilder->string(s) + if self->ubuilder != NULL: + result.uvalue = self->ubuilder->string(s) + return result + + def boolean(self, b: bool) -> EitherBuilderValue: + result = EitherBuilderValue{} + if self->lbuilder != NULL: + result.lvalue = self->lbuilder->boolean(b) + if self->ubuilder != NULL: + result.uvalue = self->ubuilder->boolean(b) + return result + + def integer(self, t: Type*, value: long) -> EitherBuilderValue: + result = EitherBuilderValue{} + if self->lbuilder != NULL: + result.lvalue = self->lbuilder->integer(t, value) + if self->ubuilder != NULL: + result.uvalue = self->ubuilder->integer(t, value) + return result + + def float_or_double(self, t: Type*, string: byte*) -> EitherBuilderValue: + result = EitherBuilderValue{} + if self->lbuilder != NULL: + result.lvalue = self->lbuilder->float_or_double(t, string) + if self->ubuilder != NULL: + result.uvalue = self->ubuilder->float_or_double(t, string) + return result + + def zero_of_type(self, t: Type*) -> EitherBuilderValue: + result = EitherBuilderValue{} + if self->lbuilder != NULL: + result.lvalue = self->lbuilder->zero_of_type(t) + if self->ubuilder != NULL: + result.uvalue = self->ubuilder->zero_of_type(t) + return result + + def enum_member(self, t: Type*, name: byte*) -> EitherBuilderValue: + result = EitherBuilderValue{} + if self->lbuilder != NULL: + result.lvalue = self->lbuilder->enum_member(t, name) + if self->ubuilder != NULL: + result.uvalue = self->ubuilder->enum_member(t, name) + return result + + # TODO: delete this? + def constant(self, c: Constant*) -> EitherBuilderValue: + match c->kind: + case ConstantKind.Bool: + return self->boolean(c->boolean) + case ConstantKind.Integer: + return self->integer(c->get_type(), c->integer.value) + case ConstantKind.Float | ConstantKind.Double: + return self->float_or_double(c->get_type(), c->double_or_float_text) + case ConstantKind.Null: + return self->zero_of_type(voidPtrType) + case ConstantKind.String: + return self->string(c->str) + case ConstantKind.EnumMember: + return self->enum_member(c->get_type(), c->get_type()->enummembers.names[c->enum_member.memberidx]) + assert False + + # a + b + def add(self, a: EitherBuilderValue, b: EitherBuilderValue) -> EitherBuilderValue: + result = EitherBuilderValue{} + if self->lbuilder != NULL: + result.lvalue = self->lbuilder->add(a.lvalue, b.lvalue) + if self->ubuilder != NULL: + result.uvalue = self->ubuilder->add(a.uvalue, b.uvalue) + return result + + # a - b + def sub(self, a: EitherBuilderValue, b: EitherBuilderValue) -> EitherBuilderValue: + result = EitherBuilderValue{} + if self->lbuilder != NULL: + result.lvalue = self->lbuilder->sub(a.lvalue, b.lvalue) + if self->ubuilder != NULL: + result.uvalue = self->ubuilder->sub(a.uvalue, b.uvalue) + return result + + # a * b + def mul(self, a: EitherBuilderValue, b: EitherBuilderValue) -> EitherBuilderValue: + result = EitherBuilderValue{} + if self->lbuilder != NULL: + result.lvalue = self->lbuilder->mul(a.lvalue, b.lvalue) + if self->ubuilder != NULL: + result.uvalue = self->ubuilder->mul(a.uvalue, b.uvalue) + return result + + # a / b + def div(self, a: EitherBuilderValue, b: EitherBuilderValue) -> EitherBuilderValue: + result = EitherBuilderValue{} + if self->lbuilder != NULL: + result.lvalue = self->lbuilder->div(a.lvalue, b.lvalue) + if self->ubuilder != NULL: + result.uvalue = self->ubuilder->div(a.uvalue, b.uvalue) + return result + + # a % b + def mod(self, a: EitherBuilderValue, b: EitherBuilderValue) -> EitherBuilderValue: + result = EitherBuilderValue{} + if self->lbuilder != NULL: + result.lvalue = self->lbuilder->mod(a.lvalue, b.lvalue) + if self->ubuilder != NULL: + result.uvalue = self->ubuilder->mod(a.uvalue, b.uvalue) + return result + + # a == b + def eq(self, a: EitherBuilderValue, b: EitherBuilderValue) -> EitherBuilderValue: + result = EitherBuilderValue{} + if self->lbuilder != NULL: + result.lvalue = self->lbuilder->eq(a.lvalue, b.lvalue) + if self->ubuilder != NULL: + result.uvalue = self->ubuilder->eq(a.uvalue, b.uvalue) + return result + + # a < b + def lt(self, a: EitherBuilderValue, b: EitherBuilderValue) -> EitherBuilderValue: + result = EitherBuilderValue{} + if self->lbuilder != NULL: + result.lvalue = self->lbuilder->lt(a.lvalue, b.lvalue) + if self->ubuilder != NULL: + result.uvalue = self->ubuilder->lt(a.uvalue, b.uvalue) + return result + + # not value + def not_(self, value: EitherBuilderValue) -> EitherBuilderValue: + result = EitherBuilderValue{} + if self->lbuilder != NULL: + result.lvalue = self->lbuilder->not_(value.lvalue) + if self->ubuilder != NULL: + result.uvalue = self->ubuilder->not_(value.uvalue) + return result + + # sizeof(any value of given type) + def size_of(self, t: Type*) -> EitherBuilderValue: + result = EitherBuilderValue{} + if self->lbuilder != NULL: + result.lvalue = self->lbuilder->size_of(t) + if self->ubuilder != NULL: + result.uvalue = self->ubuilder->size_of(t) + return result + + # memset(ptr, 0, sizeof(*ptr)) + def memset_to_zero(self, ptr: EitherBuilderValue) -> None: + if self->lbuilder != NULL: + self->lbuilder->memset_to_zero(ptr.lvalue) + if self->ubuilder != NULL: + self->ubuilder->memset_to_zero(ptr.uvalue) + + # value as to + def cast(self, value: EitherBuilderValue, to: Type*) -> EitherBuilderValue: + result = EitherBuilderValue{} + if self->lbuilder != NULL: + result.lvalue = self->lbuilder->cast(value.lvalue, to) + if self->ubuilder != NULL: + result.uvalue = self->ubuilder->cast(value.uvalue, to) + return result + + # Blocks are used to implement e.g. if statements and loops. + def add_block(self) -> EitherBuilderBlock: + result = EitherBuilderBlock{} + if self->lbuilder != NULL: + result.lblock = self->lbuilder->add_block() + if self->ubuilder != NULL: + result.ublock = self->ubuilder->add_block() + return result + + # Decide which block will contain the resulting instructions. + def set_current_block(self, block: EitherBuilderBlock) -> None: + if self->lbuilder != NULL: + self->lbuilder->set_current_block(block.lblock) + if self->ubuilder != NULL: + self->ubuilder->set_current_block(block.ublock) + + # Conditional branch: + # + # if cond: + # then + # else: + # otherwise + # + # This terminates the current block and hence leaves the builder in a + # "no current block" state, i.e. you must call set_current_block() + # before the builder is usable again. + def branch(self, cond: EitherBuilderValue, then: EitherBuilderBlock, otherwise: EitherBuilderBlock) -> None: + if self->lbuilder != NULL: + self->lbuilder->branch(cond.lvalue, then.lblock, otherwise.lblock) + if self->ubuilder != NULL: + self->ubuilder->branch(cond.uvalue, then.ublock, otherwise.ublock) + + # Go to the block. Similar to branch() but no condition. LLVM calls this + # "unconditional branch", which IMO doesn't make sense because it always + # jumps and hence doesn't branch. + # + # This terminates the current block and hence leaves the builder in a + # "no current block" state, i.e. you must call set_current_block() + # before the builder is usable again. + def jump(self, next_block: EitherBuilderBlock) -> None: + if self->lbuilder != NULL: + self->lbuilder->jump(next_block.lblock) + if self->ubuilder != NULL: + self->ubuilder->jump(next_block.ublock) + + # Add an instruction that should never run. + # May be used by optimizer, but also tells LLVM that the block ends here. + def unreachable(self) -> None: + if self->lbuilder != NULL: + self->lbuilder->unreachable() + if self->ubuilder != NULL: + self->ubuilder->unreachable() + + # Return from function/method. Value should be NULL if the function is '-> None'. + def ret(self, value: EitherBuilderValue*) -> None: + if self->lbuilder != NULL: + if value == NULL: + self->lbuilder->ret(NULL) + else: + self->lbuilder->ret(&value->lvalue) + if self->ubuilder != NULL: + if value == NULL: + self->ubuilder->ret(NULL) + else: + self->ubuilder->ret(&value->uvalue) diff --git a/compiler/builders/llvm_builder.jou b/compiler/builders/llvm_builder.jou index 1d74bf6b..bcc8189e 100644 --- a/compiler/builders/llvm_builder.jou +++ b/compiler/builders/llvm_builder.jou @@ -2,19 +2,18 @@ # function or method. # # The idea is that instead of building just LLVM IR, we can also build other -# useful data structures by modifying only the IR builder. This means we don't -# walk through the AST in multiple different places that could handle some -# corner cases differently. -# -# Currently this is used only to build LLVM IR, but more uses are planned. +# useful data structures by modifying only the IR builder, such as UVG. This +# means we don't walk through the AST in multiple different places that could +# handle some corner cases differently. import "stdlib/math.jou" import "stdlib/mem.jou" import "stdlib/str.jou" import "stdlib/io.jou" +import "../errors_and_warnings.jou" import "../ast.jou" -import "../constants.jou" +import "./either_builder.jou" import "../llvm.jou" import "../target.jou" import "../types.jou" @@ -221,25 +220,24 @@ def find_enum_member(enumtype: Type*, name: byte*) -> int: assert False -class BuilderValue: +class LBuilderValue: type: Type* llvm_value: LLVMValue* -class BuilderBlock: - llvm_block: LLVMBasicBlock* - - -class Builder: +# Not named LLVMBuilder because that is the name of LLVM's thing. +class LBuilder: llvm_module: LLVMModule* llvm_builder: LLVMBuilder* llvm_func: LLVMValue* alloca_block: LLVMBasicBlock* # local variables created here before code of function runs code_start_block: LLVMBasicBlock* current_block: LLVMBasicBlock* + returns_a_value: bool def begin_function(self, sig: Signature*, public: bool) -> None: self->llvm_func = declare_in_llvm(sig, self->llvm_module) + self->returns_a_value = sig->returntype != NULL self->alloca_block = LLVMAppendBasicBlock(self->llvm_func, "alloca") self->code_start_block = LLVMAppendBasicBlock(self->llvm_func, "code_start") @@ -253,14 +251,24 @@ class Builder: LLVMPositionBuilderAtEnd(self->llvm_builder, self->alloca_block) LLVMBuildBr(self->llvm_builder, self->code_start_block) - # Allocates enough stack space in the function to hold a value of given type. - # Returns a pointer to the stack space. - def stack_alloc(self, t: Type*) -> BuilderValue: + # TODO: use the location for debug info + def set_location(self, location: Location) -> None: + pass + + def stack_alloc(self, t: Type*, varname: byte*) -> LBuilderValue: + if varname == NULL: + debug_name = "stack_alloc" + else: + debug_name = varname + + # Place all allocations to the same block at start of function, so that + # we don't overflow the stack when the part of code that creates local + # var runs many times. LLVMPositionBuilderAtEnd(self->llvm_builder, self->alloca_block) - llvm_ptr = LLVMBuildAlloca(self->llvm_builder, type_to_llvm(t), "stack_alloc") + llvm_ptr = LLVMBuildAlloca(self->llvm_builder, type_to_llvm(t), debug_name) LLVMPositionBuilderAtEnd(self->llvm_builder, self->current_block) llvm_ptr = LLVMBuildBitCast(self->llvm_builder, llvm_ptr, type_to_llvm(voidPtrType), "legacy_llvm14_cast") - return BuilderValue{type = t->pointer_type(), llvm_value = llvm_ptr} + return LBuilderValue{type = t->pointer_type(), llvm_value = llvm_ptr} # TODO: Everything named legacy_llvm14_cast can be removed once we drop LLVM 14 support # Needed due to LLVM opaque pointer types transition @@ -268,7 +276,7 @@ class Builder: # TODO: which casts are necessary on LLVM 14 and which are not? # *ptr = value - def set_ptr(self, ptr: BuilderValue, value: BuilderValue) -> None: + def set_ptr(self, ptr: LBuilderValue, value: LBuilderValue) -> None: if ptr.type != value.type->pointer_type(): printf("Cannot set value of %s to %s\n", ptr.type->name, value.type->name) assert ptr.type == value.type->pointer_type() @@ -277,13 +285,13 @@ class Builder: LLVMBuildStore(self->llvm_builder, value.llvm_value, llvm_ptr) # *ptr - def dereference(self, ptr: BuilderValue) -> BuilderValue: + def dereference(self, ptr: LBuilderValue) -> LBuilderValue: assert ptr.type->kind == TypeKind.Pointer llvm_result = LLVMBuildLoad2(self->llvm_builder, type_to_llvm(ptr.type->value_type), ptr.llvm_value, "dereference") - return BuilderValue{type = ptr.type->value_type, llvm_value = llvm_result} + return LBuilderValue{type = ptr.type->value_type, llvm_value = llvm_result} # Returns &ptr[index] - def indexed_pointer(self, ptr: BuilderValue, index: BuilderValue) -> BuilderValue: + def indexed_pointer(self, ptr: LBuilderValue, index: LBuilderValue) -> LBuilderValue: assert ptr.type->kind == TypeKind.Pointer assert index.type == longType # doesn't work right if it's other type llvm_result = LLVMBuildGEP2( @@ -295,10 +303,10 @@ class Builder: "indexed_pointer", ) llvm_result = LLVMBuildBitCast(self->llvm_builder, llvm_result, type_to_llvm(voidPtrType), "legacy_llvm14_cast") - return BuilderValue{type = ptr.type, llvm_value = llvm_result} + return LBuilderValue{type = ptr.type, llvm_value = llvm_result} # Returns &ptr->field - def class_field_pointer(self, ptr: BuilderValue, field_name: byte*) -> BuilderValue: + def class_field_pointer(self, ptr: LBuilderValue, field_name: byte*) -> LBuilderValue: assert ptr.type->kind == TypeKind.Pointer classtype = ptr.type->value_type assert classtype->kind == TypeKind.Class @@ -315,23 +323,22 @@ class Builder: llvm_ptr = LLVMBuildBitCast(self->llvm_builder, ptr.llvm_value, LLVMPointerType(llvm_struct_type, 0), "legacy_llvm14_cast") llvm_result = LLVMBuildStructGEP2(self->llvm_builder, llvm_struct_type, llvm_ptr, field->union_id, field->name) llvm_result = LLVMBuildBitCast(self->llvm_builder, llvm_result, type_to_llvm(voidPtrType), "legacy_llvm14_cast") - return BuilderValue{type = field->type->pointer_type(), llvm_value = llvm_result} + return LBuilderValue{type = field->type->pointer_type(), llvm_value = llvm_result} - # Returns &llvm_stringiable. Type needs to be passed in because a new builder - # unaware of global variables is created for each function. - def global_var_ptr(self, name: byte*, var_type: Type*) -> BuilderValue: + # &global_variable + def global_var_ptr(self, name: byte*, var_type: Type*) -> LBuilderValue: llvm_result = LLVMGetNamedGlobal(self->llvm_module, name) assert llvm_result != NULL llvm_result = LLVMBuildBitCast(self->llvm_builder, llvm_result, type_to_llvm(voidPtrType), "legacy_llvm14_cast") - return BuilderValue{type = var_type->pointer_type(), llvm_value = llvm_result} + return LBuilderValue{type = var_type->pointer_type(), llvm_value = llvm_result} - # Returns the i'th argument given to function - def get_argument(self, i: int, argtype: Type*) -> BuilderValue: + # i'th argument given to this function + def get_argument(self, i: int, argtype: Type*) -> LBuilderValue: llvm_result = LLVMGetParam(self->llvm_func, i) - return BuilderValue{type = argtype, llvm_value = llvm_result} + return LBuilderValue{type = argtype, llvm_value = llvm_result} - # Function or method call. If method, self with the correct type must be included in args. - def call(self, sig: Signature*, args: BuilderValue*, nargs: int) -> BuilderValue: + # Function or method call, self included in args if method + def call(self, sig: Signature*, args: LBuilderValue*, nargs: int) -> LBuilderValue: assert nargs >= sig->nargs if nargs > sig->nargs: assert sig->takes_varargs @@ -352,83 +359,64 @@ class Builder: llvm_return_value = LLVMBuildCall2(self->llvm_builder, signature_to_llvm(sig), llvm_func, llvm_args, nargs, debug_name) if sig->returntype == NULL: - return BuilderValue{} + return LBuilderValue{} else: assert llvm_return_value != NULL - return BuilderValue{type = sig->returntype, llvm_value = llvm_return_value} + return LBuilderValue{type = sig->returntype, llvm_value = llvm_return_value} # string as array of bytes - def string_array(self, s: byte*, array_size: int) -> BuilderValue: + def string_array(self, s: byte*, array_size: int) -> LBuilderValue: assert strlen(s) < array_size padded: byte* = malloc(array_size) memset(padded, 0, array_size) strcpy(padded, s) llvm_array = LLVMConstString(padded, array_size, True as int) free(padded) - return BuilderValue{type = byteType->array_type(array_size), llvm_value = llvm_array} + return LBuilderValue{type = byteType->array_type(array_size), llvm_value = llvm_array} # string as '\0' terminated pointer - def string(self, s: byte*) -> BuilderValue: + def string(self, s: byte*) -> LBuilderValue: llvm_array = self->string_array(s, (strlen(s) + 1) as int).llvm_value llvm_string = LLVMAddGlobal(self->llvm_module, LLVMTypeOf(llvm_array), "string_literal") LLVMSetLinkage(llvm_string, LLVMLinkage.Private) # This makes it a static global variable LLVMSetInitializer(llvm_string, llvm_array) llvm_string = LLVMBuildBitCast(self->llvm_builder, llvm_string, type_to_llvm(byteType->pointer_type()), "legacy_llvm14_cast") - return BuilderValue{type = byteType->pointer_type(), llvm_value = llvm_string} + return LBuilderValue{type = byteType->pointer_type(), llvm_value = llvm_string} - def boolean(self, b: bool) -> BuilderValue: - return BuilderValue{ + def boolean(self, b: bool) -> LBuilderValue: + return LBuilderValue{ type = boolType, llvm_value = LLVMConstInt(LLVMInt1Type(), b as long, False as int), } - def integer(self, t: Type*, value: long) -> BuilderValue: + def integer(self, t: Type*, value: long) -> LBuilderValue: assert t->is_integer_type() - return BuilderValue{ + return LBuilderValue{ type = t, llvm_value = LLVMConstInt(type_to_llvm(t), value, (t->kind == TypeKind.SignedInteger) as int), } - def float_or_double(self, t: Type*, string: byte*) -> BuilderValue: + def float_or_double(self, t: Type*, string: byte*) -> LBuilderValue: assert t->kind == TypeKind.FloatingPoint - return BuilderValue{ + return LBuilderValue{ type = t, llvm_value = LLVMConstRealOfString(type_to_llvm(t), string) } - def zero_of_type(self, t: Type*) -> BuilderValue: - return BuilderValue{ + def zero_of_type(self, t: Type*) -> LBuilderValue: + return LBuilderValue{ type = t, llvm_value = LLVMConstNull(type_to_llvm(t)), } - def enum_member(self, t: Type*, name: byte*) -> BuilderValue: - return BuilderValue{ + def enum_member(self, t: Type*, name: byte*) -> LBuilderValue: + return LBuilderValue{ type = t, llvm_value = LLVMConstInt(LLVMInt32Type(), find_enum_member(t, name), False as int), } - def constant(self, c: Constant*) -> BuilderValue: - llvm_constant: LLVMValue* = NULL - match c->kind: - case ConstantKind.Bool: - return self->boolean(c->boolean) - case ConstantKind.Integer: - return self->integer(c->get_type(), c->integer.value) - case ConstantKind.Float | ConstantKind.Double: - return self->float_or_double(c->get_type(), c->double_or_float_text) - case ConstantKind.Null: - llvm_constant = LLVMConstNull(type_to_llvm(voidPtrType)) - case ConstantKind.String: - return self->string(c->str) - case ConstantKind.EnumMember: - llvm_constant = LLVMConstInt(LLVMInt32Type(), c->enum_member.memberidx, False as int) - - assert llvm_constant != NULL - return BuilderValue{type = c->get_type(), llvm_value = llvm_constant} - # a + b - def add(self, a: BuilderValue, b: BuilderValue) -> BuilderValue: + def add(self, a: LBuilderValue, b: LBuilderValue) -> LBuilderValue: assert a.type == b.type match a.type->kind: case TypeKind.FloatingPoint: @@ -437,10 +425,10 @@ class Builder: llvm_sum = LLVMBuildAdd(self->llvm_builder, a.llvm_value, b.llvm_value, "int_sum") case _: assert False - return BuilderValue{type = a.type, llvm_value = llvm_sum} + return LBuilderValue{type = a.type, llvm_value = llvm_sum} # a - b - def sub(self, a: BuilderValue, b: BuilderValue) -> BuilderValue: + def sub(self, a: LBuilderValue, b: LBuilderValue) -> LBuilderValue: assert a.type == b.type match a.type->kind: case TypeKind.FloatingPoint: @@ -449,10 +437,10 @@ class Builder: llvm_diff = LLVMBuildSub(self->llvm_builder, a.llvm_value, b.llvm_value, "int_diff") case _: assert False - return BuilderValue{type = a.type, llvm_value = llvm_diff} + return LBuilderValue{type = a.type, llvm_value = llvm_diff} # a * b - def mul(self, a: BuilderValue, b: BuilderValue) -> BuilderValue: + def mul(self, a: LBuilderValue, b: LBuilderValue) -> LBuilderValue: assert a.type == b.type match a.type->kind: case TypeKind.FloatingPoint: @@ -461,10 +449,10 @@ class Builder: llvm_prod = LLVMBuildMul(self->llvm_builder, a.llvm_value, b.llvm_value, "int_prod") case _: assert False - return BuilderValue{type = a.type, llvm_value = llvm_prod} + return LBuilderValue{type = a.type, llvm_value = llvm_prod} # a / b - def div(self, a: BuilderValue, b: BuilderValue) -> BuilderValue: + def div(self, a: LBuilderValue, b: LBuilderValue) -> LBuilderValue: assert a.type == b.type match a.type->kind: case TypeKind.FloatingPoint: @@ -475,10 +463,10 @@ class Builder: llvm_quot = LLVMBuildUDiv(self->llvm_builder, a.llvm_value, b.llvm_value, "uint_quot") case _: assert False - return BuilderValue{type = a.type, llvm_value = llvm_quot} + return LBuilderValue{type = a.type, llvm_value = llvm_quot} # a % b - def mod(self, a: BuilderValue, b: BuilderValue) -> BuilderValue: + def mod(self, a: LBuilderValue, b: LBuilderValue) -> LBuilderValue: assert a.type == b.type match a.type->kind: case TypeKind.FloatingPoint: @@ -489,10 +477,10 @@ class Builder: llvm_mod = LLVMBuildURem(self->llvm_builder, a.llvm_value, b.llvm_value, "uint_mod") case _: assert False - return BuilderValue{type = a.type, llvm_value = llvm_mod} + return LBuilderValue{type = a.type, llvm_value = llvm_mod} # a == b - def eq(self, a: BuilderValue, b: BuilderValue) -> BuilderValue: + def eq(self, a: LBuilderValue, b: LBuilderValue) -> LBuilderValue: assert a.type == b.type match a.type->kind: case TypeKind.SignedInteger | TypeKind.UnsignedInteger | TypeKind.Enum | TypeKind.Bool: @@ -501,10 +489,10 @@ class Builder: llvm_result = LLVMBuildFCmp(self->llvm_builder, LLVMRealPredicate.OEQ, a.llvm_value, b.llvm_value, "eq") case _: assert False - return BuilderValue{type = boolType, llvm_value = llvm_result} + return LBuilderValue{type = boolType, llvm_value = llvm_result} # a < b - def lt(self, a: BuilderValue, b: BuilderValue) -> BuilderValue: + def lt(self, a: LBuilderValue, b: LBuilderValue) -> LBuilderValue: assert a.type == b.type match a.type->kind: case TypeKind.SignedInteger: @@ -515,72 +503,57 @@ class Builder: llvm_result = LLVMBuildFCmp(self->llvm_builder, LLVMRealPredicate.OLT, a.llvm_value, b.llvm_value, "lt") case _: assert False - return BuilderValue{type = boolType, llvm_value = llvm_result} + return LBuilderValue{type = boolType, llvm_value = llvm_result} # not value - def not_(self, value: BuilderValue) -> BuilderValue: + def not_(self, value: LBuilderValue) -> LBuilderValue: llvm_result = LLVMBuildXor(self->llvm_builder, value.llvm_value, LLVMConstInt(LLVMInt1Type(), 1, False as int), "not") - return BuilderValue{type = boolType, llvm_value = llvm_result} + return LBuilderValue{type = boolType, llvm_value = llvm_result} # sizeof(any value of given type) - def size_of(self, t: Type*) -> BuilderValue: - return BuilderValue{ + def size_of(self, t: Type*) -> LBuilderValue: + return LBuilderValue{ type = longType, llvm_value = LLVMSizeOf(type_to_llvm(t)), } # memset(ptr, 0, sizeof(*ptr)) - def memset_to_zero(self, ptr: BuilderValue) -> None: + def memset_to_zero(self, ptr: LBuilderValue) -> None: assert ptr.type->kind == TypeKind.Pointer size = self->size_of(ptr.type->value_type).llvm_value zero_byte = LLVMConstInt(LLVMInt8Type(), 0, False as int) LLVMBuildMemSet(self->llvm_builder, ptr.llvm_value, zero_byte, size, 0) # value as to - def cast(self, value: BuilderValue, to: Type*) -> BuilderValue: + def cast(self, value: LBuilderValue, to: Type*) -> LBuilderValue: llvm_result = build_llvm_cast(self->llvm_builder, value.llvm_value, value.type, to) - return BuilderValue{type = to, llvm_value = llvm_result} + return LBuilderValue{type = to, llvm_value = llvm_result} # Blocks are used to implement e.g. if statements and loops. - def add_block(self) -> BuilderBlock: - llvm_block = LLVMAppendBasicBlock(self->llvm_func, "block") - return BuilderBlock{llvm_block = llvm_block} + def add_block(self) -> LLVMBasicBlock*: + return LLVMAppendBasicBlock(self->llvm_func, "block") # Decide which block will contain the resulting instructions. - def set_current_block(self, block: BuilderBlock) -> None: - LLVMPositionBuilderAtEnd(self->llvm_builder, block.llvm_block) - self->current_block = block.llvm_block + def set_current_block(self, block: LLVMBasicBlock*) -> None: + LLVMPositionBuilderAtEnd(self->llvm_builder, block) + self->current_block = block - # Conditional branch: - # - # if cond: - # then - # else: - # otherwise - # - # This leaves the builder in a "no current block" state, i.e. you must call - # set_current_block() before the builder is usable again. - def branch(self, cond: BuilderValue, then: BuilderBlock, otherwise: BuilderBlock) -> None: - LLVMBuildCondBr(self->llvm_builder, cond.llvm_value, then.llvm_block, otherwise.llvm_block) - - # Go to the block. Similar to branch() but no condition. LLVM calls this - # "unconditional branch", which IMO doesn't make sense because it always - # jumps and hence doesn't branch. - # - # This leaves the builder in a "no current block" state, i.e. you must call - # set_current_block() before the builder is usable again. - def jump(self, next_block: BuilderBlock) -> None: - LLVMBuildBr(self->llvm_builder, next_block.llvm_block) + def branch(self, cond: LBuilderValue, then: LLVMBasicBlock*, otherwise: LLVMBasicBlock*) -> None: + LLVMBuildCondBr(self->llvm_builder, cond.llvm_value, then, otherwise) + + def jump(self, next_block: LLVMBasicBlock*) -> None: + LLVMBuildBr(self->llvm_builder, next_block) - # Add an instruction that should never run. - # May be used by optimizer, but also tells LLVM that the block ends here. def unreachable(self) -> None: LLVMBuildUnreachable(self->llvm_builder) - # Return from function/method. Value should be NULL if the function is '-> None'. - def ret(self, value: BuilderValue*) -> None: + def ret(self, value: LBuilderValue*) -> None: if value == NULL: - LLVMBuildRetVoid(self->llvm_builder) + if self->returns_a_value: + # Implicit "return" at the end of a function that should return a value + LLVMBuildUnreachable(self->llvm_builder) + else: + LLVMBuildRetVoid(self->llvm_builder) else: LLVMBuildRet(self->llvm_builder, value->llvm_value) @@ -591,7 +564,8 @@ def build_llvm_ir(ast: AstFile*) -> LLVMModule*: LLVMSetTarget(module, target.triple) LLVMSetDataLayout(module, target.data_layout) - builder = Builder{llvm_module = module, llvm_builder = LLVMCreateBuilder()} + builder = LBuilder{llvm_module = module, llvm_builder = LLVMCreateBuilder()} + builder_wrapper = EitherBuilder{lbuilder = &builder} for g = ast->types.globals; g < &ast->types.globals[ast->types.nglobals]; g++: t = type_to_llvm(g->type) @@ -602,11 +576,11 @@ def build_llvm_ir(ast: AstFile*) -> LLVMModule*: for stmt = ast->body.statements; stmt < &ast->body.statements[ast->body.nstatements]; stmt++: match stmt->kind: case AstStatementKind.FunctionDef: - feed_ast_to_builder(&stmt->function, &builder) + feed_ast_to_builder(&stmt->function, stmt->location, &builder_wrapper) case AstStatementKind.Class: for inner = stmt->classdef.body->statements; inner < &stmt->classdef.body->statements[stmt->classdef.body->nstatements]; inner++: if inner->kind == AstStatementKind.MethodDef: - feed_ast_to_builder(&inner->method, &builder) + feed_ast_to_builder(&inner->method, inner->location, &builder_wrapper) case _: pass diff --git a/compiler/builders/uvg_builder.jou b/compiler/builders/uvg_builder.jou new file mode 100644 index 00000000..21d64ead --- /dev/null +++ b/compiler/builders/uvg_builder.jou @@ -0,0 +1,249 @@ +import "stdlib/str.jou" +import "stdlib/mem.jou" + +import "../uvg_analyze.jou" +import "../ast.jou" +import "../errors_and_warnings.jou" +import "../types.jou" +import "../uvg.jou" +import "./ast_to_builder.jou" +import "./either_builder.jou" + + +# Within this class, UVG variable ID -1 denotes a value that is not a pointer, +# or is some pointer that we don't keep track of. Any other ID represents a +# pointer to the corresponding UVG variable. +class UBuilder: + uvg: Uvg + current_block: UvgBlock* + returns_a_value: bool + location: Location + + def begin_function(self, sig: Signature*, public: bool) -> None: + self->returns_a_value = sig->returntype != NULL + self_class = sig->get_self_class() + if self_class == NULL: + # function + assert sizeof(self->uvg.name) >= sizeof(sig->name) + strcpy(self->uvg.name, sig->name) + else: + # method + snprintf(self->uvg.name, sizeof self->uvg.name, "%s.%s", self_class->name, sig->name) + + self->current_block = self->uvg.add_block() + + def end_function(self) -> None: + self->current_block = NULL + self->location = Location{} + + def set_location(self, location: Location) -> None: + self->location = location + + def add_instruction(self, ins: UvgInstruction) -> None: + assert ins.var >= 0 + + assert self->location.path != NULL + assert self->location.lineno != 0 + ins.location = self->location + + b = self->current_block + assert b != NULL + + b->instructions = realloc(b->instructions, sizeof(b->instructions[0]) * (b->ninstructions + 1)) + assert b->instructions != NULL + b->instructions[b->ninstructions++] = ins + + def use(self, var: int) -> None: + if var != -1: + self->add_instruction(UvgInstruction{kind = UvgInstructionKind.Use, var = var}) + + def set(self, var: int) -> None: + if var != -1: + self->add_instruction(UvgInstruction{kind = UvgInstructionKind.Set, var = var}) + + def dont_analyze(self, var: int) -> None: + if var != -1: + self->add_instruction(UvgInstruction{kind = UvgInstructionKind.DontAnalyze, var = var}) + + def stack_alloc(self, t: Type*, varname: byte*) -> int: + if varname != NULL: + # Currently the UVG does not support multiple variables with the same name. + assert not self->uvg.has_local_var(varname) + return self->uvg.get_local_var_ptr(varname) + + def set_ptr(self, ptr: int, value: int) -> None: + self->dont_analyze(value) + self->set(ptr) + + def dereference(self, ptr: int) -> int: + self->use(ptr) + return -1 + + def indexed_pointer(self, ptr: int, index: int) -> int: + self->dont_analyze(ptr) + return -1 + + def class_field_pointer(self, ptr: int, field_name: byte*) -> int: + self->dont_analyze(ptr) + return -1 + + def global_var_ptr(self, name: byte*, var_type: Type*) -> int: + return -1 + + def get_argument(self, i: int, argtype: Type*) -> int: + return -1 + + def call(self, sig: Signature*, args: int*, nargs: int) -> int: + for i = 0; i < nargs; i++: + self->dont_analyze(args[i]) + return -1 + + def string_array(self, s: byte*, array_size: int) -> int: + return -1 + + def string(self, s: byte*) -> int: + return -1 + + def boolean(self, b: bool) -> int: + return -1 + + def integer(self, t: Type*, value: long) -> int: + return -1 + + def float_or_double(self, t: Type*, string: byte*) -> int: + return -1 + + def zero_of_type(self, t: Type*) -> int: + return -1 + + def enum_member(self, t: Type*, name: byte*) -> int: + return -1 + + # a + b + def add(self, a: int, b: int) -> int: + return -1 + + # a - b + def sub(self, a: int, b: int) -> int: + return -1 + + # a * b + def mul(self, a: int, b: int) -> int: + return -1 + + # a / b + def div(self, a: int, b: int) -> int: + return -1 + + # a % b + def mod(self, a: int, b: int) -> int: + return -1 + + # a == b + def eq(self, a: int, b: int) -> int: + return -1 + + # a < b + def lt(self, a: int, b: int) -> int: + return -1 + + # not value + def not_(self, value: int) -> int: + return -1 + + # sizeof(any value of given type) + def size_of(self, t: Type*) -> int: + return -1 + + # memset(ptr, 0, sizeof(*ptr)) + def memset_to_zero(self, ptr: int) -> None: + self->set(ptr) + + # value as to + def cast(self, value: int, to: Type*) -> int: + if to->is_pointer_type(): + return value + else: + # e.g. cast pointer to long + self->dont_analyze(value) + return -1 + + # Blocks are used to implement e.g. if statements and loops. + def add_block(self) -> UvgBlock*: + return self->uvg.add_block() + + # Decide which block will contain the resulting instructions. + def set_current_block(self, block: UvgBlock*) -> None: + self->current_block = block + + def branch(self, cond: int, then: UvgBlock*, otherwise: UvgBlock*) -> None: + # TODO: do something with cond? + assert self->current_block != NULL + assert self->current_block->terminator.kind == UvgTerminatorKind.NotSet + self->current_block->terminator = UvgTerminator{ + kind = UvgTerminatorKind.Branch, + branch = UvgBranch{then = then, otherwise = otherwise}, + } + self->current_block = NULL + + def jump(self, next_block: UvgBlock*) -> None: + assert self->current_block != NULL + assert self->current_block->terminator.kind == UvgTerminatorKind.NotSet + self->current_block->terminator = UvgTerminator{ + kind = UvgTerminatorKind.Jump, + jump_block = next_block, + } + self->current_block = NULL + + def unreachable(self) -> None: + assert self->current_block != NULL + assert self->current_block->terminator.kind == UvgTerminatorKind.NotSet + self->current_block->terminator = UvgTerminator{kind = UvgTerminatorKind.Unreachable} + self->current_block = NULL + + def ret(self, value: int*) -> None: + if value != NULL: + self->set(self->uvg.get_local_var_ptr("return")) + + if self->returns_a_value: + self->use(self->uvg.get_local_var_ptr("return")) + + assert self->current_block != NULL + assert self->current_block->terminator.kind == UvgTerminatorKind.NotSet + self->current_block->terminator = UvgTerminator{kind = UvgTerminatorKind.Return} + + +enum UvgProcessing: + Print + Analyze + + +def do_processing(uvg: Uvg*, proc: UvgProcessing) -> None: + match proc: + case UvgProcessing.Print: + uvg->print() + case UvgProcessing.Analyze: + uvg_analyze(uvg) + + +@public +def build_and_process_uvgs(ast: AstFile*, proc: UvgProcessing) -> None: + builder = UBuilder{} + builder_wrapper = EitherBuilder{ubuilder = &builder} + + for stmt = ast->body.statements; stmt < &ast->body.statements[ast->body.nstatements]; stmt++: + match stmt->kind: + case AstStatementKind.FunctionDef: + feed_ast_to_builder(&stmt->function, stmt->location, &builder_wrapper) + do_processing(&builder.uvg, proc) + builder.uvg.free() + memset(&builder.uvg, 0, sizeof(builder.uvg)) + case AstStatementKind.Class: + for inner = stmt->classdef.body->statements; inner < &stmt->classdef.body->statements[stmt->classdef.body->nstatements]; inner++: + if inner->kind == AstStatementKind.MethodDef: + feed_ast_to_builder(&inner->method, inner->location, &builder_wrapper) + do_processing(&builder.uvg, proc) + builder.uvg.free() + memset(&builder.uvg, 0, sizeof(builder.uvg)) + case _: + pass diff --git a/compiler/command_line_args.jou b/compiler/command_line_args.jou index e3bcb9c7..09f28854 100644 --- a/compiler/command_line_args.jou +++ b/compiler/command_line_args.jou @@ -12,6 +12,7 @@ class CommandLineArgs: valgrind: bool # true --> Use valgrind when runnning user's jou program tokenize_only: bool # If true, tokenize the file passed on command line and don't actually compile anything parse_only: bool # If true, parse the file passed on command line and don't actually compile anything + uvg_only: bool # If true, generate and print UVG's and don't actually compile anything optlevel: int # Optimization level (0 don't optimize, 3 optimize a lot) infile: byte* # The "main" Jou file (can import other files) outfile: byte* # If not NULL, where to output executable @@ -34,9 +35,10 @@ def print_help(argv0: byte*) -> None: printf(" -v / --verbose display some progress information\n") printf(" -vv display a lot of information about all compilation steps\n") printf(" --valgrind use valgrind when running the code\n") - printf(" --tokenize-only display only the output of the tokenizer, don't do anything else\n") - printf(" --parse-only display only the AST (parse tree), don't do anything else\n") printf(" --linker-flags appended to the linker command, so you can use external libraries\n") + printf(" --tokenize-only display the output of the tokenizer, do not compile further\n") + printf(" --parse-only display the AST (parse tree), do not compile further\n") + printf(" --uvg-only display Undefined Value Graphs, do not compile further\n") @public @@ -79,6 +81,12 @@ def parse_command_line_args(argc: int, argv: byte**) -> None: exit(2) command_line_args.parse_only = True i++ + case "--uvg-only": + if argc > 3: + fprintf(stderr, "%s: --uvg-only cannot be used together with other flags (try \"%s --help\")", argv[0], argv[0]) + exit(2) + command_line_args.uvg_only = True + i++ case "--linker-flags": if command_line_args.linker_flags != NULL: fprintf(stderr, "%s: --linker-flags cannot be given multiple times (try \"%s --help\")\n", argv[0], argv[0]) diff --git a/compiler/main.jou b/compiler/main.jou index 37da2898..b7b3362c 100644 --- a/compiler/main.jou +++ b/compiler/main.jou @@ -4,6 +4,7 @@ import "stdlib/str.jou" import "stdlib/process.jou" import "./builders/llvm_builder.jou" +import "./builders/uvg_builder.jou" import "./command_line_args.jou" import "./evaluate.jou" import "./run.jou" @@ -309,6 +310,12 @@ class CompileState: free(pending_exports) + def analyze_all_uvgs(self) -> None: + for i = 0; i < self->nfiles; i++: + if command_line_args.verbosity >= 1: + printf("Building and analyzing UVGs for %s\n", self->files[i].path) + build_and_process_uvgs(&self->files[i].ast, UvgProcessing.Analyze) + def optimize(module: LLVMModule*, level: int) -> None: assert 1 <= level and level <= 3 @@ -397,6 +404,12 @@ def main(argc: int, argv: byte**) -> int: compst.typecheck_all_files() + if command_line_args.uvg_only: + mainfile = compst.find_file(command_line_args.infile) + assert mainfile != NULL + build_and_process_uvgs(&mainfile->ast, UvgProcessing.Print) + return 0 + objpaths: byte** = calloc(sizeof objpaths[0], compst.nfiles + 1) for i = 0; i < compst.nfiles; i++: llvm_ir = compst.files[i].build_llvm_ir() diff --git a/compiler/uvg.jou b/compiler/uvg.jou new file mode 100644 index 00000000..3cc73619 --- /dev/null +++ b/compiler/uvg.jou @@ -0,0 +1,153 @@ +import "stdlib/io.jou" +import "stdlib/str.jou" +import "stdlib/mem.jou" + +import "./errors_and_warnings.jou" + + +enum UvgInstructionKind: + Use # something = *x + Set # *x = something + DontAnalyze # do_something_complicated(&x) + +class UvgInstruction: + location: Location + kind: UvgInstructionKind + var: int + + +class UvgBranch: + then: UvgBlock* + otherwise: UvgBlock* + + +enum UvgTerminatorKind: + NotSet # must be first so it's zero memory + Jump + Branch + Return + Unreachable + +class UvgTerminator: + kind: UvgTerminatorKind + union: + jump_block: UvgBlock* # UvgTerminatorKind.Jump + branch: UvgBranch # UvgTerminatorKind.Branch + + +class UvgBlock: + instructions: UvgInstruction* + ninstructions: int + terminator: UvgTerminator + + def free(self) -> None: + free(self->instructions) + + +# We build one UVG for each function. +class Uvg: + name: byte[200] + + # Each block is allocated separately so that we can pass them around as + # pointers, and they don't become invalid when adding more blocks. + blocks: UvgBlock** + nblocks: int + + varnames: byte[100]* + nvars: int + + def free(self) -> None: + for i = 0; i < self->nblocks; i++: + self->blocks[i]->free() + free(self->blocks[i]) + free(self->blocks) + free(self->varnames) + + def index_of_block(self, b: UvgBlock*) -> int: + for i = 0; i < self->nblocks; i++: + if self->blocks[i] == b: + return i + assert False + + def print(self) -> None: + printf("===== UVG for %s =====\n", self->name) + + assert self->nblocks > 0 + + for i = 0; i < self->nblocks; i++: + if i == 0: + printf("block 0 (start):\n") + else: + printf("block %d:\n", i) + b = self->blocks[i] + + for ins = b->instructions; ins < &b->instructions[b->ninstructions]; ins++: + s: byte[50] + sprintf(s, " [line %-5d]", ins->location.lineno) + + # Move the ']' right next to the number because I like it that way :D + while strstr(s, " ]") != NULL: + memswap(strstr(s, " ]"), strstr(s, "]"), 1) + printf("%s", s) + + match ins->kind: + case UvgInstructionKind.Use: + printf("use %s\n", self->varnames[ins->var]) + case UvgInstructionKind.Set: + printf("set %s\n", self->varnames[ins->var]) + case UvgInstructionKind.DontAnalyze: + printf("don't analyze %s\n", self->varnames[ins->var]) + + printf(" ") + match b->terminator.kind: + case UvgTerminatorKind.Jump: + printf("Jump to block %d.\n", self->index_of_block(b->terminator.jump_block)) + case UvgTerminatorKind.Branch: + printf( + "Jump to either block %d or %d depending on some condition.\n", + self->index_of_block(b->terminator.branch.then), + self->index_of_block(b->terminator.branch.otherwise), + ) + case UvgTerminatorKind.Return: + printf("Return from function.\n") + case UvgTerminatorKind.Unreachable: + printf("The end of this block is unreachable. It will never run.\n") + case UvgTerminatorKind.NotSet: + printf("(terminator not set)\n") + printf("\n") + + def add_block(self) -> UvgBlock*: + b: UvgBlock* = malloc(sizeof(*b)) + assert b != NULL + memset(b, 0, sizeof(*b)) + + self->blocks = realloc(self->blocks, sizeof(self->blocks[0]) * (self->nblocks + 1)) + assert self->blocks != NULL + self->blocks[self->nblocks++] = b + + return b + + def has_local_var(self, varname: byte*) -> bool: + assert varname != NULL + for i = 0; i < self->nvars; i++: + if strcmp(self->varnames[i], varname) == 0: + return True + return False + + def get_local_var_ptr(self, varname: byte*) -> int: + if varname != NULL: + for i = 0; i < self->nvars; i++: + if strcmp(self->varnames[i], varname) == 0: + # Reuse existing + return i + + self->varnames = realloc(self->varnames, sizeof(self->varnames[0]) * (self->nvars + 1)) + assert self->varnames != NULL + var_id = self->nvars++ + + if varname == NULL: + sprintf(self->varnames[var_id], "$%d", var_id) + else: + assert strlen(varname) < sizeof(self->varnames[var_id]) + strcpy(self->varnames[var_id], varname) + return var_id diff --git a/compiler/uvg_analyze.jou b/compiler/uvg_analyze.jou new file mode 100644 index 00000000..b7b6d685 --- /dev/null +++ b/compiler/uvg_analyze.jou @@ -0,0 +1,7 @@ +# See doc/compiler_internals/uvg.md +import "./uvg.jou" + +@public +def uvg_analyze(uvg: Uvg*) -> None: + # TODO + pass diff --git a/doc/syntax-spec.md b/doc/compiler_internals/syntax-spec.md similarity index 100% rename from doc/syntax-spec.md rename to doc/compiler_internals/syntax-spec.md diff --git a/doc/compiler_internals/uvg.md b/doc/compiler_internals/uvg.md new file mode 100644 index 00000000..0a55dded --- /dev/null +++ b/doc/compiler_internals/uvg.md @@ -0,0 +1,182 @@ +# Undefined Value Graphs + +UVGs are used to determine which values may be undefined when the code runs. +This file explains how they work using examples. + + +## `set` and `use` + +Consider the following Jou code: + +```python +import "stdlib/io.jou" + +def foo(a: int) -> int: # line 3 + x = a + 6 # line 4 + y: int # line 5 + z = y # line 6 + printf("%d %d\n", x, y, z) # line 7 +``` + +Running `jou --uvg-only file.jou` prints the following UVG: + +``` +===== UVG for foo ===== +block 0 (start): + [line 3] set a + [line 4] use a + [line 4] set x + [line 6] use y + [line 6] set z + [line 7] use x + [line 7] use y + [line 7] use z + [line 7] use return + Return from function. +``` + +Let's look at this in more detail: +- Line 3 sets variable `a` because it is the argument of the function. +- Line 4 uses variable `a` to compute `x`. This is fine because the value of `a` has been set. +- Line 5 does not show up at all. Creating a variable is not an instruction in UVG, and all variables exist already when the function begins. +- Line 6 uses `y`, which is undefined. There will be a warning. It also assigns a value to `z`. +- Line 7 uses `y` and `return`. Here `return` is a special variable that represents the return value of the function. It is undefined because the `return` statement is missing. + +This causes the compiler to show the following warnings: + +``` +compiler warning for file "foo.jou", line 6: the value of 'y' is undefined +compiler warning for file "foo.jou", line 7: the value of 'y' is undefined +compiler error in file "foo.jou", line 7: function 'foo' must return a value, because it is defined with '-> int' +``` + + +## The "don't analyze" instruction + +Sometimes the code does something that is too complicated for the compiler to analyze. +For example: + +```python +import "stdlib/io.jou" + +def bar() -> None: # line 3 + a, b: int # line 4 + scanf("%d\n", &a) # line 5 + printf("%lld\n", &b as long) # line 6 +``` + +The UVG for this function is: + +``` +===== UVG for bar ===== +block 0 (start): + [line 5] don't analyze a + [line 6] don't analyze b + Return from function. +``` + +The "don't analyze" UVG instruction means that +the address of the variable has been used in some complicated way. +From that point on, it is not possible to determine whether the value is defined or undefined. +For example, a function call with `&a` might store `a` to a global variable and set its value later. +No matter what you do after a variable is marked as "don't analyze", the compiler will not complain. + +Even a simple variable assignment can introduce the "don't analyze" instruction. +For example, consider the following: + +```python +import "stdlib/io.jou" + +def baz() -> None: + x: int + y = &x # line 5 + scanf("%d\n", y) # line 6 +``` + +The UVG is: + +``` +===== UVG for baz ===== +block 0 (start): + [line 5] don't analyze x + [line 5] set y + [line 6] use y + Return from function. +``` + +Just setting `y = &x` emits the "don't analyze" instruction for `x`. +This way, when we see `scanf("%d\n", y)`, +we don't need to somehow know which variables might end up in the variable `y`. +That might be doable, but for now, it seems unnecessarily complicated. + + +## Value Statuses and Branching + +To implement the warnings in +[tests/should_succeed/undefined_variable.jou](../../broken_tests/should_succeed/undefined_variable.jou), +the Jou compiler keeps track of the possible statuses of the variables in UVG. +The **status** of a variable is a subset of the following: +- `undefined` +- `points to foo`, where `foo` is another variable in the UVG +- `defined` (some value has been assigned to the value, but we don't know what value) + +For example, a variable with status `{undefined, points to foo}` +is either undefined or `&foo` depending on some `if` statement or loop or other control flow thing. + +In UVG, branching and loops are handled just like in LLVM. +UVG instructions are placed into **blocks**. +The end of the block is a **terminator**, which can be: +- jump to another block +- branch: jump to one of two blocks depending on some value +- return from function +- unreachable: the end of the block will never run for some reason. + +The statuses are figured out block by block. +The compiler internally stores the status of each variable at the end of each block in UVG. +They are first set to empty sets and then filled in block by block, +revisiting blocks multiple as needed to handle loops. + +Pseudo code: + +```python +statuses_at_end = {b: {v: set() for v in values} for b in blocks} + +def analyze_block(b, warn=False): + statuses = {} + for v in values: + statuses[v] = set() + for sourceblock in blocks_that_jump_to_b: + statuses[v] |= statuses_at_end[sourceblock][v] + if b == the_start_block: + statuses[v].add("undefined") + + for ins in b.instructions: + (update statuses based on ins) + if warn and (ins uses an undefined value): + show a warning + + if statuses_at_end[b] != statuses: + statuses_at_end[b] = statuses + for destblock in blocks_that_b_jumps_to: + queue.add(destblock) + +queue = {the_start_block} +while queue: + analyze_block(queue.pop(), warn=False) +for b in all_blocks: + analyze_block(b, warn=True) +``` + +Initially, the possible statuses of values come from other blocks that jump into the block being analyzed. +Also, all values may be undefined in the beginning of the start block. + +When we have computed the statuses of variables at the end of a block, +and they differ from the previous analysis, +we take the blocks that used the outdated statuses and queue them for another round of analyzing. + +At this point, we know which statuses are possible, and control flow has basically been taken care of. +Because we only stored the statuses at the end of each block, +we must loop through the instructions in each block again to show warnings. + +This could be done in a way that loops over the instructions fewer times, but that is unnecessary, +because this code is not the performance bottleneck of the compiler.