From 5b48facc72868f11503a6f9449ed82572c922e0a Mon Sep 17 00:00:00 2001 From: Akuli Date: Wed, 8 Jan 2025 20:11:06 +0200 Subject: [PATCH] Port the whole Jou compiler to Jou (#562) --- .github/workflows/netbsd.yml | 5 +- compare_compilers.sh | 18 +- self_hosted/ast.jou | 27 +- self_hosted/build_cfg.jou | 1181 ++++++++ self_hosted/codegen.jou | 500 ++++ self_hosted/errors_and_warnings.jou | 20 +- self_hosted/evaluate.jou | 18 +- self_hosted/free.jou | 71 + self_hosted/llvm.jou | 17 + self_hosted/main.jou | 909 +++--- self_hosted/output.jou | 95 + self_hosted/parser.jou | 2 +- self_hosted/paths.jou | 88 +- self_hosted/print.jou | 248 ++ self_hosted/run.jou | 29 + self_hosted/runs_wrong.txt | 15 - self_hosted/structs.jou | 252 ++ self_hosted/target.jou | 9 +- self_hosted/typecheck.jou | 2614 +++++++++-------- self_hosted/types.jou | 466 +-- self_hosted/update.jou | 63 + self_hosted_old/ast.jou | 837 ++++++ .../create_llvm_ir.jou | 0 self_hosted_old/errors_and_warnings.jou | 19 + self_hosted_old/evaluate.jou | 72 + self_hosted_old/llvm.jou | 278 ++ self_hosted_old/main.jou | 510 ++++ self_hosted_old/parser.jou | 1145 ++++++++ self_hosted_old/paths.jou | 181 ++ self_hosted_old/runs_wrong.txt | 16 + self_hosted_old/target.jou | 70 + self_hosted_old/token.jou | 139 + self_hosted_old/tokenizer.jou | 624 ++++ self_hosted_old/typecheck.jou | 1433 +++++++++ self_hosted_old/types.jou | 246 ++ 35 files changed, 10246 insertions(+), 1971 deletions(-) create mode 100644 self_hosted/build_cfg.jou create mode 100644 self_hosted/codegen.jou create mode 100644 self_hosted/free.jou create mode 100644 self_hosted/output.jou create mode 100644 self_hosted/print.jou create mode 100644 self_hosted/run.jou create mode 100644 self_hosted/structs.jou create mode 100644 self_hosted/update.jou create mode 100644 self_hosted_old/ast.jou rename {self_hosted => self_hosted_old}/create_llvm_ir.jou (100%) create mode 100644 self_hosted_old/errors_and_warnings.jou create mode 100644 self_hosted_old/evaluate.jou create mode 100644 self_hosted_old/llvm.jou create mode 100644 self_hosted_old/main.jou create mode 100644 self_hosted_old/parser.jou create mode 100644 self_hosted_old/paths.jou create mode 100644 self_hosted_old/runs_wrong.txt create mode 100644 self_hosted_old/target.jou create mode 100644 self_hosted_old/token.jou create mode 100644 self_hosted_old/tokenizer.jou create mode 100644 self_hosted_old/typecheck.jou create mode 100644 self_hosted_old/types.jou diff --git a/.github/workflows/netbsd.yml b/.github/workflows/netbsd.yml index bef3fbd6..c5d3ec4c 100644 --- a/.github/workflows/netbsd.yml +++ b/.github/workflows/netbsd.yml @@ -10,7 +10,10 @@ jobs: timeout-minutes: 10 steps: - uses: actions/checkout@v4 - - uses: cross-platform-actions/action@v0.26.0 + # TODO: disabled for now because freezes with no output for some reason. + # Started happening in #562, so it has something to do with self-hosted compiler. + - if: false + uses: cross-platform-actions/action@v0.26.0 env: PKG_PATH: 'https://cdn.NetBSD.org/pub/pkgsrc/packages/NetBSD/amd64/10.0/All' with: diff --git a/compare_compilers.sh b/compare_compilers.sh index 7fe920af..ae8c4c9d 100755 --- a/compare_compilers.sh +++ b/compare_compilers.sh @@ -29,9 +29,9 @@ for arg in "$@"; do done if [ ${#files[@]} = 0 ]; then - # skip compiler_cli, because it has a race condition when two compilers simultaneously run it + # skip compiler_cli, because it hard-codes name of compiler executable # TODO: do not skip Advent Of Code files - files=( $(find stdlib examples tests -name '*.jou' | grep -v aoc202. | grep -v tests/should_succeed/compiler_cli | grep -v tests/crash | sort) ) + files=( $(find stdlib examples tests -name '*.jou' | grep -v aoc202. | grep -v tests/should_succeed/compiler_cli | grep -v tests/crash | grep -v x11_window | sort) ) fi if [ ${#actions[@]} = 0 ]; then actions=(tokenize parse run) @@ -108,17 +108,21 @@ for action in ${actions[@]}; do flag=--${action}-only fi - # Run both compilers, and filter out lines that are known to differ but it doesn't matter (mostly linker errors) - # Run compilers in parallel to speed up. + # Run both compilers, and filter out lines that are known to differ but it doesn't + # matter (mostly linker errors). + # + # It is tempting to run compilers in parallel to speed up, but it doesn't work + # because they use the same temporary files in jou_compiled directories. ( set +e ./jou $flag $file 2>&1 | grep -vE 'undefined reference to|multiple definition of|\bld: |compiler warning for file' - ) > tmp/compare_compilers/compiler_written_in_c.txt & + true + ) > tmp/compare_compilers/compiler_written_in_c.txt ( set +e ./self_hosted_compiler $flag $file 2>&1 | grep -vE 'undefined reference to|multiple definition of|\bld: |linking failed|compiler warning for file' - ) > tmp/compare_compilers/self_hosted.txt & - wait + true + ) > tmp/compare_compilers/self_hosted.txt if [ -f $error_list_file ] && grep -qxF $file <(cat $error_list_file | tr -d '\r'); then # The file is skipped, so the two compilers should behave differently diff --git a/self_hosted/ast.jou b/self_hosted/ast.jou index bcf3b890..4e6a96e2 100644 --- a/self_hosted/ast.jou +++ b/self_hosted/ast.jou @@ -15,6 +15,7 @@ class AstArrayType: member_type: AstType* length: AstExpression* + # TODO: use this def free(self) -> None: self->member_type->free() self->length->free() @@ -30,15 +31,19 @@ class AstType: value_type: AstType* # AstTypeKind::Pointer array: AstArrayType # AstTypeKind::Array + # TODO: use this def is_void(self) -> bool: return self->kind == AstTypeKind::Named and strcmp(self->name, "void") == 0 + # TODO: use this def is_none(self) -> bool: return self->kind == AstTypeKind::Named and strcmp(self->name, "None") == 0 + # TODO: use this def is_noreturn(self) -> bool: return self->kind == AstTypeKind::Named and strcmp(self->name, "noreturn") == 0 + # TODO: use this def print(self, show_lineno: bool) -> None: if self->kind == AstTypeKind::Named: printf("%s", self->name) @@ -54,6 +59,7 @@ class AstType: if show_lineno: printf(" [line %d]", self->location.lineno) + # TODO: use this def free(self) -> None: if self->kind == AstTypeKind::Pointer: self->value_type->free() @@ -96,7 +102,7 @@ enum AstExpressionKind: Self # not a variable lookup, so you can't use 'self' as variable name outside a class GetVariable GetEnumMember - GetClassField + GetClassField # foo.bar, foo->bar As # unary operators SizeOf # sizeof x @@ -139,15 +145,17 @@ class AstExpression: bool_value: bool call: AstCall instantiation: AstInstantiation - as_expression: AstAsExpression* # Must be pointer, because it contains an AstExpression + as_: AstAsExpression* # Must be pointer, because it contains an AstExpression array: AstArray varname: byte[100] float_or_double_text: byte[100] operands: AstExpression* # Only for operators. Length is arity, see get_arity() + # TODO: use this def print(self) -> None: self->print_with_tree_printer(TreePrinter{}) + # TODO: use this def print_with_tree_printer(self, tp: TreePrinter) -> None: printf("[line %d] ", self->location.lineno) if self->kind == AstExpressionKind::String: @@ -210,9 +218,9 @@ class AstExpression: self->class_field.instance->print_with_tree_printer(tp.print_prefix(True)) elif self->kind == AstExpressionKind::As: printf("as ") - self->as_expression->type.print(True) + self->as_->type.print(True) printf("\n") - self->as_expression->value.print_with_tree_printer(tp.print_prefix(True)) + self->as_->value.print_with_tree_printer(tp.print_prefix(True)) elif self->kind == AstExpressionKind::SizeOf: printf("sizeof\n") elif self->kind == AstExpressionKind::AddressOf: @@ -263,12 +271,13 @@ class AstExpression: for i = 0; i < self->get_arity(); i++: self->operands[i].print_with_tree_printer(tp.print_prefix(i == self->get_arity()-1)) + # TODO: use this def free(self) -> None: if self->kind == AstExpressionKind::Call: self->call.free() elif self->kind == AstExpressionKind::As: - self->as_expression->free() - free(self->as_expression) + self->as_->free() + free(self->as_) elif self->kind == AstExpressionKind::String: free(self->string) elif self->kind == AstExpressionKind::GetClassField: @@ -279,6 +288,7 @@ class AstExpression: self->operands[i].free() free(self->operands) + # TODO: use this # arity = number of operands, e.g. 2 for a binary operator such as "+" def get_arity(self) -> int: if ( @@ -312,6 +322,7 @@ class AstExpression: return 2 return 0 + # TODO: use this def can_have_side_effects(self) -> bool: return ( self->kind == AstExpressionKind::Call @@ -325,6 +336,7 @@ class AstArray: length: int items: AstExpression* + # TODO: use this def free(self) -> None: for i = 0; i < self->length; i++: self->items[i].free() @@ -383,7 +395,7 @@ class AstCall: free(self->args) class AstInstantiation: - class_name_location: Location + class_name_location: Location # TODO: probably not necessary, can use location of the instantiate expression class_name: byte[100] nfields: int field_names: byte[100]* @@ -704,6 +716,7 @@ class AstImport: location: Location specified_path: byte* # Path in jou code e.g. "stdlib/io.jou" resolved_path: byte* # Absolute path or relative to current working directory e.g. "/home/akuli/jou/stdlib/io.jou" + used: bool # For warning messages def print(self) -> None: printf( diff --git a/self_hosted/build_cfg.jou b/self_hosted/build_cfg.jou new file mode 100644 index 00000000..4a1cece7 --- /dev/null +++ b/self_hosted/build_cfg.jou @@ -0,0 +1,1181 @@ +import "stdlib/str.jou" +import "stdlib/mem.jou" + +import "./structs.jou" +import "./evaluate.jou" +import "./types.jou" +import "./errors_and_warnings.jou" +import "./ast.jou" + + +class State: + filetypes: FileTypes* + fomtypes: FunctionOrMethodTypes* + cfg: CfGraph* + current_block: CfBlock* + breakstack: CfBlock** + continuestack: CfBlock** + nloops: int + + +def find_local_var_cf(st: State*, name: byte*) -> LocalVariable*: + for var = st->cfg->locals; var < &st->cfg->locals[st->cfg->nlocals]; var++: + if strcmp((*var)->name, name) == 0: + return *var + return NULL + +def add_local_var(st: State*, t: Type*) -> LocalVariable*: + var: LocalVariable* = calloc(1, sizeof *var) + var->id = st->cfg->nlocals + var->type = t + + st->cfg->locals = realloc(st->cfg->locals, sizeof(st->cfg->locals[0]) * (st->cfg->nlocals + 1)) + assert st->cfg->locals != NULL + st->cfg->locals[st->cfg->nlocals++] = var + + return var + + +def get_expr_types(st: State*, expr: AstExpression*) -> ExpressionTypes*: + # TODO: a fancy binary search algorithm (need to add sorting) + assert st->fomtypes != NULL + for i = 0; i < st->fomtypes->n_expr_types; i++: + if st->fomtypes->expr_types[i]->expr == expr: + return st->fomtypes->expr_types[i] + return NULL + + +def add_block(st: State*) -> CfBlock*: + block: CfBlock* = calloc(1, sizeof *block) + + st->cfg->all_blocks = realloc(st->cfg->all_blocks, sizeof(st->cfg->all_blocks[0]) * (st->cfg->n_all_blocks + 1)) + assert st->cfg->all_blocks != NULL + st->cfg->all_blocks[st->cfg->n_all_blocks++] = block + + return block + + +def add_jump( + st: State*, + branchvar: LocalVariable*, + iftrue: CfBlock*, + iffalse: CfBlock*, + new_current_block: CfBlock*, +) -> None: + assert (iftrue != NULL and iffalse != NULL) or (iftrue == NULL and iffalse == NULL and branchvar == NULL) + if iftrue != iffalse: + assert branchvar != NULL + assert branchvar->type == boolType + + st->current_block->branchvar = branchvar + st->current_block->iftrue = iftrue + st->current_block->iffalse = iffalse + if new_current_block == NULL: + st->current_block = add_block(st) + else: + st->current_block = new_current_block + + +# returned pointer is only valid until next call to add_instruction() +def add_instruction( + st: State*, +# location: Location, +# k: CfInstructionKind, + #const union CfInstructionData *dat, # TODO: won't work +# operands: LocalVariable**, # NULL terminated, or NULL for empty +# destvar: LocalVariable*, + ins: CfInstruction, +) -> CfInstruction*: + #ins = CfInstruction{location=location, kind=k, destvar=destvar} + #if (dat) + # ins.data=*dat; +# +# while operands != NULL and operands[ins.noperands] != NULL: +# ins.noperands++ +# if ins.noperands > 0: +# nbytes = sizeof(ins.operands[0]) * ins.noperands +# ins.operands = malloc(nbytes) +# memcpy(ins.operands, operands, nbytes) + + st->current_block->instructions = realloc(st->current_block->instructions, sizeof(st->current_block->instructions[0]) * (st->current_block->ninstructions + 1)) + assert st->current_block->instructions != NULL + st->current_block->instructions[st->current_block->ninstructions++] = ins + + return &st->current_block->instructions[st->current_block->ninstructions - 1] + + +# TODO: do we need this? +def add_unary_op( + st: State*, + location: Location, + op: CfInstructionKind, + arg: LocalVariable*, + target: LocalVariable*, +) -> None: + ins = CfInstruction{location = location, kind = op, destvar = target} + operands = [arg, NULL as LocalVariable*] + ins.set_operands(operands) + add_instruction(st, ins) + + +# TODO: do we need this? +def add_binary_op( + st: State*, + location: Location, + op: CfInstructionKind, + lhs: LocalVariable*, + rhs: LocalVariable*, + target: LocalVariable*, +) -> None: + ins = CfInstruction{location = location, kind = op, destvar = target} + operands = [lhs, rhs, NULL as LocalVariable*] + ins.set_operands(operands) + add_instruction(st, ins) + + +# TODO: do we need this? +def add_constant(st: State*, location: Location, c: Constant, target: LocalVariable*) -> CfInstruction*: + ins = CfInstruction{location = location, kind = CfInstructionKind::Constant, constant = copy_constant(c), destvar = target} + return add_instruction(st, ins) + + +def build_bool_to_int_conversion( + st: State*, + boolvar: LocalVariable*, + location: Location, + t: Type*, +) -> LocalVariable*: + assert is_integer_type(t) + result: LocalVariable* = add_local_var(st, t) + + set1 = add_block(st) + set0 = add_block(st) + done = add_block(st) + + add_jump(st, boolvar, set1, set0, set1) + add_constant(st, location, int_constant(t, 1), result)->hide_unreachable_warning = True + add_jump(st, NULL, done, done, set0) + add_constant(st, location, int_constant(t, 0), result)->hide_unreachable_warning = True + add_jump(st, NULL, done, done, done) + + return result + + +def build_cast( + st: State*, + obj: LocalVariable*, + to: Type *, + location: Location, +) -> LocalVariable*: + if obj->type == to: + return obj + + if is_pointer_type(obj->type) and is_pointer_type(to): + result = add_local_var(st, to) + add_unary_op(st, location, CfInstructionKind::PtrCast, obj, result) + return result + + if is_number_type(obj->type) and is_number_type(to): + result = add_local_var(st, to) + add_unary_op(st, location, CfInstructionKind::NumCast, obj, result) + return result + + if is_number_type(obj->type) and obj->type->size_in_bits == 64 and is_pointer_type(to): + result = add_local_var(st, to) + add_unary_op(st, location, CfInstructionKind::Int64ToPtr, obj, result) + return result + + if is_integer_type(obj->type) or to->kind == TypeKind::Enum: + i32var = add_local_var(st, intType) + result = add_local_var(st, to) + add_unary_op(st, location, CfInstructionKind::NumCast, obj, i32var) + add_unary_op(st, location, CfInstructionKind::Int32ToEnum, i32var, result) + return result + + if obj->type->kind == TypeKind::Enum and is_integer_type(to): + i32var = add_local_var(st, intType) + result = add_local_var(st, to) + add_unary_op(st, location, CfInstructionKind::EnumToInt32, obj, i32var) + add_unary_op(st, location, CfInstructionKind::NumCast, i32var, result) + return result + + if obj->type == boolType and is_integer_type(to): + return build_bool_to_int_conversion(st, obj, location, to) + + if is_pointer_type(obj->type) and is_integer_type(to) and to->size_in_bits == 64: + result = add_local_var(st, to) + add_unary_op(st, location, CfInstructionKind::PtrToInt64, obj, result) + return result + + assert False + + +def build_bool_eq(st: State*, location: Location, a: LocalVariable*, b: LocalVariable*) -> LocalVariable*: + assert a->type == boolType + assert b->type == boolType + + # Pseudo code: + # + # if a: + # result = b + # else: + # result = not b + result = add_local_var(st, boolType) + + atrue = add_block(st) + afalse = add_block(st) + done = add_block(st) + + # if a: + add_jump(st, a, atrue, afalse, atrue) + + # result = b + add_unary_op(st, location, CfInstructionKind::VarCpy, b, result) + + # else: + add_jump(st, NULL, done, done, afalse) + + # result = not b + add_unary_op(st, location, CfInstructionKind::BoolNegate, b, result) + + add_jump(st, NULL, done, done, done) + return result + + +def build_binop( + st: State*, + op: AstExpressionKind, + location: Location, + lhs: LocalVariable*, + rhs: LocalVariable*, + result_type: Type*, +) -> LocalVariable*: + got_bools = lhs->type == boolType and rhs->type == boolType + got_numbers = is_number_type(lhs->type) and is_number_type(rhs->type) + got_pointers = is_pointer_type(lhs->type) and is_pointer_type(rhs->type) + assert got_bools or got_numbers or got_pointers + + negate = False + swap = False + + destvar: LocalVariable* + if got_bools: + assert result_type == boolType + destvar = build_bool_eq(st, location, lhs, rhs) + assert op == AstExpressionKind::Eq or op == AstExpressionKind::Ne + negate = (op == AstExpressionKind::Ne) + else: + destvar = add_local_var(st, result_type) + k: CfInstructionKind + if op == AstExpressionKind::Add: + k = CfInstructionKind::NumAdd + elif op == AstExpressionKind::Subtract: + k = CfInstructionKind::NumSub + elif op == AstExpressionKind::Multiply: + k = CfInstructionKind::NumMul + elif op == AstExpressionKind::Divide: + k = CfInstructionKind::NumDiv + elif op == AstExpressionKind::Modulo: + k = CfInstructionKind::NumMod + elif op == AstExpressionKind::Eq: + k = CfInstructionKind::NumEq + elif op == AstExpressionKind::Ne: + k = CfInstructionKind::NumEq + negate = True + elif op == AstExpressionKind::Lt: + k = CfInstructionKind::NumLt + elif op == AstExpressionKind::Gt: + k = CfInstructionKind::NumLt + swap = True + elif op == AstExpressionKind::Le: + k = CfInstructionKind::NumLt + negate = True + swap = True + elif op == AstExpressionKind::Ge: + k = CfInstructionKind::NumLt + negate = True + else: + assert False + + if swap: + add_binary_op(st, location, k, rhs, lhs, destvar) + else: + add_binary_op(st, location, k, lhs, rhs, destvar) + + if not negate: + return destvar + + negated = add_local_var(st, boolType) + add_unary_op(st, location, CfInstructionKind::BoolNegate, destvar, negated) + return negated + + +def build_class_field_pointer( + st: State*, instance: LocalVariable*, fieldname: byte*, location: Location +) -> LocalVariable*: + assert instance->type->kind == TypeKind::Pointer + assert instance->type->value_type->kind == TypeKind::Class + class_type = instance->type->value_type + + for f = class_type->classdata.fields; f < &class_type->classdata.fields[class_type->classdata.nfields]; f++: + if strcmp(f->name, fieldname) == 0: + result = add_local_var(st, get_pointer_type(f->type)) + + ins = CfInstruction{ + location = location, + kind = CfInstructionKind::PtrClassField, + destvar = result, + } + + assert sizeof(ins.fieldname) == sizeof(f->name) + strcpy(ins.fieldname, f->name) + + operands = [instance, NULL as LocalVariable*] + ins.set_operands(operands) + + add_instruction(st, ins) + return result + + assert False + + +def build_class_field( + st: State*, + instance: LocalVariable*, + fieldname: byte*, + location: Location, +) -> LocalVariable*: + ptr = add_local_var(st, get_pointer_type(instance->type)) + add_unary_op(st, location, CfInstructionKind::AddressOfLocalVar, instance, ptr) + field_ptr = build_class_field_pointer(st, ptr, fieldname, location) + field = add_local_var(st, field_ptr->type->value_type) + add_unary_op(st, location, CfInstructionKind::PtrLoad, field_ptr, field) + return field + + +enum PreOrPost: + Pre + Post + + +def build_increment_or_decrement( + st: State*, + location: Location, + inner: AstExpression*, + pop: PreOrPost, + diff: int, +) -> LocalVariable*: + assert diff == 1 or diff == -1 # 1=increment, -1=decrement + + addr = build_address_of_expression(st, inner) + assert addr->type->kind == TypeKind::Pointer + t = addr->type->value_type + + if not is_integer_type(t) and not is_pointer_type(t): + msg: byte[500] + if diff == 1: + snprintf(msg, sizeof(msg), "cannot increment a value of type %s", t->name) + else: + snprintf(msg, sizeof(msg), "cannot decrement a value of type %s", t->name) + fail(location, msg) + + old_value = add_local_var(st, t) + new_value = add_local_var(st, t) + if is_integer_type(t): + diffvar = add_local_var(st, t) + else: + diffvar = add_local_var(st, intType) + + add_constant(st, location, int_constant(diffvar->type, diff), diffvar) + add_unary_op(st, location, CfInstructionKind::PtrLoad, addr, old_value) + if is_number_type(t): + add_binary_op(st, location, CfInstructionKind::NumAdd, old_value, diffvar, new_value) + else: + add_binary_op(st, location, CfInstructionKind::PtrAddInt, old_value, diffvar, new_value) + add_binary_op(st, location, CfInstructionKind::PtrStore, addr, new_value, NULL) + + if pop == PreOrPost::Pre: + return new_value + elif pop == PreOrPost::Post: + return old_value + else: + assert False + + +enum AndOr: + And + Or + + +def build_and_or( + st: State*, + lhsexpr: AstExpression*, + rhsexpr: AstExpression*, + andor: AndOr, +) -> LocalVariable*: + # Must be careful with side effects. + # + # and: + # # lhs returning False means we don't evaluate rhs + # if lhs: + # result = rhs + # else: + # result = False + # + # or: + # # lhs returning True means we don't evaluate rhs + # if lhs: + # result = True + # else: + # result = rhs + lhs = build_expression(st, lhsexpr) + result = add_local_var(st, boolType) + + lhstrue = add_block(st) + lhsfalse = add_block(st) + done = add_block(st) + + # if lhs: + add_jump(st, lhs, lhstrue, lhsfalse, lhstrue) + + if andor == AndOr::And: + # result = rhs + rhs = build_expression(st, rhsexpr) + add_unary_op(st, rhsexpr->location, CfInstructionKind::VarCpy, rhs, result) + elif andor == AndOr::Or: + # result = True + ins = add_constant(st, lhsexpr->location, Constant{kind = ConstantKind::Bool, boolean = True}, result) + ins->hide_unreachable_warning = True + else: + assert False + + # else: + add_jump(st, NULL, done, done, lhsfalse) + + if andor == AndOr::And: + # result = False + ins = add_constant(st, lhsexpr->location, Constant{kind = ConstantKind::Bool, boolean = False}, result) + ins->hide_unreachable_warning = True + elif andor == AndOr::Or: + # result = rhs + rhs = build_expression(st, rhsexpr) + add_unary_op(st, rhsexpr->location, CfInstructionKind::VarCpy, rhs, result) + else: + assert False + + add_jump(st, NULL, done, done, done) + return result + + +def build_address_of_expression(st: State*, address_of_what: AstExpression*) -> LocalVariable*: + if address_of_what->kind == AstExpressionKind::GetVariable: + ptrtype = get_pointer_type(get_expr_types(st, address_of_what)->type) + addr = add_local_var(st, ptrtype) + + local_var = find_local_var_cf(st, address_of_what->varname) + if local_var == NULL: + # Global variable (possibly imported from another file) + ins = CfInstruction { + location = address_of_what->location, + kind = CfInstructionKind::AddressOfGlobalVar, + destvar = addr, + } + assert sizeof(ins.globalname) == sizeof(address_of_what->varname) + strcpy(ins.globalname, address_of_what->varname) + add_instruction(st, ins) + else: + add_unary_op(st, address_of_what->location, CfInstructionKind::AddressOfLocalVar, local_var, addr) + return addr + + if address_of_what->kind == AstExpressionKind::Self: + ptrtype = get_pointer_type(get_expr_types(st, address_of_what)->type) + addr = add_local_var(st, ptrtype) + + local_var = find_local_var_cf(st, "self") + assert local_var != NULL + add_unary_op(st, address_of_what->location, CfInstructionKind::AddressOfLocalVar, local_var, addr) + return addr + + if address_of_what->kind == AstExpressionKind::Dereference: + # &*foo --> just evaluate foo + return build_expression(st, &address_of_what->operands[0]) + + if address_of_what->kind == AstExpressionKind::GetClassField: + if address_of_what->class_field.uses_arrow_operator: + # &obj->field aka &(obj->field) + obj = build_expression(st, address_of_what->class_field.instance) + else: + # &obj.field aka &(obj.field), evaluate as &(&obj)->field + obj = build_address_of_expression(st, address_of_what->class_field.instance) + + assert obj->type->kind == TypeKind::Pointer + assert obj->type->value_type->kind == TypeKind::Class + return build_class_field_pointer(st, obj, address_of_what->class_field.field_name, address_of_what->location) + + if address_of_what->kind == AstExpressionKind::Indexing: + ptr = build_expression(st, &address_of_what->operands[0]) + assert ptr->type->kind == TypeKind::Pointer + + index = build_expression(st, &address_of_what->operands[1]) + assert is_integer_type(index->type) + + result = add_local_var(st, ptr->type) + add_binary_op(st, address_of_what->location, CfInstructionKind::PtrAddInt, ptr, index, result) + return result + + assert False + + +def build_function_or_method_call( + st: State*, + location: Location, + call: AstCall*, +) -> LocalVariable*: + sig: Signature* = NULL + + if call->method_call_self != NULL: + selfclass = get_expr_types(st, call->method_call_self)->type + if call->uses_arrow_operator: + assert selfclass->kind == TypeKind::Pointer + selfclass = selfclass->value_type + assert selfclass->kind == TypeKind::Class + + for s = selfclass->classdata.methods; s < &selfclass->classdata.methods[selfclass->classdata.nmethods]; s++: + assert get_self_class(s) == selfclass + if strcmp(s->name, call->name) == 0: + sig = s + break + + else: + for f = st->filetypes->functions; f < &st->filetypes->functions[st->filetypes->nfunctions]; f++: + if strcmp(f->signature.name, call->name) == 0: + sig = &f->signature + break + + assert sig != NULL + + args: LocalVariable** = calloc(call->nargs + 2, sizeof(args[0])) + k = 0 + + if call->method_call_self != NULL: + if is_pointer_type(sig->argtypes[0]) and not call->uses_arrow_operator: + args[k++] = build_address_of_expression(st, call->method_call_self) + elif (not is_pointer_type(sig->argtypes[0])) and call->uses_arrow_operator: + self_ptr = build_expression(st, call->method_call_self) + assert self_ptr->type->kind == TypeKind::Pointer + + # dereference the pointer + val = add_local_var(st, self_ptr->type->value_type) + add_unary_op(st, call->method_call_self->location, CfInstructionKind::PtrLoad, self_ptr, val) + args[k++] = val + else: + args[k++] = build_expression(st, call->method_call_self) + + for i = 0; i < call->nargs; i++: + args[k++] = build_expression(st, &call->args[i]) + + if sig->returntype != NULL: + return_value = add_local_var(st, sig->returntype) + else: + return_value = NULL + + ins = CfInstruction{ + location = location, + kind = CfInstructionKind::Call, + signature = copy_signature(sig), + destvar = return_value, + } + ins.set_operands(args) + add_instruction(st, ins) + + if sig->is_noreturn: + # Place the remaining code into an unreachable block, so you will get a warning if there is any + add_jump(st, NULL, NULL, NULL, NULL) + + free(args) + return return_value + + +def build_instantiation(st: State*, type: Type*, inst: AstInstantiation*, location: Location) -> LocalVariable*: + instance = add_local_var(st, type) + instanceptr = add_local_var(st, get_pointer_type(type)) + + add_unary_op(st, location, CfInstructionKind::AddressOfLocalVar, instance, instanceptr) + add_unary_op(st, location, CfInstructionKind::PtrMemsetToZero, instanceptr, NULL) + + for i = 0; i < inst->nfields; i++: + fieldptr = build_class_field_pointer(st, instanceptr, inst->field_names[i], inst->field_values[i].location) + fieldval = build_expression(st, &inst->field_values[i]) + add_binary_op(st, location, CfInstructionKind::PtrStore, fieldptr, fieldval, NULL) + + return instance + + +def build_array(st: State*, type: Type*, items: AstExpression*, location: Location) -> LocalVariable*: + assert type->kind == TypeKind::Array + + arr = add_local_var(st, type) + arrptr = add_local_var(st, get_pointer_type(type)) + add_unary_op(st, location, CfInstructionKind::AddressOfLocalVar, arr, arrptr) + first_item_ptr = add_local_var(st, get_pointer_type(type->array.item_type)) + add_unary_op(st, location, CfInstructionKind::PtrCast, arrptr, first_item_ptr) + + for i = 0; i < type->array.len; i++: + value = build_expression(st, &items[i]) + + ivar = add_local_var(st, intType) + add_constant(st, location, int_constant(intType, i), ivar) + + destptr = add_local_var(st, first_item_ptr->type) + add_binary_op(st, location, CfInstructionKind::PtrAddInt, first_item_ptr, ivar, destptr) + add_binary_op(st, location, CfInstructionKind::PtrStore, destptr, value, NULL) + + return arr + + +def find_enum_member(enumtype: Type*, name: byte*) -> int: + for i = 0; i < enumtype->enummembers.count; i++: + if strcmp(enumtype->enummembers.names[i], name) == 0: + return i + assert False + + +def build_expression(st: State*, expr: AstExpression*) -> LocalVariable*: + types = get_expr_types(st, expr) + + if types != NULL and types->implicit_array_to_pointer_cast: + arrptr = build_address_of_expression(st, expr) + memberptr = add_local_var(st, types->implicit_cast_type) + add_unary_op(st, expr->location, CfInstructionKind::PtrCast, arrptr, memberptr) + return memberptr + + if types != NULL and types->implicit_string_to_array_cast: + assert types->implicit_cast_type != NULL + assert types->implicit_cast_type->kind == TypeKind::Array + assert expr->kind == AstExpressionKind::String + + array_size = types->implicit_cast_type->array.len + assert strlen(expr->string) < array_size + padded: byte* = calloc(1, array_size) + assert padded != NULL + strcpy(padded, expr->string) + + result = add_local_var(st, types->implicit_cast_type) + ins = CfInstruction{ + location = expr->location, + kind = CfInstructionKind::StringArray, + strarray = CfStringArray{ + str = padded, + len = array_size, + }, + destvar = result, + } + add_instruction(st, ins) + return result + + if expr->kind == AstExpressionKind::Call: + result = build_function_or_method_call(st, expr->location, &expr->call) + if result == NULL: + # called function/method has no return value + return NULL + elif expr->kind == AstExpressionKind::Instantiate: + result = build_instantiation(st, types->type, &expr->instantiation, expr->location) + elif expr->kind == AstExpressionKind::Array: + assert types->type->kind == TypeKind::Array + assert types->type->array.len == expr->array.length + result = build_array(st, types->type, expr->array.items, expr->location) + elif expr->kind == AstExpressionKind::GetEnumMember: + result = add_local_var(st, types->type) + c = Constant{ + kind = ConstantKind::EnumMember, + enum_member = EnumMemberConstant{ + enumtype = types->type, + memberidx = find_enum_member(types->type, expr->enum_member.member_name), + } + } + add_constant(st, expr->location, c, result) + elif expr->kind == AstExpressionKind::GetVariable: + if get_special_constant(expr->varname) != -1: + result = add_local_var(st, boolType) + ins = CfInstruction{ + location = expr->location, + kind = CfInstructionKind::SpecialConstant, + destvar = result, + } + assert sizeof(ins.scname) == sizeof(expr->varname) + strcpy(ins.scname, expr->varname) + add_instruction(st, ins) + else: + temp = find_local_var_cf(st, expr->varname) + if temp != NULL: + if types->implicit_cast_type == NULL or types->type == types->implicit_cast_type: + # Must take a "snapshot" of this variable, as it may change soon. + result = add_local_var(st, temp->type) + add_unary_op(st, expr->location, CfInstructionKind::VarCpy, temp, result) + else: + result = temp + else: + # For other than local variables we can evaluate as &*variable. + # Would also work for locals, but it would confuse simplify_cfg. + temp = build_address_of_expression(st, expr) + result = add_local_var(st, types->type) + add_unary_op(st, expr->location, CfInstructionKind::PtrLoad, temp, result) + elif ( + expr->kind == AstExpressionKind::GetClassField + and not expr->class_field.uses_arrow_operator + ): + temp = build_expression(st, expr->class_field.instance) + result = build_class_field(st, temp, expr->class_field.field_name, expr->location) + elif ( + ( + expr->kind == AstExpressionKind::GetClassField + and expr->class_field.uses_arrow_operator + ) + or expr->kind == AstExpressionKind::Indexing + ): + # To evaluate foo->bar, we first evaluate &foo->bar and then dereference. + # We can similarly evaluate &foo[bar]. + # + # This technique cannot be used with all expressions. For example, &(1+2) + # doesn't work, and &foo.bar doesn't work either whenever &foo doesn't work. + # But &foo->bar and &foo[bar] always work, because foo is already a pointer + # and we only add a memory offset to it. + temp = build_address_of_expression(st, expr) + result = add_local_var(st, types->type) + add_unary_op(st, expr->location, CfInstructionKind::PtrLoad, temp, result) + elif expr->kind == AstExpressionKind::AddressOf: + result = build_address_of_expression(st, &expr->operands[0]) + elif expr->kind == AstExpressionKind::SizeOf: + result = add_local_var(st, longType) + ins = CfInstruction{ + location = expr->location, + kind = CfInstructionKind::SizeOf, + type = get_expr_types(st, &expr->operands[0])->type, + destvar = result, + } + add_instruction(st, ins) + elif expr->kind == AstExpressionKind::Dereference: + temp = build_expression(st, &expr->operands[0]) + result = add_local_var(st, types->type) + add_unary_op(st, expr->location, CfInstructionKind::PtrLoad, temp, result) + elif expr->kind == AstExpressionKind::Self: + selfvar = find_local_var_cf(st, "self") + assert selfvar != NULL + if types->implicit_cast_type == NULL or types->type == types->implicit_cast_type: + # Must take a "snapshot" of this variable, as it may change soon. + result = add_local_var(st, selfvar->type) + add_unary_op(st, expr->location, CfInstructionKind::VarCpy, selfvar, result) + else: + result = selfvar + elif ( + expr->kind == AstExpressionKind::Bool + or expr->kind == AstExpressionKind::Byte + or expr->kind == AstExpressionKind::Float + or expr->kind == AstExpressionKind::Double + or expr->kind == AstExpressionKind::Short + or expr->kind == AstExpressionKind::Int + or expr->kind == AstExpressionKind::Long + or expr->kind == AstExpressionKind::Null + or expr->kind == AstExpressionKind::String + ): + if expr->kind == AstExpressionKind::Bool: + c = Constant{kind = ConstantKind::Bool, boolean = expr->bool_value} + elif expr->kind == AstExpressionKind::Byte: + c = int_constant(byteType, expr->byte_value) + elif expr->kind == AstExpressionKind::Short: + c = int_constant(shortType, expr->short_value) + elif expr->kind == AstExpressionKind::Int: + c = int_constant(intType, expr->int_value) + elif expr->kind == AstExpressionKind::Long: + c = int_constant(longType, expr->long_value) + elif expr->kind == AstExpressionKind::Null: + c = Constant{kind = ConstantKind::Null} + elif expr->kind == AstExpressionKind::Float: + c = Constant{kind = ConstantKind::Float} + assert sizeof(c.double_or_float_text) == sizeof(expr->float_or_double_text) + strcpy(c.double_or_float_text, expr->float_or_double_text) + elif expr->kind == AstExpressionKind::Double: + c = Constant{kind = ConstantKind::Double} + assert sizeof(c.double_or_float_text) == sizeof(expr->float_or_double_text) + strcpy(c.double_or_float_text, expr->float_or_double_text) + elif expr->kind == AstExpressionKind::String: + c = Constant{kind = ConstantKind::String, str = strdup(expr->string)} + else: + assert False + result = add_local_var(st, types->type) + add_constant(st, expr->location, c, result) + elif expr->kind == AstExpressionKind::And: + result = build_and_or(st, &expr->operands[0], &expr->operands[1], AndOr::And) + elif expr->kind == AstExpressionKind::Or: + result = build_and_or(st, &expr->operands[0], &expr->operands[1], AndOr::Or) + elif expr->kind == AstExpressionKind::Not: + temp = build_expression(st, &expr->operands[0]) + result = add_local_var(st, boolType) + add_unary_op(st, expr->location, CfInstructionKind::BoolNegate, temp, result) + elif expr->kind == AstExpressionKind::Negate: + temp = build_expression(st, &expr->operands[0]) + zero = add_local_var(st, temp->type) + result = add_local_var(st, temp->type) + if temp->type == doubleType: + c = Constant{kind = ConstantKind::Double, double_or_float_text = "0"} + elif temp->type == floatType: + c = Constant{kind = ConstantKind::Float, double_or_float_text = "0"} + else: + c = int_constant(temp->type, 0) + add_constant(st, expr->location, c, zero) + add_binary_op(st, expr->location, CfInstructionKind::NumSub, zero, temp, result) + elif ( + expr->kind == AstExpressionKind::Add + or expr->kind == AstExpressionKind::Subtract + or expr->kind == AstExpressionKind::Multiply + or expr->kind == AstExpressionKind::Divide + or expr->kind == AstExpressionKind::Modulo + or expr->kind == AstExpressionKind::Eq + or expr->kind == AstExpressionKind::Ne + or expr->kind == AstExpressionKind::Gt + or expr->kind == AstExpressionKind::Ge + or expr->kind == AstExpressionKind::Lt + or expr->kind == AstExpressionKind::Le + ): + # Refactoring note: If you rewrite this in another language, like C, make sure to + # evaluate lhs first. C doesn't guarantee evaluation order of function arguments. + lhs = build_expression(st, &expr->operands[0]) + rhs = build_expression(st, &expr->operands[1]) + result = build_binop(st, expr->kind, expr->location, lhs, rhs, types->type) + elif ( + expr->kind == AstExpressionKind::PreIncr + or expr->kind == AstExpressionKind::PreDecr + or expr->kind == AstExpressionKind::PostIncr + or expr->kind == AstExpressionKind::PostDecr + ): + if expr->kind == AstExpressionKind::PreIncr: + pop = PreOrPost::Pre + diff = 1 + elif expr->kind == AstExpressionKind::PreDecr: + pop = PreOrPost::Pre + diff = -1 + elif expr->kind == AstExpressionKind::PostIncr: + pop = PreOrPost::Post + diff = 1 + elif expr->kind == AstExpressionKind::PostDecr: + pop = PreOrPost::Post + diff = -1 + else: + assert False + result = build_increment_or_decrement(st, expr->location, &expr->operands[0], pop, diff) + elif expr->kind == AstExpressionKind::As: + temp = build_expression(st, &expr->as_->value) + result = build_cast(st, temp, types->type, expr->location) + else: + assert False + + assert types != NULL + assert result->type == types->type + if types->implicit_cast_type == NULL: + return result + return build_cast(st, result, types->implicit_cast_type, expr->location) + + +def build_if_statement(st: State*, ifstmt: AstIfStatement*) -> None: + assert ifstmt->n_if_and_elifs >= 1 + + done = add_block(st) + for i = 0; i < ifstmt->n_if_and_elifs; i++: + cond: LocalVariable* = build_expression(st, &ifstmt->if_and_elifs[i].condition) + then = add_block(st) + otherwise = add_block(st) + + add_jump(st, cond, then, otherwise, then) + build_body(st, &ifstmt->if_and_elifs[i].body) + add_jump(st, NULL, done, done, otherwise) + + build_body(st, &ifstmt->else_body) + add_jump(st, NULL, done, done, done) + + +def build_assert(st: State*, assert_location: Location, assertion: AstAssertion*) -> None: + condvar = build_expression(st, &assertion->condition) + + # If the condition is true, we jump to a block where the rest of the code goes. + # If the condition is false, we jump to a block that calls _jou_assert_fail(). + trueblock = add_block(st) + falseblock = add_block(st) + add_jump(st, condvar, trueblock, falseblock, falseblock) + + argnames: byte[100]* = malloc(3 * sizeof(argnames[0])) + strcpy(argnames[0], "assertion") + strcpy(argnames[1], "path") + strcpy(argnames[2], "lineno") + + argtypes: Type** = malloc(3 * sizeof(argtypes[0])) + argtypes[0] = get_pointer_type(byteType) + argtypes[1] = get_pointer_type(byteType) + argtypes[2] = intType + + args = [ + add_local_var(st, argtypes[0]), + add_local_var(st, argtypes[1]), + add_local_var(st, argtypes[2]), + NULL as LocalVariable*, + ] + + add_constant(st, assert_location, Constant{kind = ConstantKind::String, str = assertion->condition_str}, args[0]) + tmp = strdup(assertion->condition.location.path) + add_constant(st, assert_location, Constant{kind = ConstantKind::String, str = tmp}, args[1]) + free(tmp) + add_constant(st, assert_location, int_constant(intType, assert_location.lineno), args[2]) + + ins = CfInstruction{ + location = assert_location, + kind = CfInstructionKind::Call, + signature = Signature{ + name = "_jou_assert_fail", + nargs = 3, + argtypes = argtypes, + argnames = argnames, + takes_varargs = False, + is_noreturn = True, + returntype_location = assert_location, + }, + } + ins.set_operands(args) + add_instruction(st, ins) + + st->current_block = trueblock + + +# for init; cond; incr: +# ...body... +# +# While loop is basically a special case of for loop, so it uses this too. +def build_loop( + st: State*, + init: AstStatement*, + cond: AstExpression*, + incr: AstStatement*, + body: AstBody*, +) -> None: + condblock = add_block(st) # evaluate condition and go to bodyblock or doneblock + bodyblock = add_block(st) # run loop body and go to incrblock + incrblock = add_block(st) # run incr and go to condblock + doneblock = add_block(st) # rest of the code goes here + + # TODO: can init be NULL? + if init != NULL: + build_statement(st, init) + + # Evaluate condition. Jump to loop body or skip to after loop. + add_jump(st, NULL, condblock, condblock, condblock) + condvar = build_expression(st, cond) + add_jump(st, condvar, bodyblock, doneblock, bodyblock) + + # 'break' skips to after loop, 'continue' goes to incr. + st->breakstack = realloc(st->breakstack, sizeof(st->breakstack[0]) * (st->nloops + 1)) + st->continuestack = realloc(st->continuestack, sizeof(st->continuestack[0]) * (st->nloops + 1)) + assert st->breakstack != NULL + assert st->continuestack != NULL + st->breakstack[st->nloops] = doneblock + st->continuestack[st->nloops] = incrblock + st->nloops++ + + # Run loop body + build_body(st, body) + + st->nloops-- + assert st->breakstack[st->nloops] == doneblock + assert st->continuestack[st->nloops] == incrblock + + # Run incr and jump back to condition. + add_jump(st, NULL, incrblock, incrblock, incrblock) + if incr != NULL: # TODO: can it ever be NULL? + build_statement(st, incr) + add_jump(st, NULL, condblock, condblock, doneblock) + + +def build_statement(st: State*, stmt: AstStatement*) -> None: + if stmt->kind == AstStatementKind::If: + build_if_statement(st, &stmt->if_statement) + elif stmt->kind == AstStatementKind::Assert: + build_assert(st, stmt->location, &stmt->assertion) + elif stmt->kind == AstStatementKind::Pass: + pass + elif stmt->kind == AstStatementKind::WhileLoop: + build_loop( + st, NULL, &stmt->while_loop.condition, NULL, + &stmt->while_loop.body) + elif stmt->kind == AstStatementKind::ForLoop: + build_loop( + st, stmt->for_loop.init, &stmt->for_loop.cond, stmt->for_loop.incr, + &stmt->for_loop.body) + elif stmt->kind == AstStatementKind::Break: + if st->nloops == 0: + fail(stmt->location, "'break' can only be used inside a loop") + add_jump(st, NULL, st->breakstack[st->nloops - 1], st->breakstack[st->nloops - 1], NULL) + elif stmt->kind == AstStatementKind::Continue: + if st->nloops == 0: + fail(stmt->location, "'continue' can only be used inside a loop") + add_jump(st, NULL, st->continuestack[st->nloops - 1], st->continuestack[st->nloops - 1], NULL) + elif stmt->kind == AstStatementKind::Assign: + targetexpr = &stmt->assignment.target + valueexpr = &stmt->assignment.value + + targetvar: LocalVariable* = NULL + if targetexpr->kind == AstExpressionKind::GetVariable: + targetvar = find_local_var_cf(st, targetexpr->varname) + + if targetvar != NULL: + # avoid pointers to help simplify_cfg + value = build_expression(st, valueexpr) + add_unary_op(st, stmt->location, CfInstructionKind::VarCpy, value, targetvar) + else: + # TODO: is this evaluation order good? + target = build_address_of_expression(st, targetexpr) + value = build_expression(st, valueexpr) + assert target->type->kind == TypeKind::Pointer + add_binary_op(st, stmt->location, CfInstructionKind::PtrStore, target, value, NULL) + + elif ( + stmt->kind == AstStatementKind::InPlaceAdd + or stmt->kind == AstStatementKind::InPlaceSubtract + or stmt->kind == AstStatementKind::InPlaceMultiply + or stmt->kind == AstStatementKind::InPlaceDivide + or stmt->kind == AstStatementKind::InPlaceModulo + ): + targetexpr = &stmt->assignment.target + rhsexpr = &stmt->assignment.value + + targetptr = build_address_of_expression(st, targetexpr) + rhs = build_expression(st, rhsexpr) + assert targetptr->type->kind == TypeKind::Pointer + oldvalue = add_local_var(st, targetptr->type->value_type) + add_unary_op(st, stmt->location, CfInstructionKind::PtrLoad, targetptr, oldvalue) + + if stmt->kind == AstStatementKind::InPlaceAdd: + op = AstExpressionKind::Add + elif stmt->kind == AstStatementKind::InPlaceSubtract: + op = AstExpressionKind::Subtract + elif stmt->kind == AstStatementKind::InPlaceMultiply: + op = AstExpressionKind::Multiply + elif stmt->kind == AstStatementKind::InPlaceDivide: + op = AstExpressionKind::Divide + elif stmt->kind == AstStatementKind::InPlaceModulo: + op = AstExpressionKind::Modulo + else: + assert False + + newvalue = build_binop(st, op, stmt->location, oldvalue, rhs, targetptr->type->value_type) + add_binary_op(st, stmt->location, CfInstructionKind::PtrStore, targetptr, newvalue, NULL) + + elif stmt->kind == AstStatementKind::Return: + if stmt->return_value != NULL: + retvalue = build_expression(st, stmt->return_value) + retvariable = find_local_var_cf(st, "return") + assert retvariable != NULL + add_unary_op(st, stmt->location, CfInstructionKind::VarCpy, retvalue, retvariable) + + st->current_block->iftrue = &st->cfg->end_block + st->current_block->iffalse = &st->cfg->end_block + st->current_block = add_block(st) # an unreachable block + + elif stmt->kind == AstStatementKind::DeclareLocalVar: + if stmt->var_declaration.value != NULL: + v = find_local_var_cf(st, stmt->var_declaration.name) + assert v != NULL + cfvar = build_expression(st, stmt->var_declaration.value) + add_unary_op(st, stmt->location, CfInstructionKind::VarCpy, cfvar, v) + + elif stmt->kind == AstStatementKind::ExpressionStatement: + build_expression(st, &stmt->expression) + + else: + # other statements shouldn't occur inside functions/methods + assert False + + +def build_body(st: State*, body: AstBody*) -> None: + for i = 0; i < body->nstatements; i++: + build_statement(st, &body->statements[i]) + + +def build_function_or_method( + st: State*, + selfclass: Type*, + name: byte*, + body: AstBody*, +) -> CfGraph*: + assert st->cfg == NULL + + assert st->fomtypes == NULL + for f = st->filetypes->fomtypes; f < &st->filetypes->fomtypes[st->filetypes->nfomtypes]; f++: + if strcmp(f->signature.name, name) == 0 and get_self_class(&f->signature) == selfclass: + st->fomtypes = f + break + assert st->fomtypes != NULL + + st->cfg = calloc(1, sizeof *st->cfg) + st->cfg->signature = copy_signature(&st->fomtypes->signature) + + # Copy local variables over from type checking. + # Ownership of the variables changes, they will be freed when graphs are freed. + st->cfg->nlocals = st->fomtypes->nlocals + st->cfg->locals = malloc(st->cfg->nlocals * sizeof(st->cfg->locals[0])) + assert st->cfg->locals != NULL + memcpy(st->cfg->locals, st->fomtypes->locals, st->cfg->nlocals * sizeof(st->cfg->locals[0])) + + st->cfg->all_blocks = malloc(2 * sizeof(st->cfg->all_blocks[0])) + assert st->cfg->all_blocks != NULL + st->cfg->all_blocks[0] = &st->cfg->start_block + st->cfg->all_blocks[1] = &st->cfg->end_block + st->cfg->n_all_blocks = 2 + + st->current_block = &st->cfg->start_block + + assert st->nloops == 0 + build_body(st, body) + assert st->nloops == 0 + + # Implicit return at the end of the function + st->current_block->iftrue = &st->cfg->end_block + st->current_block->iffalse = &st->cfg->end_block + + cfg = st->cfg + st->fomtypes = NULL + st->cfg = NULL + return cfg + + +# TODO: passing a type context here doesn't really make sense. +# It would be better to pass only the public symbols that have been imported. +def build_control_flow_graphs(ast: AstFile*, filetypes: FileTypes*) -> CfGraphFile: + result = CfGraphFile{filename = ast->path} + st = State{filetypes = filetypes} + + for i = 0; i < ast->body.nstatements; i++: + stmt = &ast->body.statements[i] + + if stmt->kind == AstStatementKind::Function and stmt->function.body.nstatements > 0: + g = build_function_or_method(&st, NULL, stmt->function.signature.name, &stmt->function.body) + + result.graphs = realloc(result.graphs, sizeof(result.graphs[0]) * (result.ngraphs + 1)) + assert result.graphs != NULL + result.graphs[result.ngraphs++] = g + + if stmt->kind == AstStatementKind::Class: + class_type: Type* = NULL + for t = filetypes->owned_types; t < &filetypes->owned_types[filetypes->n_owned_types]; t++: + if strcmp((*t)->name, stmt->classdef.name) == 0: + class_type = *t + break + assert class_type != NULL + + for m = stmt->classdef.members; m < &stmt->classdef.members[stmt->classdef.nmembers]; m++: + if m->kind == AstClassMemberKind::Method: + g = build_function_or_method(&st, class_type, m->method.signature.name, &m->method.body) + + result.graphs = realloc(result.graphs, sizeof(result.graphs[0]) * (result.ngraphs + 1)) + assert result.graphs != NULL + result.graphs[result.ngraphs++] = g + + free(st.breakstack) + free(st.continuestack) + return result diff --git a/self_hosted/codegen.jou b/self_hosted/codegen.jou new file mode 100644 index 00000000..bec9f6e7 --- /dev/null +++ b/self_hosted/codegen.jou @@ -0,0 +1,500 @@ +import "stdlib/math.jou" +import "stdlib/mem.jou" +import "stdlib/str.jou" + +import "./evaluate.jou" +import "./llvm.jou" +import "./target.jou" +import "./types.jou" +import "./structs.jou" + +# LLVM doesn't have a built-in union type, and you're supposed to abuse other types for that: +# https://mapping-high-level-constructs-to-llvm-ir.readthedocs.io/en/latest/basic-constructs/unions.html +# +# My first idea was to use an array of bytes that is big enough to fit anything. +# However, that might not be aligned properly. +# +# Then I tried choosing the member type that has the biggest align, and making a large enough array of it. +# Because the align is always a power of two, the memory will be suitably aligned for all member types. +# But it didn't work for some reason I still don't understand. +# +# Then I figured out how clang does it and did it the same way. +# We make a struct that contains: +# - the most aligned type as chosen before +# - array of i8 as padding to make it the right size. +# But for some reason that didn't work either. +# +# As a "last resort" I just use an array of i64 large enough and hope it's aligned as needed. +def codegen_union_type(types: LLVMType**, ntypes: int) -> LLVMType*: + # For some reason uncommenting this makes stuff compile almost 2x slower... + #if ntypes == 1: + # return types[0] + + sizeneeded = 0L + for i = 0; i < ntypes; i++: + size1 = LLVMABISizeOfType(target.target_data, types[i]) + size2 = LLVMStoreSizeOfType(target.target_data, types[i]) + + # If this assert fails, you need to figure out which of the size functions should be used. + # I don't know what their difference is. + # And if you need the alignment, there's 3 different functions for that... + assert size1 == size2 + sizeneeded = llmax(sizeneeded, size1) + + return LLVMArrayType(LLVMInt64Type(), ((sizeneeded+7)/8) as int) + + +def codegen_class_type(type: Type*) -> LLVMType*: + assert type->kind == TypeKind::Class + + n = type->classdata.nfields + + flat_elems: LLVMType** = malloc(sizeof(flat_elems[0]) * n) + for i = 0; i < n; i++: + # Treat all pointers inside structs as if they were void*. + # This allows structs to contain pointers to themselves. + if type->classdata.fields[i].type->kind == TypeKind::Pointer: + flat_elems[i] = codegen_type(voidPtrType) + else: + flat_elems[i] = codegen_type(type->classdata.fields[i].type) + + # Combine together fields of the same union. + combined: LLVMType** = malloc(sizeof(combined[0]) * n) + combinedlen = 0 + for start = 0; start < n; start = end: + end = start + 1 + while end < n and type->classdata.fields[start].union_id == type->classdata.fields[end].union_id: + end++ + combined[combinedlen++] = codegen_union_type(&flat_elems[start], end-start) + + result = LLVMStructType(combined, combinedlen, False as int) + free(flat_elems) + free(combined) + return result + + +def codegen_type(type: Type*) -> LLVMType*: + if type->kind == TypeKind::Array: + return LLVMArrayType(codegen_type(type->array.item_type), type->array.len) + if type->kind == TypeKind::Pointer: + return LLVMPointerType(codegen_type(type->value_type), 0) + if type->kind == TypeKind::FloatingPoint: + if type->size_in_bits == 32: + return LLVMFloatType() + if type->size_in_bits == 64: + return LLVMDoubleType() + assert False + if type->kind == TypeKind::VoidPointer: + # just use i8* as here https://stackoverflow.com/q/36724399 + return LLVMPointerType(LLVMInt8Type(), 0) + if ( + type->kind == TypeKind::SignedInteger + or type->kind == TypeKind::UnsignedInteger + ): + return LLVMIntType(type->size_in_bits) + if type->kind == TypeKind::Bool: + return LLVMInt1Type() + if type->kind == TypeKind::OpaqueClass: + # this is compiler internal/temporary thing and should never end up here + assert False + if type->kind == TypeKind::Class: + return codegen_class_type(type) + if type->kind == TypeKind::Enum: + return LLVMInt32Type() + assert False + + +class State: + module: LLVMModule* + builder: LLVMBuilder* + cfvars: LocalVariable** + cfvars_end: LocalVariable** + # All local variables are represented as pointers to stack space, even + # if they are never reassigned. LLVM will optimize the mess. + llvm_locals: LLVMValue** + + +def get_pointer_to_local_var(st: State*, cfvar: LocalVariable*) -> LLVMValue*: + assert cfvar != NULL + + # The loop below looks stupid, but I don't see a better alternative. + # + # I want CFG variables to be used as pointers, so that it's easy to refer to a + # variable's name and type, check if you have the same variable, etc. But I + # can't make a List of variables when building CFG, because existing variable + # pointers would become invalid as the list grows. The solution is to allocate + # each variable separately when building the CFG. + # + # Another idea I had was to count the number of variables needed beforehand, + # so I wouldn't need to ever resize the list of variables, but the CFG building + # is already complicated enough as is. + for i = 0; &st->cfvars[i] < st->cfvars_end; i++: + if st->cfvars[i] == cfvar: + return st->llvm_locals[i] + assert False + + +def get_local_var(st: State*, cfvar: LocalVariable*) -> LLVMValue*: + varptr = get_pointer_to_local_var(st, cfvar) + return LLVMBuildLoad(st->builder, varptr, cfvar->name) + + +def set_local_var(st: State*, cfvar: LocalVariable*, value: LLVMValue*) -> None: + assert cfvar != NULL + for i = 0; &st->cfvars[i] < st->cfvars_end; i++: + if st->cfvars[i] == cfvar: + LLVMBuildStore(st->builder, value, st->llvm_locals[i]) + return + assert False + + +def codegen_function_or_method_decl(st: State*, sig: Signature*) -> LLVMValue*: + fullname: byte[200] + if get_self_class(sig) != NULL: + snprintf(fullname, sizeof fullname, "%s.%s", get_self_class(sig)->name, sig->name) + else: + assert sizeof(fullname) >= sizeof(sig->name) + assert sizeof(sig->name) > 50 # this is an array, not a pointer to dynamic length string + strcpy(fullname, sig->name) + + # Make it so that this can be called many times without issue + func = LLVMGetNamedFunction(st->module, fullname) + if func != NULL: + return func + + argtypes: LLVMType** = malloc(sig->nargs * sizeof(argtypes[0])) + for i = 0; i < sig->nargs; i++: + argtypes[i] = codegen_type(sig->argtypes[i]) + + returntype: LLVMType* + # TODO: tell llvm, if we know a function is noreturn ? + if sig->returntype == NULL: # "-> noreturn" or "-> None" + returntype = LLVMVoidType() + else: + returntype = codegen_type(sig->returntype) + + functype = LLVMFunctionType(returntype, argtypes, sig->nargs, sig->takes_varargs as int) + free(argtypes) + + return LLVMAddFunction(st->module, fullname, functype) + + +def codegen_call(st: State*, sig: Signature*, args: LLVMValue**, nargs: int) -> LLVMValue*: + function: LLVMValue* = codegen_function_or_method_decl(st, sig) + assert function != NULL + assert LLVMGetTypeKind(LLVMTypeOf(function)) == LLVMTypeKind::Pointer + function_type = LLVMGetElementType(LLVMTypeOf(function)) + assert LLVMGetTypeKind(function_type) == LLVMTypeKind::Function + + debug_name: byte[100] = "" + if LLVMGetTypeKind(LLVMGetReturnType(function_type)) != LLVMTypeKind::Void: + snprintf(debug_name, sizeof debug_name, "%s_return_value", sig->name) + + return LLVMBuildCall2(st->builder, function_type, function, args, nargs, debug_name) + + +def make_a_string_constant(st: State*, s: byte*) -> LLVMValue*: + array = LLVMConstString(s, strlen(s) as int, False as int) + global_var = LLVMAddGlobal(st->module, LLVMTypeOf(array), "string_literal") + LLVMSetLinkage(global_var, LLVMLinkage::Private) # This makes it a static global variable + LLVMSetInitializer(global_var, array) + + string_type = LLVMPointerType(LLVMInt8Type(), 0) + return LLVMBuildBitCast(st->builder, global_var, string_type, "string_ptr") + + +def codegen_constant(st: State*, c: Constant*) -> LLVMValue*: + if c->kind == ConstantKind::Bool: + return LLVMConstInt(LLVMInt1Type(), c->boolean as long, False as int) + if c->kind == ConstantKind::Integer: + return LLVMConstInt(codegen_type(type_of_constant(c)), c->integer.value, c->integer.is_signed as int) + if c->kind == ConstantKind::Float or c->kind == ConstantKind::Double: + return LLVMConstRealOfString(codegen_type(type_of_constant(c)), c->double_or_float_text) + if c->kind == ConstantKind::Null: + return LLVMConstNull(codegen_type(voidPtrType)) + if c->kind == ConstantKind::String: + return make_a_string_constant(st, c->str) + if c->kind == ConstantKind::EnumMember: + return LLVMConstInt(LLVMInt32Type(), c->enum_member.memberidx, False as int) + assert False + + +def codegen_special_constant(name: byte*) -> LLVMValue*: + v = get_special_constant(name) + assert v != -1 + return LLVMConstInt(LLVMInt1Type(), v, False as int) + + +def build_signed_mod(builder: LLVMBuilder*, lhs: LLVMValue*, rhs: LLVMValue*) -> LLVMValue*: + # Jou's % operator ensures that a%b has same sign as b: + # jou_mod(a, b) = llvm_mod(llvm_mod(a, b) + b, b) + llmod = LLVMBuildSRem(builder, lhs, rhs, "smod_tmp") + sum = LLVMBuildAdd(builder, llmod, rhs, "smod_tmp") + return LLVMBuildSRem(builder, sum, rhs, "smod") + + +def build_signed_div(builder: LLVMBuilder*, lhs: LLVMValue*, rhs: LLVMValue*) -> LLVMValue*: + # LLVM's provides two divisions. One truncates, the other is an "exact div" + # that requires there is no remainder. Jou uses floor division which is + # neither of the two, but is quite easy to implement: + # + # floordiv(a, b) = exact_div(a - jou_mod(a, b), b) + # + top = LLVMBuildSub(builder, lhs, build_signed_mod(builder, lhs, rhs), "sdiv_tmp") + return LLVMBuildExactSDiv(builder, top, rhs, "sdiv") + + +def codegen_arithmetic_instruction(st: State*, ins: CfInstruction*) -> None: + lhs = get_local_var(st, ins->operands[0]) + rhs = get_local_var(st, ins->operands[1]) + + assert ins->operands[0]->type == ins->operands[1]->type + type = ins->operands[0]->type + + if type->kind == TypeKind::FloatingPoint: + if ins->kind == CfInstructionKind::NumAdd: + set_local_var(st, ins->destvar, LLVMBuildFAdd(st->builder, lhs, rhs, "float_sum")) + elif ins->kind == CfInstructionKind::NumSub: + set_local_var(st, ins->destvar, LLVMBuildFSub(st->builder, lhs, rhs, "float_diff")) + elif ins->kind == CfInstructionKind::NumMul: + set_local_var(st, ins->destvar, LLVMBuildFMul(st->builder, lhs, rhs, "float_prod")) + elif ins->kind == CfInstructionKind::NumDiv: + set_local_var(st, ins->destvar, LLVMBuildFDiv(st->builder, lhs, rhs, "float_quot")) + elif ins->kind == CfInstructionKind::NumMod: + set_local_var(st, ins->destvar, LLVMBuildFRem(st->builder, lhs, rhs, "float_mod")) + else: + assert False + + elif type->kind == TypeKind::SignedInteger: + if ins->kind == CfInstructionKind::NumAdd: + set_local_var(st, ins->destvar, LLVMBuildAdd(st->builder, lhs, rhs, "int_sum")) + elif ins->kind == CfInstructionKind::NumSub: + set_local_var(st, ins->destvar, LLVMBuildSub(st->builder, lhs, rhs, "int_diff")) + elif ins->kind == CfInstructionKind::NumMul: + set_local_var(st, ins->destvar, LLVMBuildMul(st->builder, lhs, rhs, "int_prod")) + elif ins->kind == CfInstructionKind::NumDiv: + set_local_var(st, ins->destvar, build_signed_div(st->builder, lhs, rhs)) + elif ins->kind == CfInstructionKind::NumMod: + set_local_var(st, ins->destvar, build_signed_mod(st->builder, lhs, rhs)) + else: + assert False + + elif type->kind == TypeKind::UnsignedInteger: + if ins->kind == CfInstructionKind::NumAdd: + set_local_var(st, ins->destvar, LLVMBuildAdd(st->builder, lhs, rhs, "uint_sum")) + elif ins->kind == CfInstructionKind::NumSub: + set_local_var(st, ins->destvar, LLVMBuildSub(st->builder, lhs, rhs, "uint_diff")) + elif ins->kind == CfInstructionKind::NumMul: + set_local_var(st, ins->destvar, LLVMBuildMul(st->builder, lhs, rhs, "uint_prod")) + elif ins->kind == CfInstructionKind::NumDiv: + set_local_var(st, ins->destvar, LLVMBuildUDiv(st->builder, lhs, rhs, "uint_quot")) + elif ins->kind == CfInstructionKind::NumMod: + set_local_var(st, ins->destvar, LLVMBuildURem(st->builder, lhs, rhs, "uint_mod")) + else: + assert False + + else: + assert False + + +def codegen_instruction(st: State*, ins: CfInstruction*) -> None: + if ins->kind == CfInstructionKind::Call: + args: LLVMValue** = malloc(ins->noperands * sizeof(args[0])) + for i = 0; i < ins->noperands; i++: + args[i] = get_local_var(st, ins->operands[i]) + return_value = codegen_call(st, &ins->signature, args, ins->noperands) + if ins->destvar != NULL: + set_local_var(st, ins->destvar, return_value) + free(args) + elif ins->kind == CfInstructionKind::Constant: + set_local_var(st, ins->destvar, codegen_constant(st, &ins->constant)) + elif ins->kind == CfInstructionKind::SpecialConstant: + set_local_var(st, ins->destvar, codegen_special_constant(ins->scname)) + elif ins->kind == CfInstructionKind::StringArray: + set_local_var(st, ins->destvar, LLVMConstString(ins->strarray.str, ins->strarray.len, True as int)) + elif ins->kind == CfInstructionKind::SizeOf: + set_local_var(st, ins->destvar, LLVMSizeOf(codegen_type(ins->type))) + elif ins->kind == CfInstructionKind::AddressOfLocalVar: + set_local_var(st, ins->destvar, get_pointer_to_local_var(st, ins->operands[0])) + elif ins->kind == CfInstructionKind::AddressOfGlobalVar: + set_local_var(st, ins->destvar, LLVMGetNamedGlobal(st->module, ins->globalname)) + elif ins->kind == CfInstructionKind::PtrLoad: + set_local_var(st, ins->destvar, LLVMBuildLoad(st->builder, get_local_var(st, ins->operands[0]), "ptr_load")) + elif ins->kind == CfInstructionKind::PtrStore: + LLVMBuildStore(st->builder, get_local_var(st, ins->operands[1]), get_local_var(st, ins->operands[0])) + elif ins->kind == CfInstructionKind::PtrToInt64: + set_local_var(st, ins->destvar, LLVMBuildPtrToInt(st->builder, get_local_var(st, ins->operands[0]), LLVMInt64Type(), "ptr_as_long")) + elif ins->kind == CfInstructionKind::Int64ToPtr: + set_local_var(st, ins->destvar, LLVMBuildIntToPtr(st->builder, get_local_var(st, ins->operands[0]), codegen_type(ins->destvar->type), "long_as_ptr")) + elif ins->kind == CfInstructionKind::PtrClassField: + classtype = ins->operands[0]->type->value_type + assert classtype->kind == TypeKind::Class + f = classtype->classdata.fields + while strcmp(f->name, ins->fieldname) != 0: + f++ + assert f < &classtype->classdata.fields[classtype->classdata.nfields] + + val = LLVMBuildStructGEP2(st->builder, codegen_type(classtype), get_local_var(st, ins->operands[0]), f->union_id, ins->fieldname) + # This cast is needed in two cases: + # * All pointers are i8* in structs so we can do self-referencing classes. + # * This is how unions work. + val = LLVMBuildBitCast(st->builder, val, LLVMPointerType(codegen_type(f->type), 0), "struct_member_cast") + set_local_var(st, ins->destvar, val) + elif ins->kind == CfInstructionKind::PtrMemsetToZero: + size = LLVMSizeOf(codegen_type(ins->operands[0]->type->value_type)) + LLVMBuildMemSet(st->builder, get_local_var(st, ins->operands[0]), LLVMConstInt(LLVMInt8Type(), 0, False as int), size, 0) + elif ins->kind == CfInstructionKind::PtrAddInt: + ptr_var = get_local_var(st, ins->operands[0]) + int_var = get_local_var(st, ins->operands[1]) + set_local_var(st, ins->destvar, LLVMBuildGEP(st->builder, ptr_var, &int_var, 1, "ptr_add_int")) + elif ins->kind == CfInstructionKind::NumCast: + from = ins->operands[0]->type + to = ins->destvar->type + assert is_number_type(from) and is_number_type(to) + + if is_integer_type(from) and is_integer_type(to): + # Examples: + # signed 8-bit 0xFF (-1) --> 16-bit 0xFFFF (-1 or max value) + # unsigned 8-bit 0xFF (255) --> 16-bit 0x00FF (255) + set_local_var(st, ins->destvar, LLVMBuildIntCast2(st->builder, get_local_var(st, ins->operands[0]), codegen_type(to), (from->kind == TypeKind::SignedInteger) as int, "int_cast")) + elif is_integer_type(from) and to->kind == TypeKind::FloatingPoint: + # integer --> double / float + if from->kind == TypeKind::SignedInteger: + set_local_var(st, ins->destvar, LLVMBuildSIToFP(st->builder, get_local_var(st, ins->operands[0]), codegen_type(to), "cast")) + else: + set_local_var(st, ins->destvar, LLVMBuildUIToFP(st->builder, get_local_var(st, ins->operands[0]), codegen_type(to), "cast")) + elif from->kind == TypeKind::FloatingPoint and is_integer_type(to): + if to->kind == TypeKind::SignedInteger: + set_local_var(st, ins->destvar, LLVMBuildFPToSI(st->builder, get_local_var(st, ins->operands[0]), codegen_type(to), "cast")) + else: + set_local_var(st, ins->destvar, LLVMBuildFPToUI(st->builder, get_local_var(st, ins->operands[0]), codegen_type(to), "cast")) + elif from->kind == TypeKind::FloatingPoint and to->kind == TypeKind::FloatingPoint: + set_local_var(st, ins->destvar, LLVMBuildFPCast(st->builder, get_local_var(st, ins->operands[0]), codegen_type(to), "cast")) + else: + assert False + elif ins->kind == CfInstructionKind::BoolNegate: + set_local_var(st, ins->destvar, LLVMBuildXor(st->builder, get_local_var(st, ins->operands[0]), LLVMConstInt(LLVMInt1Type(), 1, False as int), "bool_negate")) + elif ins->kind == CfInstructionKind::PtrCast: + set_local_var(st, ins->destvar, LLVMBuildBitCast(st->builder, get_local_var(st, ins->operands[0]), codegen_type(ins->destvar->type), "ptr_cast")) + elif ins->kind == CfInstructionKind::VarCpy or ins->kind == CfInstructionKind::Int32ToEnum or ins->kind == CfInstructionKind::EnumToInt32: + set_local_var(st, ins->destvar, get_local_var(st, ins->operands[0])) + elif ins->kind == CfInstructionKind::NumAdd or ins->kind == CfInstructionKind::NumSub or ins->kind == CfInstructionKind::NumMul or ins->kind == CfInstructionKind::NumDiv or ins->kind == CfInstructionKind::NumMod: + codegen_arithmetic_instruction(st, ins) + elif ins->kind == CfInstructionKind::NumEq: + if is_integer_type(ins->operands[0]->type): + set_local_var(st, ins->destvar, LLVMBuildICmp(st->builder, LLVMIntPredicate::EQ, get_local_var(st, ins->operands[0]), get_local_var(st, ins->operands[1]), "num_eq")) + else: + set_local_var(st, ins->destvar, LLVMBuildFCmp(st->builder, LLVMRealPredicate::OEQ, get_local_var(st, ins->operands[0]), get_local_var(st, ins->operands[1]), "num_eq")) + elif ins->kind == CfInstructionKind::NumLt: + if ins->operands[0]->type->kind == TypeKind::UnsignedInteger and ins->operands[1]->type->kind == TypeKind::UnsignedInteger: + set_local_var(st, ins->destvar, LLVMBuildICmp(st->builder, LLVMIntPredicate::ULT, get_local_var(st, ins->operands[0]), get_local_var(st, ins->operands[1]), "num_lt")) + elif is_integer_type(ins->operands[0]->type) and is_integer_type(ins->operands[1]->type): + set_local_var(st, ins->destvar, LLVMBuildICmp(st->builder, LLVMIntPredicate::SLT, get_local_var(st, ins->operands[0]), get_local_var(st, ins->operands[1]), "num_lt")) + else: + set_local_var(st, ins->destvar, LLVMBuildFCmp(st->builder, LLVMRealPredicate::OLT, get_local_var(st, ins->operands[0]), get_local_var(st, ins->operands[1]), "num_lt")) + else: + assert False + + +def find_block(cfg: CfGraph*, b: CfBlock*) -> int: + for i = 0; i < cfg->n_all_blocks; i++: + if cfg->all_blocks[i] == b: + return i + assert False + + +def codegen_call_to_the_special_startup_function(st: State*) -> None: + if WINDOWS or MACOS or NETBSD: + functype = LLVMFunctionType(LLVMVoidType(), NULL, 0, False as int) + func = LLVMAddFunction(st->module, "_jou_startup", functype) + LLVMBuildCall2(st->builder, functype, func, NULL, 0, "") + + +def codegen_function_or_method_def(st: State*, cfg: CfGraph*) -> None: + st->cfvars = cfg->locals + st->cfvars_end = &cfg->locals[cfg->nlocals] + st->llvm_locals = malloc(sizeof(st->llvm_locals[0]) * cfg->nlocals) + + llvm_func = codegen_function_or_method_decl(st, &cfg->signature) + + blocks: LLVMBasicBlock** = malloc(sizeof(blocks[0]) * cfg->n_all_blocks) + for i = 0; i < cfg->n_all_blocks; i++: + name: byte[50] + sprintf(name, "block%d", i) + blocks[i] = LLVMAppendBasicBlock(llvm_func, name) + + assert cfg->all_blocks[0] == &cfg->start_block + LLVMPositionBuilderAtEnd(st->builder, blocks[0]) + + if get_self_class(&cfg->signature) == NULL and strcmp(cfg->signature.name, "main") == 0: + # Insert special code at start of main() + codegen_call_to_the_special_startup_function(st) + + # Allocate stack space for local variables at start of function. + return_value: LLVMValue* = NULL + for i = 0; i < cfg->nlocals; i++: + v = cfg->locals[i] + st->llvm_locals[i] = LLVMBuildAlloca(st->builder, codegen_type(v->type), v->name) + if strcmp(v->name, "return") == 0: + return_value = st->llvm_locals[i] + + # Place arguments into the first n local variables. + for i = 0; i < cfg->signature.nargs; i++: + set_local_var(st, cfg->locals[i], LLVMGetParam(llvm_func, i)) + + for i = 0; i < cfg->n_all_blocks; i++: + b = &cfg->all_blocks[i] + LLVMPositionBuilderAtEnd(st->builder, blocks[i]) + + for ins = (*b)->instructions; ins < &(*b)->instructions[(*b)->ninstructions]; ins++: + codegen_instruction(st, ins) + + if *b == &cfg->end_block: + assert (*b)->ninstructions == 0 + # The "return" variable may have been deleted as unused. + # In that case return_value is NULL but signature.returntype isn't. + if return_value != NULL: + LLVMBuildRet(st->builder, LLVMBuildLoad(st->builder, return_value, "return_value")) + elif cfg->signature.returntype != NULL or cfg->signature.is_noreturn: + LLVMBuildUnreachable(st->builder) + else: + LLVMBuildRetVoid(st->builder) + elif (*b)->iftrue != NULL and (*b)->iffalse != NULL: + if (*b)->iftrue == (*b)->iffalse: + LLVMBuildBr(st->builder, blocks[find_block(cfg, (*b)->iftrue)]) + else: + assert (*b)->branchvar != NULL + LLVMBuildCondBr( + st->builder, + get_local_var(st, (*b)->branchvar), + blocks[find_block(cfg, (*b)->iftrue)], + blocks[find_block(cfg, (*b)->iffalse)]) + elif (*b)->iftrue == NULL and (*b)->iffalse == NULL: + LLVMBuildUnreachable(st->builder) + else: + assert False + + free(blocks) + free(st->llvm_locals) + + +def codegen(cfgfile: CfGraphFile*, ft: FileTypes*) -> LLVMModule*: + st = State{ + module = LLVMModuleCreateWithName(cfgfile->filename), + builder = LLVMCreateBuilder(), + } + + LLVMSetTarget(st.module, target.triple) + LLVMSetDataLayout(st.module, target.data_layout) + + for v = ft->globals; v < &ft->globals[ft->nglobals]; v++: + t = codegen_type(v->type) + globalptr = LLVMAddGlobal(st.module, t, v->name) + if v->defined_in_current_file: + LLVMSetInitializer(globalptr, LLVMConstNull(t)) + + for g = cfgfile->graphs; g < &cfgfile->graphs[cfgfile->ngraphs]; g++: + codegen_function_or_method_def(&st, *g) + + LLVMDisposeBuilder(st.builder) + return st.module diff --git a/self_hosted/errors_and_warnings.jou b/self_hosted/errors_and_warnings.jou index 0e520cd2..637b0f8d 100644 --- a/self_hosted/errors_and_warnings.jou +++ b/self_hosted/errors_and_warnings.jou @@ -5,12 +5,26 @@ class Location: path: byte* # Not owned. Points to a string that is held elsewhere. lineno: int -def fail(location: Location, message: byte*) -> noreturn: - # When stdout is redirected to same place as stderr, - # make sure that normal printf()s show up before our error. + +# When stdout is redirected to same place as stderr, +# make sure that normal printf()s show up before our warning. +def flush_streams() -> None: fflush(stdout) fflush(stderr) + +def show_warning(location: Location, message: byte*) -> None: + flush_streams() + + fprintf(stderr, "compiler warning for file \"%s\"", location.path) + if location.lineno != 0: + fprintf(stderr, ", line %d", location.lineno) + fprintf(stderr, ": %s\n", message) + + +def fail(location: Location, message: byte*) -> noreturn: + flush_streams() + fprintf(stderr, "compiler error in file \"%s\"", location.path) if location.lineno != 0: fprintf(stderr, ", line %d", location.lineno) diff --git a/self_hosted/evaluate.jou b/self_hosted/evaluate.jou index 355885f1..336596f9 100644 --- a/self_hosted/evaluate.jou +++ b/self_hosted/evaluate.jou @@ -1,12 +1,10 @@ -# Compile-time evaluating if statements. +import "stdlib/str.jou" +import "stdlib/mem.jou" import "./ast.jou" import "./errors_and_warnings.jou" -import "stdlib/str.jou" -import "stdlib/mem.jou" -# Return values: 1=true, 0=false, -1=unknown def get_special_constant(name: byte*) -> int: if strcmp(name, "WINDOWS") == 0: return WINDOWS as int @@ -20,10 +18,10 @@ def get_special_constant(name: byte*) -> int: def evaluate_condition(expr: AstExpression*) -> bool: if expr->kind == AstExpressionKind::GetVariable: v = get_special_constant(expr->varname) - if v == 1: - return True if v == 0: return False + if v == 1: + return True if expr->kind == AstExpressionKind::And: return evaluate_condition(&expr->operands[0]) and evaluate_condition(&expr->operands[1]) @@ -38,9 +36,9 @@ def evaluate_condition(expr: AstExpression*) -> bool: # returns the statements to replace if statement with def evaluate_compile_time_if_statement(if_stmt: AstIfStatement*) -> AstBody: result = &if_stmt->else_body - for p = if_stmt->if_and_elifs; p < &if_stmt->if_and_elifs[if_stmt->n_if_and_elifs]; p++: - if evaluate_condition(&p->condition): - result = &p->body + for i = 0; i < if_stmt->n_if_and_elifs; i++: + if evaluate_condition(&if_stmt->if_and_elifs[i].condition): + result = &if_stmt->if_and_elifs[i].body break ret = *result @@ -63,7 +61,7 @@ def replace(body: AstBody*, i: int, new: AstBody) -> None: # This handles nested if statements. -def evaluate_compile_time_if_statements_in_body(body: AstBody*) -> None: +def evaluate_compile_time_if_statements(body: AstBody*) -> None: i = 0 while i < body->nstatements: if body->statements[i].kind == AstStatementKind::If: diff --git a/self_hosted/free.jou b/self_hosted/free.jou new file mode 100644 index 00000000..61635d27 --- /dev/null +++ b/self_hosted/free.jou @@ -0,0 +1,71 @@ +# Boring boilerplate code to free up data structures used in compilation. + +import "./structs.jou" +import "./types.jou" +import "./token.jou" +import "stdlib/mem.jou" + +def free_tokens(tokenlist: Token*) -> None: + for t = tokenlist; t->kind != TokenKind::EndOfFile; t++: + if t->kind == TokenKind::String: + free(t->long_string) + free(tokenlist) + +def free_constant(c: Constant*) -> None: + if c->kind == ConstantKind::String: + free(c->str) + +def free_signature(sig: Signature*) -> None: + free(sig->argnames) + free(sig->argtypes) + +def free_export_symbol(es: ExportSymbol*) -> None: + if es->kind == ExportSymbolKind::Function: + free_signature(&es->funcsignature) + +def free_file_types(ft: FileTypes*) -> None: + for t = ft->owned_types; t < &ft->owned_types[ft->n_owned_types]; t++: + free_type(*t) + for func = ft->functions; func < &ft->functions[ft->nfunctions]; func++: + free_signature(&func->signature) + for fom = ft->fomtypes; fom < &ft->fomtypes[ft->nfomtypes]; fom++: + for et = fom->expr_types; et < &fom->expr_types[fom->n_expr_types]; et++: + free(*et) + free(fom->expr_types) + free(fom->locals) # Don't free individual locals because they're owned by CFG now + free_signature(&fom->signature) + free(ft->globals) + free(ft->types) + free(ft->owned_types) + free(ft->functions) + free(ft->fomtypes) + +def free_control_flow_graph_block(cfg: CfGraph*, b: CfBlock*) -> None: + for ins = b->instructions; ins < &b->instructions[b->ninstructions]; ins++: + if ins->kind == CfInstructionKind::Constant: + free_constant(&ins->constant) + if ins->kind == CfInstructionKind::StringArray: + free(ins->strarray.str) + if ins->kind == CfInstructionKind::Call: + free_signature(&ins->signature) + free(ins->operands) + free(b->instructions) + if b != &cfg->start_block and b != &cfg->end_block: + free(b) + +def free_cfg(cfg: CfGraph*) -> None: + free_signature(&cfg->signature) + + for b = cfg->all_blocks; b < &cfg->all_blocks[cfg->n_all_blocks]; b++: + free_control_flow_graph_block(cfg, *b) + for v = cfg->locals; v < &cfg->locals[cfg->nlocals]; v++: + free(*v) + + free(cfg->all_blocks) + free(cfg->locals) + free(cfg) + +def free_control_flow_graphs(cfgfile: CfGraphFile*) -> None: + for cfg = cfgfile->graphs; cfg < &cfgfile->graphs[cfgfile->ngraphs]; cfg++: + free_cfg(*cfg) + free(cfgfile->graphs) diff --git a/self_hosted/llvm.jou b/self_hosted/llvm.jou index 56817086..ed5e4342 100644 --- a/self_hosted/llvm.jou +++ b/self_hosted/llvm.jou @@ -18,6 +18,9 @@ class LLVMTargetData: class LLVMTargetMachine: _dummy: int +class LLVMPassManagerBuilder: + _dummy: int + # =========== Target.h =========== declare LLVMInitializeX86TargetInfo() -> None declare LLVMInitializeX86Target() -> None @@ -276,3 +279,17 @@ declare LLVMBuildInsertValue(Builder: LLVMBuilder*, AggVal: LLVMValue*, EltVal: declare LLVMCreatePassManager() -> LLVMPassManager* declare LLVMRunPassManager(PM: LLVMPassManager*, M: LLVMModule*) -> int declare LLVMDisposePassManager(PM: LLVMPassManager*) -> None + + +# =========== Transforms/PassManagerBuilder.h =========== +declare LLVMPassManagerBuilderCreate() -> LLVMPassManagerBuilder* +declare LLVMPassManagerBuilderDispose(PMB: LLVMPassManagerBuilder*) -> None +declare LLVMPassManagerBuilderSetOptLevel(PMB: LLVMPassManagerBuilder*, OptLevel: int) -> None +declare LLVMPassManagerBuilderSetSizeLevel(PMB: LLVMPassManagerBuilder*, SizeLevel: int) -> None +declare LLVMPassManagerBuilderSetDisableUnitAtATime(PMB: LLVMPassManagerBuilder*, Value: int) -> None +declare LLVMPassManagerBuilderSetDisableUnrollLoops(PMB: LLVMPassManagerBuilder*, Value: int) -> None +declare LLVMPassManagerBuilderSetDisableSimplifyLibCalls(PMB: LLVMPassManagerBuilder*, Value: int) -> None +declare LLVMPassManagerBuilderUseInlinerWithThreshold(PMB: LLVMPassManagerBuilder*, Threshold: int) -> None +declare LLVMPassManagerBuilderPopulateFunctionPassManager(PMB: LLVMPassManagerBuilder*, PM: LLVMPassManager*) -> None +declare LLVMPassManagerBuilderPopulateModulePassManager(PMB: LLVMPassManagerBuilder*, PM: LLVMPassManager*) -> None +declare LLVMPassManagerBuilderPopulateLTOPassManager(PMB: LLVMPassManagerBuilder*, PM: LLVMPassManager*, Internalize: int, RunInliner: int) -> None diff --git a/self_hosted/main.jou b/self_hosted/main.jou index 75932da6..47f7ba79 100644 --- a/self_hosted/main.jou +++ b/self_hosted/main.jou @@ -1,510 +1,503 @@ -import "../config.jou" -import "./ast.jou" -import "./errors_and_warnings.jou" -import "./tokenizer.jou" -import "./parser.jou" -import "./types.jou" -import "./typecheck.jou" -import "./paths.jou" -import "./target.jou" -import "./create_llvm_ir.jou" -import "./llvm.jou" -import "./evaluate.jou" -import "stdlib/mem.jou" +import "stdlib/io.jou" import "stdlib/process.jou" +import "stdlib/mem.jou" +import "stdlib/errno.jou" import "stdlib/str.jou" -import "stdlib/io.jou" -enum CompilerMode: - TokenizeOnly # Tokenize one file, don't recurse to imports - ParseOnly # Tokenize and parse one file, don't recurse to imports - Compile # Compile and link - CompileAndRun # Compile, link and run a program (default) +import "./build_cfg.jou" +import "./evaluate.jou" +import "./run.jou" +import "./codegen.jou" +import "./print.jou" +import "./llvm.jou" +import "./output.jou" +import "./typecheck.jou" +import "./target.jou" +import "./types.jou" +import "./free.jou" +import "./parser.jou" +import "./paths.jou" +import "./errors_and_warnings.jou" +import "./structs.jou" +import "./update.jou" +import "./tokenizer.jou" +import "./ast.jou" + -class CommandLineArgs: - mode: CompilerMode - output_file: byte* # The argument after -o, possibly with .exe appended to it - verbosity: int # Number of -v/--verbose flags given - main_path: byte* # Jou file path passed on command line +def optimize(module: LLVMModule*, level: int) -> None: + assert 1 <= level and level <= 3 + pm = LLVMCreatePassManager() + + # The default settings should be fine for Jou because they work well for + # C and C++, and Jou is quite similar to C. + pmbuilder = LLVMPassManagerBuilderCreate() + LLVMPassManagerBuilderSetOptLevel(pmbuilder, level) + LLVMPassManagerBuilderPopulateModulePassManager(pmbuilder, pm) + LLVMPassManagerBuilderDispose(pmbuilder) + + LLVMRunPassManager(pm, module) + LLVMDisposePassManager(pm) -# An error message should have already been printed to stderr, without a trailing \n -def fail_parsing_args(argv0: byte*, message: byte*) -> None: - fprintf(stderr, "%s: %s (try \"%s --help\")\n", argv0, message, argv0) - exit(2) def print_help(argv0: byte*) -> None: printf("Usage:\n") - printf(" %s [options] FILENAME.jou\n", argv0) - printf(" %s --help # This message\n", argv0) + printf(" %s [-o OUTFILE] [-O0|-O1|-O2|-O3] [--verbose] [--linker-flags \"...\"] FILENAME\n", argv0) + printf(" %s --help # This message\n", argv0) + printf(" %s --update # Download and install the latest Jou\n", argv0) printf("\n") printf("Options:\n") - printf(" -o OUTFILE output an executable file, don't run the code\n") - printf(" -v/--verbose print what compilation steps are done\n") - printf(" -vv / --verbose --verbose show what each compilation step produces\n") - printf(" --tokenize-only tokenize one file and display the resulting tokens\n") - printf(" --parse-only tokenize and parse one file and display the AST\n") - -def parse_args(argc: int, argv: byte**) -> CommandLineArgs: - result = CommandLineArgs{mode = CompilerMode::CompileAndRun} + printf(" -o OUTFILE output an executable file, don't run the code\n") + printf(" -O0/-O1/-O2/-O3 set optimization level (0 = no optimization, 1 = default, 3 = runs fastest)\n") + printf(" -v / --verbose display some progress information\n") + printf(" -vv display a lot of information about all compilation steps\n") + printf(" --valgrind use valgrind when running the code\n") + printf(" --tokenize-only display only the output of the tokenizer, don't do anything else\n") + printf(" --parse-only display only the AST (parse tree), don't do anything else\n") + printf(" --linker-flags appended to the linker command, so you can use external libraries\n") + +def parse_arguments(argc: int, argv: byte**) -> None: + memset(&command_line_args, 0, sizeof command_line_args) + command_line_args.argv0 = argv[0] + # Set default optimize to O1, user sets optimize will overwrite the default flag + command_line_args.optlevel = 1 + + if argc == 2 and strcmp(argv[1], "--help") == 0: + print_help(argv[0]) + exit(0) + + if argc == 2 and strcmp(argv[1], "--update") == 0: + update_jou_compiler() + exit(0) i = 1 while i < argc: - arg = argv[i++] - - if strcmp(arg, "--help") == 0: - print_help(argv[0]) - exit(0) - - if result.mode != CompilerMode::CompileAndRun and ( - strcmp(arg, "--tokenize-only") == 0 - or strcmp(arg, "--parse-only") == 0 - or strcmp(arg, "-o") == 0 - ): - fail_parsing_args(argv[0], "only one of --tokenize-only, --parse-only or -o can be given") - - if strcmp(arg, "--tokenize-only") == 0: - result.mode = CompilerMode::TokenizeOnly - elif strcmp(arg, "--parse-only") == 0: - result.mode = CompilerMode::ParseOnly - elif strcmp(arg, "-o") == 0: - result.mode = CompilerMode::Compile - result.output_file = argv[i++] - if result.output_file == NULL: - fail_parsing_args(argv[0], "-o must be followed by the name of an output file") - elif strcmp(arg, "--verbose") == 0: - result.verbosity++ - elif starts_with(arg, "-v") and strspn(&arg[1], "v") == strlen(arg) - 1: - result.verbosity += (strlen(arg) as int) - 1 - elif arg[0] == '-': - message = malloc(strlen(arg) + 100) - sprintf(message, "unknown option '%s'", arg) - fail_parsing_args(argv[0], message) - elif result.main_path == NULL: - result.main_path = arg - else: - fail_parsing_args(argv[0], "you can pass only one Jou file") - - if result.main_path == NULL: - fail_parsing_args(argv[0], "missing Jou file name") - - return result - -def find_file(files: FileState*, nfiles: int, path: byte*) -> FileState*: - for i = 0; i < nfiles; i++: - if strcmp(files[i].ast.path, path) == 0: - return &files[i] - return NULL - -# C:\Users\myname\.foo-bar.jou --> "_foo_bar" -# Result never contains "-", so you can add "-" separated suffixes without conflicts. -def get_sane_filename(path: byte*) -> byte[50]: - while True: - if strstr(path, "/") != NULL: - path = strstr(path, "/") - elif strstr(path, "\\") != NULL: - path = strstr(path, "\\") - else: - break - path++ # skip slash - - name: byte[50] - snprintf(name, sizeof name, "%s", path) - assert name[0] != '\0' - - if name[0] == '.': - name[0] = '_' - for i = 0; name[i] != '\0'; i++: - if name[i] == '.': - name[i] = '\0' - break - if name[i] == '-': - name[i] = '_' - return name - - -def check_main_function(ast: AstFile*) -> bool: - for i = 0; i < ast->body.nstatements; i++: - s = &ast->body.statements[i] - if s->kind == AstStatementKind::Function and strcmp(s->function.signature.name, "main") == 0: - return True - return False - -def check_ast_and_import_conflicts(ast: AstFile*, symbol: ExportSymbol*) -> None: - for i = 0; i < ast->body.nstatements; i++: - ts = &ast->body.statements[i] - if ts->kind == AstStatementKind::Function: - conflict = symbol->kind == ExportSymbolKind::Function and strcmp(ts->function.signature.name, symbol->name) == 0 + if strcmp(argv[i], "--help") == 0 or strcmp(argv[i], "--update") == 0: + fprintf(stderr, "%s: \"%s\" cannot be used with other arguments", argv[0], argv[i]) + fprintf(stderr, " (try \"%s --help\")\n", argv[0]) + exit(2) + elif strcmp(argv[i], "--verbose") == 0: + command_line_args.verbosity++ + i++ + elif starts_with(argv[i], "-v") and strspn(&argv[i][1], "v") == strlen(argv[i])-1: + command_line_args.verbosity += (strlen(argv[i]) as int) - 1 + i++ + elif strcmp(argv[i], "--valgrind") == 0: + command_line_args.valgrind = True + i++ + elif strcmp(argv[i], "--tokenize-only") == 0: + if argc > 3: + fprintf(stderr, "%s: --tokenize-only cannot be used together with other flags (try \"%s --help\")\n", argv[0], argv[0]) + exit(2) + command_line_args.tokenize_only = True + i++ + elif strcmp(argv[i], "--parse-only") == 0: + if argc > 3: + fprintf(stderr, "%s: --parse-only cannot be used together with other flags (try \"%s --help\")", argv[0], argv[0]) + exit(2) + command_line_args.parse_only = True + i++ + elif strcmp(argv[i], "--linker-flags") == 0: + if command_line_args.linker_flags != NULL: + fprintf(stderr, "%s: --linker-flags cannot be given multiple times (try \"%s --help\")\n", argv[0], argv[0]) + exit(2) + + if argc-i < 2: + fprintf(stderr, "%s: there must be a string of flags after --linker-flags (try \"%s --help\")\n", argv[0], argv[0]) + exit(2) + + command_line_args.linker_flags = argv[i+1] + i += 2 elif ( - ts->kind == AstStatementKind::GlobalVariableDeclaration - or ts->kind == AstStatementKind::GlobalVariableDefinition + strlen(argv[i]) == 3 + and starts_with(argv[i], "-O") + and argv[i][2] >= '0' + and argv[i][2] <= '3' ): - conflict = symbol->kind == ExportSymbolKind::GlobalVariable and strcmp(ts->var_declaration.name, symbol->name) == 0 - elif ts->kind == AstStatementKind::Class: - conflict = symbol->kind == ExportSymbolKind::Type and strcmp(ts->classdef.name, symbol->name) == 0 - elif ts->kind == AstStatementKind::Enum: - conflict = symbol->kind == ExportSymbolKind::Type and strcmp(ts->enumdef.name, symbol->name) == 0 + command_line_args.optlevel = argv[i][2] - '0' + i++ + elif strcmp(argv[i], "-o") == 0: + if argc-i < 2: + fprintf(stderr, "%s: there must be a file name after -o", argv[0]) + fprintf(stderr, " (try \"%s --help\")\n", argv[0]) + exit(2) + + command_line_args.outfile = argv[i+1] + if strlen(command_line_args.outfile) > 4 and ends_with(command_line_args.outfile, ".jou"): + fprintf(stderr, "%s: the filename after -o should be an executable, not a Jou file (try \"%s --help\")\n", argv[0], argv[0]) + exit(2) + i += 2 + elif argv[i][0] == '-': + fprintf(stderr, "%s: unknown argument \"%s\"", argv[0], argv[i]) + fprintf(stderr, " (try \"%s --help\")\n", argv[0]) + exit(2) + elif command_line_args.infile != NULL: + fprintf(stderr, "%s: you can only pass one Jou file (try \"%s --help\")\n", argv[0], argv[0]) + exit(2) else: - assert False - - if conflict: - if symbol->kind == ExportSymbolKind::Function: - kind_name = "function" - elif symbol->kind == ExportSymbolKind::GlobalVariable: - kind_name = "global variable" - elif symbol->kind == ExportSymbolKind::Type: - kind_name = "type" - else: - assert False + command_line_args.infile = argv[i++] - message: byte[500] - # TODO: might be nice to show where it was imported from - snprintf(message, sizeof message, "a %s named '%s' already exists", kind_name, symbol->name) - fail(ts->location, message) + if command_line_args.infile == NULL: + fprintf(stderr, "%s: missing Jou file name (try \"%s --help\")\n", argv[0], argv[0]) + exit(2) class FileState: + path: byte* ast: AstFile - typectx: FileTypes + types: FileTypes + module: LLVMModule* pending_exports: ExportSymbol* class ParseQueueItem: - path: byte* - is_imported: bool + filename: byte* import_location: Location -class Compiler: - argv0: byte* - verbosity: int +class CompileState: stdlib_path: byte* - args: CommandLineArgs* files: FileState* nfiles: int - automagic_files: byte*[10] - - def determine_automagic_files(self) -> None: - self->automagic_files[0] = malloc(strlen(self->stdlib_path) + 40) - sprintf(self->automagic_files[0], "%s/_assert_fail.jou", self->stdlib_path) - if WINDOWS or MACOS or NETBSD: - self->automagic_files[1] = malloc(strlen(self->stdlib_path) + 40) - sprintf(self->automagic_files[1], "%s/_jou_startup.jou", self->stdlib_path) - - def parse_all_files(self) -> None: - queue: ParseQueueItem* = malloc(50 * sizeof queue[0]) - queue_len = 0 - queue[queue_len++] = ParseQueueItem{path = self->args->main_path} - for i = 0; self->automagic_files[i] != NULL; i++: - queue[queue_len++] = ParseQueueItem{path = self->automagic_files[i]} - - while queue_len > 0: - item = queue[--queue_len] - - found = False - for i = 0; i < self->nfiles; i++: - if strcmp(self->files[i].ast.path, item.path) == 0: - found = True - break - if found: - continue - - if self->verbosity >= 1: - printf("Parsing %s\n", item.path) - - if item.is_imported: - tokens = tokenize(item.path, &item.import_location) - else: - tokens = tokenize(item.path, NULL) - if self->verbosity >= 2: - print_tokens(tokens) - ast = parse(tokens, self->stdlib_path) - if self->verbosity >= 2: - ast.print() - free(tokens) # TODO: do this properly - - evaluate_compile_time_if_statements_in_body(&ast.body) - - if item.is_imported and check_main_function(&ast): - assert item.import_location.path != NULL - fail(item.import_location, "imported file should not have `main` function") - - self->files = realloc(self->files, sizeof self->files[0] * (self->nfiles + 1)) - self->files[self->nfiles++] = FileState{ast = ast} - - for i = 0; i < ast.nimports; i++: - # TODO: offsetof() - queue = realloc(queue, sizeof queue[0] * (queue_len + 1)) - queue[queue_len++] = ParseQueueItem{ - path = ast.imports[i].resolved_path, - is_imported = True, - import_location = ast.imports[i].location, - } - - free(queue) - - def process_imports_and_exports(self) -> None: - if self->verbosity >= 1: - printf("Processing imports/exports\n") - - for idest = 0; idest < self->nfiles; idest++: - dest = &self->files[idest] - seen_before: FileState** = malloc(sizeof(seen_before[0]) * dest->ast.nimports) - - for i = 0; i < dest->ast.nimports; i++: - imp = &dest->ast.imports[i] - - src: FileState* = NULL - for isrc = 0; isrc < self->nfiles; isrc++: - if strcmp(self->files[isrc].ast.path, imp->resolved_path) == 0: - src = &self->files[isrc] - break - assert src != NULL - - if src == dest: - fail(imp->location, "the file itself cannot be imported") - - for k = 0; k < i; k++: - if seen_before[k] == src: - message: byte[500] - snprintf(message, sizeof(message), "file \"%s\" is imported twice", imp->specified_path) - fail(imp->location, message) - seen_before[i] = src - - for exp = src->pending_exports; exp->name[0] != '\0'; exp++: - if self->verbosity >= 1: - printf( - " %s: imported in %s, exported in %s\n", - exp->name, src->ast.path, dest->ast.path, - ) - check_ast_and_import_conflicts(&dest->ast, exp) - dest->typectx.add_imported_symbol(exp) - - free(seen_before) - - for i = 0; i < self->nfiles; i++: - free(self->files[i].pending_exports) - self->files[i].pending_exports = NULL - - def typecheck_stage1_all_files(self) -> None: - for i = 0; i < self->nfiles; i++: - if self->verbosity >= 1: - printf("Type-check stage 1: %s\n", self->files[i].ast.path) - - assert self->files[i].pending_exports == NULL - self->files[i].pending_exports = typecheck_stage1_create_types( - &self->files[i].typectx, - &self->files[i].ast, - ) - - def typecheck_stage2_all_files(self) -> None: - for i = 0; i < self->nfiles; i++: - if self->verbosity >= 1: - printf("Type-check stage 2: %s\n", self->files[i].ast.path) - - assert self->files[i].pending_exports == NULL - self->files[i].pending_exports = typecheck_stage2_populate_types( - &self->files[i].typectx, - &self->files[i].ast, - ) - - def typecheck_stage3_all_files(self) -> None: - for i = 0; i < self->nfiles; i++: - if self->verbosity >= 1: - printf("Type-check stage 3: %s\n", self->files[i].ast.path) - - typecheck_stage3_function_and_method_bodies( - &self->files[i].typectx, - &self->files[i].ast, - ) - - def get_object_file_paths(self) -> byte**: - sane_names: byte[50]* = calloc(sizeof sane_names[0], self->nfiles) - result: byte** = calloc(sizeof result[0], self->nfiles + 1) # NULL terminated - - # First, extract just the names from file paths. - # "blah/blah/foo.jou" --> "foo" - for i = 0; i < self->nfiles; i++: - sane_names[i] = get_sane_filename(self->files[i].ast.path) - - for i = 0; i < self->nfiles; i++: - # If there are 3 files named foo.jou in different directories, their object files will be foo.o, foo-1.o, foo-2.o - counter = 0 - for k = 0; k < i; k++: - if strcmp(sane_names[k], sane_names[i]) == 0: - counter++ - - name: byte[100] - if counter == 0: - sprintf(name, "%s.o", sane_names[i]) - else: - sprintf(name, "%s-%d.o", sane_names[i], counter) - result[i] = get_path_to_file_in_jou_compiled(name) + parse_queue: ParseQueueItem* + parse_queue_len: int - free(sane_names) - return result +def find_file(compst: CompileState*, path: byte*) -> FileState*: + for fs = compst->files; fs < &compst->files[compst->nfiles]; fs++: + if strcmp(fs->path, path) == 0: + return fs + return NULL - def get_exe_file_path(self) -> byte*: - if self->args->output_file == NULL: - tmp = get_sane_filename(self->args->main_path) - exe = get_path_to_file_in_jou_compiled(tmp) - else: - exe = strdup(self->args->output_file) - - if WINDOWS and not ends_with(exe, ".exe") and not ends_with(exe, ".EXE"): - exe = realloc(exe, strlen(exe) + 10) - strcat(exe, ".exe") - - if WINDOWS: - for i = 0; exe[i] != '\0'; i++: - if exe[i] == '/': - exe[i] = '\\' - - return exe - - def create_object_files(self) -> byte**: - paths = self->get_object_file_paths() - - for i = 0; i < self->nfiles; i++: - if self->verbosity >= 1: - printf("Build LLVM IR: %s\n", self->files[i].ast.path) - - module = create_llvm_ir(&self->files[i].ast, &self->files[i].typectx) - if self->verbosity >= 2: - # Don't want to use LLVMDumpModule() because it dumps to stdout. - # When redirected, stdout and stderr tend to get mixed up into a weird order. - s = LLVMPrintModuleToString(module) - puts(s) - LLVMDisposeMessage(s) - - if self->verbosity >= 1: - printf("Verify LLVM IR: %s\n", self->files[i].ast.path) - LLVMVerifyModule(module, LLVMVerifierFailureAction::AbortProcess, NULL) - - path = paths[i] - if self->verbosity >= 1: - printf("Emit LLVM IR: %s --> %s\n", self->files[i].ast.path, path) - - error: byte* = NULL - if LLVMTargetMachineEmitToFile(target.target_machine, module, path, LLVMCodeGenFileType::ObjectFile, &error) != 0: - assert error != NULL - fprintf(stderr, "error in LLVMTargetMachineEmitToFile(): %s\n", error) - exit(1) - assert error == NULL - - return paths - - def link(self, object_files: byte**) -> byte*: - exe = self->get_exe_file_path() - if WINDOWS: - c_compiler = find_installation_directory() - c_compiler = realloc(c_compiler, strlen(c_compiler) + 100) - strcat(c_compiler, "\\mingw64\\bin\\gcc.exe") +def open_the_file(path: byte*, import_location: Location*) -> FILE*: + f = fopen(path, "rb") + if f == NULL: + msg: byte[500] + if import_location != NULL: + snprintf(msg, sizeof(msg), "cannot import from \"%s\": %s", path, strerror(get_errno())) + fail(*import_location, msg) else: - c_compiler = get_jou_clang_path() + snprintf(msg, sizeof(msg), "cannot open file: %s", strerror(get_errno())) + fail(Location{path=path}, msg) + return f - command_size = strlen(c_compiler) + strlen(exe) + 100 - for i = 0; object_files[i] != NULL; i++: - command_size += 5 - command_size += strlen(object_files[i]) - command: byte* = malloc(command_size) +def defines_main(ast: AstFile*) -> bool: + for i = 0; i < ast->body.nstatements; i++: + s = &ast->body.statements[i] + if s->kind == AstStatementKind::Function and strcmp(s->function.signature.name, "main") == 0: + return True + return False - sprintf(command, "\"%s\" -o \"%s\"", c_compiler, exe) - for i = 0; object_files[i] != NULL; i++: - sprintf(&command[strlen(command)], " \"%s\"", object_files[i]) - strcat(command, " -lm") +def parse_file(compst: CompileState*, filename: byte*, import_location: Location*) -> None: + if find_file(compst, filename) != NULL: + return - if WINDOWS: - # windows strips outermost quotes for some reason, so let's quote it all one more time... - memmove(&command[1], &command[0], strlen(command) + 1) - command[0] = '"' - strcat(command, "\"") + fs = FileState{path = strdup(filename)} - if self->verbosity >= 1: - printf("Run linker command: %s\n", command) + if command_line_args.verbosity >= 1: + printf("Tokenizing %s\n", filename) + tokens = tokenize(fs.path, import_location) - # make sure that compiler output shows up before command output, even if redirected - fflush(stdout) - fflush(stderr) + if command_line_args.verbosity >= 2: + print_tokens(tokens) - ret = system(command) - free(command) - if ret != 0: - fprintf(stderr, "%s: linking failed\n", self->argv0) - exit(1) + if command_line_args.verbosity >= 1: + printf("Parsing %s\n", filename) + fs.ast = parse(tokens, compst->stdlib_path) + free_tokens(tokens) + + # TODO: enable this + if command_line_args.verbosity >= 1: + printf("Evaluating compile-time if statements in %s\n", filename) + evaluate_compile_time_if_statements(&fs.ast.body) + + if command_line_args.verbosity >= 2: + fs.ast.print() + + # If it's not the file passed on command line, it shouldn't define main() + if strcmp(filename, command_line_args.infile) != 0 and defines_main(&fs.ast): + # Set error location to import, so user immediately knows which file + # imports something that defines main(). + assert import_location != NULL + fail(*import_location, "imported file should not have `main` function") + + for imp = fs.ast.imports; imp < &fs.ast.imports[fs.ast.nimports]; imp++: + compst->parse_queue = realloc(compst->parse_queue, sizeof(compst->parse_queue[0]) * (compst->parse_queue_len + 1)) + assert compst->parse_queue != NULL + compst->parse_queue[compst->parse_queue_len++] = ParseQueueItem{ + filename = imp->resolved_path, + import_location = imp->location, + } - return exe + compst->files = realloc(compst->files, sizeof(compst->files[0]) * (compst->nfiles + 1)) + assert compst->files != NULL + compst->files[compst->nfiles++] = fs + +def parse_all_pending_files(compst: CompileState*) -> None: + while compst->parse_queue_len > 0: + it = compst->parse_queue[--compst->parse_queue_len] + parse_file(compst, it.filename, &it.import_location) + free(compst->parse_queue) + +def compile_ast_to_object_file(fs: FileState*) -> byte*: + if command_line_args.verbosity >= 1: + printf("Building Control Flow Graphs: %s\n", fs->path) + + cfgfile = build_control_flow_graphs(&fs->ast, &fs->types) + for imp = fs->ast.imports; imp < &fs->ast.imports[fs->ast.nimports]; imp++: + if not imp->used: + msg: byte[500] + snprintf(msg, sizeof msg, "\"%s\" imported but not used", imp->specified_path) + show_warning(imp->location, msg) + + if command_line_args.verbosity >= 2: + print_control_flow_graphs(&cfgfile) + + # TODO: implement this +# if command_line_args.verbosity >= 1: +# printf("Analyzing CFGs: %s\n", fs->path) +# simplify_control_flow_graphs(&cfgfile) +# if command_line_args.verbosity >= 2: +# print_control_flow_graphs(&cfgfile) + + if command_line_args.verbosity >= 1: + printf("Building LLVM IR: %s\n", fs->path) + + mod = codegen(&cfgfile, &fs->types) + # TODO: free the control flow graphs, this crashes for some reason + #free_control_flow_graphs(&cfgfil + + if command_line_args.verbosity >= 2: + print_llvm_ir(mod, False) + + # If this fails, it is not just users writing dumb code, it is a bug in this compiler. + # This compiler should always fail with an error elsewhere, or generate valid LLVM IR. + LLVMVerifyModule(mod, LLVMVerifierFailureAction::AbortProcess, NULL) + + if command_line_args.optlevel != 0: + if command_line_args.verbosity >= 1: + printf("Optimizing %s (level %d)\n", fs->path, command_line_args.optlevel) + optimize(mod, command_line_args.optlevel) + if command_line_args.verbosity >= 2: + print_llvm_ir(mod, True) + + objpath = compile_to_object_file(mod) + LLVMDisposeModule(mod) + return objpath + + +def statement_conflicts_with_an_import(stmt: AstStatement*, importsym: ExportSymbol*) -> bool: + if stmt->kind == AstStatementKind::Function: + return ( + importsym->kind == ExportSymbolKind::Function + and strcmp(importsym->name, stmt->function.signature.name) == 0 + ) + + if ( + stmt->kind == AstStatementKind::GlobalVariableDeclaration + or stmt->kind == AstStatementKind::GlobalVariableDefinition + ): + return ( + importsym->kind == ExportSymbolKind::GlobalVar + and strcmp(importsym->name, stmt->var_declaration.name) == 0 + ) + + if stmt->kind == AstStatementKind::Class: + return ( + importsym->kind == ExportSymbolKind::Type + and strcmp(importsym->name, stmt->classdef.name) == 0 + ) + + if stmt->kind == AstStatementKind::Enum: + return ( + importsym->kind == ExportSymbolKind::Type + and strcmp(importsym->name, stmt->enumdef.name) == 0 + ) + + assert False # TODO + +def add_imported_symbol(fs: FileState*, es: ExportSymbol*, imp: AstImport*) -> None: + for i = 0; i < fs->ast.body.nstatements; i++: + if statement_conflicts_with_an_import(&fs->ast.body.statements[i], es): + if es->kind == ExportSymbolKind::Function: + wat = "function" + elif es->kind == ExportSymbolKind::GlobalVar: + wat = "global variable" + elif es->kind == ExportSymbolKind::Type: + wat = "type" + else: + assert False - def run(self, exe: byte*) -> None: - command = malloc(strlen(exe) + 10) - sprintf(command, "\"%s\"", exe) - if self->verbosity >= 1: - printf("Run the compiled program command: %s\n", command) + msg: byte[500] + snprintf(msg, sizeof msg, "a %s named '%s' already exists", wat, es->name) + fail(fs->ast.body.statements[i].location, msg) - # make sure that compiler output shows up before command output, even if redirected - fflush(stdout) - fflush(stderr) + if es->kind == ExportSymbolKind::Function: + fs->types.functions = realloc(fs->types.functions, sizeof(fs->types.functions[0]) * (fs->types.nfunctions + 1)) + assert fs->types.functions != NULL + fs->types.functions[fs->types.nfunctions++] = SignatureAndUsedPtr{ + signature = copy_signature(&es->funcsignature), + usedptr = &imp->used, + } + elif es->kind == ExportSymbolKind::Type: + fs->types.types = realloc(fs->types.types, sizeof(fs->types.types[0]) * (fs->types.ntypes + 1)) + assert fs->types.types != NULL + fs->types.types[fs->types.ntypes++] = TypeAndUsedPtr{ + type = es->type, + usedptr = &imp->used, + } + elif es->kind == ExportSymbolKind::GlobalVar: + g = GlobalVariable{ + type = es->type, + usedptr = &imp->used, + } - ret = system(command) - if ret != 0: - # TODO: print something? The shell doesn't print stuff - # like "Segmentation fault" on Windows afaik - exit(1) + assert strlen(es->name) < sizeof g.name + strcpy(g.name, es->name) + fs->types.globals = realloc(fs->types.globals, sizeof(fs->types.globals[0]) * (fs->types.nglobals + 1)) + assert fs->types.globals != NULL + fs->types.globals[fs->types.nglobals++] = g + else: + assert False + +def add_imported_symbols(compst: CompileState*) -> None: + for to = compst->files; to < &compst->files[compst->nfiles]; to++: + seen_before: FileState** = NULL + seen_before_len = 0 + + for imp = to->ast.imports; imp < &to->ast.imports[to->ast.nimports]; imp++: + from = find_file(compst, imp->resolved_path) + assert from != NULL + + if from == to: + fail(imp->location, "the file itself cannot be imported") + + for i = 0; i < seen_before_len; i++: + if seen_before[i] == from: + msg: byte[500] + snprintf(msg, sizeof msg, "file \"%s\" is imported twice", imp->specified_path) + fail(imp->location, msg) + + seen_before = realloc(seen_before, sizeof(seen_before[0]) * (seen_before_len + 1)) + seen_before[seen_before_len++] = from + + for es = from->pending_exports; es->name[0] != '\0'; es++: + if command_line_args.verbosity >= 2: + if es->kind == ExportSymbolKind::Function: + kindstr = "function" + elif es->kind == ExportSymbolKind::GlobalVar: + kindstr = "global var" + elif es->kind == ExportSymbolKind::Type: + kindstr = "type" + else: + assert False + + printf("Adding imported %s %s: %s --> %s\n", + kindstr, es->name, from->path, to->path) + + add_imported_symbol(to, es, imp) + + free(seen_before) + + # Mark all exports as no longer pending. + for fs = compst->files; fs < &compst->files[compst->nfiles]; fs++: + for es = fs->pending_exports; es->name[0] != '\0'; es++: + free_export_symbol(es) + free(fs->pending_exports) + fs->pending_exports = NULL + +def include_special_stdlib_file(compst: CompileState*, filename: byte*) -> None: + path = malloc(strlen(compst->stdlib_path) + strlen(filename) + 123) + sprintf(path, "%s/%s", compst->stdlib_path, filename) + parse_file(compst, path, NULL) + free(path) def main(argc: int, argv: byte**) -> int: init_target() init_types() + stdlib = find_stdlib() + parse_arguments(argc, argv) + + compst = CompileState{ stdlib_path = stdlib } + if command_line_args.verbosity >= 2: + printf("Target triple: %s\n", target.triple) + printf("Data layout: %s\n", target.data_layout) + + if command_line_args.tokenize_only or command_line_args.parse_only: + tokens = tokenize(command_line_args.infile, NULL) + if command_line_args.tokenize_only: + print_tokens(tokens) + else: + ast = parse(tokens, compst.stdlib_path) + ast.print() + ast.free() + free_tokens(tokens) + return 0 + + include_special_stdlib_file(&compst, "_assert_fail.jou") + + if WINDOWS or MACOS or NETBSD: + include_special_stdlib_file(&compst, "_jou_startup.jou") + + parse_file(&compst, command_line_args.infile, NULL) + parse_all_pending_files(&compst) + + if command_line_args.verbosity >= 1: + printf("Type-checking...\n") + + for fs = compst.files; fs < &compst.files[compst.nfiles]; fs++: + if command_line_args.verbosity >= 1: + printf(" stage 1: %s\n", fs->path) + fs->pending_exports = typecheck_stage1_create_types(&fs->types, &fs->ast) + + add_imported_symbols(&compst) + for fs = compst.files; fs < &compst.files[compst.nfiles]; fs++: + if command_line_args.verbosity >= 1: + printf(" stage 2: %s\n", fs->path) + fs->pending_exports = typecheck_stage2_populate_types(&fs->types, &fs->ast) + + add_imported_symbols(&compst) + for fs = compst.files; fs < &compst.files[compst.nfiles]; fs++: + if command_line_args.verbosity >= 1: + printf(" stage 3: %s\n", fs->path) + typecheck_stage3_function_and_method_bodies(&fs->types, &fs->ast) + + objpaths: byte** = calloc(sizeof objpaths[0], compst.nfiles + 1) + for i = 0; i < compst.nfiles; i++: + objpaths[i] = compile_ast_to_object_file(&compst.files[i]) + + # Check for missing main() as late as possible, so that other errors come first. + # This way Jou users can work on other functions before main() function is written. + mainfile = find_file(&compst, command_line_args.infile) + assert mainfile != NULL + if not defines_main(&mainfile->ast): + fail(Location{path=mainfile->path, lineno=0}, "missing `main` function to execute the program") + + for fs = compst.files; fs < &compst.files[compst.nfiles]; fs++: + fs->ast.free() + free(fs->path) + free_file_types(&fs->types) + free(compst.files) + free(stdlib) + + if command_line_args.outfile != NULL: + exepath = strdup(command_line_args.outfile) + else: + exepath = get_default_exe_path() - args = parse_args(argc, argv) + run_linker(objpaths, exepath) + for i = 0; objpaths[i] != NULL; i++: + free(objpaths[i]) + free(objpaths) - if args.mode == CompilerMode::TokenizeOnly: - tokens = tokenize(args.main_path, NULL) - print_tokens(tokens) - free(tokens) - elif args.mode == CompilerMode::ParseOnly: - tokens = tokenize(args.main_path, NULL) - stdlib_path = find_stdlib() - ast = parse(tokens, stdlib_path) - ast.print() - ast.free() - free(tokens) - free(stdlib_path) - elif args.mode == CompilerMode::Compile or args.mode == CompilerMode::CompileAndRun: - compiler = Compiler{ - argv0 = argv[0], - verbosity = args.verbosity, - stdlib_path = find_stdlib(), - args = &args, - } - compiler.determine_automagic_files() - compiler.parse_all_files() - - compiler.typecheck_stage1_all_files() - compiler.process_imports_and_exports() - compiler.typecheck_stage2_all_files() - compiler.process_imports_and_exports() - compiler.typecheck_stage3_all_files() - - mainfile = find_file(compiler.files, compiler.nfiles, args.main_path) - assert mainfile != NULL - - if not check_main_function(&mainfile->ast): - l = Location{path=mainfile->ast.path, lineno=0} - fail(l, "missing `main` function to execute the program") - - object_files = compiler.create_object_files() - executable = compiler.link(object_files) - for i = 0; object_files[i] != NULL; i++: - free(object_files[i]) - free(object_files) - - # TODO: make the -o flag work - if args.mode == CompilerMode::CompileAndRun: - compiler.run(executable) - free(executable) - for i = 0; compiler.automagic_files[i] != NULL; i++: - free(compiler.automagic_files[i]) + ret = 0 + if command_line_args.outfile == NULL: + if command_line_args.verbosity >= 1: + printf("Run: %s\n", exepath) + ret = run_exe(exepath, command_line_args.valgrind) - else: - assert False + free(exepath) + + # not really necessary, but makes valgrind much happier + free_global_type_state() + cleanup_target() - return 0 + return ret diff --git a/self_hosted/output.jou b/self_hosted/output.jou new file mode 100644 index 00000000..53045365 --- /dev/null +++ b/self_hosted/output.jou @@ -0,0 +1,95 @@ +import "stdlib/str.jou" +import "stdlib/io.jou" +import "stdlib/process.jou" +import "stdlib/mem.jou" + +import "../config.jou" +import "./llvm.jou" +import "./target.jou" +import "./structs.jou" +import "./paths.jou" + + +def run_linker(objpaths: byte**, exepath: byte*) -> None: + jou_exe = find_current_executable() + instdir = dirname(jou_exe) + + if command_line_args.linker_flags != NULL: + linker_flags = malloc(strlen(command_line_args.linker_flags) + 50) + assert linker_flags != NULL + strcpy(linker_flags, "-lm ") + strcat(linker_flags, command_line_args.linker_flags) + else: + linker_flags = strdup("-lm") + assert linker_flags != NULL + + size = 10L + for i = 0; objpaths[i] != NULL; i++: + size += strlen(objpaths[i]) + 10 + + quoted_object_files: byte* = malloc(size) + assert quoted_object_files != NULL + quoted_object_files[0] = '\0' + + for i = 0; objpaths[i] != NULL; i++: + if i != 0: + strcat(quoted_object_files, " ") + strcat(quoted_object_files, "\"") + strcat(quoted_object_files, objpaths[i]) # TODO: escape properly? + strcat(quoted_object_files, "\"") + + size = strlen(instdir) + strlen(quoted_object_files) + strlen(exepath) + strlen(linker_flags) + 100 + if get_jou_clang_path() != NULL: + size += strlen(get_jou_clang_path()) + command: byte* = malloc(size) + + if WINDOWS: + # Assume mingw with clang has been downloaded with windows_setup.sh. + # Could also use clang, but gcc has less dependencies so we can make the Windows zips smaller. + # Windows quoting is weird. The outermost quotes get stripped here. + snprintf(command, size, "\"\"%s\\mingw64\\bin\\gcc.exe\" %s -o \"%s\" %s\"", instdir, quoted_object_files, exepath, linker_flags) + else: + # Assume clang is installed and use it to link. Could use lld, but clang is needed anyway. + # instdir is not used in this case. + snprintf(command, size, "'%s' %s -o '%s' %s", get_jou_clang_path(), quoted_object_files, exepath, linker_flags) + + free(quoted_object_files) + free(jou_exe) + free(linker_flags) + + if command_line_args.verbosity >= 2: + printf("Running linker: %s\n", command) + elif command_line_args.verbosity >= 1: + printf("Running linker\n") + + if system(command) != 0: + exit(1) + free(command) + + +def compile_to_object_file(module: LLVMModule*) -> byte*: + len = 0L + objname = get_filename_without_jou_suffix(LLVMGetSourceFileName(module, &len)) + + objname = realloc(objname, strlen(objname) + 10) + if WINDOWS: + strcat(objname, ".obj") + else: + strcat(objname, ".o") + + path = get_path_to_file_in_jou_compiled(objname) + free(objname) + + if command_line_args.verbosity >= 1: + printf("Emitting object file: %s\n", path) + + tmppath = strdup(path) + error: byte* = NULL + if LLVMTargetMachineEmitToFile(target.target_machine, module, tmppath, LLVMCodeGenFileType::ObjectFile, &error) != 0: + assert error != NULL + fprintf(stderr, "failed to emit object file \"%s\": %s\n", path, error) + exit(1) + free(tmppath) + + assert error == NULL + return path diff --git a/self_hosted/parser.jou b/self_hosted/parser.jou index bd3840bd..84e1cf36 100644 --- a/self_hosted/parser.jou +++ b/self_hosted/parser.jou @@ -695,7 +695,7 @@ class Parser: result = AstExpression{ location = as_location, kind = AstExpressionKind::As, - as_expression = p, + as_ = p, } return result diff --git a/self_hosted/paths.jou b/self_hosted/paths.jou index 8f52c706..93cfd9a7 100644 --- a/self_hosted/paths.jou +++ b/self_hosted/paths.jou @@ -3,6 +3,8 @@ import "stdlib/str.jou" import "stdlib/io.jou" import "stdlib/process.jou" +import "./structs.jou" + if WINDOWS: declare GetModuleFileNameA(hModule: void*, lpFilename: byte*, nSize: int) -> int elif MACOS: @@ -119,15 +121,91 @@ else: def my_mkdir(path: byte*) -> None: mkdir(path, 0o777) # this is what mkdir in bash does according to strace + +def write_gitignore(p: byte*) -> None: + filename: byte* = malloc(strlen(p) + 100) + sprintf(filename, "%s/.gitignore", p) + + f = fopen(filename, "r") + if f != NULL: + # already exists + fclose(f) + else: + # write '*' into gitignore, so that git won't track any compiled files + f = fopen(filename, "w") + if f != NULL: + fprintf(f, "*") + fclose(f) + + free(filename) + + +def mkdir_exist_ok(p: byte*) -> None: + # TODO: check if errno == EEXIST + # Currently no good way to access EEXIST constant + my_mkdir(p) + + def get_path_to_file_in_jou_compiled(filename: byte*) -> byte*: - # TODO: is placing jou_compiled to current working directory a good idea? - my_mkdir("jou_compiled") - my_mkdir("jou_compiled/self_hosted") + # Place compiled files so that it's difficult to get race conditions when + # compiling multiple Jou files simultaneously (tests do that) + tmp = strdup(command_line_args.infile) + infile_folder = strdup(dirname(tmp)) + free(tmp) + + subfolder = get_filename_without_jou_suffix(command_line_args.infile) + + result: byte* = malloc(strlen(infile_folder) + strlen(subfolder) + strlen(filename) + 100) + assert result != NULL + + sprintf(result, "%s/jou_compiled", infile_folder) + mkdir_exist_ok(result) + write_gitignore(result) - result: byte* = malloc(strlen(filename) + 100) - sprintf(result, "jou_compiled/self_hosted/%s", filename) + sprintf(result, "%s/jou_compiled/%s", infile_folder, subfolder) + mkdir_exist_ok(result) + + sprintf(result, "%s/jou_compiled/%s/%s", infile_folder, subfolder, filename) + + free(infile_folder) + free(subfolder) return result + + +def get_filename_without_jou_suffix(path: byte*) -> byte*: + last_slash = strrchr(path, '/') + if last_slash != NULL: + path = &last_slash[1] + + if WINDOWS: + last_slash = strrchr(path, '\\') + if last_slash != NULL: + path = &last_slash[1] + + len = strlen(path) + if len > 4 and ends_with(path, ".jou"): + len -= 4 + + result: byte* = malloc(len+1) + assert result != NULL + memcpy(result, path, len) + result[len] = '\0' + + return result + + +def get_default_exe_path() -> byte*: + name = get_filename_without_jou_suffix(command_line_args.infile) + if WINDOWS: + name = realloc(name, strlen(name) + 10) + strcat(name, ".exe") + + path = get_path_to_file_in_jou_compiled(name) + free(name) + return path + + # TODO: put this to stdlib? or does it do too much for a stdlib function? def delete_slice(start: byte*, end: byte*) -> None: memmove(start, end, strlen(end) + 1) diff --git a/self_hosted/print.jou b/self_hosted/print.jou new file mode 100644 index 00000000..6b9d44e4 --- /dev/null +++ b/self_hosted/print.jou @@ -0,0 +1,248 @@ +import "stdlib/io.jou" +import "stdlib/str.jou" +import "stdlib/mem.jou" + +import "./llvm.jou" +import "./structs.jou" +import "./types.jou" + + +def print_string(s: byte*, len: int) -> None: + putchar('"') + for i = 0; i < len or (len == -1 and s[i] != '\0'); i++: + if 32 <= s[i] and s[i] <= 126: + # printable ascii character + putchar(s[i]) + elif s[i] == '\n': + printf("\\n") + else: + printf("\\x%02x", s[i]) # TODO: \x is not yet recognized by the tokenizer + putchar('"') + + +def print_constant(c: Constant*) -> None: + if c->kind == ConstantKind::EnumMember: + printf("enum member %d of %s", c->enum_member.memberidx, c->enum_member.enumtype->name) + elif c->kind == ConstantKind::Bool: + if c->boolean: + printf("True") + else: + printf("False") + elif c->kind == ConstantKind::Float: + printf("float %s", c->double_or_float_text) + elif c->kind == ConstantKind::Double: + printf("double %s", c->double_or_float_text) + elif c->kind == ConstantKind::Integer: + if c->integer.is_signed: + signed_or_unsigned = "signed" + else: + signed_or_unsigned = "unsigned" + printf( + "%lld (%d-bit %s)", + c->integer.value, + c->integer.size_in_bits, + signed_or_unsigned, + ) + elif c->kind == ConstantKind::Null: + printf("NULL") + elif c->kind == ConstantKind::String: + print_string(c->str, -1) + else: + assert False + + +global printed_varnames: byte[10][10] +global printed_varnames_idx: int + + + +def varname_for_printing(var: LocalVariable*) -> byte*: + if var->name[0] != '\0': + # it is named, not a dummy + return var->name + + # Cycle through enough space for a few variables, so that you + # can call this several times inside the same printf(). + s: byte* = printed_varnames[printed_varnames_idx++] + printed_varnames_idx %= (sizeof(printed_varnames) / sizeof(printed_varnames[0])) as int + + sprintf(s, "$%d", var->id) + return s + + +def very_short_number_type_description(t: Type*) -> byte*: + if t->kind == TypeKind::FloatingPoint: + return "floating" + if t->kind == TypeKind::SignedInteger: + return "signed" + if t->kind == TypeKind::UnsignedInteger: + return "unsigned" + assert False + + +def print_cf_instruction(ins: CfInstruction*) -> None: + printf(" line %-4d ", ins->location.lineno) + + if ins->destvar != NULL: + printf("%s = ", varname_for_printing(ins->destvar)) + + if ins->kind == CfInstructionKind::AddressOfLocalVar: + printf("address of %s (local variable)", varname_for_printing(ins->operands[0])) + elif ins->kind == CfInstructionKind::AddressOfGlobalVar: + printf("address of %s (global variable)", ins->globalname) + elif ins->kind == CfInstructionKind::SizeOf: + printf("sizeof %s", ins->type->name) + elif ins->kind == CfInstructionKind::BoolNegate: + printf("boolean negation of %s", varname_for_printing(ins->operands[0])) + elif ins->kind == CfInstructionKind::Call: + if get_self_class(&ins->signature) != NULL: + printf("call method %s.", get_self_class(&ins->signature)->name) + else: + printf("call function ") + printf("%s(", ins->signature.name) + for i = 0; i < ins->noperands; i++: + if i != 0: + printf(", ") + printf("%s", varname_for_printing(ins->operands[i])) + printf(")") + elif ins->kind == CfInstructionKind::NumCast: + printf( + "number cast %s (%d-bit %s --> %d-bit %s)", + varname_for_printing(ins->operands[0]), + ins->operands[0]->type->size_in_bits, + very_short_number_type_description(ins->operands[0]->type), + ins->destvar->type->size_in_bits, + very_short_number_type_description(ins->destvar->type)) + elif ins->kind == CfInstructionKind::EnumToInt32: + printf("cast %s from enum to 32-bit signed int", varname_for_printing(ins->operands[0])) + elif ins->kind == CfInstructionKind::Int32ToEnum: + printf("cast %s from 32-bit signed int to enum", varname_for_printing(ins->operands[0])) + elif ins->kind == CfInstructionKind::PtrToInt64: + printf("cast %s to 64-bit integer", varname_for_printing(ins->operands[0])) + elif ins->kind == CfInstructionKind::Int64ToPtr: + printf("cast %s from 64-bit integer to pointer", varname_for_printing(ins->operands[0])) + elif ins->kind == CfInstructionKind::Constant: + print_constant(&ins->constant) + elif ins->kind == CfInstructionKind::SpecialConstant: + printf("special constant \"%s\"", ins->scname) + elif ins->kind == CfInstructionKind::StringArray: + printf("string array ") + print_string(ins->strarray.str, ins->strarray.len) + elif ( + ins->kind == CfInstructionKind::NumAdd + or ins->kind == CfInstructionKind::NumSub + or ins->kind == CfInstructionKind::NumMul + or ins->kind == CfInstructionKind::NumDiv + or ins->kind == CfInstructionKind::NumMod + or ins->kind == CfInstructionKind::NumEq + or ins->kind == CfInstructionKind::NumLt + ): + if ins->kind == CfInstructionKind::NumAdd: + printf("num add ") + elif ins->kind == CfInstructionKind::NumSub: + printf("num sub ") + elif ins->kind == CfInstructionKind::NumMul: + printf("num mul ") + elif ins->kind == CfInstructionKind::NumDiv: + printf("num div ") + elif ins->kind == CfInstructionKind::NumMod: + printf("num mod ") + elif ins->kind == CfInstructionKind::NumEq: + printf("num eq ") + elif ins->kind == CfInstructionKind::NumLt: + printf("num lt ") + else: + assert False + printf("%s, %s", varname_for_printing(ins->operands[0]), varname_for_printing(ins->operands[1])) + elif ins->kind == CfInstructionKind::PtrLoad: + # Extra parentheses to make these stand out a bit. + printf("*(%s)", varname_for_printing(ins->operands[0])) + elif ins->kind == CfInstructionKind::PtrStore: + printf("*(%s) = %s", varname_for_printing(ins->operands[0]), varname_for_printing(ins->operands[1])) + elif ins->kind == CfInstructionKind::PtrAddInt: + printf("ptr %s + integer %s", varname_for_printing(ins->operands[0]), varname_for_printing(ins->operands[1])) + elif ins->kind == CfInstructionKind::PtrClassField: + printf("%s + offset of field \"%s\"", varname_for_printing(ins->operands[0]), ins->fieldname) + elif ins->kind == CfInstructionKind::PtrCast: + printf("pointer cast %s", varname_for_printing(ins->operands[0])) + elif ins->kind == CfInstructionKind::PtrMemsetToZero: + printf("set value of pointer %s to zero bytes", varname_for_printing(ins->operands[0])) + elif ins->kind == CfInstructionKind::VarCpy: + printf("%s", varname_for_printing(ins->operands[0])) + else: + assert False + printf("\n") + + +def print_control_flow_graph(cfg: CfGraph*) -> None: + printed_varnames_idx = 0 + + sigstr: byte* = signature_to_string(&cfg->signature, True, True) + printf("Function %s\n", sigstr) + free(sigstr) + + printf(" Variables:\n") + for var = cfg->locals; var < &cfg->locals[cfg->nlocals]; var++: + printf(" %-20s %s\n", varname_for_printing(*var), (*var)->type->name) + + for blockidx = 0; blockidx < cfg->n_all_blocks; blockidx++: + b = cfg->all_blocks[blockidx] + + printf(" Block %d", blockidx) + #printf(" at %p", b) + + if b == &cfg->start_block: + printf(" (start block)") + if b == &cfg->end_block: + assert b->ninstructions == 0 + printf(" is the end block.\n") + continue + + printf(":\n") + + for ins = b->instructions; ins < &b->instructions[b->ninstructions]; ins++: + print_cf_instruction(ins) + + if b == &cfg->end_block: + assert b->iftrue == NULL + assert b->iffalse == NULL + elif b->iftrue == NULL and b->iffalse == NULL: + printf(" Execution stops here. We have called a noreturn function.\n") + else: + trueidx = -1 + falseidx = -1 + for i = 0; i < cfg->n_all_blocks; i++: + if cfg->all_blocks[i] == b->iftrue: + trueidx = i + if cfg->all_blocks[i]==b->iffalse: + falseidx = i + assert trueidx != -1 + assert falseidx != -1 + if trueidx == falseidx: + printf(" Jump to block %d.\n", trueidx) + else: + assert b->branchvar != NULL + printf(" If %s is True jump to block %d, otherwise block %d.\n", + varname_for_printing(b->branchvar), trueidx, falseidx) + + printf("\n") + + +def print_control_flow_graphs(cfgfile: CfGraphFile*) -> None: + printf("===== Control Flow Graphs for file \"%s\" =====\n", cfgfile->filename) + for cfg = cfgfile->graphs; cfg < &cfgfile->graphs[cfgfile->ngraphs]; cfg++: + print_control_flow_graph(*cfg) + +def print_llvm_ir(module: LLVMModule*, is_optimized: bool) -> None: + if is_optimized: + opt_or_unopt = "Optimized" + else: + opt_or_unopt = "Unoptimized" + + len = 0L + filename = LLVMGetSourceFileName(module, &len) + printf("===== %s LLVM IR for file \"%.*s\" =====\n", opt_or_unopt, len as int, filename) + + s = LLVMPrintModuleToString(module) + puts(s) + LLVMDisposeMessage(s) diff --git a/self_hosted/run.jou b/self_hosted/run.jou new file mode 100644 index 00000000..4646992b --- /dev/null +++ b/self_hosted/run.jou @@ -0,0 +1,29 @@ +import "stdlib/str.jou" +import "stdlib/mem.jou" +import "stdlib/io.jou" +import "stdlib/process.jou" + + +def run_exe(exepath: byte*, valgrind: bool) -> int: + command = malloc(strlen(exepath) + 1000) + if WINDOWS: + sprintf(command, "\"%s\"", exepath) + while strstr(command, "/") != NULL: + *strstr(command, "/") = '\\' + else: + if valgrind: + sprintf(command, "valgrind -q --leak-check=full --show-leak-kinds=all --error-exitcode=1 '%s'", exepath) + else: + sprintf(command, "'%s'", exepath) + + # Make sure that everything else shows up before the user's prints. + fflush(stdout) + fflush(stderr) + + ret = system(command) + free(command) + + if ret == 0: + return 0 # success + else: + return 1 # TODO: extract actual error code / return value diff --git a/self_hosted/runs_wrong.txt b/self_hosted/runs_wrong.txt index d0b06d88..1eb1415b 100644 --- a/self_hosted/runs_wrong.txt +++ b/self_hosted/runs_wrong.txt @@ -1,16 +1 @@ -# This is a list of files that don't behave correctly when ran with the self-hosted compiler. tests/other_errors/missing_return.jou -tests/other_errors/missing_value_in_return.jou -tests/other_errors/noreturn_but_return_with_value.jou -tests/other_errors/noreturn_but_return_without_value.jou -tests/should_succeed/compiler_cli.jou -tests/should_succeed/linked_list.jou -tests/should_succeed/pointer.jou -tests/should_succeed/printf.jou -tests/other_errors/return_void.jou -tests/should_succeed/stderr.jou -tests/should_succeed/unused_import.jou -tests/wrong_type/cannot_be_indexed.jou -tests/wrong_type/index.jou -tests/should_succeed/method_by_value.jou -tests/wrong_type/self_annotation.jou diff --git a/self_hosted/structs.jou b/self_hosted/structs.jou new file mode 100644 index 00000000..e11ad2ef --- /dev/null +++ b/self_hosted/structs.jou @@ -0,0 +1,252 @@ +# TODO: delete this file, merge into others + +import "stdlib/str.jou" +import "stdlib/mem.jou" + +import "./llvm.jou" +import "./ast.jou" +import "./types.jou" +import "./errors_and_warnings.jou" + +class CommandLineArgs: + argv0: byte* # Program name + verbosity: int # How much debug/progress info to print, how many times -v/--verbose passed + valgrind: bool # true --> Use valgrind when runnning user's jou program + tokenize_only: bool # If true, tokenize the file passed on command line and don't actually compile anything + parse_only: bool # If true, parse the file passed on command line and don't actually compile anything + optlevel: int # Optimization level (0 don't optimize, 3 optimize a lot) + infile: byte* # The "main" Jou file (can import other files) + outfile: byte* # If not NULL, where to output executable + linker_flags: byte* # String that is appended to linking command + +# Command-line arguments are a global variable because I like it. +global command_line_args: CommandLineArgs + +# Constants can appear in AST and also compilation steps after AST. +enum ConstantKind: + EnumMember + Integer + Float + Double + String + Null + Bool + +class IntegerConstant: + size_in_bits: int + is_signed: bool + value: long + +class EnumMemberConstant: + enumtype: Type* + memberidx: int + +class Constant: + kind: ConstantKind + union: + integer: IntegerConstant + str: byte* + # TODO: rename double_or_float_text --> float_or_double_text to be consistent with AST + double_or_float_text: byte[100] # convenient because LLVM wants a string anyway + boolean: bool + enum_member: EnumMemberConstant + +def copy_constant(c: Constant) -> Constant: + if c.kind == ConstantKind::String: + c.str = strdup(c.str) + assert c.str != NULL + return c + +def int_constant(type: Type*, value: long) -> Constant: + assert is_integer_type(type) + return Constant{ + kind = ConstantKind::Integer, + integer = IntegerConstant{ + size_in_bits = type->size_in_bits, + is_signed = type->kind == TypeKind::SignedInteger, + value = value + } + } + +class Signature: + name: byte[100] # Function or method name. For methods it does not include the name of the class. + nargs: int + argtypes: Type** + argnames: byte[100]* + takes_varargs: bool # true for functions like printf() + # TODO: rename to return_type + returntype: Type* # NULL, if does not return a value + is_noreturn: bool + # TODO: rename to return_type_location + returntype_location: Location # meaningful even if returntype is NULL + + +class GlobalVariable: + name: byte[100] # Same as in user's code, never empty + type: Type* + defined_in_current_file: bool # not declare-only (e.g. stdout) or imported + usedptr: bool* # If non-NULL, set to true when the variable is used. This is how we detect unused imports. + +class LocalVariable: + id: int # Unique, but you can also compare pointers to Variable. + name: byte[100] # Same name as in user's code, empty for temporary variables created by compiler + type: Type* + is_argument: bool # First n variables are always the arguments + +class ExpressionTypes: + expr: AstExpression* # not owned + type: Type* + implicit_cast_type: Type* # NULL for no implicit cast + + # Flags to indicate whether special kinds of implicit casts happened + implicit_array_to_pointer_cast: bool # Foo[N] to Foo* + implicit_string_to_array_cast: bool # "..." to byte[N] + +enum ExportSymbolKind: + Function + Type + GlobalVar + +class ExportSymbol: + kind: ExportSymbolKind + name: byte[100] # TODO: maybe this should be 200 because it can be ClassName.method_name? or something else? + union: + funcsignature: Signature + type: Type* # ExportSymbolKind::Type and ExportSymbolKind::GlobalVar + +# Type information about a function or method defined in the current file. +class FunctionOrMethodTypes: + signature: Signature + expr_types: ExpressionTypes** + n_expr_types: int + locals: LocalVariable** + nlocals: int + +class TypeAndUsedPtr: + type: Type* + usedptr: bool* + +class SignatureAndUsedPtr: + signature: Signature + usedptr: bool* + +# Type information about a file. +class FileTypes: + current_fom_types: FunctionOrMethodTypes* # conceptually this is internal to typecheck.c + fomtypes: FunctionOrMethodTypes* + nfomtypes: int + globals: GlobalVariable* + nglobals: int + owned_types: Type** # These will be freed later + n_owned_types: int + types: TypeAndUsedPtr* + ntypes: int + functions: SignatureAndUsedPtr* + nfunctions: int + + +class CfStringArray: + str: byte* + len: int + +enum CfInstructionKind: + Constant + SpecialConstant # e.g. "WINDOWS", unlike CF_Constant this doesn't trigger "this code will never run" warnings + StringArray + Call # function or method call, depending on whether self_type is NULL (see below) + AddressOfLocalVar + AddressOfGlobalVar + SizeOf + PtrMemsetToZero # takes one operand, a pointer: memset(ptr, 0, sizeof(*ptr)) + PtrStore # *op1 = op2 (does not use destvar, takes 2 operands) + PtrLoad # aka dereference + PtrToInt64 + Int64ToPtr + PtrClassField # takes 1 operand (pointer), sets destvar to &op->fieldname + PtrCast + PtrAddInt + # Left and right side of number operations must be of the same type (except CfInstructionKind::NumCast). + NumAdd + NumSub + NumMul + NumDiv + NumMod + NumEq + NumLt + NumCast + EnumToInt32 + Int32ToEnum + BoolNegate # TODO: get rid of this? + VarCpy # similar to assignment statements: var1 = var2 + +# Control Flow Graph. +# Struct names not prefixed with Cfg because it looks too much like "config" to me +class CfInstruction: + location: Location + kind: CfInstructionKind + union: + constant: Constant # CfInstructionKind::Constant + strarray: CfStringArray # CfInstructionKind::StringArray + signature: Signature # CfInstructionKind::Call + fieldname: byte[100] # CfInstructionKind::PtrClassField + globalname: byte[100] # CfInstructionKind::AddressOfGlobalVar + scname: byte[100] # CfInstructionKind::SpecialConstant + type: Type* # CfInstructionKind::SizeOf + + operands: LocalVariable** # e.g. numbers to add, function arguments + noperands: int + destvar: LocalVariable* # NULL when it doesn't make sense, e.g. functions that return void + hide_unreachable_warning: bool # usually false, can be set to true to avoid unreachable warning false positives + + # operands should be NULL-terminated array, or NULL for empty + # TODO: does it ever need to be NULL? + # TODO: do we need this method at all? + def set_operands(self, operands: LocalVariable**) -> None: + self->noperands = 0 + while operands != NULL and operands[self->noperands] != NULL: + self->noperands++ + + nbytes = sizeof(self->operands[0]) * self->noperands + self->operands = malloc(nbytes) + assert self->operands != NULL + memcpy(self->operands, operands, nbytes) + + +class CfBlock: + instructions: CfInstruction* + ninstructions: int + branchvar: LocalVariable* # boolean value used to decide where to jump next + + # iftrue and iffalse are NULL for special end block and after calling a noreturn function. + # When iftrue and iffalse are the same, the branchvar is not used and may be NULL. + iftrue: CfBlock* + iffalse: CfBlock* + +class CfGraph: + signature: Signature + start_block: CfBlock # First block + end_block: CfBlock # Always empty. Return statement jumps here. + all_blocks: CfBlock** + n_all_blocks: int + locals: LocalVariable** # First n variables are the function arguments + nlocals: int + +class CfGraphFile: + filename: byte* + graphs: CfGraph** # only for defined functions + ngraphs: int + + +# LLVM makes a mess of how to define what kind of computer will run the +# compiled programs. Sometimes it wants a target triple, sometimes a +# data layout. Sometimes it wants a string, sometimes an object +# representing the thing. +# +# This struct aims to provide everything you may ever need. Hopefully it +# will make the mess slightly less miserable to you. +class Target: + triple: byte[100] + data_layout: byte[500] + target_ref: LLVMTarget* + target_machine_ref: LLVMTargetMachine* + target_data_ref: LLVMTargetData* diff --git a/self_hosted/target.jou b/self_hosted/target.jou index a06d91f2..3495d676 100644 --- a/self_hosted/target.jou +++ b/self_hosted/target.jou @@ -21,11 +21,6 @@ class Target: global target: Target -# TODO: run this with atexit() once we have function pointers -#def cleanup() -> None: -# LLVMDisposeTargetMachine(target.target_machine) -# LLVMDisposeTargetData(target.target_data) - def init_target() -> None: LLVMInitializeX86TargetInfo() LLVMInitializeX86Target() @@ -68,3 +63,7 @@ def init_target() -> None: assert strlen(tmp) < sizeof target.data_layout strcpy(target.data_layout, tmp) LLVMDisposeMessage(tmp) + +def cleanup_target() -> None: + LLVMDisposeTargetMachine(target.target_machine) + LLVMDisposeTargetData(target.target_data) diff --git a/self_hosted/typecheck.jou b/self_hosted/typecheck.jou index 3eddb036..15bff2ef 100644 --- a/self_hosted/typecheck.jou +++ b/self_hosted/typecheck.jou @@ -1,571 +1,454 @@ -# Type checking is split into several stages: -# 1. Create types. After this, classes defined in Jou exist, but -# they are opaque and contain no members. Enums exist and contain -# their members (although it doesn't really matter whether enum -# members are handled in stage 1 or 2). -# 2. Check signatures, global variables and class bodies, but ignore -# bodies of functions and methods. This stage assumes that all -# types exist, but doesn't need to know what fields each class has. -# 3. Check function and method bodies. -# -# The goal of this design is to make cyclic imports possible. At each -# stage, we don't need the results from the same stage, only from -# previous stages. This means that cyclic imports "just work" if we do -# each stage on all files before moving on to the next stage. - -import "stdlib/io.jou" import "stdlib/str.jou" +import "stdlib/io.jou" +import "stdlib/math.jou" import "stdlib/mem.jou" -import "./ast.jou" + +import "./structs.jou" +import "./evaluate.jou" import "./types.jou" +import "./ast.jou" import "./errors_and_warnings.jou" -import "./evaluate.jou" - - -def can_cast_implicitly(from: Type*, to: Type*) -> bool: - # TODO: document these properly. But they are: - # array to pointer, e.g. int[3] --> int* (needs special-casing elsewhere) - # from one integer type to another bigger integer type, unless it is signed-->unsigned - # between two pointer types when one of the two is void* - # from float to double (TODO) - return ( - from == to - or (from->kind == TypeKind::Array and to->kind == TypeKind::Pointer and from->array.item_type == to->value_type) - or (from->kind == TypeKind::Array and to->kind == TypeKind::VoidPointer) - or ( - from->is_integer_type() - and to->is_integer_type() - and from->size_in_bits < to->size_in_bits - and not (from->kind == TypeKind::SignedInteger and to->kind == TypeKind::UnsignedInteger) - ) - or (from == &float_type and to == &double_type) - or (from->is_integer_type() and to->kind == TypeKind::FloatingPoint) - or (from->is_pointer_type() and to->is_pointer_type() and (from == &void_ptr_type or to == &void_ptr_type)) - ) -def can_cast_explicitly(from: Type*, to: Type*) -> bool: - return ( - from == to - or (from->kind == TypeKind::Array and to->kind == TypeKind::Pointer and from->array.item_type == to->value_type) - or (from->kind == TypeKind::Array and to->kind == TypeKind::VoidPointer) - or (from->is_pointer_type() and to->is_pointer_type()) - or (from->is_number_type() and to->is_number_type()) - or (from->is_integer_type() and to->kind == TypeKind::Enum) - or (from->kind == TypeKind::Enum and to->is_integer_type()) - or (from == &bool_type and to->is_integer_type()) - or (from->is_pointer_type() and to == long_type) - or (from == long_type and to->is_pointer_type()) - ) +def find_type(ft: FileTypes*, name: byte*) -> Type*: + for t = ft->types; t < &ft->types[ft->ntypes]; t++: + if strcmp(t->type->name, name) == 0: + if t->usedptr != NULL: + *t->usedptr = True + return t->type + return NULL + +def find_function(ft: FileTypes*, name: byte*) -> Signature*: + for f = ft->functions; f < &ft->functions[ft->nfunctions]; f++: + if strcmp(f->signature.name, name) == 0: + if f->usedptr != NULL: + *f->usedptr = True + return &f->signature + return NULL + +def find_method(selfclass: Type*, name: byte*) -> Signature*: + if selfclass->kind != TypeKind::Class: + return NULL + for m = selfclass->classdata.methods; m < &selfclass->classdata.methods[selfclass->classdata.nmethods]; m++: + if strcmp(m->name, name) == 0: + return m + return NULL + +def find_function_or_method(ft: FileTypes*, selfclass: Type*, name: byte*) -> Signature*: + if selfclass != NULL: + return find_method(selfclass, name) + else: + return find_function(ft, name) + +def find_local_var(ft: FileTypes*, name: byte*) -> LocalVariable*: + if ft->current_fom_types != NULL: + for var = ft->current_fom_types->locals; var < &ft->current_fom_types->locals[ft->current_fom_types->nlocals]; var++: + if strcmp((*var)->name, name) == 0: + return *var + return NULL + +def find_any_var(ft: FileTypes*, name: byte*) -> Type*: + if get_special_constant(name) != -1: + return boolType + if ft->current_fom_types != NULL: + for lvar = ft->current_fom_types->locals; lvar < &ft->current_fom_types->locals[ft->current_fom_types->nlocals]; lvar++: + if strcmp((*lvar)->name, name) == 0: + return (*lvar)->type + for gvar = ft->globals; gvar < &ft->globals[ft->nglobals]; gvar++: + if strcmp(gvar->name, name) == 0: + if gvar->usedptr != NULL: + *gvar->usedptr = True + return gvar->type + return NULL -# Implicit casts are used in many places, e.g. function arguments. -# -# When you pass an argument of the wrong type, it's best to give an error message -# that says so, instead of some generic "expected type foo, got object of type bar" -# kind of message. -# -# The template can contain "" and "". They will be substituted with names -# of types. We cannot use printf() style functions because the arguments can be in -# any order. -def fail_with_implicit_cast_error(location: Location, template: byte*, from: Type*, to: Type*) -> None: - assert template != NULL +def short_type_description(t: Type*) -> byte*: + if t->kind == TypeKind::OpaqueClass or t->kind == TypeKind::Class: + return "a class" + if t->kind == TypeKind::Enum: + return "an enum" + if t->kind == TypeKind::VoidPointer or t->kind == TypeKind::Pointer: + return "a pointer type" + if ( + t->kind == TypeKind::SignedInteger + or t->kind == TypeKind::UnsignedInteger + or t->kind == TypeKind::FloatingPoint + ): + return "a number type" + if t->kind == TypeKind::Array: + return "an array type" + if t->kind == TypeKind::Bool: + return "the built-in bool type" + assert False - n = 0 - for i = 0; template[i] != '\0'; i++: - if template[i] == '<': - n++ +def typecheck_stage1_create_types(ft: FileTypes*, ast: AstFile*) -> ExportSymbol*: + exports: ExportSymbol* = NULL + nexports = 0 - message: byte* = malloc(sizeof(from->name)*n + strlen(template) + 1) - message[0] = '\0' - while *template != '\0': - if starts_with(template, ""): - template = &template[6] - strcat(message, from->name) - elif starts_with(template, ""): - template = &template[4] - strcat(message, to->name) + for i = 0; i < ast->body.nstatements; i++: + stmt = &ast->body.statements[i] + + name: byte[100] + if stmt->kind == AstStatementKind::Class: + assert sizeof(name) == sizeof(stmt->classdef.name) + strcpy(name, stmt->classdef.name) + t = create_opaque_class(name) + elif stmt->kind == AstStatementKind::Enum: + assert sizeof(name) == sizeof(stmt->enumdef.name) + strcpy(name, stmt->enumdef.name) + t = create_enum(name, stmt->enumdef.member_count, stmt->enumdef.member_names) else: - s = [*template++, '\0'] - strcat(message, s) - - fail(location, message) + continue + existing = find_type(ft, name) + if existing != NULL: + msg: byte[500] + snprintf(msg, sizeof(msg), "%s named '%s' already exists", short_type_description(existing), name) + fail(stmt->location, msg) -# To understand the purpose of ExportSymbol, suppose file A imports file B. -# - Type checking file B produces an ExportSymbol that matches the import in file A. -# - Before the next type checking stage, the ExportSymbol is added to file A's types. -# - During the next stage, file A can use the imported symbol. -enum ExportSymbolKind: - Function - Type - GlobalVariable - -class ExportSymbol: - kind: ExportSymbolKind - name: byte[100] - - union: - signature: Signature # ExportSymbolKind::Function - type: Type* # ExportSymbolKind::Type, ExportSymbolKind::GlobalVariable - - def print(self) -> None: - if self->kind == ExportSymbolKind::Function: - s = self->signature.to_string(True, True) - printf("ExportSymbol: function %s\n", s) - free(s) - elif self->kind == ExportSymbolKind::Type: - printf("ExportSymbol: type %s as \"%s\"\n", self->type->name, self->name) - elif self->kind == ExportSymbolKind::GlobalVariable: - printf("ExportSymbol: variable %s: %s\n", self->name, self->type->name) - else: - assert False + ft->types = realloc(ft->types, sizeof(ft->types[0]) * (ft->ntypes + 1)) + assert ft->types != NULL + ft->types[ft->ntypes++] = TypeAndUsedPtr{type=t, usedptr=NULL} -class ExpressionTypes: - expression: AstExpression* - original_type: Type* - implicit_cast_type: Type* # NULL if no implicit casting is needed - next: ExpressionTypes* # TODO: switch to more efficient structure than linked list? - - # Flags to indicate whether special kinds of implicit casts happened - implicit_array_to_pointer_cast: bool # Foo[N] to Foo* - implicit_string_to_array_cast: bool # "..." to byte[N] - - def get_type_after_implicit_cast(self) -> Type*: - assert self->original_type != NULL - if self->implicit_cast_type == NULL: - return self->original_type - return self->implicit_cast_type - - # TODO: error_location is probably unnecessary, can get location from self->expression - def do_implicit_cast(self, to: Type*, error_location: Location, error_template: byte*) -> None: - # This cannot be called multiple times - assert self->implicit_cast_type == NULL - assert not self->implicit_array_to_pointer_cast - assert not self->implicit_string_to_array_cast - - from = self->original_type - if from == to: - return + ft->owned_types = realloc(ft->owned_types, sizeof(ft->owned_types[0]) * (ft->n_owned_types + 1)) + assert ft->owned_types != NULL + ft->owned_types[ft->n_owned_types++] = t - if ( - self->expression->kind == AstExpressionKind::String - and from == byte_type->get_pointer_type() - and to->kind == TypeKind::Array - and to->array.item_type == byte_type - ): - string_size = strlen(self->expression->string) + 1 - if to->array.length < string_size: - message: byte[100] - snprintf( - message, sizeof message, - "a string of %d bytes (including '\\0') does not fit into %s", - string_size, to->name, - ) - fail(error_location, message) - self->implicit_string_to_array_cast = True - # Passing in NULL for error_template can be used to force a cast to happen. - elif error_template != NULL and not can_cast_implicitly(from, to): - fail_with_implicit_cast_error(error_location, error_template, from, to) - - self->implicit_cast_type = to - if from->kind == TypeKind::Array and to->is_pointer_type(): - self->implicit_array_to_pointer_cast = True - ensure_can_take_address( - self->expression, - "cannot create a pointer into an array that comes from %s (try storing it to a local variable first)", - ) + es = ExportSymbol{kind = ExportSymbolKind::Type, type = t} + assert sizeof(es.name) == sizeof(name) + strcpy(es.name, name) - # Does not store the new type to self, because explicit casts have their own AstExpression which has its own expression types. - def do_explicit_cast(self, to: Type*, error_location: Location) -> None: - assert self->implicit_cast_type == NULL - assert not self->implicit_array_to_pointer_cast - - from = self->original_type - if not can_cast_explicitly(from, to): - message: byte[500] - snprintf(&message[0], sizeof message, "cannot cast from type %s to %s", from->name, to->name) - fail(error_location, message) - - if from->kind == TypeKind::Array and to->is_pointer_type(): - self->cast_array_to_pointer() - - def cast_array_to_pointer(self) -> None: - assert self->original_type->kind == TypeKind::Array - self->do_implicit_cast(self->original_type->array.item_type->get_pointer_type(), Location{}, NULL) - -class LocalVariable: - name: byte[100] - type: Type* - next: LocalVariable* # TODO: switch to more efficient structure than linked list? - -class GlobalVariable: - name: byte[100] - type: Type* - -class FunctionOrMethodTypes: - signature: Signature - expression_types: ExpressionTypes* - local_vars: LocalVariable* - - def get_expression_types(self, expr: AstExpression*) -> ExpressionTypes*: - for et = self->expression_types; et != NULL; et = et->next: - if et->expression == expr: - return et - return NULL + exports = realloc(exports, sizeof(exports[0]) * (nexports + 1)) + assert exports != NULL + exports[nexports++] = es - def find_local_var(self, name: byte*) -> LocalVariable*: - for v = self->local_vars; v != NULL; v = v->next: - if strcmp(v->name, name) == 0: - return v - return NULL + exports = realloc(exports, sizeof(exports[0]) * (nexports + 1)) + assert exports != NULL + exports[nexports] = ExportSymbol{} # list terminator + return exports -# All type information for a Jou file. This is initially empty, and is filled during each stage of type checking. -class FileTypes: - # Includes imported and defined functions. - all_functions: Signature* - n_all_functions: int - - defined_functions: FunctionOrMethodTypes* - n_defined_functions: int - - types: Type** - ntypes: int - - globals: GlobalVariable* - nglobals: int - - def add_imported_symbol(self, symbol: ExportSymbol*) -> None: - if symbol->kind == ExportSymbolKind::Type: - self->types = realloc(self->types, (self->ntypes + 1) * sizeof(self->types[0])) - self->types[self->ntypes++] = symbol->type - elif symbol->kind == ExportSymbolKind::Function: - self->all_functions = realloc(self->all_functions, sizeof self->all_functions[0] * (self->n_all_functions + 1)) - self->all_functions[self->n_all_functions++] = symbol->signature.copy() - elif symbol->kind == ExportSymbolKind::GlobalVariable: - pass # TODO - else: - symbol->print() - assert False +def evaluate_array_length(expr: AstExpression*) -> int: + if expr->kind == AstExpressionKind::Int: + return expr->int_value + fail(expr->location, "cannot evaluate array length at compile time") + +def is_void(t: AstType*) -> bool: + return t->kind == AstTypeKind::Named and strcmp(t->name, "void") == 0 + +def is_none(t: AstType*) -> bool: + return t->kind == AstTypeKind::Named and strcmp(t->name, "None") == 0 + +def is_noreturn(t: AstType*) -> bool: + return t->kind == AstTypeKind::Named and strcmp(t->name, "noreturn") == 0 + +def type_from_ast(ft: FileTypes*, asttype: AstType*) -> Type*: + msg: byte[500] + + if is_void(asttype) or is_none(asttype) or is_noreturn(asttype): + snprintf(msg, sizeof(msg), "'%s' cannot be used here because it is not a type", asttype->name) + fail(asttype->location, msg) + + if asttype->kind == AstTypeKind::Named: + if strcmp(asttype->name, "short") == 0: + return shortType + if strcmp(asttype->name, "int") == 0: + return intType + if strcmp(asttype->name, "long") == 0: + return longType + if strcmp(asttype->name, "byte") == 0: + return byteType + if strcmp(asttype->name, "bool") == 0: + return boolType + if strcmp(asttype->name, "float") == 0: + return floatType + if strcmp(asttype->name, "double") == 0: + return doubleType + + found = find_type(ft, asttype->name) + if found != NULL: + return found + + snprintf(msg, sizeof(msg), "there is no type named '%s'", asttype->name) + fail(asttype->location, msg) + + if asttype->kind == AstTypeKind::Pointer: + if is_void(asttype->value_type): + return voidPtrType + return get_pointer_type(type_from_ast(ft, asttype->value_type)) + + if asttype->kind == AstTypeKind::Array: + tmp = type_from_ast(ft, asttype->value_type) + len = evaluate_array_length(asttype->array.length) + if len <= 0: + fail(asttype->array.length->location, "array length must be positive") + return get_array_type(tmp, len) - def find_function(self, name: byte*) -> Signature*: - for i = 0; i < self->n_all_functions; i++: - if strcmp(self->all_functions[i].name, name) == 0: - return &self->all_functions[i] - return NULL + assert False - # If class_type is NULL, this finds a function - def find_defined_function_or_method(self, name: byte*, class_type: Type*) -> FunctionOrMethodTypes*: - assert class_type == NULL or class_type->kind == TypeKind::Class - for i = 0; i < self->n_defined_functions; i++: - if ( - strcmp(self->defined_functions[i].signature.name, name) == 0 - and self->defined_functions[i].signature.get_containing_class() == class_type - ): - return &self->defined_functions[i] - return NULL +def handle_global_var(ft: FileTypes*, vardecl: AstNameTypeValue*, defined_here: bool) -> ExportSymbol: + assert ft->current_fom_types == NULL # find_any_var() only finds global vars + if find_any_var(ft, vardecl->name) != NULL: + msg: byte[500] + snprintf(msg, sizeof(msg), "a global variable named '%s' already exists", vardecl->name) + fail(vardecl->name_location, msg) + + assert vardecl->value == NULL + g = GlobalVariable{ + type = type_from_ast(ft, &vardecl->type), + defined_in_current_file = defined_here, + } - def find_type(self, name: byte*) -> Type*: - for i = 0; i < self->ntypes; i++: - if strcmp(self->types[i]->name, name) == 0: - return self->types[i] - return NULL + assert sizeof(g.name) == sizeof(vardecl->name) + strcpy(g.name, vardecl->name) - def find_global_var(self, name: byte*) -> Type*: - for i = 0; i < self->nglobals; i++: - if strcmp(self->globals[i].name, name) == 0: - return self->globals[i].type - return NULL + ft->globals = realloc(ft->globals, sizeof(ft->globals[0]) * (ft->nglobals + 1)) + assert ft->globals != NULL + ft->globals[ft->nglobals++] = g -def check_type_doesnt_exist(ft: FileTypes*, name: byte*, location: Location) -> None: - existing = ft->find_type(name) - if existing != NULL: - description = short_type_description(existing) - message: byte[500] - snprintf(message, sizeof message, "%s named '%s' already exists", description, name) - fail(location, message) + es = ExportSymbol{kind = ExportSymbolKind::GlobalVar, type = g.type} + assert sizeof(es.name) == sizeof(g.name) + strcpy(es.name, g.name) + return es -def typecheck_stage1_create_types(ft: FileTypes*, file: AstFile*) -> ExportSymbol*: - exports: ExportSymbol* = NULL - nexports = 0 +def handle_signature(ft: FileTypes*, astsig: AstSignature*, self_class: Type*) -> Signature: + msg: byte[500] - for i = 0; i < file->body.nstatements; i++: - if file->body.statements[i].kind == AstStatementKind::Class: - classdef = &file->body.statements[i].classdef - check_type_doesnt_exist(ft, classdef->name, classdef->name_location) - t = create_opaque_class(classdef->name) - elif file->body.statements[i].kind == AstStatementKind::Enum: - enumdef = &file->body.statements[i].enumdef - check_type_doesnt_exist(ft, enumdef->name, enumdef->name_location) - t = create_enum(enumdef->name, enumdef->member_count, enumdef->member_names) + if find_function_or_method(ft, self_class, astsig->name) != NULL: + if self_class != NULL: + snprintf(msg, sizeof(msg), "a method named '%s' already exists", astsig->name) else: - continue - - ft->types = realloc(ft->types, (ft->ntypes + 1) * sizeof ft->types[0]) - ft->types[ft->ntypes++] = t - exports = realloc(exports, (nexports + 1) * sizeof exports[0]) - exports[nexports++] = ExportSymbol{ - kind = ExportSymbolKind::Type, - name = t->name, - type = t, - } - - exports = realloc(exports, sizeof exports[0] * (nexports + 1)) - exports[nexports] = ExportSymbol{} - return exports + snprintf(msg, sizeof(msg), "a function named '%s' already exists", astsig->name) + fail(astsig->name_location, msg) + sig = Signature{nargs = astsig->nargs, takes_varargs = astsig->takes_varargs} + assert sizeof(sig.name) == sizeof(astsig->name) + strcpy(sig.name, astsig->name) -def evaluate_array_length(expression: AstExpression*) -> int: - # TODO: support something more fancy? - if expression->kind == AstExpressionKind::Int: - return expression->int_value - fail(expression->location, "cannot evaluate array length at compile time") - -def type_from_ast(ft: FileTypes*, ast_type: AstType*) -> Type*: - if ast_type->is_void(): - fail(ast_type->location, "'void' cannot be used here because it is not a type") - if ast_type->is_none(): - fail(ast_type->location, "'None' cannot be used here because it is not a type") - if ast_type->is_noreturn(): - fail(ast_type->location, "'noreturn' cannot be used here because it is not a type") - - if ast_type->kind == AstTypeKind::Named: - if strcmp(ast_type->name, "short") == 0: - return short_type - if strcmp(ast_type->name, "int") == 0: - return int_type - if strcmp(ast_type->name, "long") == 0: - return long_type - if strcmp(ast_type->name, "byte") == 0: - return byte_type - if strcmp(ast_type->name, "bool") == 0: - return &bool_type - if strcmp(ast_type->name, "float") == 0: - return &float_type - if strcmp(ast_type->name, "double") == 0: - return &double_type - - result = ft->find_type(ast_type->name) - if result != NULL: - return result - - message: byte* = malloc(strlen(ast_type->name) + 100) - sprintf(message, "there is no type named '%s'", ast_type->name) - fail(ast_type->location, message) - - if ast_type->kind == AstTypeKind::Pointer: - if ast_type->value_type->is_void(): - return &void_ptr_type - return type_from_ast(ft, ast_type->value_type)->get_pointer_type() - - if ast_type->kind == AstTypeKind::Array: - member_type = type_from_ast(ft, ast_type->array.member_type) - length = evaluate_array_length(ast_type->array.length) - if length <= 0: - fail(ast_type->array.length->location, "array length must be positive") - return member_type->get_array_type(length) - - ast_type->print(True) - printf("\n") - assert False # TODO - -def handle_signature(ft: FileTypes*, astsig: AstSignature*, self_type: Type*) -> Signature: - assert self_type == NULL or self_type->kind == TypeKind::Class - - sig = Signature{ - name = astsig->name, - nargs = astsig->nargs, - takes_varargs = astsig->takes_varargs, - } - - sig.argnames = malloc(sizeof sig.argnames[0] * sig.nargs) + size = sizeof(sig.argnames[0]) * sig.nargs + sig.argnames = malloc(size) for i = 0; i < sig.nargs; i++: - sig.argnames[i] = astsig->args[i].name + assert sizeof(sig.argnames[i]) == sizeof(astsig->args[i].name) + strcpy(sig.argnames[i], astsig->args[i].name) - sig.argtypes = malloc(sizeof sig.argtypes[0] * sig.nargs) + sig.argtypes = malloc(sizeof(sig.argtypes[0]) * sig.nargs) for i = 0; i < sig.nargs; i++: - if strcmp(astsig->args[i].name, "self") == 0: - assert self_type != NULL - sig.argtypes[i] = self_type->get_pointer_type() + if ( + strcmp(sig.argnames[i], "self") == 0 + and astsig->args[i].type.kind == AstTypeKind::Named + and astsig->args[i].type.name[0] == '\0' + ): + # just "self" without a type after it --> default to "self: Foo*" in class Foo + argtype = get_pointer_type(self_class) else: - sig.argtypes[i] = type_from_ast(ft, &astsig->args[i].type) + argtype = type_from_ast(ft, &astsig->args[i].type) + + if strcmp(sig.argnames[i], "self") == 0 and argtype != self_class and argtype != get_pointer_type(self_class): + snprintf(msg, sizeof(msg), "type of self must be %s* (default) or %s", self_class->name, self_class->name) + fail(astsig->args[i].type.location, msg) + + sig.argtypes[i] = argtype - if astsig->return_type.is_none() or astsig->return_type.is_noreturn(): - sig.return_type = NULL + sig.is_noreturn = is_noreturn(&astsig->return_type) + if is_none(&astsig->return_type) or is_noreturn(&astsig->return_type): + sig.returntype = NULL + elif is_void(&astsig->return_type): + fail(astsig->return_type.location, "void is not a valid return type, use '-> None' if the function does not return a value") else: - sig.return_type = type_from_ast(ft, &astsig->return_type) + sig.returntype = type_from_ast(ft, &astsig->return_type) - if self_type == NULL and strcmp(sig.name, "main") == 0: + if self_class == NULL and strcmp(sig.name, "main") == 0: # special main() function checks - if sig.return_type != int_type: + if sig.returntype != intType: fail(astsig->return_type.location, "the main() function must return int") - if sig.nargs != 0 and not ( - sig.nargs == 2 - and sig.argtypes[0] == int_type - and sig.argtypes[1] == byte_type->get_pointer_type()->get_pointer_type() + if ( + sig.nargs != 0 + and not ( + sig.nargs == 2 + and sig.argtypes[0] == intType + and sig.argtypes[1] == get_pointer_type(get_pointer_type(byteType)) + ) ): fail( astsig->args[0].type.location, "if the main() function takes parameters, it should be defined like this: def main(argc: int, argv: byte**) -> int" ) + sig.returntype_location = astsig->return_type.location + + if self_class == NULL: + ft->functions = realloc(ft->functions, sizeof(ft->functions[0]) * (ft->nfunctions + 1)) + assert ft->functions != NULL + ft->functions[ft->nfunctions++] = SignatureAndUsedPtr{ + signature = copy_signature(&sig), + usedptr = NULL, + } + return sig def handle_class_members_stage2(ft: FileTypes*, classdef: AstClassDef*) -> None: - # Previous type-checking stage created an opaque type. + # Previous type-checking stage created an opaque struct. type: Type* = NULL - for i = 0; i < ft->ntypes; i++: - if strcmp(ft->types[i]->name, classdef->name) == 0: - type = ft->types[i] + for s = ft->owned_types; s < &ft->owned_types[ft->n_owned_types]; s++: + if strcmp((*s)->name, classdef->name) == 0: + type = *s break - assert type != NULL + assert type != NULL assert type->kind == TypeKind::OpaqueClass type->kind = TypeKind::Class - memset(&type->class_members, 0, sizeof type->class_members) + memset(&type->classdata, 0, sizeof type->classdata) union_id = 0 - for i = 0; i < classdef->nmembers; i++: - member = &classdef->members[i] - if member->kind == AstClassMemberKind::Field: - type->class_members.fields = realloc(type->class_members.fields, (type->class_members.nfields + 1) * sizeof type->class_members.fields[0]) - type->class_members.fields[type->class_members.nfields++] = ClassField{ - name = member->field.name, - type = type_from_ast(ft, &member->field.type), + for m = classdef->members; m < &classdef->members[classdef->nmembers]; m++: + if m->kind == AstClassMemberKind::Field: + f = ClassField{ + type = type_from_ast(ft, &m->field.type), union_id = union_id++, } - elif member->kind == AstClassMemberKind::Union: + assert sizeof(f.name) == sizeof(m->field.name) + strcpy(f.name, m->field.name) + + type->classdata.fields = realloc(type->classdata.fields, sizeof(type->classdata.fields[0]) * (type->classdata.nfields + 1)) + assert type->classdata.fields != NULL + type->classdata.fields[type->classdata.nfields++] = f + + elif m->kind == AstClassMemberKind::Union: uid = union_id++ - for k = 0; k < member->union_fields.nfields; k++: - type->class_members.fields = realloc(type->class_members.fields, (type->class_members.nfields + 1) * sizeof type->class_members.fields[0]) - type->class_members.fields[type->class_members.nfields++] = ClassField{ - name = member->union_fields.fields[k].name, - type = type_from_ast(ft, &member->union_fields.fields[k].type), + for ntv = m->union_fields.fields; ntv < &m->union_fields.fields[m->union_fields.nfields]; ntv++: + f = ClassField{ + type = type_from_ast(ft, &ntv->type), union_id = uid, } - elif member->kind == AstClassMemberKind::Method: + assert sizeof(f.name) == sizeof(ntv->name) + strcpy(f.name, ntv->name) + + type->classdata.fields = realloc(type->classdata.fields, sizeof(type->classdata.fields[0]) * (type->classdata.nfields + 1)) + assert type->classdata.fields != NULL + type->classdata.fields[type->classdata.nfields++] = f + + elif m->kind == AstClassMemberKind::Method: # Don't handle the method body yet: that is a part of stage 3, not stage 2 - sig = handle_signature(ft, &member->method.signature, type) - type->class_members.methods = realloc(type->class_members.methods, sizeof type->class_members.methods[0] * (type->class_members.nmethods + 1)) - type->class_members.methods[type->class_members.nmethods++] = sig + sig = handle_signature(ft, &m->method.signature, type) + + type->classdata.methods = realloc(type->classdata.methods, sizeof(type->classdata.methods[0]) * (type->classdata.nmethods + 1)) + assert type->classdata.methods != NULL + type->classdata.methods[type->classdata.nmethods++] = sig + else: assert False -# Returned array is terminated by ExportSymbol with empty name. -def typecheck_stage2_populate_types(ft: FileTypes*, ast_file: AstFile*) -> ExportSymbol*: - message: byte[200] +def typecheck_stage2_populate_types(ft: FileTypes*, ast: AstFile*) -> ExportSymbol*: exports: ExportSymbol* = NULL nexports = 0 - for i = 0; i < ast_file->body.nstatements; i++: - ts = &ast_file->body.statements[i] - - if ts->kind == AstStatementKind::Function: - if ft->find_function(ts->function.signature.name) != NULL: - snprintf( - message, sizeof message, - "a function named '%s' already exists", - ts->function.signature.name, - ) - fail(ts->location, message) - - sig = handle_signature(ft, &ts->function.signature, NULL) - ft->all_functions = realloc(ft->all_functions, sizeof ft->all_functions[0] * (ft->n_all_functions + 1)) - ft->all_functions[ft->n_all_functions++] = sig.copy() - exports = realloc(exports, sizeof exports[0] * (nexports + 1)) - exports[nexports++] = ExportSymbol{ - kind = ExportSymbolKind::Function, - name = sig.name, - signature = sig, - } - - if ts->kind == AstStatementKind::Class: - handle_class_members_stage2(ft, &ts->classdef) - - if ( - ts->kind == AstStatementKind::GlobalVariableDeclaration - or ts->kind == AstStatementKind::GlobalVariableDefinition - ): - if ft->find_global_var(ts->var_declaration.name) != NULL: - snprintf( - message, sizeof message, - "a global variable named '%s' already exists", - ts->var_declaration.name, - ) - fail(ts->location, message) - - assert ts->var_declaration.value == NULL - type = type_from_ast(ft, &ts->var_declaration.type) - ft->globals = realloc(ft->globals, (ft->nglobals + 1) * sizeof ft->globals[0]) - ft->globals[ft->nglobals++] = GlobalVariable{name = ts->var_declaration.name, type = type} - - exports = realloc(exports, sizeof exports[0] * (nexports + 1)) - exports[nexports++] = ExportSymbol{ - kind = ExportSymbolKind::GlobalVariable, - name = ts->var_declaration.name, - type = type, - } + for i = 0; i < ast->body.nstatements; i++: + stmt = &ast->body.statements[i] + + exports = realloc(exports, sizeof(exports[0]) * (nexports + 1)) + assert exports != NULL + + if stmt->kind == AstStatementKind::GlobalVariableDeclaration: + exports[nexports++] = handle_global_var(ft, &stmt->var_declaration, False) + elif stmt->kind == AstStatementKind::GlobalVariableDefinition: + exports[nexports++] = handle_global_var(ft, &stmt->var_declaration, True) + elif stmt->kind == AstStatementKind::Function: + sig = handle_signature(ft, &stmt->function.signature, NULL) + es = ExportSymbol{kind = ExportSymbolKind::Function, funcsignature = sig} + assert sizeof(es.name) == sizeof(sig.name) + strcpy(es.name, sig.name) + exports[nexports++] = es + elif stmt->kind == AstStatementKind::Class: + handle_class_members_stage2(ft, &stmt->classdef) + elif stmt->kind == AstStatementKind::Enum: + pass # Everything done in previous type-checking steps. + else: + assert False - exports = realloc(exports, sizeof exports[0] * (nexports + 1)) + exports = realloc(exports, sizeof(exports[0]) * (nexports + 1)) + assert exports != NULL exports[nexports] = ExportSymbol{} return exports +def add_variable(ft: FileTypes*, t: Type*, name: byte*) -> LocalVariable*: + var: LocalVariable* = calloc(1, sizeof *var) + var->id = ft->current_fom_types->nlocals + var->type = t -def plural_s(n: int) -> byte*: - if n == 1: - return "" - return "s" - -def nth(n: int) -> byte[100]: - first_few = [NULL as byte*, "first", "second", "third", "fourth", "fifth", "sixth"] - result: byte[100] + assert name != NULL + assert find_local_var(ft, name) == NULL + assert strlen(name) < sizeof(var->name) + strcpy(var->name, name) - if n < sizeof first_few / sizeof first_few[0]: - strcpy(result, first_few[n]) - else: - sprintf(result, "%dth", n) - return result + ft->current_fom_types->locals = realloc(ft->current_fom_types->locals, sizeof(ft->current_fom_types->locals[0]) * (ft->current_fom_types->nlocals + 1)) + assert ft->current_fom_types->locals != NULL + ft->current_fom_types->locals[ft->current_fom_types->nlocals++] = var -def short_type_description(t: Type*) -> byte*: - if t->kind == TypeKind::Class or t->kind == TypeKind::OpaqueClass: - return "a class" - if t->kind == TypeKind::Enum: - return "an enum" - if t->is_pointer_type(): - return "a pointer type" - if t->is_number_type(): - return "a number type" - if t->kind == TypeKind::Array: - return "an array type" - if t == &bool_type: - return "the built-in bool type" - assert False + return var -# TODO: make this a method in class AstExpression? -def short_expression_description(expr: AstExpression*) -> byte[200]: - result: byte[200] +global short_expr_desc_result: byte[200] - # Imagine "cannot assign to" in front of these, e.g. "cannot assign to a constant" +# Intended for errors. Returned string can be overwritten in next call. +# Imagine "cannot assign to" in front of these, e.g. "cannot assign to a constant" +def short_expression_description(expr: AstExpression*) -> byte*: if ( expr->kind == AstExpressionKind::String - or expr->kind == AstExpressionKind::Short or expr->kind == AstExpressionKind::Int + or expr->kind == AstExpressionKind::Short or expr->kind == AstExpressionKind::Long or expr->kind == AstExpressionKind::Byte + or expr->kind == AstExpressionKind::Float + or expr->kind == AstExpressionKind::Double or expr->kind == AstExpressionKind::Bool - or expr->kind == AstExpressionKind::Null ): return "a constant" - elif ( - expr->kind == AstExpressionKind::Negate - or expr->kind == AstExpressionKind::Add + + if expr->kind == AstExpressionKind::Null: + return "NULL" + if expr->kind == AstExpressionKind::GetEnumMember: + return "an enum member" + if expr->kind == AstExpressionKind::SizeOf: + return "a sizeof expression" + if expr->kind == AstExpressionKind::Instantiate: + return "a newly created instance" + if expr->kind == AstExpressionKind::Array: + return "an array literal" + if expr->kind == AstExpressionKind::Indexing: + return "an indexed value" + if expr->kind == AstExpressionKind::As: + return "the result of a cast" + if expr->kind == AstExpressionKind::Dereference: + return "the value of a pointer" + if expr->kind == AstExpressionKind::And: + return "the result of 'and'" + if expr->kind == AstExpressionKind::Or: + return "the result of 'or'" + if expr->kind == AstExpressionKind::Not: + return "the result of 'not'" + if expr->kind == AstExpressionKind::Self: + return "self" + + if expr->kind == AstExpressionKind::Call: + if expr->call.method_call_self == NULL: + return "a function call" + else: + return "a method call" + + if expr->kind == AstExpressionKind::GetVariable: + if get_special_constant(expr->varname) != -1: + return "a special constant" + return "a variable" + + if ( + expr->kind == AstExpressionKind::Add or expr->kind == AstExpressionKind::Subtract or expr->kind == AstExpressionKind::Multiply or expr->kind == AstExpressionKind::Divide or expr->kind == AstExpressionKind::Modulo + or expr->kind == AstExpressionKind::Negate ): return "the result of a calculation" - elif ( + + if ( expr->kind == AstExpressionKind::Eq or expr->kind == AstExpressionKind::Ne or expr->kind == AstExpressionKind::Gt @@ -574,50 +457,29 @@ def short_expression_description(expr: AstExpression*) -> byte[200]: or expr->kind == AstExpressionKind::Le ): return "the result of a comparison" - elif expr->kind == AstExpressionKind::Call: - sprintf(result, "a %s call", expr->call.function_or_method()) - return result - elif expr->kind == AstExpressionKind::Instantiate: - return "a newly created instance" - elif expr->kind == AstExpressionKind::GetVariable: - if get_special_constant(expr->varname) == -1: - return "a variable" - return "a special constant" - elif expr->kind == AstExpressionKind::GetEnumMember: - return "an enum member" - elif expr->kind == AstExpressionKind::GetClassField: - snprintf(result, sizeof result, "field '%s'", expr->class_field.field_name) - return result - elif expr->kind == AstExpressionKind::As: - return "the result of a cast" - elif expr->kind == AstExpressionKind::SizeOf: - return "a sizeof expression" - elif expr->kind == AstExpressionKind::AddressOf: - subresult = short_expression_description(expr->operands) - snprintf(result, sizeof result, "address of %s", subresult) - return result - elif expr->kind == AstExpressionKind::Dereference: - return "the value of a pointer" - elif expr->kind == AstExpressionKind::And: - return "the result of 'and'" - elif expr->kind == AstExpressionKind::Or: - return "the result of 'or'" - elif expr->kind == AstExpressionKind::Not: - return "the result of 'not'" - elif expr->kind == AstExpressionKind::PreIncr or expr->kind == AstExpressionKind::PostIncr: + + if ( + expr->kind == AstExpressionKind::PreIncr + or expr->kind == AstExpressionKind::PostIncr + ): return "the result of incrementing a value" - elif expr->kind == AstExpressionKind::PreDecr or expr->kind == AstExpressionKind::PostDecr: + + if ( + expr->kind == AstExpressionKind::PreDecr + or expr->kind == AstExpressionKind::PostDecr + ): return "the result of decrementing a value" - elif expr->kind == AstExpressionKind::Indexing: - return "an indexed value" - elif expr->kind == AstExpressionKind::Self: - return "self" - elif expr->kind == AstExpressionKind::Array: - return "an array literal" - else: - expr->print() - printf("*** %d\n", expr->kind) - assert False + + if expr->kind == AstExpressionKind::AddressOf: + snprintf(short_expr_desc_result, sizeof short_expr_desc_result, "address of %s", short_expression_description(&expr->operands[0])) + return short_expr_desc_result + + if expr->kind == AstExpressionKind::GetClassField: + snprintf(short_expr_desc_result, sizeof short_expr_desc_result, "field '%s'", expr->class_field.field_name) + return short_expr_desc_result + + assert False + # The & operator can't go in front of most expressions. # You can't do &(1 + 2), for example. @@ -625,809 +487,1045 @@ def short_expression_description(expr: AstExpression*) -> byte[200]: # The same rules apply to assignments: "foo = bar" is treated as setting the # value of the pointer &foo to bar. # -# error_template can be e.g. "cannot take address of %s" or "cannot assign to %s" -def ensure_can_take_address(expression: AstExpression*, error_template: byte*) -> None: - if expression->kind == AstExpressionKind::GetClassField: - # &foo.bar --> must ensure we can take address of foo. - # Doesn't apply to &foo->bar because that's foo + some offset, so foo is already a pointer. - if not expression->class_field.uses_arrow_operator: - # Turn "cannot assign to %s" into "cannot assign to a field of %s". - # This assumes that error_template is relatively simple, i.e. it only contains one %s somewhere. - new_template = malloc(strlen(error_template) + 50) - sprintf(new_template, error_template, "a field of %s") - ensure_can_take_address(&expression->operands[0], new_template) - free(new_template) - return - - if expression->kind == AstExpressionKind::GetVariable: - # &foo is usually fine, but &WINDOWS is not - if get_special_constant(expression->varname) == -1: - return +# errmsg_template can be e.g. "cannot take address of %s" or "cannot assign to %s" +def ensure_can_take_address(fom: FunctionOrMethodTypes*, expr: AstExpression*, errmsg_template: byte*) -> None: + assert fom != NULL if ( - expression->kind == AstExpressionKind::Dereference # &*foo - or expression->kind == AstExpressionKind::Indexing # &foo[bar] = foo + some offset (foo is a pointer) + expr->kind == AstExpressionKind::Dereference + or expr->kind == AstExpressionKind::Indexing # &foo[bar] + or ( + # &foo->bar = foo + offset (it doesn't use &foo) + expr->kind == AstExpressionKind::GetClassField + and expr->class_field.uses_arrow_operator + ) ): return - # Anything else is an error. - desc: byte[200] = short_expression_description(expression) - error = malloc(strlen(error_template) + 300) - sprintf(error, error_template, desc) - fail(expression->location, error) + if expr->kind == AstExpressionKind::GetClassField: + # &foo.bar = &foo + offset + assert not expr->class_field.uses_arrow_operator -def max(a: int, b: int) -> int: - if a > b: - return a - return b + # Turn "cannot assign to %s" into "cannot assign to a field of %s". + # This assumes that errmsg_template is relatively simple, i.e. it only contains one %s somewhere. + newtemplate: byte* = malloc(strlen(errmsg_template) + 100) + sprintf(newtemplate, errmsg_template, "a field of %s") -def check_binop( - op: AstExpressionKind, - location: Location, - lhs_types: ExpressionTypes*, - rhs_types: ExpressionTypes*, -) -> Type*: - result_is_bool = False - if op == AstExpressionKind::Add: - do_what = "add" - elif op == AstExpressionKind::Subtract: - do_what = "subtract" - elif op == AstExpressionKind::Multiply: - do_what = "multiply" - elif op == AstExpressionKind::Divide: - do_what = "divide" - elif op == AstExpressionKind::Modulo: - do_what = "take remainder with" - else: - assert ( - op == AstExpressionKind::Eq - or op == AstExpressionKind::Ne - or op == AstExpressionKind::Gt - or op == AstExpressionKind::Ge - or op == AstExpressionKind::Lt - or op == AstExpressionKind::Le - ) - do_what = "compare" - result_is_bool = True + ensure_can_take_address(fom, &expr->operands[0], newtemplate) + free(newtemplate) + return - got_bools = lhs_types->original_type == &bool_type and rhs_types->original_type == &bool_type - got_integers = lhs_types->original_type->is_integer_type() and rhs_types->original_type->is_integer_type() - got_numbers = lhs_types->original_type->is_number_type() and rhs_types->original_type->is_number_type() - got_enums = lhs_types->original_type->kind == TypeKind::Enum and rhs_types->original_type->kind == TypeKind::Enum - got_pointers = ( - lhs_types->original_type->is_pointer_type() - and rhs_types->original_type->is_pointer_type() - and ( - # Ban comparisons like int* == byte*, unless one of the two types is void* - lhs_types->original_type == rhs_types->original_type - or lhs_types->original_type == &void_ptr_type - or rhs_types->original_type == &void_ptr_type - ) - ) + # You can usually take address of variable, but you can't take address of special + # constant (e.g. &WINDOWS) + if ( + expr->kind == AstExpressionKind::GetVariable + and get_special_constant(expr->varname) == -1 + ): + return + # You can take address of self if it's not passed as a pointer: + # + # def method(self: MyClass) -> None: + # do_something(&self) + # + # This lets you e.g. write methods that return a modified instance. if ( - (not got_bools and not got_numbers and not got_enums and not got_pointers) - or (op != AstExpressionKind::Eq and op != AstExpressionKind::Ne and not got_numbers) + expr->kind == AstExpressionKind::Self + and fom->signature.argtypes[0]->kind == TypeKind::Class ): - message: byte[500] - snprintf( - message, sizeof message, - "wrong types: cannot %s %s and %s", - do_what, lhs_types->original_type->name, rhs_types->original_type->name, - ) - fail(location, message) + return - if got_bools: - cast_type = &bool_type - elif got_integers: - size = max(lhs_types->original_type->size_in_bits, rhs_types->original_type->size_in_bits) - if ( - lhs_types->original_type->kind == TypeKind::SignedInteger - or rhs_types->original_type->kind == TypeKind::SignedInteger - ): - cast_type = &signed_integers[size] - else: - cast_type = &unsigned_integers[size] - elif got_numbers: - if lhs_types->original_type == &double_type or rhs_types->original_type == &double_type: - cast_type = &double_type + msg: byte[500] + snprintf(msg, sizeof(msg), errmsg_template, short_expression_description(expr)) + fail(expr->location, msg) + +# Implicit casts are used in many places, e.g. function arguments. +# +# When you pass an argument of the wrong type, it's best to give an error message +# that says so, instead of some generic "expected type foo, got object of type bar" +# kind of message. +# +# The template can contain "" and "". They will be substituted with names +# of types. We cannot use printf() style functions because the arguments can be in +# any order. +def fail_with_implicit_cast_error(location: Location, template: byte*, from: Type*, to: Type*) -> None: + assert template != NULL + + n = 0 + for i = 0; template[i] != '\0'; i++: + if template[i] == '<': + n++ + + message: byte* = malloc(sizeof(from->name)*n + strlen(template) + 1) + message[0] = '\0' + while *template != '\0': + if starts_with(template, ""): + template = &template[6] + strcat(message, from->name) + elif starts_with(template, ""): + template = &template[4] + strcat(message, to->name) else: - cast_type = &float_type - elif got_pointers: - cast_type = &void_ptr_type - elif got_enums: - cast_type = int_type + s = [*template++, '\0'] + strcat(message, s) + + fail(location, message) + + +def can_cast_implicitly(from: Type*, to: Type*) -> bool: + # TODO: document these properly. But they are: + # array to pointer, e.g. int[3] --> int* (needs special-casing elsewhere) + # from one integer type to another bigger integer type, unless it is signed-->unsigned + # between two pointer types when one of the two is void* + # from float to double (TODO) + return ( + from == to + or (from->kind == TypeKind::Array and to->kind == TypeKind::Pointer and from->array.item_type == to->value_type) + or (from->kind == TypeKind::Array and to->kind == TypeKind::VoidPointer) + or ( + from->is_integer_type() + and to->is_integer_type() + and from->size_in_bits < to->size_in_bits + and not (from->kind == TypeKind::SignedInteger and to->kind == TypeKind::UnsignedInteger) + ) + or (from == floatType and to == doubleType) + or (from->is_integer_type() and to->kind == TypeKind::FloatingPoint) + or (from->is_pointer_type() and to->is_pointer_type() and (from == voidPtrType or to == voidPtrType)) + ) + +def can_cast_explicitly(from: Type*, to: Type*) -> bool: + return ( + from == to + or (from->kind == TypeKind::Array and to->kind == TypeKind::Pointer and from->array.item_type == to->value_type) + or (from->kind == TypeKind::Array and to->kind == TypeKind::VoidPointer) + or (from->is_pointer_type() and to->is_pointer_type()) + or (from->is_number_type() and to->is_number_type()) + or (from->is_integer_type() and to->kind == TypeKind::Enum) + or (from->kind == TypeKind::Enum and to->is_integer_type()) + or (from == boolType and to->is_integer_type()) + or (from->is_pointer_type() and to == longType) + or (from == longType and to->is_pointer_type()) + ) + +def do_implicit_cast( + fom: FunctionOrMethodTypes*, + types: ExpressionTypes*, + to: Type*, + location: Location, + errormsg_template: byte*, +) -> None: + assert types->implicit_cast_type == NULL + assert not types->implicit_array_to_pointer_cast + from = types->type + if from == to: + return + + if ( + types->expr->kind == AstExpressionKind::String + and from == get_pointer_type(byteType) + and to->kind == TypeKind::Array + and to->array.item_type == byteType + ): + string_size = strlen(types->expr->string) + 1 + if to->array.len < string_size: + msg: byte[500] + snprintf(msg, sizeof(msg), "a string of %d bytes (including '\\0') does not fit into %s", string_size, to->name) + fail(location, msg) + types->implicit_string_to_array_cast = True + # Passing in NULL for errormsg_template can be used to "force" a cast to happen. + elif errormsg_template != NULL and not can_cast_implicitly(from, to): + fail_with_implicit_cast_error(location, errormsg_template, from, to) + + types->implicit_cast_type = to + types->implicit_array_to_pointer_cast = (from->kind == TypeKind::Array and is_pointer_type(to)) + + if types->implicit_array_to_pointer_cast: + ensure_can_take_address( + fom, + types->expr, + "cannot create a pointer into an array that comes from %s (try storing it to a local variable first)" + ) + +def cast_array_to_pointer(fom: FunctionOrMethodTypes*, types: ExpressionTypes*) -> None: + assert types->type->kind == TypeKind::Array + do_implicit_cast(fom, types, get_pointer_type(types->type->array.item_type), Location{}, NULL) + +def do_explicit_cast(fom: FunctionOrMethodTypes*, types: ExpressionTypes*, to: Type*, location: Location) -> None: + assert types->implicit_cast_type == NULL + from = types->type + + msg: byte[500] + + if from == to: + snprintf(msg, sizeof(msg), "unnecessary cast from %s to %s", from->name, to->name) + show_warning(location, msg) + + if not can_cast_explicitly(from, to): + snprintf(msg, sizeof(msg), "cannot cast from type %s to %s", from->name, to->name) + fail(location, msg) + + if from->kind == TypeKind::Array and is_pointer_type(to): + cast_array_to_pointer(fom, types) + +def typecheck_expression_not_void(ft: FileTypes*, expr: AstExpression*) -> ExpressionTypes*: + types: ExpressionTypes* = typecheck_expression(ft, expr) + if types != NULL: + return types + + # Should be function/method call that returns void + assert expr->kind == AstExpressionKind::Call + + msg: byte[500] + if expr->call.method_call_self == NULL: + snprintf(msg, sizeof(msg), "function '%s' does not return a value", expr->call.name) + else: + snprintf(msg, sizeof(msg), "method '%s' does not return a value", expr->call.name) + fail(expr->location, msg) + +def typecheck_expression_with_implicit_cast( + ft: FileTypes*, + expr: AstExpression*, + casttype: Type*, + errormsg_template: byte*, +) -> None: + types = typecheck_expression_not_void(ft, expr) + do_implicit_cast(ft->current_fom_types, types, casttype, expr->location, errormsg_template) + +def check_binop( + fom: FunctionOrMethodTypes*, + op: AstExpressionKind, + location: Location, + lhstypes: ExpressionTypes*, + rhstypes: ExpressionTypes*, +) -> Type*: + do_what: byte* + if op == AstExpressionKind::Add: + do_what = "add" + elif op == AstExpressionKind::Subtract: + do_what = "subtract" + elif op == AstExpressionKind::Multiply: + do_what = "multiply" + elif op == AstExpressionKind::Divide: + do_what = "divide" + elif op == AstExpressionKind::Modulo: + do_what = "take remainder with" + elif ( + op == AstExpressionKind::Eq + or op == AstExpressionKind::Ne + or op == AstExpressionKind::Gt + or op == AstExpressionKind::Ge + or op == AstExpressionKind::Lt + or op == AstExpressionKind::Le + ): + do_what = "compare" else: assert False - lhs_types->do_implicit_cast(cast_type, Location{}, NULL) - rhs_types->do_implicit_cast(cast_type, Location{}, NULL) + got_bools = lhstypes->type == boolType and rhstypes->type == boolType + got_integers = is_integer_type(lhstypes->type) and is_integer_type(rhstypes->type) + got_numbers = is_number_type(lhstypes->type) and is_number_type(rhstypes->type) + got_enums = lhstypes->type->kind == TypeKind::Enum and rhstypes->type->kind == TypeKind::Enum + got_pointers = ( + is_pointer_type(lhstypes->type) + and is_pointer_type(rhstypes->type) + and ( + # Ban comparisons like int* == byte*, unless one of the two types is void* + lhstypes->type == rhstypes->type + or lhstypes->type == voidPtrType + or rhstypes->type == voidPtrType + ) + ) - if result_is_bool: - return &bool_type - else: + if ( + ( + (not got_bools) + and (not got_numbers) + and (not got_enums) + and (not got_pointers) + ) or ( + (got_bools or got_enums) + and op != AstExpressionKind::Eq + and op != AstExpressionKind::Ne + ) or ( + got_pointers + and op != AstExpressionKind::Eq + and op != AstExpressionKind::Ne + and op != AstExpressionKind::Gt + and op != AstExpressionKind::Ge + and op != AstExpressionKind::Lt + and op != AstExpressionKind::Le + ) + ): + msg: byte[500] + snprintf(msg, sizeof(msg), "wrong types: cannot %s %s and %s", do_what, lhstypes->type->name, rhstypes->type->name) + fail(location, msg) + + cast_type: Type* = NULL + if got_bools: + cast_type = boolType + if got_integers: + cast_type = get_integer_type( + max(lhstypes->type->size_in_bits, rhstypes->type->size_in_bits), + lhstypes->type->kind == TypeKind::SignedInteger or rhstypes->type->kind == TypeKind::SignedInteger + ) + if got_numbers and not got_integers: + if lhstypes->type == doubleType or rhstypes->type == doubleType: + cast_type = doubleType + else: + cast_type = floatType + if got_pointers: + cast_type = get_integer_type(64, False) + if got_enums: + cast_type = intType + assert cast_type != NULL + + do_implicit_cast(fom, lhstypes, cast_type, Location{}, NULL) + do_implicit_cast(fom, rhstypes, cast_type, Location{}, NULL) + + if ( + op == AstExpressionKind::Add + or op == AstExpressionKind::Subtract + or op == AstExpressionKind::Multiply + or op == AstExpressionKind::Divide + or op == AstExpressionKind::Modulo + ): return cast_type -def check_class_field(location: Location, class_type: Type*, field_name: byte*) -> ClassField*: - assert class_type->kind == TypeKind::Class + if ( + op == AstExpressionKind::Eq + or op == AstExpressionKind::Ne + or op == AstExpressionKind::Gt + or op == AstExpressionKind::Ge + or op == AstExpressionKind::Lt + or op == AstExpressionKind::Le + ): + return boolType + + assert False + + +def check_increment_or_decrement(ft: FileTypes*, expr: AstExpression*) -> Type*: + bad_type_fmt, bad_expr_fmt: byte* + + if expr->kind == AstExpressionKind::PreIncr or expr->kind == AstExpressionKind::PostIncr: + bad_type_fmt = "cannot increment a value of type %s" + bad_expr_fmt = "cannot increment %s" + elif expr->kind == AstExpressionKind::PreDecr or expr->kind == AstExpressionKind::PostDecr: + bad_type_fmt = "cannot decrement a value of type %s" + bad_expr_fmt = "cannot decrement %s" + else: + assert False + + ensure_can_take_address(ft->current_fom_types, &expr->operands[0], bad_expr_fmt) + + t = typecheck_expression_not_void(ft, &expr->operands[0])->type + if not is_integer_type(t) and not is_pointer_type(t): + msg: byte[500] + snprintf(msg, sizeof(msg), bad_type_fmt, t->name) + fail(expr->location, msg) + return t + +def typecheck_dereferenced_pointer(location: Location, t: Type*) -> None: + # TODO: improved error message for dereferencing void* + if t->kind != TypeKind::Pointer: + msg: byte[500] + snprintf(msg, sizeof(msg), "the dereference operator '*' is only for pointers, not for %s", t->name) + fail(location, msg) + +# ptr[index] +def typecheck_indexing( + ft: FileTypes*, + ptrexpr: AstExpression*, + indexexpr: AstExpression*, +) -> Type*: + msg: byte[500] + + types = typecheck_expression_not_void(ft, ptrexpr) + + if types->type->kind == TypeKind::Array: + cast_array_to_pointer(ft->current_fom_types, types) + ptrtype = types->implicit_cast_type + else: + if types->type->kind != TypeKind::Pointer: + snprintf(msg, sizeof(msg), "value of type %s cannot be indexed", types->type->name) + fail(ptrexpr->location, msg) + ptrtype = types->type + + assert ptrtype != NULL + assert ptrtype->kind == TypeKind::Pointer + + indextypes = typecheck_expression_not_void(ft, indexexpr) + if not is_integer_type(indextypes->type): + snprintf(msg, sizeof(msg), "the index inside [...] must be an integer, not %s", indextypes->type->name) + fail(indexexpr->location, msg) + + # LLVM assumes that indexes smaller than 64 bits are signed. + # https://github.com/Akuli/jou/issues/48 + do_implicit_cast(ft->current_fom_types, indextypes, longType, Location{}, NULL) + + return ptrtype->value_type + + +def typecheck_and_or( + ft: FileTypes*, + lhsexpr: AstExpression*, + rhsexpr: AstExpression*, + and_or: byte*, +) -> None: + assert strcmp(and_or, "and") == 0 or strcmp(and_or, "or") == 0 + + errormsg: byte[100] + snprintf(errormsg, sizeof(errormsg), "'%s' only works with booleans, not ", and_or) + + typecheck_expression_with_implicit_cast(ft, lhsexpr, boolType, errormsg) + typecheck_expression_with_implicit_cast(ft, rhsexpr, boolType, errormsg) + + +global nth_result_buffer: byte[100] + + +# Be aware that return value may change as you call this many times. +def nth(n: int) -> byte*: + assert n >= 1 + + first_few = [NULL as byte*, "first", "second", "third", "fourth", "fifth", "sixth"] + if n < sizeof(first_few) / sizeof(first_few[0]): + return first_few[n] + + sprintf(nth_result_buffer, "%dth", n) + return nth_result_buffer + + +def plural_s(n: int) -> byte*: + if n == 1: + # e.g. "1 argument" + return "" + else: + # e.g. "0 arguments", "2 arguments" + return "s" + + +# returns NULL if the function doesn't return anything, otherwise non-owned pointer to non-owned type +def typecheck_function_or_method_call(ft: FileTypes*, call: AstCall*, self_type: Type*, location: Location) -> Type*: + msg: byte[500] + + sig = find_function_or_method(ft, self_type, call->name) + if sig == NULL: + if self_type == NULL: + snprintf(msg, sizeof(msg), "function '%s' not found", call->name) + elif self_type->kind == TypeKind::Class: + snprintf( + msg, sizeof(msg), + "class %s does not have a method named '%s'", + self_type->name, call->name) + elif self_type->kind == TypeKind::Pointer and find_method(self_type->value_type, call->name) != NULL: + snprintf( + msg, sizeof(msg), + "the method '%s' is defined on class %s, not on the pointer type %s, so you need to dereference the pointer first (e.g. by using '->' instead of '.')", + call->name, self_type->value_type->name, self_type->name) + else: + snprintf( + msg, sizeof(msg), + "type %s does not have any methods because it is %s, not a class", + self_type->name, short_type_description(self_type)) + + fail(location, msg) + + if self_type == NULL: + function_or_method = "function" + else: + function_or_method = "method" + + sigstr = signature_to_string(sig, False, False) + + nargs = sig->nargs + if self_type != NULL: + nargs-- + + if call->nargs < nargs or (call->nargs > nargs and not sig->takes_varargs): + snprintf( + msg, sizeof(msg), + "%s %s takes %d argument%s, but it was called with %d argument%s", + function_or_method, + sigstr, + nargs, + plural_s(nargs), + call->nargs, + plural_s(call->nargs), + ) + fail(location, msg) + + k = 0 + for i = 0; i < sig->nargs; i++: + if strcmp(sig->argnames[i], "self") == 0: + continue + # This is a common error, so worth spending some effort to get a good error message. + snprintf(msg, sizeof msg, "%s argument of %s %s should have type , not ", nth(i+1), function_or_method, sigstr) + typecheck_expression_with_implicit_cast(ft, &call->args[k++], sig->argtypes[i], msg) + + for i = k; i < call->nargs; i++: + # This code runs for varargs, e.g. the things to format in printf(). + types = typecheck_expression_not_void(ft, &call->args[i]) + + if types->type->kind == TypeKind::Array: + cast_array_to_pointer(ft->current_fom_types, types) + elif ( + (is_integer_type(types->type) and types->type->size_in_bits < 32) + or types->type == boolType + ): + # Add implicit cast to signed int, just like in C. + do_implicit_cast(ft->current_fom_types, types, intType, Location{}, NULL) + elif types->type == floatType: + do_implicit_cast(ft->current_fom_types, types, doubleType, Location{}, NULL) + + free(sigstr) + return sig->returntype + - field = class_type->class_members.find_field(field_name) - if field == NULL: - message: byte[500] - snprintf(message, sizeof message, "class %s has no field named '%s'", class_type->name, field_name) - fail(location, message) - return field +def typecheck_class_field( + classtype: Type*, + fieldname: byte*, + location: Location, +) -> ClassField*: + assert classtype->kind == TypeKind::Class + for f = classtype->classdata.fields; f < &classtype->classdata.fields[classtype->classdata.nfields]; f++: + if strcmp(f->name, fieldname) == 0: + return f + + msg: byte[500] + snprintf(msg, sizeof(msg), "class %s has no field named '%s'", classtype->name, fieldname) + fail(location, msg) + + +def typecheck_instantiation(ft: FileTypes*, inst: AstInstantiation*, location: Location) -> Type*: + tmp = AstType{kind = AstTypeKind::Named, location = inst->class_name_location} + assert sizeof(tmp.name) == sizeof(inst->class_name) + strcpy(tmp.name, inst->class_name) + t = type_from_ast(ft, &tmp) + + msg: byte[500] + + if t->kind != TypeKind::Class: + snprintf( + msg, sizeof(msg), + "the %s{...} syntax is only for classes, but %s is %s", + t->name, t->name, short_type_description(t)) + fail(location, msg) + + specified_fields: ClassField** = malloc(sizeof(specified_fields[0]) * inst->nfields) + + for i = 0; i < inst->nfields; i++: + f = typecheck_class_field(t, inst->field_names[i], inst->field_values[i].location) + + snprintf(msg, sizeof msg, + "value for field '%s' of class %s must be of type , not ", + inst->field_names[i], inst->class_name) + typecheck_expression_with_implicit_cast(ft, &inst->field_values[i], f->type, msg) + specified_fields[i] = f -def cast_array_items_to_a_common_type(error_location: Location, types: ExpressionTypes**, ntypes: int) -> Type*: + for i1 = 0; i1 < inst->nfields; i1++: + for i2 = i1+1; i2 < inst->nfields; i2++: + if specified_fields[i1]->union_id == specified_fields[i2]->union_id: + snprintf(msg, sizeof(msg), + "fields '%s' and '%s' cannot be set simultaneously because they belong to the same union", + specified_fields[i1]->name, specified_fields[i2]->name) + fail(inst->field_values[i2].location,msg) + + free(specified_fields) + return t + + +def enum_member_exists(t: Type*, name: byte*) -> bool: + assert t->kind == TypeKind::Enum + for i = 0; i < t->enummembers.count; i++: + if strcmp(t->enummembers.names[i], name) == 0: + return True + return False + + +def cast_array_members_to_a_common_type(fom: FunctionOrMethodTypes*, error_location: Location, exprtypes: ExpressionTypes**) -> Type*: # Avoid O(ntypes^2) code in a long array where all or almost all items have the same type. - # This is at most O(ntypes*ndistinct). - distinct: Type** = malloc(sizeof distinct[0] * ntypes) + # This is at most O(ntypes*k) where k is the number of distinct types. + distinct: Type** = NULL ndistinct = 0 - for i = 0; i < ntypes; i++: + + for et = exprtypes; *et != NULL; et++: found = False - for k = 0; k < ndistinct; k++: - if distinct[k] == types[i]->original_type: + for t = distinct; t < &distinct[ndistinct]; t++: + if (*et)->type == *t: found = True break if not found: - distinct[ndistinct++] = types[i]->original_type + distinct = realloc(distinct, sizeof(distinct[0]) * (ndistinct + 1)) + assert distinct != NULL + distinct[ndistinct++] = (*et)->type - compatible_with_all: Type** = malloc(sizeof compatible_with_all[0] * ndistinct) + compatible_with_all: Type** = NULL n_compatible_with_all = 0 - for i = 0; i < ndistinct; i++: - compat = True - for k = 0; k < ndistinct; k++: - if not can_cast_implicitly(distinct[k], distinct[i]): - compat = False + + for t = distinct; t < &distinct[ndistinct]; t++: + t_compatible_with_all_others = True + for t2 = distinct; t2 < &distinct[ndistinct]; t2++: + if not can_cast_implicitly(*t2, *t): + t_compatible_with_all_others = False break - if compat: - compatible_with_all[n_compatible_with_all++] = distinct[i] + + if t_compatible_with_all_others: + compatible_with_all = realloc(compatible_with_all, sizeof(compatible_with_all[0]) * (n_compatible_with_all + 1)) + assert compatible_with_all != NULL + compatible_with_all[n_compatible_with_all++] = *t if n_compatible_with_all != 1: - # Can't make an unambiguous choice. Mention all types we considered in the error message. - assert sizeof distinct[0]->name == 100 - message: byte* = calloc(200, ndistinct+1) - strcpy(message, "array items have different types (") - for i = 0; i < ndistinct; i++: - if i != 0: - strcat(message, ", ") - strcat(message, distinct[i]->name) - strcat(message, ")") - fail(error_location, message) - - item_type = compatible_with_all[0] + size = 500L + for t = distinct; t < &distinct[ndistinct]; t++: + size += strlen((*t)->name) + 3 # 1 for comma, 1 for space, 1 because why not lol + + msg: byte* = malloc(size) + assert msg != NULL + + strcpy(msg, "array items have different types (") + for t = distinct; t < &distinct[ndistinct]; t++: + if t != distinct: + strcat(msg, ", ") + strcat(msg, (*t)->name) + strcat(msg, ")") + fail(error_location, msg) + + elemtype = compatible_with_all[0] free(distinct) free(compatible_with_all) - for i = 0; i < ntypes; i++: - types[i]->do_implicit_cast(item_type, Location{}, NULL) - return item_type - - -class Stage3TypeChecker: - file_types: FileTypes* - current_function_or_method: FunctionOrMethodTypes* - nested_loop_count: int - - def add_local_var(self, name: byte*, type: Type*) -> LocalVariable*: - v: LocalVariable* = calloc(1, sizeof *v) - assert strlen(name) < sizeof v->name - strcpy(v->name, name) - v->type = type - - dest_pointer = &self->current_function_or_method->local_vars - while *dest_pointer != NULL: - dest_pointer = &(*dest_pointer)->next - - *dest_pointer = v - return v - - def find_var(self, name: byte*) -> Type*: - if get_special_constant(name) != -1: - return &bool_type - local_var = self->current_function_or_method->find_local_var(name) - if local_var != NULL: - return local_var->type - for i = 0; i < self->file_types->nglobals; i++: - if strcmp(self->file_types->globals[i].name, name) == 0: - return self->file_types->globals[i].type - return NULL + for et = exprtypes; *et != NULL; et++: + do_implicit_cast(fom, *et, elemtype, error_location, NULL) + return elemtype + + +def typecheck_expression(ft: FileTypes*, expr: AstExpression*) -> ExpressionTypes*: + msg: byte[500] + result: Type* = NULL + + if expr->kind == AstExpressionKind::Bool: + result = boolType + elif expr->kind == AstExpressionKind::Byte: + result = byteType + elif expr->kind == AstExpressionKind::Double: + result = doubleType + elif expr->kind == AstExpressionKind::Float: + result = floatType + elif expr->kind == AstExpressionKind::Short: + result = shortType + elif expr->kind == AstExpressionKind::Int: + result = intType + elif expr->kind == AstExpressionKind::Long: + result = longType + elif expr->kind == AstExpressionKind::Null: + result = voidPtrType + elif expr->kind == AstExpressionKind::String: + result = get_pointer_type(byteType) - def find_function_or_method(self, self_type: Type*, name: byte*) -> Signature*: - if self_type == NULL: - return self->file_types->find_function(name) - elif self_type->kind == TypeKind::Class: - return self_type->class_members.find_method(name) - else: - return NULL + elif expr->kind == AstExpressionKind::GetEnumMember: + result = find_type(ft, expr->enum_member.enum_name) + if result == NULL: + snprintf(msg, sizeof(msg), "there is no type named '%s'", expr->enum_member.enum_name) + fail(expr->location, msg) + if result->kind != TypeKind::Enum: + snprintf( + msg, sizeof(msg), + "the '::' syntax is only for enums, but %s is %s", + expr->enum_member.enum_name, + short_type_description(result), + ) + fail(expr->location, msg) + if not enum_member_exists(result, expr->enum_member.member_name): + snprintf( + msg, sizeof(msg), + "enum %s has no member named '%s'", + expr->enum_member.enum_name, expr->enum_member.member_name) + fail(expr->location, msg) - def do_call(self, call: AstCall*) -> Type*: - message: byte[500] + elif expr->kind == AstExpressionKind::SizeOf: + typecheck_expression_not_void(ft, &expr->operands[0]) + result = longType - if call->method_call_self != NULL: - self_type = self->do_expression(call->method_call_self)->original_type - if call->uses_arrow_operator: - if self_type->kind != TypeKind::Pointer or self_type->value_type->kind != TypeKind::Class: - snprintf( - message, sizeof message, - "left side of the '->' operator must be a pointer, not %s", - self_type->name, - ) - fail(call->location, message) - self_type = self_type->value_type - else: - self_type = NULL - - signature = self->find_function_or_method(self_type, call->name) - if signature == NULL: - if self_type == NULL: - snprintf(message, sizeof message, "function '%s' not found", call->name) - elif ( - self_type->kind == TypeKind::Pointer - and self_type->value_type->kind == TypeKind::Class - and self_type->value_type->class_members.find_method(call->name) != NULL - ): - snprintf( - message, sizeof message, - "the method '%s' is defined on class %s, not on the pointer type %s, so you need to dereference the pointer first (e.g. by using '->' instead of '.')", - call->name, self_type->value_type->name, self_type->name, - ) - elif self_type->kind == TypeKind::Class: + elif expr->kind == AstExpressionKind::Instantiate: + result = typecheck_instantiation(ft, &expr->instantiation, expr->location) + + elif expr->kind == AstExpressionKind::Array: + n = expr->array.length + exprtypes: ExpressionTypes** = calloc(sizeof(exprtypes[0]), n+1) + for i = 0; i < n; i++: + exprtypes[i] = typecheck_expression_not_void(ft, &expr->array.items[i]) + + membertype = cast_array_members_to_a_common_type(ft->current_fom_types, expr->location, exprtypes) + free(exprtypes) + result = get_array_type(membertype, n) + + elif expr->kind == AstExpressionKind::GetClassField: + if expr->class_field.uses_arrow_operator: + temptype = typecheck_expression_not_void(ft, expr->class_field.instance)->type + if temptype->kind != TypeKind::Pointer or temptype->value_type->kind != TypeKind::Class: snprintf( - message, sizeof message, - "class %s does not have a method named '%s'", self_type->name, call->name, - ) - else: + msg, sizeof(msg), + "left side of the '->' operator must be a pointer to a class, not %s", + temptype->name) + fail(expr->location, msg) + result = typecheck_class_field(temptype->value_type, expr->class_field.field_name, expr->location)->type + else: + temptype = typecheck_expression_not_void(ft, expr->class_field.instance)->type + if temptype->kind != TypeKind::Class: snprintf( - message, sizeof message, - "type %s does not have any methods because it is %s, not a class", - self_type->name, short_type_description(self_type), - ) - fail(call->location, message) + msg, sizeof(msg), + "left side of the '.' operator must be an instance of a class, not %s", + temptype->name) + fail(expr->location, msg) + result = typecheck_class_field(temptype, expr->class_field.field_name, expr->location)->type - if call->method_call_self != NULL and not call->uses_arrow_operator: - snprintf( - message, sizeof message, - "cannot take address of %%s, needed for calling the %s() method", call->name) - ensure_can_take_address(call->method_call_self, message) + elif expr->kind == AstExpressionKind::Call: + if expr->call.method_call_self == NULL: + result = typecheck_function_or_method_call(ft, &expr->call, NULL, expr->location) + elif expr->call.uses_arrow_operator: + temptype = typecheck_expression_not_void(ft, expr->call.method_call_self)->type + if temptype->kind != TypeKind::Pointer: + snprintf(msg, sizeof(msg), + "left side of the '->' operator must be a pointer, not %s", + temptype->name) + fail(expr->location, msg) + result = typecheck_function_or_method_call(ft, &expr->call, temptype->value_type, expr->location) + else: + temptype = typecheck_expression_not_void(ft, expr->call.method_call_self)->type + result = typecheck_function_or_method_call(ft, &expr->call, temptype, expr->location) + + # If self argument is passed by pointer, make sure we can create that pointer + found = False + assert temptype->kind == TypeKind::Class + for m = temptype->classdata.methods; m < &temptype->classdata.methods[temptype->classdata.nmethods]; m++: + if strcmp(m->name, expr->call.name) != 0: + continue + + if is_pointer_type(m->argtypes[0]): + assert strstr(expr->call.name, "%") == NULL + snprintf( + msg, sizeof msg, + "cannot take address of %%s, needed for calling the %s() method", + expr->call.name) + ensure_can_take_address(ft->current_fom_types, expr->call.method_call_self, msg) - signature_string = signature->to_string(False, False) + found = True + break - expected = signature->nargs - if self_type != NULL: - expected-- # exclude self + assert found - if call->nargs < expected or (call->nargs > expected and not signature->takes_varargs): - snprintf( - message, sizeof message, - "%s %s takes %d argument%s, but it was called with %d argument%s", - signature->function_or_method(), - signature_string, - expected, - plural_s(expected), - call->nargs, - plural_s(call->nargs), - ) - fail(call->location, message) + if result == NULL: + # no return value produced + return NULL - k = 0 - for i = 0; i < signature->nargs; i++: - if strcmp(signature->argnames[i], "self") == 0: - continue + elif expr->kind == AstExpressionKind::Indexing: + result = typecheck_indexing(ft, &expr->operands[0], &expr->operands[1]) - # This is a common error, so worth spending some effort to get a good error message. - tmp = nth(i+1) - snprintf( - message, sizeof message, - "%s argument of %s %s should have type , not ", - tmp, signature->function_or_method(), signature_string, - ) - self->do_expression_and_implicit_cast(&call->args[k++], signature->argtypes[i], message) - - for i = k; i < call->nargs; i++: - # This code runs for varargs, e.g. the things to format in printf(). - types = self->do_expression(&call->args[i]) - - if ( - (types->original_type->is_integer_type() and types->original_type->size_in_bits < 32) - or types->original_type == &bool_type - ): - # Add implicit cast to signed int, just like in C. - types->do_implicit_cast(int_type, Location{}, NULL) - elif types->original_type == &float_type: - types->do_implicit_cast(&double_type, Location{}, NULL) - elif types->original_type->kind == TypeKind::Array: - types->cast_array_to_pointer() - - free(signature_string) - return signature->return_type - - def do_increment_or_decrement(self, expression: AstExpression*, increment_or_decrement: byte*) -> Type*: - assert strcmp(increment_or_decrement, "increment") == 0 or strcmp(increment_or_decrement, "decrement") == 0 - - bad_expression_error_template: byte[50] - sprintf(bad_expression_error_template, "cannot %s %%s", increment_or_decrement) - ensure_can_take_address(&expression->operands[0], bad_expression_error_template) - - t = self->do_expression(&expression->operands[0])->original_type - if not t->is_integer_type() and not t->is_pointer_type(): - error: byte* = malloc(strlen(t->name) + 100) - sprintf(error, "cannot %s a value of type %s", increment_or_decrement, t->name) - fail(expression->location, error) - return t - - def do_enum_member(self, location: Location, enum_name: byte*, member_name: byte*) -> Type*: - message: byte[200] - - enum_type = self->file_types->find_type(enum_name) - if enum_type == NULL: - snprintf(message, sizeof message, "there is no type named '%s'", enum_name) - fail(location, message) - - if enum_type->kind != TypeKind::Enum: - snprintf( - message, sizeof message, - "the '::' syntax is only for enums, but %s is %s", - enum_name, short_type_description(enum_type), - ) - fail(location, message) + elif expr->kind == AstExpressionKind::AddressOf: + ensure_can_take_address(ft->current_fom_types, &expr->operands[0], "the '&' operator cannot be used with %s") + temptype = typecheck_expression_not_void(ft, &expr->operands[0])->type + result = get_pointer_type(temptype) - if enum_type->enum_members.find_index(member_name) == -1: - snprintf(message, sizeof message, "enum %s has no member named '%s'", enum_name, member_name) - fail(location, message) + elif expr->kind == AstExpressionKind::GetVariable: + result = find_any_var(ft, expr->varname) + if result == NULL: + snprintf(msg, sizeof(msg), "no variable named '%s'", expr->varname) + fail(expr->location, msg) - return enum_type + elif expr->kind == AstExpressionKind::Self: + selfvar = find_local_var(ft, "self") + assert selfvar != NULL + result = selfvar->type - def do_instantiation(self, instantiation: AstInstantiation*) -> Type*: - message:byte[500] + elif expr->kind == AstExpressionKind::Dereference: + temptype = typecheck_expression_not_void(ft, &expr->operands[0])->type + typecheck_dereferenced_pointer(expr->location, temptype) + result = temptype->value_type - t = self->file_types->find_type(instantiation->class_name) - if t == NULL: - snprintf( - message, sizeof message, - "there is no type named '%s'", instantiation->class_name, - ) - fail(instantiation->class_name_location, message) + elif expr->kind == AstExpressionKind::And: + typecheck_and_or(ft, &expr->operands[0], &expr->operands[1], "and") + result = boolType - if t->kind != TypeKind::Class: - description = short_type_description(t) - snprintf( - message, sizeof message, - "the %s{...} syntax is only for classes, but %s is %s", - t->name, t->name, description, - ) - fail(instantiation->class_name_location, message) + elif expr->kind == AstExpressionKind::Or: + typecheck_and_or(ft, &expr->operands[0], &expr->operands[1], "or") + result = boolType - specified_fields: ClassField** = malloc(sizeof specified_fields[0] * instantiation->nfields) - for i = 0; i < instantiation->nfields; i++: - snprintf( - message, sizeof message, - "value for field '%s' of class %s must be of type , not ", - instantiation->field_names[i], t->name, - ) - specified_fields[i] = check_class_field( - instantiation->field_values[i].location, - t, - instantiation->field_names[i], - ) - self->do_expression_and_implicit_cast( - &instantiation->field_values[i], - specified_fields[i]->type, - message, - ) + elif expr->kind == AstExpressionKind::Not: + typecheck_expression_with_implicit_cast( + ft, &expr->operands[0], boolType, + "value after 'not' must be a boolean, not ") + result = boolType + + elif expr->kind == AstExpressionKind::Negate: + result = typecheck_expression_not_void(ft, &expr->operands[0])->type + if result->kind != TypeKind::SignedInteger and result->kind != TypeKind::FloatingPoint: + snprintf(msg, sizeof(msg), + "value after '-' must be a float or double or a signed integer, not %s", + result->name) + fail(expr->location, msg) - for i1 = 0; i1 < instantiation->nfields; i1++: - for i2 = i1+1; i2 < instantiation->nfields; i2++: - if specified_fields[i1]->union_id == specified_fields[i2]->union_id: - snprintf( - message, sizeof message, - "fields '%s' and '%s' cannot be set simultaneously because they belong to the same union", - specified_fields[i1]->name, - specified_fields[i2]->name, - ) - fail(instantiation->field_values[i2].location, message) - - return t - - def do_indexing(self, pointer: AstExpression*, index: AstExpression*) -> Type*: - message: byte[500] - types = self->do_expression(pointer) - - if types->original_type->kind == TypeKind::Array: - types->cast_array_to_pointer() - pointer_type = types->implicit_cast_type - elif types->original_type->kind == TypeKind::Pointer: - pointer_type = types->original_type - else: - snprintf(message, sizeof message[0], "value of type %s cannot be indexed", types->original_type->name) - fail(pointer->location, message) - - index_types = self->do_expression(index) - assert index_types != NULL - - if not index_types->original_type->is_integer_type(): - snprintf(message, sizeof message[0], "the index inside [...] must be an integer, not %s", index_types->original_type->name) - fail(index->location, message) - - # LLVM assumes that indexes smaller than 64 bits are signed. - # https://github.com/Akuli/jou/issues/48 - index_types->do_implicit_cast(long_type, Location{}, NULL) - - return pointer_type->value_type - - def do_expression_maybe_void(self, expression: AstExpression*) -> ExpressionTypes*: - result: Type* - message: byte[200] - - if expression->kind == AstExpressionKind::String: - result = byte_type->get_pointer_type() - elif expression->kind == AstExpressionKind::Bool: - result = &bool_type - elif expression->kind == AstExpressionKind::Byte: - result = byte_type - elif expression->kind == AstExpressionKind::Short: - result = short_type - elif expression->kind == AstExpressionKind::Int: - result = int_type - elif expression->kind == AstExpressionKind::Long: - result = long_type - elif expression->kind == AstExpressionKind::Float: - result = &float_type - elif expression->kind == AstExpressionKind::Double: - result = &double_type - elif expression->kind == AstExpressionKind::Null: - result = &void_ptr_type - elif expression->kind == AstExpressionKind::Array: - n = expression->array.length - item_types: ExpressionTypes** = malloc(n * sizeof item_types[0]) - for i = 0; i < n; i++: - item_types[i] = self->do_expression(&expression->array.items[i]) - member_type = cast_array_items_to_a_common_type(expression->location, item_types, n) - free(item_types) - result = member_type->get_array_type(n) - elif expression->kind == AstExpressionKind::Call: - result = self->do_call(&expression->call) - if result == NULL: - return NULL - elif expression->kind == AstExpressionKind::GetVariable: - result = self->find_var(expression->varname) - if result == NULL: - snprintf(message, sizeof message, "no variable named '%s'", expression->varname) - fail(expression->location, message) - elif expression->kind == AstExpressionKind::As: - value_types = self->do_expression(&expression->as_expression->value) - result = type_from_ast(self->file_types, &expression->as_expression->type) - value_types->do_explicit_cast(result, expression->location) - elif expression->kind == AstExpressionKind::GetEnumMember: - result = self->do_enum_member( - expression->location, - expression->enum_member.enum_name, - expression->enum_member.member_name, - ) - elif expression->kind == AstExpressionKind::And: - self->do_expression_and_implicit_cast(&expression->operands[0], &bool_type, "'and' only works with booleans, not ") - self->do_expression_and_implicit_cast(&expression->operands[1], &bool_type, "'and' only works with booleans, not ") - result = &bool_type - elif expression->kind == AstExpressionKind::Or: - self->do_expression_and_implicit_cast(&expression->operands[0], &bool_type, "'or' only works with booleans, not ") - self->do_expression_and_implicit_cast(&expression->operands[1], &bool_type, "'or' only works with booleans, not ") - result = &bool_type - elif ( - expression->kind == AstExpressionKind::Add - or expression->kind == AstExpressionKind::Subtract - or expression->kind == AstExpressionKind::Multiply - or expression->kind == AstExpressionKind::Divide - or expression->kind == AstExpressionKind::Modulo - or expression->kind == AstExpressionKind::Eq - or expression->kind == AstExpressionKind::Ne - or expression->kind == AstExpressionKind::Gt - or expression->kind == AstExpressionKind::Ge - or expression->kind == AstExpressionKind::Lt - or expression->kind == AstExpressionKind::Le - ): - lhs_types = self->do_expression(&expression->operands[0]) - rhs_types = self->do_expression(&expression->operands[1]) - result = check_binop(expression->kind, expression->location, lhs_types, rhs_types) - elif expression->kind == AstExpressionKind::Negate: - result = self->do_expression(&expression->operands[0])->original_type - # TODO: check for floats/doubles too - if result->kind != TypeKind::SignedInteger and result->kind != TypeKind::FloatingPoint: - snprintf( - message, sizeof message, - "value after '-' must be a float or double or a signed integer, not %s", - result->name, - ) - fail(expression->location, message) - elif expression->kind == AstExpressionKind::PreIncr or expression->kind == AstExpressionKind::PostIncr: - result = self->do_increment_or_decrement(expression, "increment") - elif expression->kind == AstExpressionKind::PreDecr or expression->kind == AstExpressionKind::PostDecr: - result = self->do_increment_or_decrement(expression, "decrement") - elif expression->kind == AstExpressionKind::GetClassField: - lhs_type = self->do_expression(expression->class_field.instance)->original_type - if expression->class_field.uses_arrow_operator: - if lhs_type->kind != TypeKind::Pointer or lhs_type->value_type->kind != TypeKind::Class: - snprintf( - message, sizeof message, - "left side of the '->' operator must be a pointer to a class, not %s", - lhs_type->name, - ) - fail(expression->location, message) - result = check_class_field(expression->location, lhs_type->value_type, expression->class_field.field_name)->type - else: - if lhs_type->kind != TypeKind::Class: - snprintf( - message, sizeof message, - "left side of the '.' operator must be an instance of a class, not %s", - lhs_type->name, - ) - fail(expression->location, message) - result = check_class_field(expression->location, lhs_type, expression->class_field.field_name)->type - elif expression->kind == AstExpressionKind::AddressOf: - ensure_can_take_address(&expression->operands[0], "the '&' operator cannot be used with %s") - result = self->do_expression(&expression->operands[0])->original_type->get_pointer_type() - elif expression->kind == AstExpressionKind::Dereference: - pointer_type = self->do_expression(expression->operands)->original_type - if pointer_type->kind != TypeKind::Pointer: - snprintf( - message, sizeof message, - "the dereference operator '*' is only for pointers, not for %s", - pointer_type->name, - ) - fail(expression->location, message) - result = pointer_type->value_type - elif expression->kind == AstExpressionKind::Instantiate: - result = self->do_instantiation(&expression->instantiation) - elif expression->kind == AstExpressionKind::Indexing: - result = self->do_indexing(&expression->operands[0], &expression->operands[1]) - elif expression->kind == AstExpressionKind::Not: - self->do_expression_and_implicit_cast( - &expression->operands[0], &bool_type, - "value after 'not' must be a boolean, not ", - ) - result = &bool_type - elif expression->kind == AstExpressionKind::Self: - class_type = self->current_function_or_method->signature.get_containing_class() - assert class_type != NULL - result = class_type->get_pointer_type() - elif expression->kind == AstExpressionKind::SizeOf: - self->do_expression(&expression->operands[0]) - result = long_type + elif ( + expr->kind == AstExpressionKind::Add + or expr->kind == AstExpressionKind::Subtract + or expr->kind == AstExpressionKind::Multiply + or expr->kind == AstExpressionKind::Divide + or expr->kind == AstExpressionKind::Modulo + or expr->kind == AstExpressionKind::Eq + or expr->kind == AstExpressionKind::Ne + or expr->kind == AstExpressionKind::Gt + or expr->kind == AstExpressionKind::Ge + or expr->kind == AstExpressionKind::Lt + or expr->kind == AstExpressionKind::Le + ): + lhstypes = typecheck_expression_not_void(ft, &expr->operands[0]) + rhstypes = typecheck_expression_not_void(ft, &expr->operands[1]) + result = check_binop(ft->current_fom_types, expr->kind, expr->location, lhstypes, rhstypes) + + elif ( + expr->kind == AstExpressionKind::PreIncr + or expr->kind == AstExpressionKind::PreDecr + or expr->kind == AstExpressionKind::PostIncr + or expr->kind == AstExpressionKind::PostDecr + ): + result = check_increment_or_decrement(ft, expr) + + elif expr->kind == AstExpressionKind::As: + origtypes = typecheck_expression_not_void(ft, &expr->as_->value) + result = type_from_ast(ft, &expr->as_->type) + do_explicit_cast(ft->current_fom_types, origtypes, result, expr->location) + + else: + printf("%d\n", expr->kind) + assert False + + assert result != NULL + + types: ExpressionTypes* = calloc(1, sizeof *types) + types->expr = expr + types->type = result + + ft->current_fom_types->expr_types = realloc(ft->current_fom_types->expr_types, sizeof(ft->current_fom_types->expr_types[0]) * (ft->current_fom_types->n_expr_types + 1)) + assert ft->current_fom_types->expr_types != NULL + ft->current_fom_types->expr_types[ft->current_fom_types->n_expr_types++] = types + + return types + + +def typecheck_body(ft: FileTypes*, body: AstBody*) -> None: + for i = 0; i < body->nstatements; i++: + typecheck_statement(ft, &body->statements[i]) + + +def typecheck_if_statement(ft: FileTypes*, ifstmt: AstIfStatement*) -> None: + for i = 0; i < ifstmt->n_if_and_elifs; i++: + if i == 0: + errmsg = "'if' condition must be a boolean, not " else: - printf("*** expr %d\n", expression->kind as int) - expression->print() - assert False + errmsg = "'elif' condition must be a boolean, not " - p: ExpressionTypes* = malloc(sizeof *p) - *p = ExpressionTypes{ - expression = expression, - original_type = result, - next = self->current_function_or_method->expression_types, - } - self->current_function_or_method->expression_types = p - return p - - def do_expression(self, expression: AstExpression*) -> ExpressionTypes*: - types = self->do_expression_maybe_void(expression) - if types == NULL: - assert expression->kind == AstExpressionKind::Call - name = expression->call.name - message = malloc(strlen(name) + 100) - sprintf(message, "%s '%s' does not return a value", expression->call.function_or_method(), name) - fail(expression->location, message) - return types + typecheck_expression_with_implicit_cast( + ft, &ifstmt->if_and_elifs[i].condition, boolType, errmsg) + typecheck_body(ft, &ifstmt->if_and_elifs[i].body) - def do_expression_and_implicit_cast( - self, - expression: AstExpression*, - cast_type: Type*, - error_message_template: byte*, - ) -> ExpressionTypes*: - types = self->do_expression(expression) - types->do_implicit_cast(cast_type, expression->location, error_message_template) - return types + typecheck_body(ft, &ifstmt->else_body) - def do_in_place_operation( - self, - location: Location, - target: AstExpression*, # the foo of "foo += 1" - value: AstExpression*, # the 1 of "foo += 1" - op_expr_kind: AstExpressionKind, # e.g. AstExpressionKind::Add - op_description: byte[20], # e.g. "addition" - ) -> None: - ensure_can_take_address(target, "cannot assign to %s") - target_types = self->do_expression(target) - value_types = self->do_expression(value) - - t = check_binop(op_expr_kind, location, target_types, value_types) - temp_value_types = ExpressionTypes{ expression = target, original_type = t } - - error_template: byte[200] - strcpy(error_template, op_description) - strcat(error_template, " produced a value of type which cannot be assigned back to ") - temp_value_types.do_implicit_cast(target_types->original_type, location, error_template) +def typecheck_statement(ft: FileTypes*, stmt: AstStatement*) -> None: + msg: byte[500] - # I think it is currently impossible to cast target. - # If this assert fails, we probably need a new error message. - assert target_types->implicit_cast_type == NULL + if stmt->kind == AstStatementKind::If: + typecheck_if_statement(ft, &stmt->if_statement) - def do_statement(self, statement: AstStatement*) -> None: - if statement->kind == AstStatementKind::Assert: - self->do_expression_and_implicit_cast( - &statement->expression, &bool_type, "assertion must be a boolean, not " - ) + elif stmt->kind == AstStatementKind::WhileLoop: + typecheck_expression_with_implicit_cast( + ft, &stmt->while_loop.condition, boolType, + "'while' condition must be a boolean, not ") + typecheck_body(ft, &stmt->while_loop.body) - elif statement->kind == AstStatementKind::ExpressionStatement: - self->do_expression_maybe_void(&statement->expression) + elif stmt->kind == AstStatementKind::ForLoop: + typecheck_statement(ft, stmt->for_loop.init) + typecheck_expression_with_implicit_cast( + ft, &stmt->for_loop.cond, boolType, + "'for' condition must be a boolean, not ") + typecheck_body(ft, &stmt->for_loop.body) + typecheck_statement(ft, stmt->for_loop.incr) - elif statement->kind == AstStatementKind::Return: - sig = &self->current_function_or_method->signature + elif ( + stmt->kind == AstStatementKind::Break + or stmt->kind == AstStatementKind::Continue + or stmt->kind == AstStatementKind::Pass + ): + pass - # TODO: check for noreturn functions + elif stmt->kind == AstStatementKind::Assign: + targetexpr = &stmt->assignment.target + valueexpr = &stmt->assignment.value - msg: byte[500] + if ( + targetexpr->kind == AstExpressionKind::GetVariable + and find_any_var(ft, targetexpr->varname) == NULL + ): + # Making a new variable. Use the type of the value being assigned. + types = typecheck_expression_not_void(ft, valueexpr) + add_variable(ft, types->type, targetexpr->varname) + else: + # Convert value to the type of an existing variable or other assignment target. + ensure_can_take_address(ft->current_fom_types, targetexpr, "cannot assign to %s") - if statement->return_value != NULL and sig->return_type == NULL: - snprintf( - msg, sizeof msg, - "%s '%s' cannot return a value because it was defined with '-> None'", - sig->function_or_method(), sig->name, - ) - fail(statement->location, msg) - if statement->return_value == NULL and sig->return_type != NULL: - snprintf( - msg, sizeof msg, - "%s '%s' must return a value because it was defined with '-> %s'", - sig->function_or_method(), sig->name, sig->return_type->name, - ) - fail(statement->location, msg) - - if statement->return_value != NULL: - cast_error_msg: byte[500] - snprintf( - cast_error_msg, sizeof cast_error_msg, - "attempting to return a value of type from %s '%s' defined with '-> '", - sig->function_or_method(), sig->name, - ) - self->do_expression_and_implicit_cast( - statement->return_value, sig->return_type, cast_error_msg - ) - - elif statement->kind == AstStatementKind::Assign: - target_expr = &statement->assignment.target - value_expr = &statement->assignment.value - ensure_can_take_address(target_expr, "cannot assign to %s") - - if ( - target_expr->kind == AstExpressionKind::GetVariable - and self->find_var(target_expr->varname) == NULL - ): - # Making a new variable. Use the type of the value being assigned. - types = self->do_expression(value_expr) - self->add_local_var(target_expr->varname, types->original_type) + if targetexpr->kind == AstExpressionKind::Dereference: + strcpy(msg, "cannot place a value of type into a pointer of type *") else: - # Convert value to the type of an existing variable or other assignment target. - # This tends to fail often, so try to produce a helpful error message. - error_template: byte[500] - if target_expr->kind == AstExpressionKind::Dereference: - error_template = "cannot place a value of type into a pointer of type *" - else: - target_description: byte[200] = short_expression_description(target_expr) - snprintf( - error_template, sizeof error_template, - "cannot assign a value of type to %s of type ", - target_description, - ) - - target_types = self->do_expression(target_expr) - self->do_expression_and_implicit_cast(value_expr, target_types->original_type, error_template) - - elif statement->kind == AstStatementKind::DeclareLocalVar: - ntv: AstNameTypeValue* = &statement->var_declaration - if self->find_var(ntv->name) != NULL: - message: byte[200] - snprintf(message, sizeof message, "a variable named '%s' already exists", ntv->name) - fail(statement->location, message) - - type = type_from_ast(self->file_types, &ntv->type) - self->add_local_var(ntv->name, type) - if ntv->value != NULL: - self->do_expression_and_implicit_cast( - ntv->value, type, - "initial value for variable of type cannot be of type ", - ) - - elif statement->kind == AstStatementKind::If: - for i = 0; i < statement->if_statement.n_if_and_elifs; i++: - if i == 0: - template = "'if' condition must be a boolean, not " - else: - template = "'elif' condition must be a boolean, not " - self->do_expression_and_implicit_cast( - &statement->if_statement.if_and_elifs[i].condition, &bool_type, template - ) - self->do_body(&statement->if_statement.if_and_elifs[i].body) - self->do_body(&statement->if_statement.else_body) - - elif statement->kind == AstStatementKind::WhileLoop: - self->do_expression_and_implicit_cast( - &statement->while_loop.condition, &bool_type, - "'while' condition must be a boolean, not ", - ) - self->nested_loop_count++ - self->do_body(&statement->while_loop.body) - self->nested_loop_count-- - - elif statement->kind == AstStatementKind::ForLoop: - self->do_statement(statement->for_loop.init) - self->do_expression_and_implicit_cast( - &statement->for_loop.cond, &bool_type, - "'for' condition must be a boolean, not ", - ) - self->nested_loop_count++ - self->do_body(&statement->for_loop.body) - self->nested_loop_count-- - self->do_statement(statement->for_loop.incr) - - elif statement->kind == AstStatementKind::Pass: - pass - - elif statement->kind == AstStatementKind::Break: - if self->nested_loop_count == 0: - fail(statement->location, "'break' can only be used inside a loop") - - elif statement->kind == AstStatementKind::Continue: - if self->nested_loop_count == 0: - fail(statement->location, "'continue' can only be used inside a loop") - - elif statement->kind == AstStatementKind::InPlaceAdd: - self->do_in_place_operation(statement->location, &statement->assignment.target, &statement->assignment.value, AstExpressionKind::Add, "addition") - elif statement->kind == AstStatementKind::InPlaceSubtract: - self->do_in_place_operation(statement->location, &statement->assignment.target, &statement->assignment.value, AstExpressionKind::Subtract, "subtraction") - elif statement->kind == AstStatementKind::InPlaceMultiply: - self->do_in_place_operation(statement->location, &statement->assignment.target, &statement->assignment.value, AstExpressionKind::Multiply, "multiplication") - elif statement->kind == AstStatementKind::InPlaceDivide: - self->do_in_place_operation(statement->location, &statement->assignment.target, &statement->assignment.value, AstExpressionKind::Divide, "division") - elif statement->kind == AstStatementKind::InPlaceModulo: - self->do_in_place_operation(statement->location, &statement->assignment.target, &statement->assignment.value, AstExpressionKind::Modulo, "modulo") + snprintf(msg, sizeof msg, + "cannot assign a value of type to %s of type ", + short_expression_description(targetexpr)) + targettypes = typecheck_expression_not_void(ft, targetexpr) + typecheck_expression_with_implicit_cast(ft, valueexpr, targettypes->type, msg) + + elif ( + stmt->kind == AstStatementKind::InPlaceAdd + or stmt->kind == AstStatementKind::InPlaceSubtract + or stmt->kind == AstStatementKind::InPlaceMultiply + or stmt->kind == AstStatementKind::InPlaceDivide + or stmt->kind == AstStatementKind::InPlaceModulo + ): + targetexpr = &stmt->assignment.target + valueexpr = &stmt->assignment.value + + ensure_can_take_address(ft->current_fom_types, targetexpr, "cannot assign to %s") + targettypes = typecheck_expression_not_void(ft, targetexpr) + value_types = typecheck_expression_not_void(ft, valueexpr) + + if stmt->kind == AstStatementKind::InPlaceAdd: + op = AstExpressionKind::Add + opname = "addition" + elif stmt->kind == AstStatementKind::InPlaceSubtract: + op = AstExpressionKind::Subtract + opname = "subtraction" + elif stmt->kind == AstStatementKind::InPlaceMultiply: + op = AstExpressionKind::Multiply + opname = "multiplication" + elif stmt->kind == AstStatementKind::InPlaceDivide: + op = AstExpressionKind::Divide + opname = "division" + elif stmt->kind == AstStatementKind::InPlaceModulo: + op = AstExpressionKind::Modulo + opname = "modulo" else: - statement->print() - printf("*** typecheck: unknown statement kind %d\n", statement->kind) assert False - def do_body(self, body: AstBody*) -> None: - for i = 0; i < body->nstatements; i++: - self->do_statement(&body->statements[i]) + t = check_binop(ft->current_fom_types, op, stmt->location, targettypes, value_types) + tempvalue_types = ExpressionTypes{expr = targetexpr, type = t} - def define_function_or_method(self, signature: Signature*, body: AstBody*) -> None: - assert self->current_function_or_method == NULL - self->file_types->defined_functions = realloc( - self->file_types->defined_functions, - (self->file_types->n_defined_functions + 1) * sizeof self->file_types->defined_functions[0], - ) - self->current_function_or_method = &self->file_types->defined_functions[self->file_types->n_defined_functions++] - *self->current_function_or_method = FunctionOrMethodTypes{signature = signature->copy()} - - for k = 0; k < signature->nargs; k++: - self->add_local_var(signature->argnames[k], signature->argtypes[k]) - - self->do_body(body) - self->current_function_or_method = NULL - - -def typecheck_stage3_function_and_method_bodies(file_types: FileTypes*, ast_file: AstFile*) -> None: - checker = Stage3TypeChecker{file_types = file_types} - for i = 0; i < ast_file->body.nstatements; i++: - ts = &ast_file->body.statements[i] - if ts->kind == AstStatementKind::Function and ts->function.body.nstatements > 0: - signature = file_types->find_function(ts->function.signature.name) - assert signature != NULL - checker.define_function_or_method(signature, &ts->function.body) - elif ts->kind == AstStatementKind::Class: - class_type = file_types->find_type(ts->classdef.name) - assert class_type != NULL # created in previous typecheck stage - assert class_type->kind == TypeKind::Class - for k = 0; k < ts->classdef.nmembers; k++: - if ts->classdef.members[k].kind == AstClassMemberKind::Method: - signature = class_type->class_members.find_method(ts->classdef.members[k].method.signature.name) - checker.define_function_or_method(signature, &ts->classdef.members[k].method.body) + snprintf(msg, sizeof msg, "%s produced a value of type which cannot be assigned back to ", opname) + do_implicit_cast(ft->current_fom_types, &tempvalue_types, targettypes->type, stmt->location, msg) + + # I think it is currently impossible to cast target. + # If this assert fails, we probably need to add another error message for it. + assert targettypes->implicit_cast_type == NULL + + elif stmt->kind == AstStatementKind::Return: + if ft->current_fom_types->signature.is_noreturn: + snprintf(msg, sizeof(msg), + "function '%s' cannot return because it was defined with '-> noreturn'", + ft->current_fom_types->signature.name) + fail(stmt->location, msg) + + return_type = ft->current_fom_types->signature.returntype + + if stmt->return_value != NULL and return_type == NULL: + snprintf(msg, sizeof(msg), "function '%s' cannot return a value because it was defined with '-> None'", + ft->current_fom_types->signature.name) + fail(stmt->location, msg) + + if return_type != NULL and stmt->return_value == NULL: + snprintf(msg, sizeof(msg), + "a return value is needed, because the return type of function '%s' is %s", + ft->current_fom_types->signature.name, + ft->current_fom_types->signature.returntype->name) + fail(stmt->location, msg) + + if stmt->return_value != NULL: + snprintf(msg, sizeof msg, + "attempting to return a value of type from function '%s' defined with '-> '", + ft->current_fom_types->signature.name) + typecheck_expression_with_implicit_cast( + ft, stmt->return_value, find_local_var(ft, "return")->type, msg) + + elif stmt->kind == AstStatementKind::DeclareLocalVar: + if find_any_var(ft, stmt->var_declaration.name) != NULL: + snprintf(msg, sizeof(msg), "a variable named '%s' already exists", stmt->var_declaration.name) + fail(stmt->location, msg) + + type = type_from_ast(ft, &stmt->var_declaration.type) + add_variable(ft, type, stmt->var_declaration.name) + + if stmt->var_declaration.value != NULL: + typecheck_expression_with_implicit_cast( + ft, stmt->var_declaration.value, type, + "initial value for variable of type cannot be of type ") + + elif stmt->kind == AstStatementKind::ExpressionStatement: + typecheck_expression(ft, &stmt->expression) + + elif stmt->kind == AstStatementKind::Assert: + typecheck_expression_with_implicit_cast(ft, &stmt->expression, boolType, "assertion must be a boolean, not ") + + else: + assert False + + +def typecheck_function_or_method_body(ft: FileTypes*, sig: Signature*, body: AstBody*) -> None: + assert ft->current_fom_types == NULL + + ft->fomtypes = realloc(ft->fomtypes, sizeof(ft->fomtypes[0]) * (ft->nfomtypes + 1)) + assert ft->fomtypes != NULL + ft->fomtypes[ft->nfomtypes++] = FunctionOrMethodTypes{} + + ft->current_fom_types = &ft->fomtypes[ft->nfomtypes - 1] + ft->current_fom_types->signature = copy_signature(sig) + + for i = 0; i < sig->nargs; i++: + v = add_variable(ft, sig->argtypes[i], sig->argnames[i]) + v->is_argument = True + + if sig->returntype != NULL: + add_variable(ft, sig->returntype, "return") + + typecheck_body(ft, body) + ft->current_fom_types = NULL + + +def typecheck_stage3_function_and_method_bodies(ft: FileTypes*, ast: AstFile*) -> None: + for i = 0; i < ast->body.nstatements; i++: + stmt = &ast->body.statements[i] + if stmt->kind == AstStatementKind::Function and stmt->function.body.nstatements > 0: + sig: Signature* = NULL + for f = ft->functions; f < &ft->functions[ft->nfunctions]; f++: + if strcmp(f->signature.name, stmt->function.signature.name) == 0: + sig = &f->signature + break + assert sig != NULL + + typecheck_function_or_method_body(ft, sig, &stmt->function.body) + + if stmt->kind == AstStatementKind::Class: + classtype: Type* = NULL + for t = ft->owned_types; t < &ft->owned_types[ft->n_owned_types]; t++: + if strcmp((*t)->name, stmt->classdef.name) == 0: + classtype = *t + break + assert classtype != NULL + + for m = stmt->classdef.members; m < &stmt->classdef.members[stmt->classdef.nmembers]; m++: + if m->kind != AstClassMemberKind::Method: + continue + method = &m->method + + sig = NULL + for s = classtype->classdata.methods; s < &classtype->classdata.methods[classtype->classdata.nmethods]; s++: + if strcmp(s->name, method->signature.name) == 0: + sig = s + break + assert sig != NULL + typecheck_function_or_method_body(ft, sig, &method->body) diff --git a/self_hosted/types.jou b/self_hosted/types.jou index 5656fa0e..c8c845a5 100644 --- a/self_hosted/types.jou +++ b/self_hosted/types.jou @@ -1,28 +1,7 @@ -import "stdlib/str.jou" import "stdlib/mem.jou" - -enum TypeKind: - Bool - SignedInteger - UnsignedInteger - FloatingPoint - Pointer - VoidPointer - Class - OpaqueClass - Enum - Array - -class EnumMembers: - count: int - names: byte[100]* - - # Returns -1 for not found - def find_index(self, name: byte*) -> int: - for i = 0; i < self->count; i++: - if strcmp(self->names[i], name) == 0: - return i - return -1 +import "stdlib/str.jou" +import "./structs.jou" +import "./free.jou" class ClassField: name: byte[100] @@ -31,216 +10,301 @@ class ClassField: # It means that only one of the fields can be used at a time. union_id: int -class ClassMembers: +class ClassData: fields: ClassField* nfields: int methods: Signature* nmethods: int - def find_field(self, name: byte*) -> ClassField*: - for i = 0; i < self->nfields; i++: - if strcmp(self->fields[i].name, name) == 0: - return &self->fields[i] - return NULL +class ArrayType: + item_type: Type* + len: int - def find_method(self, name: byte*) -> Signature*: - for i = 0; i < self->nmethods; i++: - if strcmp(self->methods[i].name, name) == 0: - return &self->methods[i] - return NULL +class EnumType: + count: int + names: byte[100]* -class ArrayInfo: - length: int - item_type: Type* +enum TypeKind: + SignedInteger + UnsignedInteger + Bool + FloatingPoint # float or double + Pointer + VoidPointer + Array + Class + OpaqueClass # class with unknown members. TODO when used? + Enum class Type: - name: byte[100] + name: byte[500] # All types have a name for error messages and debugging. kind: TypeKind - union: size_in_bits: int # SignedInteger, UnsignedInteger, FloatingPoint - value_type: Type* # Pointer (not used for VoidPointer) - enum_members: EnumMembers - class_members: ClassMembers - array: ArrayInfo - - # Pointers and arrays of a given type live as long as the type itself. - # To make it possible, we just store them within the type. - # These are initially NULL and created dynamically as needed. - # - # Do not access these outside this file. - cached_pointer_type: Type* - cached_array_types: Type** - n_cached_array_types: int + value_type: Type* # Pointer + classdata: ClassData # Class + array: ArrayType # Array + enummembers: EnumType def is_integer_type(self) -> bool: return self->kind == TypeKind::SignedInteger or self->kind == TypeKind::UnsignedInteger def is_number_type(self) -> bool: - return self->is_integer_type() or self->kind == TypeKind::FloatingPoint + return is_integer_type(self) or self->kind == TypeKind::FloatingPoint def is_pointer_type(self) -> bool: - return self->kind == TypeKind::Pointer or self->kind == TypeKind::VoidPointer - - def get_pointer_type(self) -> Type*: - if self->cached_pointer_type == NULL: - pointer_name: byte[100] - snprintf(pointer_name, sizeof pointer_name, "%s*", self->name) - - self->cached_pointer_type = malloc(sizeof *self->cached_pointer_type) - *self->cached_pointer_type = Type{ - name = pointer_name, - kind = TypeKind::Pointer, - value_type = self, - } - - return self->cached_pointer_type - - def get_array_type(self, length: int) -> Type*: - assert length > 0 - - for i = 0; i < self->n_cached_array_types; i++: - if self->cached_array_types[i]->array.length == length: - return self->cached_array_types[i] - - array_name: byte[100] - snprintf(array_name, sizeof array_name, "%s[%d]", self->name, length) - - t: Type* = malloc(sizeof *t) - *t = Type{ - name = array_name, - kind = TypeKind::Array, - array = ArrayInfo{length = length, item_type = self}, - } - - self->cached_array_types = realloc(self->cached_array_types, sizeof self->cached_array_types[0] * (self->n_cached_array_types + 1)) - self->cached_array_types[self->n_cached_array_types++] = t - return t - -# Typese are cached into global state, so you can use == between -# pointers to compare them. Also, you don't usually need to copy a -# type, you can just pass around a pointer to it. -global signed_integers: Type[65] # indexed by size in bits (8, 16, 32, 64) -global unsigned_integers: Type[65] # indexed by size in bits (8, 16, 32, 64) -global bool_type: Type -global void_ptr_type: Type -global float_type: Type -global double_type: Type - -# TODO: it seems weird in other files these are pointers but bool_type isn't -global byte_type: Type* -global short_type: Type* -global int_type: Type* -global long_type: Type* + return (self->kind == TypeKind::Pointer or self->kind == TypeKind::VoidPointer) + +class TypeInfo: + type: Type + # TODO: pointer should be TypeInfo* but can't be due to compiler bug + pointer: void* # type that represents a pointer to this type, or NULL + # TODO: arrays should be TypeInfo** but can't be due to compiler bug + arrays: void* # types that represent arrays of this type + narrays: long + + def get_pointer(self) -> TypeInfo*: + return self->pointer + + def get_arrays(self) -> TypeInfo**: + return self->arrays + +# Types are cached into global state. This makes a lot of things easier +# because you don't need to copy and free the types everywhere. This is +# important: previously it was a lot of work to find forgotten copies and +# frees with valgrind. +# +# This also simplifies checking whether two types are the same type: you +# can simply use "==" between two "const Type *" pointers. +# +# Class types are a bit different. When you make a class, you get a +# pointer that you must pass to free_type() later. You can still "==" +# compare types, because two different classes with the same members are +# not the same type. +class GlobalTypeState: + integers: TypeInfo[2][65] # integers[i][j] = i-bit integer, j=1 for signed, j=0 for unsigned + boolean: TypeInfo + doublelele: TypeInfo + floater: TypeInfo + voidptr: TypeInfo + +global global_type_state: GlobalTypeState + +global boolType: Type* # bool +global shortType: Type* # short (16-bit signed) +global intType: Type* # int (32-bit signed) +global longType: Type* # long (64-bit signed) +global byteType: Type* # byte (8-bit unsigned) +global floatType: Type* # float (32-bit) +global doubleType: Type* # double (64-bit) +global voidPtrType: Type* # void* + +# The TypeInfo for type T contains the type T* (if it has been used) +# and all array types with element type T. +def free_pointer_and_array_types(info: TypeInfo*) -> None: + free_type(&info->get_pointer()->type) + for arrtype = info->get_arrays(); arrtype < &info->get_arrays()[info->narrays]; arrtype++: + free_type(&(*arrtype)->type) + free(info->arrays) + +def free_type(t: Type*) -> None: + if t != NULL: + if t->kind == TypeKind::Class: + for m = t->classdata.methods; m < &t->classdata.methods[t->classdata.nmethods]; m++: + free_signature(m) + free(t->classdata.fields) + free(t->classdata.methods) + + ti = t as TypeInfo* + assert &ti->type == t + + free_pointer_and_array_types(ti) + free(t) + +def free_global_type_state() -> None: + free_pointer_and_array_types(&global_type_state.boolean) + free_pointer_and_array_types(&global_type_state.floater) + free_pointer_and_array_types(&global_type_state.doublelele) + free_pointer_and_array_types(&global_type_state.voidptr) + for size = 8; size <= 64; size *= 2: + for is_signed = 0; is_signed <= 1; is_signed++: + free_pointer_and_array_types(&global_type_state.integers[size][is_signed]) def init_types() -> None: - void_ptr_type = Type{name = "void*", kind = TypeKind::VoidPointer} - bool_type = Type{name = "bool", kind = TypeKind::Bool} - float_type = Type{name = "float", size_in_bits = 32, kind = TypeKind::FloatingPoint} - double_type = Type{name = "double", size_in_bits = 64, kind = TypeKind::FloatingPoint} + memset(&global_type_state, 0, sizeof(global_type_state)) + + boolType = &global_type_state.boolean.type + shortType = &global_type_state.integers[16][1].type + intType = &global_type_state.integers[32][1].type + longType = &global_type_state.integers[64][1].type + byteType = &global_type_state.integers[8][0].type + floatType = &global_type_state.floater.type + doubleType = &global_type_state.doublelele.type + voidPtrType = &global_type_state.voidptr.type + + global_type_state.boolean.type = Type{name = "bool", kind = TypeKind::Bool } + global_type_state.voidptr.type = Type{name = "void*", kind = TypeKind::VoidPointer } + global_type_state.floater.type = Type{name = "float", kind = TypeKind::FloatingPoint, size_in_bits = 32 } + global_type_state.doublelele.type = Type{name = "double", kind = TypeKind::FloatingPoint, size_in_bits = 64 } for size = 8; size <= 64; size *= 2: - sprintf(signed_integers[size].name, "<%d-bit signed integer>", size) - sprintf(unsigned_integers[size].name, "<%d-bit unsigned integer>", size) - signed_integers[size].kind = TypeKind::SignedInteger - unsigned_integers[size].kind = TypeKind::UnsignedInteger - signed_integers[size].size_in_bits = size - unsigned_integers[size].size_in_bits = size - - byte_type = &unsigned_integers[8] - short_type = &signed_integers[16] - int_type = &signed_integers[32] - long_type = &signed_integers[64] - - byte_type->name = "byte" - short_type->name = "short" - int_type->name = "int" - long_type->name = "long" + global_type_state.integers[size][0].type.kind = TypeKind::UnsignedInteger + global_type_state.integers[size][1].type.kind = TypeKind::SignedInteger + + global_type_state.integers[size][0].type.size_in_bits = size + global_type_state.integers[size][1].type.size_in_bits = size + + sprintf(global_type_state.integers[size][0].type.name, "<%d-bit unsigned integer>", size) + sprintf(global_type_state.integers[size][1].type.name, "<%d-bit signed integer>", size) + + strcpy(global_type_state.integers[8][0].type.name, "byte") + strcpy(global_type_state.integers[16][1].type.name, "short") + strcpy(global_type_state.integers[32][1].type.name, "int") + strcpy(global_type_state.integers[64][1].type.name, "long") + +def get_integer_type(size_in_bits: int, is_signed: bool) -> Type*: + assert size_in_bits==8 or size_in_bits==16 or size_in_bits==32 or size_in_bits==64 + return &global_type_state.integers[size_in_bits][is_signed as int].type + +def get_pointer_type(t: Type*) -> Type*: + info = t as TypeInfo* + assert t == &info->type # the 'type' field is first member and has 0 bytes offset + + if info->pointer == NULL: + ptr: TypeInfo* = calloc(1, sizeof *ptr) + ptr->type = Type{kind=TypeKind::Pointer, value_type=t} + snprintf(ptr->type.name, sizeof ptr->type.name, "%s*", t->name) + info->pointer = ptr + + return &info->get_pointer()->type + +def get_array_type(t: Type*, len: int) -> Type*: + info = t as TypeInfo* + assert &info->type == t + + assert len > 0 + for existing = info->get_arrays(); existing < &info->get_arrays()[info->narrays]; existing++: + if (*existing)->type.array.len == len: + return &(*existing)->type + + arr: TypeInfo* = calloc(1, sizeof *arr) + arr->type = Type{kind = TypeKind::Array, array = ArrayType{item_type = t, len = len}} + snprintf(arr->type.name, sizeof arr->type.name, "%s[%d]", t->name, len) + info->arrays = realloc(info->arrays, sizeof(info->get_arrays()[0]) * (info->narrays + 1)) + assert info->arrays != NULL + info->get_arrays()[info->narrays++] = arr + return &arr->type + +def is_integer_type(t: Type*) -> bool: + return t->kind == TypeKind::SignedInteger or t->kind == TypeKind::UnsignedInteger + +def is_number_type(t: Type*) -> bool: + return is_integer_type(t) or t->kind == TypeKind::FloatingPoint + +def is_pointer_type(t: Type*) -> bool: + return (t->kind == TypeKind::Pointer or t->kind == TypeKind::VoidPointer) + +def type_of_constant(c: Constant*) -> Type*: + if c->kind == ConstantKind::EnumMember: + return c->enum_member.enumtype + if c->kind == ConstantKind::Null: + return voidPtrType + if c->kind == ConstantKind::Double: + return doubleType + if c->kind == ConstantKind::Float: + return floatType + if c->kind == ConstantKind::Bool: + return boolType + if c->kind == ConstantKind::String: + return get_pointer_type(byteType) + if c->kind == ConstantKind::Integer: + return get_integer_type(c->integer.size_in_bits, c->integer.is_signed) + assert False def create_opaque_class(name: byte*) -> Type*: - result: Type* = malloc(sizeof *result) - *result = Type{kind = TypeKind::OpaqueClass} - assert strlen(name) < sizeof result->name - strcpy(result->name, name) - return result + result: TypeInfo* = calloc(1, sizeof *result) + result->type = Type{kind = TypeKind::OpaqueClass} + + assert strlen(name) < sizeof result->type.name + strcpy(result->type.name, name) -def create_enum(name: byte*, member_count: int, member_names: byte[100]*) -> Type*: - copied_member_names: byte[100]* = malloc(member_count * sizeof copied_member_names[0]) - memcpy(copied_member_names, member_names, member_count * sizeof copied_member_names[0]) + return &result->type - result: Type* = malloc(sizeof *result) - *result = Type{ +def create_enum(name: byte*, membercount: int, membernames: byte[100]*) -> Type*: + result: TypeInfo* = calloc(1, sizeof *result) + result->type = Type{ kind = TypeKind::Enum, - enum_members = EnumMembers{count = member_count, names = copied_member_names}, + enummembers = EnumType{count=membercount, names=membernames}, } - assert strlen(name) < sizeof result->name - strcpy(result->name, name) - return result + assert strlen(name) < sizeof result->type.name + strcpy(result->type.name, name) + + return &result->type + + +def get_self_class(sig: Signature*) -> Type*: + if sig->nargs > 0 and strcmp(sig->argnames[0], "self") == 0: + if sig->argtypes[0]->kind == TypeKind::Pointer: + return sig->argtypes[0]->value_type + if sig->argtypes[0]->kind == TypeKind::Class: + return sig->argtypes[0] + assert False + return NULL + +def signature_to_string(sig: Signature*, include_return_type: bool, include_self: bool) -> byte*: + result = strdup(sig->name) + assert result != NULL + + result = realloc(result, strlen(result) + 2) + assert result != NULL + strcat(result, "(") + + for i = 0; i < sig->nargs; i++: + if strcmp(sig->argnames[i], "self") == 0 and not include_self: + continue + + assert sizeof sig->argnames[i] == 100 + assert sizeof sig->argtypes[i]->name == 500 + result = realloc(result, strlen(result) + 1000) + assert result != NULL + strcat(result, sig->argnames[i]) + strcat(result, ": ") + strcat(result, sig->argtypes[i]->name) + if i < sig->nargs - 1: + strcat(result, ", ") + + result = realloc(result, strlen(result) + 100) + assert result != NULL -class Signature: - name: byte[100] # name of function or method, after "def" keyword - nargs: int - argnames: byte[100]* - argtypes: Type** - takes_varargs: bool # True for functions like printf() - return_type: Type* - - def get_containing_class(self) -> Type*: - for i = 0; i < self->nargs; i++: - if strcmp(self->argnames[i], "self") == 0: - assert self->argtypes[i]->kind == TypeKind::Pointer - assert self->argtypes[i]->value_type->kind == TypeKind::Class - return self->argtypes[i]->value_type - return NULL - - def is_method(self) -> bool: - return self->get_containing_class() != NULL - - def function_or_method(self) -> byte*: - if self->is_method(): - return "method" + if sig->takes_varargs: + if sig->nargs != 0: + strcat(result, ", ") + strcat(result, "...") + strcat(result, ")") + + if include_return_type: + assert sizeof(sig->returntype->name) == 500 + result = realloc(result, strlen(result) + 600) + assert result != NULL + + strcat(result, " -> ") + if sig->is_noreturn: + strcat(result, "noreturn") + elif sig->returntype == NULL: + strcat(result, "void") else: - return "function" + strcat(result, sig->returntype->name) - def to_string(self, include_self: bool, include_return_type: bool) -> byte*: - result: byte* = malloc(500*(self->nargs + 1)) - strcpy(result, self->name) + return result - strcat(result, "(") +def copy_signature(sig: Signature*) -> Signature: + result = *sig - for i = 0; i < self->nargs; i++: - if strcmp(self->argnames[i], "self") == 0 and not include_self: - continue - strcat(result, self->argnames[i]) - strcat(result, ": ") - strcat(result, self->argtypes[i]->name) - strcat(result, ", ") + result.argtypes = malloc(sizeof(result.argtypes[0]) * result.nargs) + memcpy(result.argtypes, sig->argtypes, sizeof(result.argtypes[0]) * result.nargs) - if self->takes_varargs: - strcat(result, "...") - elif ends_with(result, ", "): - result[strlen(result)-2] = '\0' - - strcat(result, ")") - - if include_return_type: - if self->return_type == NULL: - strcat(result, " -> None") - else: - strcat(result, " -> ") - strcat(result, self->return_type->name) - - return result - - def copy(self) -> Signature: - result = *self - result.argnames = malloc(result.nargs * sizeof(result.argnames[0])) - result.argtypes = malloc(result.nargs * sizeof(result.argtypes[0])) - memcpy(result.argnames, self->argnames, result.nargs * sizeof(result.argnames[0])) - memcpy(result.argtypes, self->argtypes, result.nargs * sizeof(result.argtypes[0])) - return result - - def free(self) -> None: - free(self->argnames) - free(self->argtypes) + result.argnames = malloc(sizeof(result.argnames[0]) * result.nargs) + memcpy(result.argnames, sig->argnames, sizeof(result.argnames[0]) * result.nargs) + + return result diff --git a/self_hosted/update.jou b/self_hosted/update.jou new file mode 100644 index 00000000..b1f95a49 --- /dev/null +++ b/self_hosted/update.jou @@ -0,0 +1,63 @@ +# Self-update: "jou --update" updates the Jou compiler. + +import "stdlib/ascii.jou" +import "stdlib/errno.jou" +import "stdlib/str.jou" +import "stdlib/io.jou" +import "stdlib/mem.jou" +import "stdlib/process.jou" +import "./paths.jou" + +# TODO: add some kind of chdir to standard library +if WINDOWS: + declare _chdir(dirname: byte*) -> int + + def chdir(dir: byte*) -> int: + return _chdir(dir) +else: + declare chdir(path: byte*) -> int + + +def fail_update() -> None: + puts("") + puts("Updating Jou failed. If you need help, please create an issue on GitHub:") + puts(" https://github.com/Akuli/jou/issues/new") + exit(1) + +def confirm(prompt: byte*) -> None: + printf("%s (y/n) ", prompt) + fflush(stdout) + + line: byte[50] + fgets(line, sizeof(line) as int, stdin) + trim_ascii_whitespace(line) + + yes = strcmp(line, "Y") == 0 or strcmp(line, "y") == 0 + if not yes: + printf("Aborted.\n") + exit(1) + +def update_jou_compiler() -> None: + exe = find_current_executable() + exedir = dirname(exe) + printf("Installation directory: %s\n\n", exedir) + + if chdir(exedir) == -1: + fprintf(stderr, "chdir(\"%s\") failed: %s\n", exedir, strerror(get_errno())) + fail_update() + + if WINDOWS: + confirm("Download and install the latest version of Jou from GitHub releases?") + if system("powershell -ExecutionPolicy bypass -File update.ps1") != 0: + fail_update() + elif NETBSD: + confirm("Run \"git pull && gmake\"?") + if system("git pull && gmake") != 0: + fail_update() + else: + confirm("Run \"git pull && make\"?") + if system("git pull && make") != 0: + fail_update() + + free(exe) + printf("\n\nYou now have the latest version of Jou :)\n") diff --git a/self_hosted_old/ast.jou b/self_hosted_old/ast.jou new file mode 100644 index 00000000..bcf3b890 --- /dev/null +++ b/self_hosted_old/ast.jou @@ -0,0 +1,837 @@ +import "stdlib/io.jou" +import "stdlib/str.jou" +import "stdlib/mem.jou" +import "./errors_and_warnings.jou" + +# TODO: move to stdlib +declare isprint(b: int) -> int + +enum AstTypeKind: + Named + Pointer + Array + +class AstArrayType: + member_type: AstType* + length: AstExpression* + + def free(self) -> None: + self->member_type->free() + self->length->free() + free(self->member_type) + free(self->length) + +class AstType: + kind: AstTypeKind + location: Location + + union: + name: byte[100] # AstTypeKind::Named + value_type: AstType* # AstTypeKind::Pointer + array: AstArrayType # AstTypeKind::Array + + def is_void(self) -> bool: + return self->kind == AstTypeKind::Named and strcmp(self->name, "void") == 0 + + def is_none(self) -> bool: + return self->kind == AstTypeKind::Named and strcmp(self->name, "None") == 0 + + def is_noreturn(self) -> bool: + return self->kind == AstTypeKind::Named and strcmp(self->name, "noreturn") == 0 + + def print(self, show_lineno: bool) -> None: + if self->kind == AstTypeKind::Named: + printf("%s", self->name) + elif self->kind == AstTypeKind::Pointer: + self->value_type->print(False) + printf("*") + elif self->kind == AstTypeKind::Array: + self->array.member_type->print(False) + printf("[]") # TODO: show the size expression better? + else: + assert False + + if show_lineno: + printf(" [line %d]", self->location.lineno) + + def free(self) -> None: + if self->kind == AstTypeKind::Pointer: + self->value_type->free() + free(self->value_type) + if self->kind == AstTypeKind::Array: + self->array.free() + +# Statements and expressions can be printed in a tree. +# To see a tree, run: +# +# $ jou --parse-only examples/hello.jou +# +class TreePrinter: + prefix: byte[100] + + # Returned subprinter can be used to print elements "inside" the current line. + def print_prefix(self, is_last_child: bool) -> TreePrinter: + subprinter = TreePrinter{} + if is_last_child: + printf("%s`--- ", self->prefix) + snprintf(subprinter.prefix, sizeof subprinter.prefix, "%s ", self->prefix) + else: + printf("%s|--- ", self->prefix) + snprintf(subprinter.prefix, sizeof subprinter.prefix, "%s| ", self->prefix) + return subprinter + +enum AstExpressionKind: + String + Int + Short + Long + Byte + Float + Double + Bool + Null + Array + Call # function call or method call + Instantiate # MyClass{x=1, y=2} + Self # not a variable lookup, so you can't use 'self' as variable name outside a class + GetVariable + GetEnumMember + GetClassField + As + # unary operators + SizeOf # sizeof x + AddressOf # &x + Dereference # *x + Negate # -x + Not # not x + PreIncr # ++x + PostIncr # x++ + PreDecr # --x + PostDecr # x-- + # binary operators + Add # x+y + Subtract # x-y + Multiply # x*y + Divide # x/y + Indexing # x[y] + Modulo # x % y + Eq # x == y + Ne # x != y + Gt # x > y + Lt # x < y + Ge # x >= y + Le # x <= y + And # x and y + Or # x or y + +class AstExpression: + location: Location + kind: AstExpressionKind + + union: + enum_member: AstEnumMember + class_field: AstClassField + string: byte* + int_value: int + short_value: short + long_value: long + byte_value: byte + bool_value: bool + call: AstCall + instantiation: AstInstantiation + as_expression: AstAsExpression* # Must be pointer, because it contains an AstExpression + array: AstArray + varname: byte[100] + float_or_double_text: byte[100] + operands: AstExpression* # Only for operators. Length is arity, see get_arity() + + def print(self) -> None: + self->print_with_tree_printer(TreePrinter{}) + + def print_with_tree_printer(self, tp: TreePrinter) -> None: + printf("[line %d] ", self->location.lineno) + if self->kind == AstExpressionKind::String: + printf("\"") + for s = self->string; *s != 0; s++: + if isprint(*s) != 0: + putchar(*s) + elif *s == '\n': + printf("\\n") + else: + printf("\\x%02x", *s) + printf("\"\n") + elif self->kind == AstExpressionKind::Short: + printf("%hd (16-bit signed)\n", self->short_value) + elif self->kind == AstExpressionKind::Int: + printf("%d (32-bit signed)\n", self->int_value) + elif self->kind == AstExpressionKind::Long: + printf("%lld (64-bit signed)\n", self->long_value) + elif self->kind == AstExpressionKind::Byte: + printf("%d (8-bit unsigned)\n", self->byte_value) + elif self->kind == AstExpressionKind::Float: + printf("float %s\n", self->float_or_double_text) + elif self->kind == AstExpressionKind::Double: + printf("double %s\n", self->float_or_double_text) + elif self->kind == AstExpressionKind::Bool: + if self->bool_value: + printf("True\n") + else: + printf("False\n") + elif self->kind == AstExpressionKind::Null: + printf("NULL\n") + elif self->kind == AstExpressionKind::Indexing: + printf("indexing\n") + elif self->kind == AstExpressionKind::Array: + printf("array\n") + for i = 0; i < self->array.length; i++: + self->array.items[i].print_with_tree_printer(tp.print_prefix(i == self->array.length-1)) + elif self->kind == AstExpressionKind::Call: + if self->call.uses_arrow_operator: + printf("dereference and ") + printf("call %s \"%s\"\n", self->call.function_or_method(), self->call.name) + self->call.print(tp) + elif self->kind == AstExpressionKind::Instantiate: + printf("instantiate \"%s\"\n", self->instantiation.class_name) + self->instantiation.print(tp) + elif self->kind == AstExpressionKind::Self: + printf("self\n") + elif self->kind == AstExpressionKind::GetVariable: + printf("get variable \"%s\"\n", self->varname) + elif self->kind == AstExpressionKind::GetEnumMember: + printf( + "get member \"%s\" from enum \"%s\"\n", + self->enum_member.member_name, + self->enum_member.enum_name, + ) + elif self->kind == AstExpressionKind::GetClassField: + if self->class_field.uses_arrow_operator: + printf("dereference and ") + printf("get class field \"%s\"\n", self->class_field.field_name) + self->class_field.instance->print_with_tree_printer(tp.print_prefix(True)) + elif self->kind == AstExpressionKind::As: + printf("as ") + self->as_expression->type.print(True) + printf("\n") + self->as_expression->value.print_with_tree_printer(tp.print_prefix(True)) + elif self->kind == AstExpressionKind::SizeOf: + printf("sizeof\n") + elif self->kind == AstExpressionKind::AddressOf: + printf("address of\n") + elif self->kind == AstExpressionKind::Dereference: + printf("dereference\n") + elif self->kind == AstExpressionKind::Negate: + printf("negate\n") + elif self->kind == AstExpressionKind::Not: + printf("not\n") + elif self->kind == AstExpressionKind::PreIncr: + printf("pre-increment\n") + elif self->kind == AstExpressionKind::PostIncr: + printf("post-increment\n") + elif self->kind == AstExpressionKind::PreDecr: + printf("pre-decrement\n") + elif self->kind == AstExpressionKind::PostDecr: + printf("post-decrement\n") + elif self->kind == AstExpressionKind::Add: + printf("add\n") + elif self->kind == AstExpressionKind::Subtract: + printf("sub\n") + elif self->kind == AstExpressionKind::Multiply: + printf("mul\n") + elif self->kind == AstExpressionKind::Divide: + printf("div\n") + elif self->kind == AstExpressionKind::Modulo: + printf("mod\n") + elif self->kind == AstExpressionKind::Eq: + printf("eq\n") + elif self->kind == AstExpressionKind::Ne: + printf("ne\n") + elif self->kind == AstExpressionKind::Gt: + printf("gt\n") + elif self->kind == AstExpressionKind::Ge: + printf("ge\n") + elif self->kind == AstExpressionKind::Lt: + printf("lt\n") + elif self->kind == AstExpressionKind::Le: + printf("le\n") + elif self->kind == AstExpressionKind::And: + printf("and\n") + elif self->kind == AstExpressionKind::Or: + printf("or\n") + else: + printf("?????\n") + + for i = 0; i < self->get_arity(); i++: + self->operands[i].print_with_tree_printer(tp.print_prefix(i == self->get_arity()-1)) + + def free(self) -> None: + if self->kind == AstExpressionKind::Call: + self->call.free() + elif self->kind == AstExpressionKind::As: + self->as_expression->free() + free(self->as_expression) + elif self->kind == AstExpressionKind::String: + free(self->string) + elif self->kind == AstExpressionKind::GetClassField: + self->class_field.free() + + if self->get_arity() != 0: + for i = 0; i < self->get_arity(); i++: + self->operands[i].free() + free(self->operands) + + # arity = number of operands, e.g. 2 for a binary operator such as "+" + def get_arity(self) -> int: + if ( + self->kind == AstExpressionKind::SizeOf + or self->kind == AstExpressionKind::AddressOf + or self->kind == AstExpressionKind::Dereference + or self->kind == AstExpressionKind::Negate + or self->kind == AstExpressionKind::Not + or self->kind == AstExpressionKind::PreIncr + or self->kind == AstExpressionKind::PreDecr + or self->kind == AstExpressionKind::PostIncr + or self->kind == AstExpressionKind::PostDecr + ): + return 1 + if ( + self->kind == AstExpressionKind::Add + or self->kind == AstExpressionKind::Subtract + or self->kind == AstExpressionKind::Multiply + or self->kind == AstExpressionKind::Divide + or self->kind == AstExpressionKind::Indexing + or self->kind == AstExpressionKind::Modulo + or self->kind == AstExpressionKind::Eq + or self->kind == AstExpressionKind::Ne + or self->kind == AstExpressionKind::Gt + or self->kind == AstExpressionKind::Lt + or self->kind == AstExpressionKind::Ge + or self->kind == AstExpressionKind::Le + or self->kind == AstExpressionKind::And + or self->kind == AstExpressionKind::Or + ): + return 2 + return 0 + + def can_have_side_effects(self) -> bool: + return ( + self->kind == AstExpressionKind::Call + or self->kind == AstExpressionKind::PreIncr + or self->kind == AstExpressionKind::PreDecr + or self->kind == AstExpressionKind::PostIncr + or self->kind == AstExpressionKind::PostDecr + ) + +class AstArray: + length: int + items: AstExpression* + + def free(self) -> None: + for i = 0; i < self->length; i++: + self->items[i].free() + free(self->items) + +class AstEnumMember: + enum_name: byte[100] + member_name: byte[100] + +class AstClassField: + instance: AstExpression* + uses_arrow_operator: bool # distinguishes foo.bar and foo->bar + field_name: byte[100] + + def free(self) -> None: + self->instance->free() + free(self->instance) + +class AstAsExpression: + value: AstExpression + type: AstType + + def free(self) -> None: + self->value.free() + self->type.free() + +class AstCall: + location: Location + name: byte[100] # name of function or method + method_call_self: AstExpression* # NULL for function calls, the foo of foo.bar() for method calls + uses_arrow_operator: bool # distinguishes foo->bar() and foo.bar() + nargs: int + args: AstExpression* + + # Useful for formatting error messages, but not much else. + def function_or_method(self) -> byte*: + if self->method_call_self == NULL: + return "function" + else: + return "method" + + def print(self, tp: TreePrinter) -> None: + if self->method_call_self != NULL: + sub = tp.print_prefix(self->nargs == 0) + printf("self: ") + self->method_call_self->print_with_tree_printer(sub) + + for i = 0; i < self->nargs; i++: + sub = tp.print_prefix(i == self->nargs - 1) + printf("argument %d: ", i) + self->args[i].print_with_tree_printer(sub) + + def free(self) -> None: + for i = 0; i < self->nargs; i++: + self->args[i].free() + free(self->args) + +class AstInstantiation: + class_name_location: Location + class_name: byte[100] + nfields: int + field_names: byte[100]* + field_values: AstExpression* + + def print(self, tp: TreePrinter) -> None: + for i = 0; i < self->nfields; i++: + sub = tp.print_prefix(i == self->nfields - 1) + printf("field \"%s\": ", self->field_names[i]) + self->field_values[i].print_with_tree_printer(sub) + + def free(self) -> None: + for i = 0; i < self->nfields; i++: + self->field_values[i].free() + free(self->field_names) + free(self->field_values) + +class AstAssertion: + condition: AstExpression + condition_str: byte* + +enum AstStatementKind: + ExpressionStatement # Evaluate an expression. Discard the result. + Assert + Pass + Return + If + WhileLoop + ForLoop + Break + Continue + DeclareLocalVar # x: SomeType = y (the "= y" is optional) + Assign # x = y + InPlaceAdd # x += y + InPlaceSubtract # x -= y + InPlaceMultiply # x *= y + InPlaceDivide # x /= y + InPlaceModulo # x %= y + Function + Class + Enum + GlobalVariableDeclaration + GlobalVariableDefinition + +class AstStatement: + location: Location + kind: AstStatementKind + + union: + expression: AstExpression # ExpressionStatement, Assert + if_statement: AstIfStatement + while_loop: AstConditionAndBody + for_loop: AstForLoop + return_value: AstExpression* # can be NULL + assignment: AstAssignment + var_declaration: AstNameTypeValue # DeclareLocalVar + function: AstFunctionOrMethod + classdef: AstClassDef + enumdef: AstEnumDef + assertion: AstAssertion + + def print(self) -> None: + self->print_with_tree_printer(TreePrinter{}) + + def print_with_tree_printer(self, tp: TreePrinter) -> None: + printf("[line %d] ", self->location.lineno) + if self->kind == AstStatementKind::ExpressionStatement: + printf("expression statement\n") + self->expression.print_with_tree_printer(tp.print_prefix(True)) + elif self->kind == AstStatementKind::Assert: + printf("assert \"%s\"\n", self->assertion.condition_str) + self->assertion.condition.print_with_tree_printer(tp.print_prefix(True)) + elif self->kind == AstStatementKind::Pass: + printf("pass\n") + elif self->kind == AstStatementKind::Return: + printf("return\n") + if self->return_value != NULL: + self->return_value->print_with_tree_printer(tp.print_prefix(True)) + elif self->kind == AstStatementKind::If: + printf("if\n") + self->if_statement.print(tp) + elif self->kind == AstStatementKind::ForLoop: + printf("for loop\n") + self->for_loop.print(tp) + elif self->kind == AstStatementKind::WhileLoop: + printf("while loop\n") + self->while_loop.print_with_tree_printer(tp, True) + elif self->kind == AstStatementKind::Break: + printf("break\n") + elif self->kind == AstStatementKind::Continue: + printf("continue\n") + elif self->kind == AstStatementKind::DeclareLocalVar: + printf("declare local var ") + self->var_declaration.print_with_tree_printer(&tp) + elif self->kind == AstStatementKind::Assign: + printf("assign\n") + self->assignment.print_with_tree_printer(tp) + elif self->kind == AstStatementKind::InPlaceAdd: + printf("in-place add\n") + self->assignment.print_with_tree_printer(tp) + elif self->kind == AstStatementKind::InPlaceSubtract: + printf("in-place sub\n") + self->assignment.print_with_tree_printer(tp) + elif self->kind == AstStatementKind::InPlaceMultiply: + printf("in-place mul\n") + self->assignment.print_with_tree_printer(tp) + elif self->kind == AstStatementKind::InPlaceDivide: + printf("in-place div\n") + self->assignment.print_with_tree_printer(tp) + elif self->kind == AstStatementKind::InPlaceModulo: + printf("in-place mod\n") + self->assignment.print_with_tree_printer(tp) + elif self->kind == AstStatementKind::Function: + if self->function.body.nstatements == 0: + printf("declare a function: ") + else: + printf("define a function: ") + self->function.print_with_tree_printer(tp) + elif self->kind == AstStatementKind::Class: + printf("define a ") + self->classdef.print_with_tree_printer(tp) + elif self->kind == AstStatementKind::Enum: + printf("define ") + self->enumdef.print_with_tree_printer(tp) + elif self->kind == AstStatementKind::GlobalVariableDeclaration: + printf("declare global var ") + self->var_declaration.print_with_tree_printer(NULL) + printf("\n") + elif self->kind == AstStatementKind::GlobalVariableDefinition: + printf("define global var ") + self->var_declaration.print_with_tree_printer(NULL) + printf("\n") + else: + printf("??????\n") + + def free(self) -> None: + if self->kind == AstStatementKind::Enum: + self->enumdef.free() + if self->kind == AstStatementKind::ExpressionStatement: + self->expression.free() + if self->kind == AstStatementKind::Return and self->return_value != NULL: + self->return_value->free() + free(self->return_value) + if self->kind == AstStatementKind::If: + self->if_statement.free() + if self->kind == AstStatementKind::ForLoop: + self->for_loop.free() + +# Useful for e.g. "while condition: body", "if condition: body" +class AstConditionAndBody: + condition: AstExpression + body: AstBody + + def print(self) -> None: + self->print_with_tree_printer(TreePrinter{}, True) + + def print_with_tree_printer(self, tp: TreePrinter, body_is_last_sub_item: bool) -> None: + sub = tp.print_prefix(False) + printf("condition: ") + self->condition.print_with_tree_printer(sub) + + sub = tp.print_prefix(body_is_last_sub_item) + printf("body:\n") + self->body.print_with_tree_printer(sub) + + def free(self) -> None: + self->condition.free() + self->body.free() + +class AstAssignment: + target: AstExpression + value: AstExpression + + def print(self) -> None: + self->print_with_tree_printer(TreePrinter{}) + + def print_with_tree_printer(self, tp: TreePrinter) -> None: + self->target.print_with_tree_printer(tp.print_prefix(False)) + self->value.print_with_tree_printer(tp.print_prefix(True)) + +class AstIfStatement: + if_and_elifs: AstConditionAndBody* + n_if_and_elifs: int # At least 1 (the if statement). The rest, if any, are elifs. + else_body: AstBody # Empty if there is no else + + def print(self, tp: TreePrinter) -> None: + for i = 0; i < self->n_if_and_elifs; i++: + self->if_and_elifs[i].print_with_tree_printer(tp, i == self->n_if_and_elifs - 1 and self->else_body.nstatements == 0) + + if self->else_body.nstatements > 0: + sub = tp.print_prefix(True) + printf("else body:\n") + self->else_body.print_with_tree_printer(sub) + + def free(self) -> None: + for i = 0; i < self->n_if_and_elifs; i++: + self->if_and_elifs[i].free() + free(self->if_and_elifs) + self->else_body.free() + +class AstForLoop: + # for init; cond; incr: + # ...body... + # + # init and incr must be pointers because this struct goes inside AstStatement. + init: AstStatement* + cond: AstExpression + incr: AstStatement* + body: AstBody + + def print(self, tp: TreePrinter) -> None: + sub = tp.print_prefix(False) + printf("init: ") + self->init->print_with_tree_printer(sub) + + sub = tp.print_prefix(False) + printf("cond: ") + self->cond.print_with_tree_printer(sub) + + sub = tp.print_prefix(False) + printf("incr: ") + self->incr->print_with_tree_printer(sub) + + sub = tp.print_prefix(True) + printf("body:\n") + self->body.print_with_tree_printer(sub) + + def free(self) -> None: + self->init->free() + free(self->init) + self->cond.free() + self->incr->free() + free(self->incr) + self->body.free() + +class AstNameTypeValue: + # name: type = value + name: byte[100] + name_location: Location + type: AstType + value: AstExpression* # can be NULL + + def print(self) -> None: + tp = TreePrinter{} + self->print_with_tree_printer(&tp) + + # tp can be set to NULL, in that case no trailing newline is printed + def print_with_tree_printer(self, tp: TreePrinter*) -> None: + printf("%s: ", self->name) + self->type.print(True) + if tp == NULL: + assert self->value == NULL + else: + printf("\n") + if self->value != NULL: + sub = tp->print_prefix(True) + printf("initial value: ") + self->value->print_with_tree_printer(sub) + + def free(self) -> None: + if self->value != NULL: + self->value->free() + free(self->value) + +class AstBody: + statements: AstStatement* + nstatements: int + + def print(self) -> None: + self->print_with_tree_printer(TreePrinter{}) + + def print_with_tree_printer(self, tp: TreePrinter) -> None: + for i = 0; i < self->nstatements; i++: + self->statements[i].print_with_tree_printer(tp.print_prefix(i == self->nstatements - 1)) + + def free(self) -> None: + for i = 0; i < self->nstatements; i++: + self->statements[i].free() + free(self->statements) + +class AstSignature: + name_location: Location + name: byte[100] # name of function or method, after "def" keyword + args: AstNameTypeValue* + nargs: int + takes_varargs: bool # True for functions like printf() + return_type: AstType + + def print(self) -> None: + printf("%s(", self->name) + for i = 0; i < self->nargs; i++: + if i != 0: + printf(", ") + + if ( + strcmp(self->args[i].name, "self") == 0 + and self->args[i].type.kind == AstTypeKind::Named + and self->args[i].type.name[0] == '\0' + ): + # self with implicitly given type + printf("self") + else: + self->args[i].print_with_tree_printer(NULL) + + if self->takes_varargs: + if self->nargs != 0: + printf(", ") + printf("...") + + printf(") -> ") + self->return_type.print(True) + printf("\n") + + def free(self) -> None: + self->return_type.free() + +class AstImport: + location: Location + specified_path: byte* # Path in jou code e.g. "stdlib/io.jou" + resolved_path: byte* # Absolute path or relative to current working directory e.g. "/home/akuli/jou/stdlib/io.jou" + + def print(self) -> None: + printf( + "line %d: Import \"%s\", which resolves to \"%s\".\n", + self->location.lineno, self->specified_path, self->resolved_path) + + def free(self) -> None: + free(self->specified_path) + free(self->resolved_path) + +class AstFile: + path: byte* # not owned + imports: AstImport* + nimports: int + body: AstBody + + def print(self) -> None: + printf("===== AST for file \"%s\" =====\n", self->path) + for i = 0; i < self->nimports; i++: + self->imports[i].print() + for i = 0; i < self->body.nstatements; i++: + self->body.statements[i].print() + + def free(self) -> None: + for i = 0; i < self->nimports; i++: + self->imports[i].free() + free(self->imports) + self->body.free() + +class AstFunctionOrMethod: + signature: AstSignature + body: AstBody # empty body means declaration, otherwise it's a definition + + def print(self) -> None: + self->print_with_tree_printer(TreePrinter{}) + + def print_with_tree_printer(self, tp: TreePrinter) -> None: + self->signature.print() + self->body.print_with_tree_printer(tp) + + def free(self) -> None: + self->signature.free() + self->body.free() + +class AstUnionFields: + fields: AstNameTypeValue* + nfields: int + + def print(self, tp: TreePrinter) -> None: + for i = 0; i < self->nfields; i++: + subprinter = tp.print_prefix(i == self->nfields-1) + self->fields[i].print_with_tree_printer(&subprinter) # TODO: does this need to be optional/pointer? + + def free(self) -> None: + for i = 0; i < self->nfields; i++: + self->fields[i].free() + free(self->fields) + +enum AstClassMemberKind: + Field + Union + Method + +class AstClassMember: + kind: AstClassMemberKind + union: + field: AstNameTypeValue + union_fields: AstUnionFields + method: AstFunctionOrMethod + + def print(self, tp: TreePrinter) -> None: + if self->kind == AstClassMemberKind::Field: + printf("field ") + self->field.print_with_tree_printer(NULL) + printf("\n") + elif self->kind == AstClassMemberKind::Union: + printf("union:\n") + self->union_fields.print(tp) + elif self->kind == AstClassMemberKind::Method: + printf("method ") + self->method.signature.print() + self->method.body.print_with_tree_printer(tp) + else: + assert False + + def free(self) -> None: + if self->kind == AstClassMemberKind::Field: + self->field.free() + elif self->kind == AstClassMemberKind::Union: + self->union_fields.free() + elif self->kind == AstClassMemberKind::Method: + self->method.free() + else: + assert False + +class AstClassDef: + name: byte[100] + name_location: Location + members: AstClassMember* + nmembers: int + + def print(self) -> None: + self->print_with_tree_printer(TreePrinter{}) + + def print_with_tree_printer(self, tp: TreePrinter) -> None: + printf("class \"%s\" with %d members\n", self->name, self->nmembers) + for i = 0; i < self->nmembers; i++: + self->members[i].print(tp.print_prefix(i == self->nmembers-1)) + + def free(self) -> None: + for i = 0; i < self->nmembers; i++: + self->members[i].free() + free(self->members) + +class AstEnumDef: + name: byte[100] + name_location: Location + member_count: int + member_names: byte[100]* + + def print(self) -> None: + self->print_with_tree_printer(TreePrinter{}) + + def print_with_tree_printer(self, tp: TreePrinter) -> None: + printf("enum \"%s\" with %d members\n", self->name, self->member_count) + for i = 0; i < self->member_count; i++: + tp.print_prefix(i == self->member_count-1) + puts(self->member_names[i]) + + def free(self) -> None: + free(self->member_names) diff --git a/self_hosted/create_llvm_ir.jou b/self_hosted_old/create_llvm_ir.jou similarity index 100% rename from self_hosted/create_llvm_ir.jou rename to self_hosted_old/create_llvm_ir.jou diff --git a/self_hosted_old/errors_and_warnings.jou b/self_hosted_old/errors_and_warnings.jou new file mode 100644 index 00000000..0e520cd2 --- /dev/null +++ b/self_hosted_old/errors_and_warnings.jou @@ -0,0 +1,19 @@ +import "stdlib/process.jou" +import "stdlib/io.jou" + +class Location: + path: byte* # Not owned. Points to a string that is held elsewhere. + lineno: int + +def fail(location: Location, message: byte*) -> noreturn: + # When stdout is redirected to same place as stderr, + # make sure that normal printf()s show up before our error. + fflush(stdout) + fflush(stderr) + + fprintf(stderr, "compiler error in file \"%s\"", location.path) + if location.lineno != 0: + fprintf(stderr, ", line %d", location.lineno) + fprintf(stderr, ": %s\n", message) + + exit(1) diff --git a/self_hosted_old/evaluate.jou b/self_hosted_old/evaluate.jou new file mode 100644 index 00000000..355885f1 --- /dev/null +++ b/self_hosted_old/evaluate.jou @@ -0,0 +1,72 @@ +# Compile-time evaluating if statements. + +import "./ast.jou" +import "./errors_and_warnings.jou" +import "stdlib/str.jou" +import "stdlib/mem.jou" + + +# Return values: 1=true, 0=false, -1=unknown +def get_special_constant(name: byte*) -> int: + if strcmp(name, "WINDOWS") == 0: + return WINDOWS as int + if strcmp(name, "MACOS") == 0: + return MACOS as int + if strcmp(name, "NETBSD") == 0: + return NETBSD as int + return -1 + + +def evaluate_condition(expr: AstExpression*) -> bool: + if expr->kind == AstExpressionKind::GetVariable: + v = get_special_constant(expr->varname) + if v == 1: + return True + if v == 0: + return False + + if expr->kind == AstExpressionKind::And: + return evaluate_condition(&expr->operands[0]) and evaluate_condition(&expr->operands[1]) + if expr->kind == AstExpressionKind::Or: + return evaluate_condition(&expr->operands[0]) or evaluate_condition(&expr->operands[1]) + if expr->kind == AstExpressionKind::Not: + return not evaluate_condition(&expr->operands[0]) + + fail(expr->location, "cannot evaluate condition at compile time") + + +# returns the statements to replace if statement with +def evaluate_compile_time_if_statement(if_stmt: AstIfStatement*) -> AstBody: + result = &if_stmt->else_body + for p = if_stmt->if_and_elifs; p < &if_stmt->if_and_elifs[if_stmt->n_if_and_elifs]; p++: + if evaluate_condition(&p->condition): + result = &p->body + break + + ret = *result + *result = AstBody{} # avoid double-free + return ret + + +# Replace body->statements[i] with zero or more statements from another body. +def replace(body: AstBody*, i: int, new: AstBody) -> None: + body->statements[i].free() + + item_size = sizeof(body->statements[0]) + body->statements = realloc(body->statements, (body->nstatements + new.nstatements) * item_size) + memmove(&body->statements[i + new.nstatements], &body->statements[i+1], (body->nstatements - (i+1)) * item_size) + memcpy(&body->statements[i], new.statements, new.nstatements * item_size) + + free(new.statements) + body->nstatements-- + body->nstatements += new.nstatements + + +# This handles nested if statements. +def evaluate_compile_time_if_statements_in_body(body: AstBody*) -> None: + i = 0 + while i < body->nstatements: + if body->statements[i].kind == AstStatementKind::If: + replace(body, i, evaluate_compile_time_if_statement(&body->statements[i].if_statement)) + else: + i++ diff --git a/self_hosted_old/llvm.jou b/self_hosted_old/llvm.jou new file mode 100644 index 00000000..56817086 --- /dev/null +++ b/self_hosted_old/llvm.jou @@ -0,0 +1,278 @@ +class LLVMModule: + _dummy: int +class LLVMType: + _dummy: int +class LLVMValue: + _dummy: int +class LLVMBasicBlock: + _dummy: int +class LLVMBuilder: + _dummy: int +class LLVMPassManager: + _dummy: int + +class LLVMTarget: + _dummy: int +class LLVMTargetData: + _dummy: int +class LLVMTargetMachine: + _dummy: int + +# =========== Target.h =========== +declare LLVMInitializeX86TargetInfo() -> None +declare LLVMInitializeX86Target() -> None +declare LLVMInitializeX86TargetMC() -> None +declare LLVMInitializeX86AsmPrinter() -> None +declare LLVMInitializeX86AsmParser() -> None +declare LLVMInitializeX86Disassembler() -> None + +declare LLVMDisposeTargetData(TD: LLVMTargetData*) -> None +declare LLVMCopyStringRepOfTargetData(TD: LLVMTargetData*) -> byte* + +declare LLVMStoreSizeOfType(TD: LLVMTargetData*, Ty: LLVMType*) -> long +declare LLVMABISizeOfType(TD: LLVMTargetData*, Ty: LLVMType*) -> long + + + +# =========== TargetMachine.h =========== +enum LLVMCodeGenOptLevel: + none # can't make it None because that is a keyword + Less + Default + Aggressive + +enum LLVMRelocMode: + Default + Static + PIC + DynamicNoPic + ROPI + RWPI + ROPI_RWPI + +enum LLVMCodeModel: + Default + JITDefault + Tiny + Small + Kernel + Medium + Large + +enum LLVMCodeGenFileType: + AssemblyFile + ObjectFile + +declare LLVMCreateTargetMachine(T: LLVMTarget*, Triple: byte*, CPU: byte*, Features: byte*, Level: LLVMCodeGenOptLevel, Reloc: LLVMRelocMode, CodeModel: LLVMCodeModel) -> LLVMTargetMachine* +declare LLVMDisposeTargetMachine(T: LLVMTargetMachine*) -> None +declare LLVMCreateTargetDataLayout(T: LLVMTargetMachine*) -> LLVMTargetData* +declare LLVMTargetMachineEmitToFile(T: LLVMTargetMachine*, M: LLVMModule*, Filename: byte*, codegen: LLVMCodeGenFileType, ErrorMessage: byte**) -> int +declare LLVMGetTargetFromTriple(Triple: byte*, T: LLVMTarget**, ErrorMessage: byte**) -> int +declare LLVMGetDefaultTargetTriple() -> byte* + +# =========== Linker.h =========== +declare LLVMLinkModules2(Dest: LLVMModule*, Src: LLVMModule*) -> int + +# =========== Analysis.h =========== +enum LLVMVerifierFailureAction: + AbortProcess + PrintMessage + ReturnStatus + +declare LLVMVerifyModule(M: LLVMModule*, Action: LLVMVerifierFailureAction, OutMessage: byte**) -> int + +# =========== Core.h =========== +enum LLVMTypeKind: + Void + Half + Float + Double + X86_FP80 + FP128 + PPC_FP128 + Label + Integer + Function + Struct + Array + Pointer + Vector + Metadata + X86_MMX + Token + ScalableVector + BFloat + +enum LLVMLinkage: + External + AvailableExternally + LinkOnceAny + LinkOnceODR + Obsolete1 + WeakAny + WeakODR + Appending + Internal + Private + Obsolete2 + Obsolete3 + ExternalWeak + Obsolete4 + Common + LinkerPrivate + LinkerPrivateWeak + +enum LLVMIntPredicate: + # TODO: a better way to start the enum at 32 + Dummy0 + Dummy1 + Dummy2 + Dummy3 + Dummy4 + Dummy5 + Dummy6 + Dummy7 + Dummy8 + Dummy9 + Dummy10 + Dummy11 + Dummy12 + Dummy13 + Dummy14 + Dummy15 + Dummy16 + Dummy17 + Dummy18 + Dummy19 + Dummy20 + Dummy21 + Dummy22 + Dummy23 + Dummy24 + Dummy25 + Dummy26 + Dummy27 + Dummy28 + Dummy29 + Dummy30 + Dummy31 + EQ + NE + UGT + UGE + ULT + ULE + SGT + SGE + SLT + SLE + +enum LLVMRealPredicate: + AlwaysFalse + OEQ + OGT + OGE + OLT + OLE + ONE + ORD + UNO + UEQ + UGT + UGE + ULT + ULE + UNE + AlwaysTrue + +declare LLVMVoidType() -> LLVMType* +declare LLVMFloatType() -> LLVMType* +declare LLVMDoubleType() -> LLVMType* +declare LLVMFunctionType(ReturnType: LLVMType*, ParamTypes: LLVMType**, ParamCount: int, IsVarArg: int) -> LLVMType* +declare LLVMStructType(ElementTypes: LLVMType**, ElementCount: int, Packed: int) -> LLVMType* +declare LLVMArrayType(ElementType: LLVMType*, ElementCount: int) -> LLVMType* +declare LLVMPointerType(ElementType: LLVMType*, AddressSpace: int) -> LLVMType* +declare LLVMDisposeMessage(Message: byte*) -> None +declare LLVMModuleCreateWithName(ModuleID: byte*) -> LLVMModule* +declare LLVMDisposeModule(M: LLVMModule*) -> None +declare LLVMGetSourceFileName(M: LLVMModule*, Len: long*) -> byte* # Return value not owned +declare LLVMSetDataLayout(M: LLVMModule*, DataLayoutStr: byte*) -> None +declare LLVMSetTarget(M: LLVMModule*, Triple: byte*) -> None +declare LLVMDumpModule(M: LLVMModule*) -> None +declare LLVMPrintModuleToString(M: LLVMModule*) -> byte* +declare LLVMAddFunction(M: LLVMModule*, Name: byte*, FunctionTy: LLVMType*) -> LLVMValue* +declare LLVMGetNamedFunction(M: LLVMModule*, Name: byte*) -> LLVMValue* +declare LLVMGetTypeKind(Ty: LLVMType*) -> LLVMTypeKind +declare LLVMInt1Type() -> LLVMType* +declare LLVMInt8Type() -> LLVMType* +declare LLVMInt16Type() -> LLVMType* +declare LLVMInt32Type() -> LLVMType* +declare LLVMInt64Type() -> LLVMType* +declare LLVMIntType(NumBits: int) -> LLVMType* +declare LLVMGetReturnType(FunctionTy: LLVMType*) -> LLVMType* +declare LLVMGetParam(Fn: LLVMValue*, Index: int) -> LLVMValue* +declare LLVMGetElementType(Ty: LLVMType*) -> LLVMType* +declare LLVMTypeOf(Val: LLVMValue*) -> LLVMType* +declare LLVMConstNull(Ty: LLVMType*) -> LLVMValue* +declare LLVMGetUndef(Ty: LLVMType*) -> LLVMValue* +declare LLVMConstInt(IntTy: LLVMType*, N: long, SignExtend: int) -> LLVMValue* +declare LLVMConstRealOfString(RealTy: LLVMType*, Text: byte*) -> LLVMValue* +declare LLVMConstString(Str: byte*, Length: int, DontNullTerminate: int) -> LLVMValue* +declare LLVMSizeOf(Ty: LLVMType*) -> LLVMValue* +declare LLVMSetLinkage(Global: LLVMValue*, Linkage: LLVMLinkage) -> None +declare LLVMAddGlobal(M: LLVMModule*, Ty: LLVMType*, Name: byte*) -> LLVMValue* +declare LLVMGetNamedGlobal(M: LLVMModule*, Name: byte*) -> LLVMValue* +declare LLVMSetInitializer(GlobalVar: LLVMValue*, ConstantVal: LLVMValue*) -> None +declare LLVMAppendBasicBlock(Fn: LLVMValue*, Name: byte*) -> LLVMBasicBlock* +declare LLVMAddIncoming(PhiNode: LLVMValue*, IncomingValues: LLVMValue**, IncomingBlocks: LLVMBasicBlock**, Count: int) -> None +declare LLVMCreateBuilder() -> LLVMBuilder* +declare LLVMPositionBuilderAtEnd(Builder: LLVMBuilder*, Block: LLVMBasicBlock*) -> None +declare LLVMGetInsertBlock(Builder: LLVMBuilder*) -> LLVMBasicBlock* +declare LLVMDisposeBuilder(Builder: LLVMBuilder*) -> None +declare LLVMBuildRet(Builder: LLVMBuilder*, V: LLVMValue*) -> LLVMValue* +declare LLVMBuildRetVoid(Builder: LLVMBuilder*) -> LLVMValue* +declare LLVMBuildBr(Builder: LLVMBuilder*, Dest: LLVMBasicBlock*) -> LLVMValue* +declare LLVMBuildCondBr(Builder: LLVMBuilder*, If: LLVMValue*, Then: LLVMBasicBlock*, Else: LLVMBasicBlock*) -> LLVMValue* +declare LLVMBuildUnreachable(Builder: LLVMBuilder*) -> LLVMValue* +declare LLVMBuildAdd(Builder: LLVMBuilder*, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue* +declare LLVMBuildFAdd(Builder: LLVMBuilder*, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue* +declare LLVMBuildSub(Builder: LLVMBuilder*, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue* +declare LLVMBuildFSub(Builder: LLVMBuilder*, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue* +declare LLVMBuildMul(Builder: LLVMBuilder*, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue* +declare LLVMBuildFMul(Builder: LLVMBuilder*, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue* +declare LLVMBuildUDiv(Builder: LLVMBuilder*, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue* +declare LLVMBuildExactSDiv(Builder: LLVMBuilder*, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue* +declare LLVMBuildFDiv(Builder: LLVMBuilder*, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue* +declare LLVMBuildURem(Builder: LLVMBuilder*, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue* +declare LLVMBuildSRem(Builder: LLVMBuilder*, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue* +declare LLVMBuildFRem(Builder: LLVMBuilder*, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue* +declare LLVMBuildXor(Builder: LLVMBuilder*, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue* +declare LLVMBuildNeg(Builder: LLVMBuilder*, V: LLVMValue*, Name: byte*) -> LLVMValue* +declare LLVMBuildFNeg(Builder: LLVMBuilder*, V: LLVMValue*, Name: byte*) -> LLVMValue* +declare LLVMBuildMemSet(Builder: LLVMBuilder*, Ptr: LLVMValue*, Val: LLVMValue*, Len: LLVMValue*, Align: int) -> LLVMValue* +declare LLVMBuildAlloca(Builder: LLVMBuilder*, Ty: LLVMType*, Name: byte*) -> LLVMValue* +declare LLVMBuildLoad(Builder: LLVMBuilder*, PointerVal: LLVMValue*, Name: byte*) -> LLVMValue* +declare LLVMBuildStore(Builder: LLVMBuilder*, Val: LLVMValue*, Ptr: LLVMValue*) -> LLVMValue* +declare LLVMBuildGEP(Builder: LLVMBuilder*, Pointer: LLVMValue*, Indices: LLVMValue**, NumIndices: int, Name: byte*) -> LLVMValue* +declare LLVMBuildStructGEP2(Builder: LLVMBuilder*, Ty: LLVMType*, Pointer: LLVMValue*, Idx: int, Name: byte*) -> LLVMValue* +declare LLVMBuildTrunc(Builder: LLVMBuilder*, Val: LLVMValue*, DestTy: LLVMType*, Name: byte*) -> LLVMValue* +declare LLVMBuildZExt(Builder: LLVMBuilder*, Val: LLVMValue*, DestTy: LLVMType*, Name: byte*) -> LLVMValue* +declare LLVMBuildSExt(Builder: LLVMBuilder*, Val: LLVMValue*, DestTy: LLVMType*, Name: byte*) -> LLVMValue* +declare LLVMBuildFPToUI(Builder: LLVMBuilder*, Val: LLVMValue*, DestTy: LLVMType*, Name: byte*) -> LLVMValue* +declare LLVMBuildFPToSI(Builder: LLVMBuilder*, Val: LLVMValue*, DestTy: LLVMType*, Name: byte*) -> LLVMValue* +declare LLVMBuildUIToFP(Builder: LLVMBuilder*, Val: LLVMValue*, DestTy: LLVMType*, Name: byte*) -> LLVMValue* +declare LLVMBuildSIToFP(Builder: LLVMBuilder*, Val: LLVMValue*, DestTy: LLVMType*, Name: byte*) -> LLVMValue* +declare LLVMBuildPtrToInt(Builder: LLVMBuilder*, Val: LLVMValue*, DestTy: LLVMType*, Name: byte*) -> LLVMValue* +declare LLVMBuildIntToPtr(Builder: LLVMBuilder*, Val: LLVMValue*, DestTy: LLVMType*, Name: byte*) -> LLVMValue* +declare LLVMBuildBitCast(Builder: LLVMBuilder*, Val: LLVMValue*, DestTy: LLVMType*, Name: byte*) -> LLVMValue* +declare LLVMBuildIntCast2(Builder: LLVMBuilder*, Val: LLVMValue*, DestTy: LLVMType*, IsSigned: int, Name: byte*) -> LLVMValue* +declare LLVMBuildFPCast(Builder: LLVMBuilder*, Val: LLVMValue*, DestTy: LLVMType*, Name: byte*) -> LLVMValue* +declare LLVMBuildICmp(Builder: LLVMBuilder*, Op: LLVMIntPredicate, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue* +declare LLVMBuildFCmp(Builder: LLVMBuilder*, Op: LLVMRealPredicate, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue* +declare LLVMBuildPhi(Builder: LLVMBuilder*, Ty: LLVMType*, Name: byte*) -> LLVMValue* +declare LLVMBuildCall2(Builder: LLVMBuilder*, FunctionTy: LLVMType*, Fn: LLVMValue*, Args: LLVMValue**, NumArgs: int, Name: byte*) -> LLVMValue* +declare LLVMBuildExtractValue(Builder: LLVMBuilder*, AggVal: LLVMValue*, Index: int, Name: byte*) -> LLVMValue* +declare LLVMBuildInsertValue(Builder: LLVMBuilder*, AggVal: LLVMValue*, EltVal: LLVMValue*, Index: int, Name: byte*) -> LLVMValue* +declare LLVMCreatePassManager() -> LLVMPassManager* +declare LLVMRunPassManager(PM: LLVMPassManager*, M: LLVMModule*) -> int +declare LLVMDisposePassManager(PM: LLVMPassManager*) -> None diff --git a/self_hosted_old/main.jou b/self_hosted_old/main.jou new file mode 100644 index 00000000..75932da6 --- /dev/null +++ b/self_hosted_old/main.jou @@ -0,0 +1,510 @@ +import "../config.jou" +import "./ast.jou" +import "./errors_and_warnings.jou" +import "./tokenizer.jou" +import "./parser.jou" +import "./types.jou" +import "./typecheck.jou" +import "./paths.jou" +import "./target.jou" +import "./create_llvm_ir.jou" +import "./llvm.jou" +import "./evaluate.jou" +import "stdlib/mem.jou" +import "stdlib/process.jou" +import "stdlib/str.jou" +import "stdlib/io.jou" + +enum CompilerMode: + TokenizeOnly # Tokenize one file, don't recurse to imports + ParseOnly # Tokenize and parse one file, don't recurse to imports + Compile # Compile and link + CompileAndRun # Compile, link and run a program (default) + +class CommandLineArgs: + mode: CompilerMode + output_file: byte* # The argument after -o, possibly with .exe appended to it + verbosity: int # Number of -v/--verbose flags given + main_path: byte* # Jou file path passed on command line + + +# An error message should have already been printed to stderr, without a trailing \n +def fail_parsing_args(argv0: byte*, message: byte*) -> None: + fprintf(stderr, "%s: %s (try \"%s --help\")\n", argv0, message, argv0) + exit(2) + +def print_help(argv0: byte*) -> None: + printf("Usage:\n") + printf(" %s [options] FILENAME.jou\n", argv0) + printf(" %s --help # This message\n", argv0) + printf("\n") + printf("Options:\n") + printf(" -o OUTFILE output an executable file, don't run the code\n") + printf(" -v/--verbose print what compilation steps are done\n") + printf(" -vv / --verbose --verbose show what each compilation step produces\n") + printf(" --tokenize-only tokenize one file and display the resulting tokens\n") + printf(" --parse-only tokenize and parse one file and display the AST\n") + +def parse_args(argc: int, argv: byte**) -> CommandLineArgs: + result = CommandLineArgs{mode = CompilerMode::CompileAndRun} + + i = 1 + while i < argc: + arg = argv[i++] + + if strcmp(arg, "--help") == 0: + print_help(argv[0]) + exit(0) + + if result.mode != CompilerMode::CompileAndRun and ( + strcmp(arg, "--tokenize-only") == 0 + or strcmp(arg, "--parse-only") == 0 + or strcmp(arg, "-o") == 0 + ): + fail_parsing_args(argv[0], "only one of --tokenize-only, --parse-only or -o can be given") + + if strcmp(arg, "--tokenize-only") == 0: + result.mode = CompilerMode::TokenizeOnly + elif strcmp(arg, "--parse-only") == 0: + result.mode = CompilerMode::ParseOnly + elif strcmp(arg, "-o") == 0: + result.mode = CompilerMode::Compile + result.output_file = argv[i++] + if result.output_file == NULL: + fail_parsing_args(argv[0], "-o must be followed by the name of an output file") + elif strcmp(arg, "--verbose") == 0: + result.verbosity++ + elif starts_with(arg, "-v") and strspn(&arg[1], "v") == strlen(arg) - 1: + result.verbosity += (strlen(arg) as int) - 1 + elif arg[0] == '-': + message = malloc(strlen(arg) + 100) + sprintf(message, "unknown option '%s'", arg) + fail_parsing_args(argv[0], message) + elif result.main_path == NULL: + result.main_path = arg + else: + fail_parsing_args(argv[0], "you can pass only one Jou file") + + if result.main_path == NULL: + fail_parsing_args(argv[0], "missing Jou file name") + + return result + +def find_file(files: FileState*, nfiles: int, path: byte*) -> FileState*: + for i = 0; i < nfiles; i++: + if strcmp(files[i].ast.path, path) == 0: + return &files[i] + return NULL + +# C:\Users\myname\.foo-bar.jou --> "_foo_bar" +# Result never contains "-", so you can add "-" separated suffixes without conflicts. +def get_sane_filename(path: byte*) -> byte[50]: + while True: + if strstr(path, "/") != NULL: + path = strstr(path, "/") + elif strstr(path, "\\") != NULL: + path = strstr(path, "\\") + else: + break + path++ # skip slash + + name: byte[50] + snprintf(name, sizeof name, "%s", path) + assert name[0] != '\0' + + if name[0] == '.': + name[0] = '_' + for i = 0; name[i] != '\0'; i++: + if name[i] == '.': + name[i] = '\0' + break + if name[i] == '-': + name[i] = '_' + return name + + +def check_main_function(ast: AstFile*) -> bool: + for i = 0; i < ast->body.nstatements; i++: + s = &ast->body.statements[i] + if s->kind == AstStatementKind::Function and strcmp(s->function.signature.name, "main") == 0: + return True + return False + +def check_ast_and_import_conflicts(ast: AstFile*, symbol: ExportSymbol*) -> None: + for i = 0; i < ast->body.nstatements; i++: + ts = &ast->body.statements[i] + if ts->kind == AstStatementKind::Function: + conflict = symbol->kind == ExportSymbolKind::Function and strcmp(ts->function.signature.name, symbol->name) == 0 + elif ( + ts->kind == AstStatementKind::GlobalVariableDeclaration + or ts->kind == AstStatementKind::GlobalVariableDefinition + ): + conflict = symbol->kind == ExportSymbolKind::GlobalVariable and strcmp(ts->var_declaration.name, symbol->name) == 0 + elif ts->kind == AstStatementKind::Class: + conflict = symbol->kind == ExportSymbolKind::Type and strcmp(ts->classdef.name, symbol->name) == 0 + elif ts->kind == AstStatementKind::Enum: + conflict = symbol->kind == ExportSymbolKind::Type and strcmp(ts->enumdef.name, symbol->name) == 0 + else: + assert False + + if conflict: + if symbol->kind == ExportSymbolKind::Function: + kind_name = "function" + elif symbol->kind == ExportSymbolKind::GlobalVariable: + kind_name = "global variable" + elif symbol->kind == ExportSymbolKind::Type: + kind_name = "type" + else: + assert False + + message: byte[500] + # TODO: might be nice to show where it was imported from + snprintf(message, sizeof message, "a %s named '%s' already exists", kind_name, symbol->name) + fail(ts->location, message) + + +class FileState: + ast: AstFile + typectx: FileTypes + pending_exports: ExportSymbol* + +class ParseQueueItem: + path: byte* + is_imported: bool + import_location: Location + +class Compiler: + argv0: byte* + verbosity: int + stdlib_path: byte* + args: CommandLineArgs* + files: FileState* + nfiles: int + automagic_files: byte*[10] + + def determine_automagic_files(self) -> None: + self->automagic_files[0] = malloc(strlen(self->stdlib_path) + 40) + sprintf(self->automagic_files[0], "%s/_assert_fail.jou", self->stdlib_path) + if WINDOWS or MACOS or NETBSD: + self->automagic_files[1] = malloc(strlen(self->stdlib_path) + 40) + sprintf(self->automagic_files[1], "%s/_jou_startup.jou", self->stdlib_path) + + def parse_all_files(self) -> None: + queue: ParseQueueItem* = malloc(50 * sizeof queue[0]) + queue_len = 0 + queue[queue_len++] = ParseQueueItem{path = self->args->main_path} + for i = 0; self->automagic_files[i] != NULL; i++: + queue[queue_len++] = ParseQueueItem{path = self->automagic_files[i]} + + while queue_len > 0: + item = queue[--queue_len] + + found = False + for i = 0; i < self->nfiles; i++: + if strcmp(self->files[i].ast.path, item.path) == 0: + found = True + break + if found: + continue + + if self->verbosity >= 1: + printf("Parsing %s\n", item.path) + + if item.is_imported: + tokens = tokenize(item.path, &item.import_location) + else: + tokens = tokenize(item.path, NULL) + if self->verbosity >= 2: + print_tokens(tokens) + ast = parse(tokens, self->stdlib_path) + if self->verbosity >= 2: + ast.print() + free(tokens) # TODO: do this properly + + evaluate_compile_time_if_statements_in_body(&ast.body) + + if item.is_imported and check_main_function(&ast): + assert item.import_location.path != NULL + fail(item.import_location, "imported file should not have `main` function") + + self->files = realloc(self->files, sizeof self->files[0] * (self->nfiles + 1)) + self->files[self->nfiles++] = FileState{ast = ast} + + for i = 0; i < ast.nimports; i++: + # TODO: offsetof() + queue = realloc(queue, sizeof queue[0] * (queue_len + 1)) + queue[queue_len++] = ParseQueueItem{ + path = ast.imports[i].resolved_path, + is_imported = True, + import_location = ast.imports[i].location, + } + + free(queue) + + def process_imports_and_exports(self) -> None: + if self->verbosity >= 1: + printf("Processing imports/exports\n") + + for idest = 0; idest < self->nfiles; idest++: + dest = &self->files[idest] + seen_before: FileState** = malloc(sizeof(seen_before[0]) * dest->ast.nimports) + + for i = 0; i < dest->ast.nimports; i++: + imp = &dest->ast.imports[i] + + src: FileState* = NULL + for isrc = 0; isrc < self->nfiles; isrc++: + if strcmp(self->files[isrc].ast.path, imp->resolved_path) == 0: + src = &self->files[isrc] + break + assert src != NULL + + if src == dest: + fail(imp->location, "the file itself cannot be imported") + + for k = 0; k < i; k++: + if seen_before[k] == src: + message: byte[500] + snprintf(message, sizeof(message), "file \"%s\" is imported twice", imp->specified_path) + fail(imp->location, message) + seen_before[i] = src + + for exp = src->pending_exports; exp->name[0] != '\0'; exp++: + if self->verbosity >= 1: + printf( + " %s: imported in %s, exported in %s\n", + exp->name, src->ast.path, dest->ast.path, + ) + check_ast_and_import_conflicts(&dest->ast, exp) + dest->typectx.add_imported_symbol(exp) + + free(seen_before) + + for i = 0; i < self->nfiles; i++: + free(self->files[i].pending_exports) + self->files[i].pending_exports = NULL + + def typecheck_stage1_all_files(self) -> None: + for i = 0; i < self->nfiles; i++: + if self->verbosity >= 1: + printf("Type-check stage 1: %s\n", self->files[i].ast.path) + + assert self->files[i].pending_exports == NULL + self->files[i].pending_exports = typecheck_stage1_create_types( + &self->files[i].typectx, + &self->files[i].ast, + ) + + def typecheck_stage2_all_files(self) -> None: + for i = 0; i < self->nfiles; i++: + if self->verbosity >= 1: + printf("Type-check stage 2: %s\n", self->files[i].ast.path) + + assert self->files[i].pending_exports == NULL + self->files[i].pending_exports = typecheck_stage2_populate_types( + &self->files[i].typectx, + &self->files[i].ast, + ) + + def typecheck_stage3_all_files(self) -> None: + for i = 0; i < self->nfiles; i++: + if self->verbosity >= 1: + printf("Type-check stage 3: %s\n", self->files[i].ast.path) + + typecheck_stage3_function_and_method_bodies( + &self->files[i].typectx, + &self->files[i].ast, + ) + + def get_object_file_paths(self) -> byte**: + sane_names: byte[50]* = calloc(sizeof sane_names[0], self->nfiles) + result: byte** = calloc(sizeof result[0], self->nfiles + 1) # NULL terminated + + # First, extract just the names from file paths. + # "blah/blah/foo.jou" --> "foo" + for i = 0; i < self->nfiles; i++: + sane_names[i] = get_sane_filename(self->files[i].ast.path) + + for i = 0; i < self->nfiles; i++: + # If there are 3 files named foo.jou in different directories, their object files will be foo.o, foo-1.o, foo-2.o + counter = 0 + for k = 0; k < i; k++: + if strcmp(sane_names[k], sane_names[i]) == 0: + counter++ + + name: byte[100] + if counter == 0: + sprintf(name, "%s.o", sane_names[i]) + else: + sprintf(name, "%s-%d.o", sane_names[i], counter) + result[i] = get_path_to_file_in_jou_compiled(name) + + free(sane_names) + return result + + def get_exe_file_path(self) -> byte*: + if self->args->output_file == NULL: + tmp = get_sane_filename(self->args->main_path) + exe = get_path_to_file_in_jou_compiled(tmp) + else: + exe = strdup(self->args->output_file) + + if WINDOWS and not ends_with(exe, ".exe") and not ends_with(exe, ".EXE"): + exe = realloc(exe, strlen(exe) + 10) + strcat(exe, ".exe") + + if WINDOWS: + for i = 0; exe[i] != '\0'; i++: + if exe[i] == '/': + exe[i] = '\\' + + return exe + + def create_object_files(self) -> byte**: + paths = self->get_object_file_paths() + + for i = 0; i < self->nfiles; i++: + if self->verbosity >= 1: + printf("Build LLVM IR: %s\n", self->files[i].ast.path) + + module = create_llvm_ir(&self->files[i].ast, &self->files[i].typectx) + if self->verbosity >= 2: + # Don't want to use LLVMDumpModule() because it dumps to stdout. + # When redirected, stdout and stderr tend to get mixed up into a weird order. + s = LLVMPrintModuleToString(module) + puts(s) + LLVMDisposeMessage(s) + + if self->verbosity >= 1: + printf("Verify LLVM IR: %s\n", self->files[i].ast.path) + LLVMVerifyModule(module, LLVMVerifierFailureAction::AbortProcess, NULL) + + path = paths[i] + if self->verbosity >= 1: + printf("Emit LLVM IR: %s --> %s\n", self->files[i].ast.path, path) + + error: byte* = NULL + if LLVMTargetMachineEmitToFile(target.target_machine, module, path, LLVMCodeGenFileType::ObjectFile, &error) != 0: + assert error != NULL + fprintf(stderr, "error in LLVMTargetMachineEmitToFile(): %s\n", error) + exit(1) + assert error == NULL + + return paths + + def link(self, object_files: byte**) -> byte*: + exe = self->get_exe_file_path() + if WINDOWS: + c_compiler = find_installation_directory() + c_compiler = realloc(c_compiler, strlen(c_compiler) + 100) + strcat(c_compiler, "\\mingw64\\bin\\gcc.exe") + else: + c_compiler = get_jou_clang_path() + + command_size = strlen(c_compiler) + strlen(exe) + 100 + for i = 0; object_files[i] != NULL; i++: + command_size += 5 + command_size += strlen(object_files[i]) + command: byte* = malloc(command_size) + + sprintf(command, "\"%s\" -o \"%s\"", c_compiler, exe) + for i = 0; object_files[i] != NULL; i++: + sprintf(&command[strlen(command)], " \"%s\"", object_files[i]) + strcat(command, " -lm") + + if WINDOWS: + # windows strips outermost quotes for some reason, so let's quote it all one more time... + memmove(&command[1], &command[0], strlen(command) + 1) + command[0] = '"' + strcat(command, "\"") + + if self->verbosity >= 1: + printf("Run linker command: %s\n", command) + + # make sure that compiler output shows up before command output, even if redirected + fflush(stdout) + fflush(stderr) + + ret = system(command) + free(command) + if ret != 0: + fprintf(stderr, "%s: linking failed\n", self->argv0) + exit(1) + + return exe + + def run(self, exe: byte*) -> None: + command = malloc(strlen(exe) + 10) + sprintf(command, "\"%s\"", exe) + if self->verbosity >= 1: + printf("Run the compiled program command: %s\n", command) + + # make sure that compiler output shows up before command output, even if redirected + fflush(stdout) + fflush(stderr) + + ret = system(command) + if ret != 0: + # TODO: print something? The shell doesn't print stuff + # like "Segmentation fault" on Windows afaik + exit(1) + + +def main(argc: int, argv: byte**) -> int: + init_target() + init_types() + + args = parse_args(argc, argv) + + if args.mode == CompilerMode::TokenizeOnly: + tokens = tokenize(args.main_path, NULL) + print_tokens(tokens) + free(tokens) + elif args.mode == CompilerMode::ParseOnly: + tokens = tokenize(args.main_path, NULL) + stdlib_path = find_stdlib() + ast = parse(tokens, stdlib_path) + ast.print() + ast.free() + free(tokens) + free(stdlib_path) + elif args.mode == CompilerMode::Compile or args.mode == CompilerMode::CompileAndRun: + compiler = Compiler{ + argv0 = argv[0], + verbosity = args.verbosity, + stdlib_path = find_stdlib(), + args = &args, + } + compiler.determine_automagic_files() + compiler.parse_all_files() + + compiler.typecheck_stage1_all_files() + compiler.process_imports_and_exports() + compiler.typecheck_stage2_all_files() + compiler.process_imports_and_exports() + compiler.typecheck_stage3_all_files() + + mainfile = find_file(compiler.files, compiler.nfiles, args.main_path) + assert mainfile != NULL + + if not check_main_function(&mainfile->ast): + l = Location{path=mainfile->ast.path, lineno=0} + fail(l, "missing `main` function to execute the program") + + object_files = compiler.create_object_files() + executable = compiler.link(object_files) + for i = 0; object_files[i] != NULL; i++: + free(object_files[i]) + free(object_files) + + # TODO: make the -o flag work + if args.mode == CompilerMode::CompileAndRun: + compiler.run(executable) + free(executable) + for i = 0; compiler.automagic_files[i] != NULL; i++: + free(compiler.automagic_files[i]) + + else: + assert False + + return 0 diff --git a/self_hosted_old/parser.jou b/self_hosted_old/parser.jou new file mode 100644 index 00000000..bd3840bd --- /dev/null +++ b/self_hosted_old/parser.jou @@ -0,0 +1,1145 @@ +import "stdlib/ascii.jou" +import "stdlib/str.jou" +import "stdlib/io.jou" +import "stdlib/mem.jou" +import "./token.jou" +import "./ast.jou" +import "./errors_and_warnings.jou" +import "./paths.jou" + + +# arity = number of operands, e.g. 2 for a binary operator such as "+" +# +# This cannot be used for ++ and --, because with them we can't know the kind from +# just the token (e.g. ++ could mean pre-increment or post-increment). +def build_operator_expression(t: Token*, arity: int, operands: AstExpression*) -> AstExpression: + assert arity == 1 or arity == 2 + nbytes = arity * sizeof operands[0] + ptr = malloc(nbytes) + memcpy(ptr, operands, nbytes) + + result = AstExpression{location = t->location, operands = ptr} + + if t->is_operator("&"): + assert arity == 1 + result.kind = AstExpressionKind::AddressOf + elif t->is_operator("["): + assert arity == 2 + result.kind = AstExpressionKind::Indexing + elif t->is_operator("=="): + assert arity == 2 + result.kind = AstExpressionKind::Eq + elif t->is_operator("!="): + assert arity == 2 + result.kind = AstExpressionKind::Ne + elif t->is_operator(">"): + assert arity == 2 + result.kind = AstExpressionKind::Gt + elif t->is_operator(">="): + assert arity == 2 + result.kind = AstExpressionKind::Ge + elif t->is_operator("<"): + assert arity == 2 + result.kind = AstExpressionKind::Lt + elif t->is_operator("<="): + assert arity == 2 + result.kind = AstExpressionKind::Le + elif t->is_operator("+"): + assert arity == 2 + result.kind = AstExpressionKind::Add + elif t->is_operator("-"): + if arity == 2: + result.kind = AstExpressionKind::Subtract + else: + result.kind = AstExpressionKind::Negate + elif t->is_operator("*"): + if arity == 2: + result.kind = AstExpressionKind::Multiply + else: + result.kind = AstExpressionKind::Dereference + elif t->is_operator("/"): + assert arity == 2 + result.kind = AstExpressionKind::Divide + elif t->is_operator("%"): + assert arity == 2 + result.kind = AstExpressionKind::Modulo + elif t->is_keyword("and"): + assert arity == 2 + result.kind = AstExpressionKind::And + elif t->is_keyword("or"): + assert arity == 2 + result.kind = AstExpressionKind::Or + elif t->is_keyword("not"): + assert arity == 1 + result.kind = AstExpressionKind::Not + else: + assert False + + assert result.get_arity() == arity + return result + +# reverse code golfing: https://xkcd.com/1960/ +def determine_the_kind_of_a_statement_that_starts_with_an_expression( + this_token_is_after_that_initial_expression: Token* +) -> AstStatementKind: + if this_token_is_after_that_initial_expression->is_operator("="): + return AstStatementKind::Assign + if this_token_is_after_that_initial_expression->is_operator("+="): + return AstStatementKind::InPlaceAdd + if this_token_is_after_that_initial_expression->is_operator("-="): + return AstStatementKind::InPlaceSubtract + if this_token_is_after_that_initial_expression->is_operator("*="): + return AstStatementKind::InPlaceMultiply + if this_token_is_after_that_initial_expression->is_operator("/="): + return AstStatementKind::InPlaceDivide + if this_token_is_after_that_initial_expression->is_operator("%="): + return AstStatementKind::InPlaceModulo + return AstStatementKind::ExpressionStatement + +class MemberInfo: + kind: byte* + name: byte[100] + name_location: Location + +def check_class_for_duplicate_names(classdef: AstClassDef*) -> None: + n = 0 + for i = 0; i < classdef->nmembers; i++: + member = &classdef->members[i] + # We will make a separate MemberInfo for each union field + if member->kind == AstClassMemberKind::Union: + n += member->union_fields.nfields + else: + n++ + + infos: MemberInfo* = malloc(n * sizeof infos[0] * 1000) + destptr: MemberInfo* = infos + + for i = 0; i < classdef->nmembers; i++: + member = &classdef->members[i] + if member->kind == AstClassMemberKind::Field: + *destptr++ = MemberInfo{ + kind = "a field", + name = member->field.name, + name_location = member->field.name_location, + } + elif member->kind == AstClassMemberKind::Union: + for k = 0; k < member->union_fields.nfields; k++: + *destptr++ = MemberInfo{ + kind = "a union member", + name = member->union_fields.fields[k].name, + name_location = member->union_fields.fields[k].name_location, + } + elif member->kind == AstClassMemberKind::Method: + *destptr++ = MemberInfo{ + kind = "a method", + name = member->method.signature.name, + name_location = member->method.signature.name_location, + } + else: + assert False + + assert destptr == &infos[n] + + for p1 = infos; p1 < destptr; p1++: + for p2 = &p1[1]; p2 < destptr; p2++: + if strcmp(p1->name, p2->name) == 0: + message: byte[500] + snprintf( + message, sizeof message, + "class %s already has %s named '%s'", + classdef->name, p1->kind, p1->name, + ) + fail(p2->name_location, message) + + +# TODO: this function is just bad... +def read_assertion_from_file(start: Location, end: Location) -> byte*: + assert start.path == end.path + + f = fopen(start.path, "rb") + assert f != NULL + + line: byte[1024] + lineno = 1 + while lineno < start.lineno: + assert fgets(line, sizeof(line) as int, f) != NULL + lineno++ + + result: byte* = malloc(2000 * (end.lineno - start.lineno + 1)) + result[0] = '\0' + + while lineno <= end.lineno: + assert fgets(line, sizeof(line) as int, f) != NULL + lineno++ + + # TODO: strings containing '#' ... so much wrong with dis + if strstr(line, "#") != NULL: + *strstr(line, "#") = '\0' + trim_ascii_whitespace(line) + + # Add spaces between lines, but not after '(' or before ')' + if not starts_with(line, ")") and not ends_with(result, "("): + strcat(result, " ") + strcat(result, line) + + fclose(f) + + trim_ascii_whitespace(result) + if starts_with(result, "assert"): + memmove(result, &result[6], strlen(&result[6]) + 1) + trim_ascii_whitespace(result) + + return result + + +class Parser: + tokens: Token* + stdlib_path: byte* + is_parsing_method_body: bool + + def eat_newline(self) -> None: + if self->tokens->kind != TokenKind::Newline: + self->tokens->fail_expected_got("end of line") + self->tokens++ + + def parse_import(self) -> AstImport: + assert self->tokens->is_keyword("import") + import_keyword = self->tokens++ + + path_token = self->tokens++ + if path_token->kind != TokenKind::String: + path_token->fail_expected_got("a string to specify the file name") + + self->eat_newline() + + if starts_with(path_token->long_string, "stdlib/"): + # Starts with stdlib --> import from where stdlib actually is + tmp = NULL + part1 = self->stdlib_path + part2 = &path_token->long_string[7] + elif starts_with(path_token->long_string, "."): + # Relative to directory where the file is + tmp = strdup(path_token->location.path) + part1 = dirname(tmp) + part2 = path_token->long_string + else: + fail( + path_token->location, + "import path must start with 'stdlib/' (standard-library import) or a dot (relative import)" + ) + + # 1 for slash, 1 for \0, 1 for fun + path = malloc(strlen(part1) + strlen(part2) + 3) + sprintf(path, "%s/%s", part1, part2) + free(tmp) + + simplify_path(path) + return AstImport{ + location = import_keyword->location, + specified_path = strdup(path_token->long_string), + resolved_path = path, + } + + def parse_type(self) -> AstType: + if not ( + self->tokens->kind == TokenKind::Name + or self->tokens->is_keyword("None") + or self->tokens->is_keyword("void") + or self->tokens->is_keyword("noreturn") + or self->tokens->is_keyword("short") + or self->tokens->is_keyword("int") + or self->tokens->is_keyword("long") + or self->tokens->is_keyword("byte") + or self->tokens->is_keyword("float") + or self->tokens->is_keyword("double") + or self->tokens->is_keyword("bool") + ): + self->tokens->fail_expected_got("a type") + + result = AstType{ + kind = AstTypeKind::Named, + location = self->tokens->location, + name = self->tokens->short_string, + } + self->tokens++ + + while self->tokens->is_operator("*") or self->tokens->is_operator("["): + p: AstType* = malloc(sizeof *p) + *p = result + + if self->tokens->is_operator("*"): + result = AstType{ + location = (self->tokens++)->location, # TODO: shouldn't need all the parentheses + kind = AstTypeKind::Pointer, + value_type = p, + } + else: + location = (self->tokens++)->location + + len_expression: AstExpression* = malloc(sizeof *len_expression) + *len_expression = self->parse_expression() + + if not self->tokens->is_operator("]"): + self->tokens->fail_expected_got("a ']' to end the array size") + self->tokens++ + + result = AstType{ + location = location, + kind = AstTypeKind::Array, + array = AstArrayType{ + member_type = p, + length = len_expression, + } + } + + return result + + def parse_name_type_value(self, expected_what_for_name: byte*) -> AstNameTypeValue: + if self->tokens->kind != TokenKind::Name: + assert expected_what_for_name != NULL + self->tokens->fail_expected_got(expected_what_for_name) + + result = AstNameTypeValue{name = self->tokens->short_string, name_location = self->tokens->location} + self->tokens++ + + if not self->tokens->is_operator(":"): + self->tokens->fail_expected_got("':' and a type after it (example: \"foo: int\")") + self->tokens++ + result.type = self->parse_type() + + if self->tokens->is_operator("="): + self->tokens++ + p: AstExpression* = malloc(sizeof *p) + *p = self->parse_expression() + result.value = p + + return result + + def parse_function_or_method_signature(self, is_method: bool) -> AstSignature: + # TODO: change error messages to say method, when it is a method (#243) + used_self: bool = False + if self->tokens->kind != TokenKind::Name: + self->tokens->fail_expected_got("a function name") + + result = AstSignature{ + name_location = self->tokens->location, + name = self->tokens->short_string, + } + self->tokens++ + + if not self->tokens->is_operator("("): + self->tokens->fail_expected_got("a '(' to denote the start of function arguments") + self->tokens++ + + while not self->tokens->is_operator(")"): + if result.takes_varargs: + fail(self->tokens->location, "if '...' is used, it must be the last parameter") + + if self->tokens->is_operator("..."): + result.takes_varargs = True + self->tokens++ + + elif self->tokens->is_keyword("self"): + if not is_method: + fail(self->tokens->location, "'self' cannot be used here") + + self_arg = AstNameTypeValue{ + name = "self", + name_location = self->tokens->location, + } + self->tokens++ + + if self->tokens->is_operator(":"): + self->tokens++ + self_arg.type = self->parse_type() + + result.args = realloc(result.args, sizeof result.args[0] * (result.nargs+1)) + result.args[result.nargs++] = self_arg + used_self = True + + else: + arg = self->parse_name_type_value("an argument name") + if arg.value != NULL: + fail(arg.value->location, "arguments cannot have default values") + + for i = 0; i < result.nargs; i++: + if strcmp(result.args[i].name, arg.name) == 0: + message: byte[200] + snprintf( + message, sizeof message, + "there are multiple arguments named '%s'", arg.name) + fail(arg.name_location, message) + + result.args = realloc(result.args, sizeof result.args[0] * (result.nargs+1)) + result.args[result.nargs++] = arg + + if not self->tokens->is_operator(","): + break + self->tokens++ + + if not self->tokens->is_operator(")"): + self->tokens->fail_expected_got("a ')'") + self->tokens++ + + # TODO: + # * If is_method, ensure that self parameter exists and is first + # * Else, ensure that self parameter does not exists + + # Special case for common typo: def foo(): + if self->tokens->is_operator(":"): + fail(self->tokens->location, "return type must be specified with '->', or with '-> None' if the function doesn't return anything") + if not self->tokens->is_operator("->"): + self->tokens->fail_expected_got("a '->'") + self->tokens++ + + if not used_self and is_method: + throwerror: byte[300] + snprintf(throwerror, sizeof throwerror, "missing self, should be 'def %s(self, ...)'", result.name) + fail(self->tokens->location, throwerror) + + result.return_type = self->parse_type() + return result + + def parse_call(self) -> AstCall: + assert self->tokens->kind == TokenKind::Name # must be checked when calling this function + result = AstCall{location = self->tokens->location, name = self->tokens->short_string} + self->tokens++ + assert self->tokens->is_operator("(") + self->tokens++ + + while not self->tokens->is_operator(")"): + result.args = realloc(result.args, sizeof result.args[0] * (result.nargs+1)) + result.args[result.nargs++] = self->parse_expression() + if not self->tokens->is_operator(","): + break + self->tokens++ + + if not self->tokens->is_operator(")"): + self->tokens->fail_expected_got("a ')'") + self->tokens++ + + return result + + def parse_instantiation(self) -> AstInstantiation: + assert self->tokens->kind == TokenKind::Name # must be checked when calling this function + result = AstInstantiation{class_name_location = self->tokens->location, class_name = self->tokens->short_string} + self->tokens++ + assert self->tokens->is_operator("{") + self->tokens++ + + while not self->tokens->is_operator("}"): + if self->tokens->kind != TokenKind::Name: + self->tokens->fail_expected_got("a field name") + field_name = self->tokens->short_string + for i = 0; i < result.nfields; i++: + if strcmp(result.field_names[i], field_name) == 0: + error: byte[500] + snprintf(error, sizeof error, "multiple values were given for field '%s'", field_name) + fail(self->tokens->location, error) + result.field_names = realloc(result.field_names, (result.nfields + 1) * sizeof result.field_names[0]) + result.field_names[result.nfields] = field_name + self->tokens++ + + if not self->tokens->is_operator("="): + msg: byte[300] + snprintf(msg, sizeof msg, "'=' followed by a value for field '%s'", field_name) + self->tokens->fail_expected_got(msg) + self->tokens++ + + result.field_values = realloc(result.field_values, sizeof result.field_values[0] * (result.nfields+1)) + result.field_values[result.nfields] = self->parse_expression() + + result.nfields++ + if not self->tokens->is_operator(","): + break + self->tokens++ + + if not self->tokens->is_operator("}"): + self->tokens->fail_expected_got("a '}'") + self->tokens++ + + return result + + def parse_array(self) -> AstArray: + assert self->tokens->is_operator("[") + self->tokens++ + + result = AstArray{} + while not self->tokens->is_operator("]"): + result.items = realloc(result.items, (result.length + 1) * sizeof result.items[0]) + result.items[result.length++] = self->parse_expression() + if not self->tokens->is_operator(","): + break + self->tokens++ + + if not self->tokens->is_operator("]"): + self->tokens->fail_expected_got("a ']' to end the array") + if result.length == 0: + fail(self->tokens->location, "arrays cannot be empty") + self->tokens++ + + return result + + def parse_elementary_expression(self) -> AstExpression: + expr = AstExpression{location = self->tokens->location} + + if self->tokens->kind == TokenKind::Short: + expr.kind = AstExpressionKind::Short + expr.short_value = self->tokens->short_value + self->tokens++ + elif self->tokens->kind == TokenKind::Int: + expr.kind = AstExpressionKind::Int + expr.int_value = self->tokens->int_value + self->tokens++ + elif self->tokens->kind == TokenKind::Long: + expr.kind = AstExpressionKind::Long + expr.long_value = self->tokens->long_value + self->tokens++ + elif self->tokens->kind == TokenKind::Byte: + expr.kind = AstExpressionKind::Byte + expr.byte_value = self->tokens->byte_value + self->tokens++ + elif self->tokens->kind == TokenKind::String: + expr.kind = AstExpressionKind::String + expr.string = strdup(self->tokens->long_string) + self->tokens++ + elif self->tokens->kind == TokenKind::Float: + expr.kind = AstExpressionKind::Float + expr.float_or_double_text = self->tokens->short_string + self->tokens++ + elif self->tokens->kind == TokenKind::Double: + expr.kind = AstExpressionKind::Double + expr.float_or_double_text = self->tokens->short_string + self->tokens++ + elif self->tokens->is_keyword("True"): + expr.kind = AstExpressionKind::Bool + expr.bool_value = True + self->tokens++ + elif self->tokens->is_keyword("False"): + expr.kind = AstExpressionKind::Bool + expr.bool_value = False + self->tokens++ + elif self->tokens->is_keyword("NULL"): + expr.kind = AstExpressionKind::Null + self->tokens++ + elif self->tokens->is_keyword("None"): + fail(self->tokens->location, "None is not a value in Jou, use e.g. -1 for numbers or NULL for pointers") + elif self->tokens->is_keyword("self"): + if not self->is_parsing_method_body: + fail(self->tokens->location, "'self' cannot be used here") + expr.kind = AstExpressionKind::Self + self->tokens++ + elif self->tokens->kind == TokenKind::Name: + if self->tokens[1].is_operator("("): + expr.kind = AstExpressionKind::Call + expr.call = self->parse_call() + elif self->tokens[1].is_operator("{"): + expr.kind = AstExpressionKind::Instantiate + expr.instantiation = self->parse_instantiation() + elif self->tokens[1].is_operator("::") and self->tokens[2].kind == TokenKind::Name: + expr.kind = AstExpressionKind::GetEnumMember + expr.enum_member = AstEnumMember{ + enum_name = self->tokens->short_string, + member_name = self->tokens[2].short_string, + } + self->tokens++ + self->tokens++ + self->tokens++ + else: + expr.kind = AstExpressionKind::GetVariable + expr.varname = self->tokens->short_string + self->tokens++ + elif self->tokens->is_operator("("): + self->tokens++ + expr = self->parse_expression() + if not self->tokens->is_operator(")"): + self->tokens->fail_expected_got("a ')'") + self->tokens++ + elif self->tokens->is_operator("["): + expr.kind = AstExpressionKind::Array + expr.array = self->parse_array() + else: + self->tokens->fail_expected_got("an expression") + + return expr + + def parse_expression_with_fields_and_methods_and_indexing(self) -> AstExpression: + result = self->parse_elementary_expression() + + while self->tokens->is_operator(".") or self->tokens->is_operator("->") or self->tokens->is_operator("["): + if self->tokens->is_operator("["): + open_bracket = self->tokens++ + operands = [result, self->parse_expression()] + if not self->tokens->is_operator("]"): + self->tokens->fail_expected_got("a ']'") + self->tokens++ + result = build_operator_expression(open_bracket, 2, operands) + + else: + start_op = self->tokens++ + if self->tokens->kind != TokenKind::Name: + self->tokens->fail_expected_got("a field or method name") + + instance: AstExpression* = malloc(sizeof *instance) + *instance = result + + if self->tokens[1].is_operator("("): + call = self->parse_call() + call.method_call_self = instance + call.uses_arrow_operator = start_op->is_operator("->") + result = AstExpression{ + location = call.location, + kind = AstExpressionKind::Call, + call = call, + } + else: + result = AstExpression{ + location = self->tokens->location, + kind = AstExpressionKind::GetClassField, + class_field = AstClassField{ + instance = instance, + uses_arrow_operator = start_op->is_operator("->"), + field_name = self->tokens->short_string, + }, + } + self->tokens++ + + return result + + def parse_expression_with_unary_operators(self) -> AstExpression: + # prefix = sequneces of 0 or more unary operator tokens: start,start+1,...,end-1 + prefix_start = self->tokens + while ( + self->tokens->is_operator("++") + or self->tokens->is_operator("--") + or self->tokens->is_operator("&") + or self->tokens->is_operator("*") + or self->tokens->is_keyword("sizeof") + ): + self->tokens++ + prefix_end = self->tokens + + result = self->parse_expression_with_fields_and_methods_and_indexing() + + suffix_start = self->tokens + while self->tokens->is_operator("++") or self->tokens->is_operator("--"): + self->tokens++ + suffix_end = self->tokens + + while prefix_start != prefix_end or suffix_start != suffix_end: + # ++ and -- "bind tighter", so *foo++ is equivalent to *(foo++) + # It is implemented by always consuming ++/-- prefixes and suffixes when they exist. + if prefix_start != prefix_end and prefix_end[-1].is_operator("++"): + token = --prefix_end + kind = AstExpressionKind::PreIncr + elif prefix_start != prefix_end and prefix_end[-1].is_operator("--"): + token = --prefix_end + kind = AstExpressionKind::PreDecr + elif suffix_start != suffix_end and suffix_start[0].is_operator("++"): + token = suffix_start++ + kind = AstExpressionKind::PostIncr + elif suffix_start != suffix_end and suffix_start[0].is_operator("--"): + token = suffix_start++ + kind = AstExpressionKind::PostDecr + else: + # We don't have ++ or --, so it must be something in the prefix + assert prefix_start != prefix_end and suffix_start == suffix_end + token = --prefix_end + if token->is_operator("*"): + kind = AstExpressionKind::Dereference + elif token->is_operator("&"): + kind = AstExpressionKind::AddressOf + elif token->is_keyword("sizeof"): + kind = AstExpressionKind::SizeOf + else: + assert False + + p: AstExpression* = malloc(sizeof(*p)) + *p = result + result = AstExpression{location = token->location, kind = kind, operands = p} + + return result + + def parse_expression_with_mul_and_div(self) -> AstExpression: + result = self->parse_expression_with_unary_operators() + while self->tokens->is_operator("*") or self->tokens->is_operator("/") or self->tokens->is_operator("%"): + t = self->tokens++ + lhs_rhs = [result, self->parse_expression_with_unary_operators()] + result = build_operator_expression(t, 2, lhs_rhs) + return result + + def parse_expression_with_add(self) -> AstExpression: + if self->tokens->is_operator("-"): + minus = self->tokens++ + else: + minus = NULL + + result = self->parse_expression_with_mul_and_div() + if minus != NULL: + result = build_operator_expression(minus, 1, &result) + + while self->tokens->is_operator("+") or self->tokens->is_operator("-"): + t = self->tokens++ + lhs_rhs = [result, self->parse_expression_with_mul_and_div()] + result = build_operator_expression(t, 2, lhs_rhs) + + return result + + # "as" operator has somewhat low precedence, so that "1+2 as float" works as expected + def parse_expression_with_as(self) -> AstExpression: + result = self->parse_expression_with_add() + while self->tokens->is_keyword("as"): + as_location = (self->tokens++)->location # TODO: shouldn't need so many parentheses + p: AstAsExpression* = malloc(sizeof(*p)) + *p = AstAsExpression{type = self->parse_type(), value = result} + result = AstExpression{ + location = as_location, + kind = AstExpressionKind::As, + as_expression = p, + } + return result + + def parse_expression_with_comparisons(self) -> AstExpression: + result = self->parse_expression_with_as() + if self->tokens->is_comparison(): + t = self->tokens++ + lhs_rhs = [result, self->parse_expression_with_as()] + result = build_operator_expression(t, 2, lhs_rhs) + if self->tokens->is_comparison(): + fail(self->tokens->location, "comparisons cannot be chained") + return result + + def parse_expression_with_not(self) -> AstExpression: + if self->tokens->is_keyword("not"): + not_token = self->tokens + self->tokens++ + else: + not_token = NULL + + if self->tokens->is_keyword("not"): + fail(self->tokens->location, "'not' cannot be repeated") + + result = self->parse_expression_with_comparisons() + if not_token != NULL: + result = build_operator_expression(not_token, 1, &result) + return result + + def parse_expression_with_and_or(self) -> AstExpression: + result = self->parse_expression_with_not() + got_and = False + got_or = False + + while True: + if self->tokens->is_keyword("and"): + got_and = True + elif self->tokens->is_keyword("or"): + got_or = True + else: + break + if got_and and got_or: + fail(self->tokens->location, "'and' cannot be chained with 'or', you need more parentheses") + + t = self->tokens++ + lhs_rhs = [result, self->parse_expression_with_not()] + result = build_operator_expression(t, 2, lhs_rhs) + + return result + + def parse_expression(self) -> AstExpression: + return self->parse_expression_with_and_or() + + # does not eat a trailing newline + def parse_oneline_statement(self) -> AstStatement: + result = AstStatement{ location = self->tokens->location } + if self->tokens->is_keyword("return"): + self->tokens++ + result.kind = AstStatementKind::Return + if self->tokens->kind != TokenKind::Newline: + result.return_value = malloc(sizeof *result.return_value) + *result.return_value = self->parse_expression() + elif self->tokens->is_keyword("assert"): + self->tokens++ + result.kind = AstStatementKind::Assert + start = self->tokens->location + result.assertion.condition = self->parse_expression() + end = self->tokens->location + result.assertion.condition_str = read_assertion_from_file(start, end) + elif self->tokens->is_keyword("pass"): + self->tokens++ + result.kind = AstStatementKind::Pass + elif self->tokens->is_keyword("break"): + self->tokens++ + result.kind = AstStatementKind::Break + elif self->tokens->is_keyword("continue"): + self->tokens++ + result.kind = AstStatementKind::Continue + elif self->tokens->kind == TokenKind::Name and self->tokens[1].is_operator(":"): + # "foo: int" creates a variable "foo" of type "int" + result.kind = AstStatementKind::DeclareLocalVar + result.var_declaration = self->parse_name_type_value(NULL) + else: + expr = self->parse_expression() + result.kind = determine_the_kind_of_a_statement_that_starts_with_an_expression(self->tokens) + if result.kind == AstStatementKind::ExpressionStatement: + if not expr.can_have_side_effects(): + fail(expr.location, "not a valid statement") + result.expression = expr + else: + self->tokens++ + result.assignment = AstAssignment{target = expr, value = self->parse_expression()} + if self->tokens->is_operator("="): + # Would fail elsewhere anyway, but let's make the error message clear + fail(self->tokens->location, "only one variable can be assigned at a time") + + return result + + def parse_if_statement(self) -> AstIfStatement: + ifs_and_elifs: AstConditionAndBody* = NULL + n = 0 + + assert self->tokens->is_keyword("if") + while True: + self->tokens++ + cond = self->parse_expression() + body = self->parse_body() + ifs_and_elifs = realloc(ifs_and_elifs, sizeof ifs_and_elifs[0] * (n+1)) + ifs_and_elifs[n++] = AstConditionAndBody{condition = cond, body = body} + if not self->tokens->is_keyword("elif"): + break + + if self->tokens->is_keyword("else"): + self->tokens++ + else_body = self->parse_body() + else: + else_body = AstBody{} + + return AstIfStatement{ + if_and_elifs = ifs_and_elifs, + n_if_and_elifs = n, + else_body = else_body, + } + + def parse_while_loop(self) -> AstConditionAndBody: + assert self->tokens->is_keyword("while") + self->tokens++ + cond = self->parse_expression() + body = self->parse_body() + return AstConditionAndBody{condition = cond, body = body} + + def parse_for_loop(self) -> AstForLoop: + assert self->tokens->is_keyword("for") + self->tokens++ + + # Check if it's "for i in ..." loop, those are not supported + if ( + self->tokens[0].kind == TokenKind::Name + and self->tokens[1].kind == TokenKind::Name + and strcmp(self->tokens[1].short_string, "in") == 0 + ): + fail(self->tokens[1].location, "Python-style for loops aren't supported. Use e.g. 'for i = 0; i < 10; i++'") + + init: AstStatement* = malloc(sizeof *init) + incr: AstStatement* = malloc(sizeof *incr) + + *init = self->parse_oneline_statement() + if not self->tokens->is_operator(";"): + self->tokens->fail_expected_got("a ';'") + self->tokens++ + cond = self->parse_expression() + if not self->tokens->is_operator(";"): + self->tokens->fail_expected_got("a ';'") + self->tokens++ + *incr = self->parse_oneline_statement() + + return AstForLoop{ + init = init, + cond = cond, + incr = incr, + body = self->parse_body(), + } + + # Parses the "x: int" part of "x, y, z: int", leaving "y, z: int" to be parsed later. + def parse_first_of_multiple_local_var_declares(self) -> AstNameTypeValue: + assert self->tokens->kind == TokenKind::Name + + ntv = AstNameTypeValue{ + name = self->tokens->short_string, + name_location = self->tokens->location, + } + + # Take a backup of the parser where first variable name and its comma are consumed. + save_state = *self + save_state.tokens = &save_state.tokens[2] + + # Skip variables and commas so we can parse the type that comes after it + self->tokens++ + while self->tokens->is_operator(",") and self->tokens[1].kind == TokenKind::Name: + self->tokens = &self->tokens[2] + + # Error for "x, y = 0" + if self->tokens->is_operator("="): + fail(self->tokens->location, "only one variable can be assigned at a time") + + if self->tokens->is_operator("="): + fail(self->tokens->location, "only one variable can be assigned at a time") + + if not self->tokens->is_operator(":"): + self->tokens->fail_expected_got("':' and a type after it (example: \"foo, bar: int\")") + self->tokens++ + + ntv.type = self->parse_type() + + # Error for "x, y: int = 0" + if self->tokens->is_operator("="): + fail(self->tokens->location, "only one variable can be assigned at a time") + + *self = save_state + return ntv + + def parse_statement(self) -> AstStatement: + if self->tokens->is_keyword("import"): + fail(self->tokens->location, "imports must be in the beginning of the file") + if self->tokens->is_keyword("def"): + return AstStatement{ + location = self->tokens->location, + kind = AstStatementKind::Function, + function = self->parse_function_or_method(False), + } + + if self->tokens->is_keyword("declare"): + location = (self->tokens++)->location + if self->tokens->is_keyword("global"): + self->tokens++ + result = AstStatement{ + location = location, + kind = AstStatementKind::GlobalVariableDeclaration, + var_declaration = self->parse_name_type_value("a variable name"), + } + if result.var_declaration.value != NULL: + fail( + result.var_declaration.value->location, + "a value cannot be given when declaring a global variable", + ) + else: + result = AstStatement{ + location = location, + kind = AstStatementKind::Function, + function = AstFunctionOrMethod{signature = self->parse_function_or_method_signature(False)}, + } + self->eat_newline() + return result + + if self->tokens->is_keyword("global"): + result = AstStatement{ + location = (self->tokens++)->location, + kind = AstStatementKind::GlobalVariableDefinition, + var_declaration = self->parse_name_type_value("a variable name"), + } + if result.var_declaration.value != NULL: + fail( + result.var_declaration.value->location, + "specifying a value for a global variable is not supported yet", + ) + self->eat_newline() + return result + + if self->tokens->is_keyword("class"): + return AstStatement{ + location = self->tokens->location, + kind = AstStatementKind::Class, + classdef = self->parse_class(), + } + + if self->tokens->is_keyword("enum"): + return AstStatement{ + location = self->tokens->location, + kind = AstStatementKind::Enum, + enumdef = self->parse_enum(), + } + + if self->tokens->is_keyword("if"): + return AstStatement{ + location = self->tokens->location, + kind = AstStatementKind::If, + if_statement = self->parse_if_statement(), + } + + if self->tokens->is_keyword("for"): + return AstStatement{ + location = self->tokens->location, + kind = AstStatementKind::ForLoop, + for_loop = self->parse_for_loop(), + } + + if self->tokens->is_keyword("while"): + return AstStatement{ + location = self->tokens->location, + kind = AstStatementKind::WhileLoop, + while_loop = self->parse_while_loop(), + } + + if ( + self->tokens[0].kind == TokenKind::Name + and self->tokens[1].is_operator(",") + and self->tokens[2].kind == TokenKind::Name + ): + return AstStatement{ + location = self->tokens->location, + kind = AstStatementKind::DeclareLocalVar, + var_declaration = self->parse_first_of_multiple_local_var_declares(), + } + + result = self->parse_oneline_statement() + self->eat_newline() + return result + + def parse_start_of_body(self) -> None: + if not self->tokens->is_operator(":"): + self->tokens->fail_expected_got("':' followed by a new line with more indentation") + self->tokens++ + + if self->tokens->kind != TokenKind::Newline: + self->tokens->fail_expected_got("a new line with more indentation after ':'") + self->tokens++ + + if self->tokens->kind != TokenKind::Indent: + self->tokens->fail_expected_got("more indentation after ':'") + self->tokens++ + + def parse_body(self) -> AstBody: + self->parse_start_of_body() + + result: AstStatement* = NULL + n = 0 + while self->tokens->kind != TokenKind::Dedent: + result = realloc(result, sizeof result[0] * (n+1)) + result[n++] = self->parse_statement() + self->tokens++ + + return AstBody{ statements = result, nstatements = n } + + def parse_function_or_method(self, is_method: bool) -> AstFunctionOrMethod: + assert self->tokens->is_keyword("def") + self->tokens++ + + signature = self->parse_function_or_method_signature(is_method) + if strcmp(signature.name, "__init__") == 0 and is_method: + fail(self->tokens->location, "Jou does not have a special __init__ method like Python") + if signature.takes_varargs: + fail(self->tokens->location, "functions with variadic arguments cannot be defined yet") + + assert not self->is_parsing_method_body + self->is_parsing_method_body = is_method + body = self->parse_body() + self->is_parsing_method_body = False + + return AstFunctionOrMethod{signature = signature, body = body} + + def parse_class(self) -> AstClassDef: + assert self->tokens->is_keyword("class") + self->tokens++ + + if self->tokens->kind != TokenKind::Name: + self->tokens->fail_expected_got("a name for the class") + + result = AstClassDef{ + name = self->tokens->short_string, + name_location = self->tokens->location, + } + self->tokens++ + + self->parse_start_of_body() + while self->tokens->kind != TokenKind::Dedent: + if self->tokens->is_keyword("def"): + new_member = AstClassMember{ + kind = AstClassMemberKind::Method, + method = self->parse_function_or_method(True), + } + + elif self->tokens->is_keyword("union"): + union_keyword_location = (self->tokens++)->location + self->parse_start_of_body() + + union_fields = AstUnionFields{} + while self->tokens->kind != TokenKind::Dedent: + field = self->parse_name_type_value("a union member") + if field.value != NULL: + fail(field.value->location, "union members cannot have default values") + + union_fields.fields = realloc(union_fields.fields, (union_fields.nfields + 1) * sizeof union_fields.fields[0]) + union_fields.fields[union_fields.nfields++] = field + self->eat_newline() + + self->tokens++ + + if union_fields.nfields < 2: + fail(union_keyword_location, "unions must have at least 2 members") + + new_member = AstClassMember{ + kind = AstClassMemberKind::Union, + union_fields = union_fields, + } + + else: + field = self->parse_name_type_value("a method, a field or a union") + if field.value != NULL: + fail(field.value->location, "class fields cannot have default values") + new_member = AstClassMember{ + kind = AstClassMemberKind::Field, + field = field, + } + self->eat_newline() + + result.members = realloc(result.members, (result.nmembers + 1) * sizeof result.members[0]) + result.members[result.nmembers++] = new_member + + self->tokens++ + check_class_for_duplicate_names(&result) + return result + + def parse_enum(self) -> AstEnumDef: + assert self->tokens->is_keyword("enum") + self->tokens++ + + if self->tokens->kind != TokenKind::Name: + self->tokens->fail_expected_got("a name for the enum") + + result = AstEnumDef{ + name = self->tokens->short_string, + name_location = self->tokens->location, + } + self->tokens++ + + self->parse_start_of_body() + while self->tokens->kind != TokenKind::Dedent: + if self->tokens->kind != TokenKind::Name: + self->tokens->fail_expected_got("a name for an enum member") + + for i = 0; i < result.member_count; i++: + if strcmp(result.member_names[i], self->tokens->short_string) == 0: + assert sizeof self->tokens->short_string == 100 + error: byte[200] + sprintf(error, "the enum has two members named '%s'", self->tokens->short_string) + fail(self->tokens->location, error) + + result.member_names = realloc(result.member_names, sizeof result.member_names[0] * (result.member_count + 1)) + result.member_names[result.member_count++] = self->tokens->short_string + self->tokens++ + self->eat_newline() + + self->tokens++ + return result + +def parse(tokens: Token*, stdlib_path: byte*) -> AstFile: + parser = Parser{tokens = tokens, stdlib_path = stdlib_path} + result = AstFile{path = tokens[0].location.path} + + while parser.tokens->is_keyword("import"): + result.imports = realloc(result.imports, sizeof result.imports[0] * (result.nimports+1)) + result.imports[result.nimports++] = parser.parse_import() + + while parser.tokens->kind != TokenKind::EndOfFile: + result.body.statements = realloc(result.body.statements, sizeof result.body.statements[0] * (result.body.nstatements + 1)) + result.body.statements[result.body.nstatements++] = parser.parse_statement() + + return result diff --git a/self_hosted_old/paths.jou b/self_hosted_old/paths.jou new file mode 100644 index 00000000..8f52c706 --- /dev/null +++ b/self_hosted_old/paths.jou @@ -0,0 +1,181 @@ +import "stdlib/mem.jou" +import "stdlib/str.jou" +import "stdlib/io.jou" +import "stdlib/process.jou" + +if WINDOWS: + declare GetModuleFileNameA(hModule: void*, lpFilename: byte*, nSize: int) -> int +elif MACOS: + declare _NSGetExecutablePath(buf: byte*, bufsize: int*) -> int +else: + declare readlink(linkpath: byte*, result: byte*, result_size: long) -> long + +if WINDOWS: + declare _mkdir(path: byte*) -> int +else: + declare mkdir(path: byte*, mode: int) -> int # posix + +declare dirname(path: byte*) -> byte* +declare stat(path: byte*, buf: byte[1000]*) -> int # lol + + +def fail_finding_exe() -> noreturn: + # TODO: include os error message (GetLastError / errno) + fprintf(stderr, "error: cannot locate currently running executable, needed for finding the Jou standard library\n") + exit(1) + + +if WINDOWS: + def find_current_executable() -> byte*: + buf = NULL + for size = 2L; True; size *= 2: + buf = realloc(buf, size) + memset(buf, 0, size) + ret = GetModuleFileNameA(NULL, buf, size as int) + if ret <= 0: + fail_finding_exe() + if ret < size: + # buffer is big enough, it fits + return buf + +elif MACOS: + def find_current_executable() -> byte*: + n = 1 + result: byte* = malloc(n) + ret = _NSGetExecutablePath(result, &n) # sets n to desired size + assert ret < 0 # didn't fit + result = realloc(result, n) + ret = _NSGetExecutablePath(result, &n) + if ret != 0: + fail_finding_exe() + return result + +else: + def find_current_executable() -> byte*: + buf = NULL + for size = 2L; True; size *= 2: + buf = realloc(buf, size) + memset(buf, 0, size) + ret = readlink("/proc/self/exe", buf, size) + if ret <= 0: + fail_finding_exe() + if ret < size: + # buffer is big enough, it fits + return buf + + +def find_installation_directory() -> byte*: + exe = find_current_executable() + result = strdup(dirname(exe)) + free(exe) + return result + + +def find_stdlib() -> byte*: + checked: byte*[3] + memset(&checked, 0, sizeof checked) + + exedir = find_current_executable() + while WINDOWS and strstr(exedir, "\\") != NULL: + *strstr(exedir, "\\") = '/' + + for i = 0; i < sizeof checked / sizeof checked[0]; i++: + tmp = strdup(dirname(exedir)) + free(exedir) + exedir = tmp + + if strlen(exedir) <= 3: + # give up, seems like we reached root of file system (e.g. "C:/" or "/") + break + + path = malloc(strlen(exedir) + 10) + sprintf(path, "%s/stdlib", exedir) + + iojou: byte* = malloc(strlen(path) + 10) + sprintf(iojou, "%s/io.jou", path) + buf: byte[1000] + stat_result = stat(iojou, &buf) + free(iojou) + + if stat_result == 0: + free(exedir) + return path + + checked[i] = path + + # TODO: test this + fprintf(stderr, "error: cannot find the Jou standard library in any of the following locations:\n") + for i = 0; i < sizeof checked / sizeof checked[0] and checked[i] != NULL; i++: + fprintf(stderr, " %s\n", checked[i]) + exit(1) + +# Ignoring return values, because there's currently no way to check errno. +# We need to ignore the error when directory exists already (EEXIST). +# Ideally we wouldn't ignore any other errors. +if WINDOWS: + def my_mkdir(path: byte*) -> None: + _mkdir(path) +else: + def my_mkdir(path: byte*) -> None: + mkdir(path, 0o777) # this is what mkdir in bash does according to strace + +def get_path_to_file_in_jou_compiled(filename: byte*) -> byte*: + # TODO: is placing jou_compiled to current working directory a good idea? + my_mkdir("jou_compiled") + my_mkdir("jou_compiled/self_hosted") + + result: byte* = malloc(strlen(filename) + 100) + sprintf(result, "jou_compiled/self_hosted/%s", filename) + return result + +# TODO: put this to stdlib? or does it do too much for a stdlib function? +def delete_slice(start: byte*, end: byte*) -> None: + memmove(start, end, strlen(end) + 1) + +# In paths, "foo/../" is usually unnecessary, because it goes to a folder "foo" and then +# immediately back up. However, it makes a difference in a few cases: +# +# 1. folder "foo" doesn't exist +# 2. folder "foo" is a symlink to a different place +# 3. we are actually looking at "../../" (so "foo" is "..") +# +# Special cases 1 and 2 are not relevant in the Jou compiler, but special case 3 is relevant +# when importing from "../../file.jou" (bad style, but should work). +# +# This function deletes one unnecessary "foo/../", and may be called recursively to delete +# all of them. +def simplify_dotdot_once(path: byte*) -> bool: + assert strstr(path, "\\") == NULL # should be already taken care of when calling this + + for p = strstr(path, "/../"); p != NULL; p = strstr(&p[1], "/../"): + end = &p[4] + start = p + while start > path and start[-1] != '/': + start-- + + if not starts_with(start, "../"): + delete_slice(start, end) + return True + + return False + +def simplify_path(path: byte*) -> None: + if WINDOWS: + # Backslash to forward slash. + for p = path; *p != '\0'; p++: + if *p == '\\': + *p = '/' + + # Delete "." components. + while starts_with(path, "./"): + delete_slice(path, &path[2]) + + while True: + p = strstr(path, "/./") + if p == NULL: + break # TODO: walrus operator p := strstr(...) + delete_slice(p, &p[2]) + + # Delete unnecessary ".." components. + while simplify_dotdot_once(path): + pass diff --git a/self_hosted_old/runs_wrong.txt b/self_hosted_old/runs_wrong.txt new file mode 100644 index 00000000..d0b06d88 --- /dev/null +++ b/self_hosted_old/runs_wrong.txt @@ -0,0 +1,16 @@ +# This is a list of files that don't behave correctly when ran with the self-hosted compiler. +tests/other_errors/missing_return.jou +tests/other_errors/missing_value_in_return.jou +tests/other_errors/noreturn_but_return_with_value.jou +tests/other_errors/noreturn_but_return_without_value.jou +tests/should_succeed/compiler_cli.jou +tests/should_succeed/linked_list.jou +tests/should_succeed/pointer.jou +tests/should_succeed/printf.jou +tests/other_errors/return_void.jou +tests/should_succeed/stderr.jou +tests/should_succeed/unused_import.jou +tests/wrong_type/cannot_be_indexed.jou +tests/wrong_type/index.jou +tests/should_succeed/method_by_value.jou +tests/wrong_type/self_annotation.jou diff --git a/self_hosted_old/target.jou b/self_hosted_old/target.jou new file mode 100644 index 00000000..a06d91f2 --- /dev/null +++ b/self_hosted_old/target.jou @@ -0,0 +1,70 @@ +# LLVM makes a mess of how to define what kind of computer will run the +# compiled programs. Sometimes it wants a target triple, sometimes a +# data layout. Sometimes it wants a string, sometimes an object +# representing the thing. +# +# This file aims to provide everything you may ever need. Hopefully it +# will make the mess slightly less miserable to you. Just use the global +# "target" variable, it contains everything you will ever need. + +import "./llvm.jou" +import "stdlib/str.jou" +import "stdlib/io.jou" +import "stdlib/process.jou" + +class Target: + triple: byte[100] + data_layout: byte[500] + target: LLVMTarget* + target_machine: LLVMTargetMachine* + target_data: LLVMTargetData* + +global target: Target + +# TODO: run this with atexit() once we have function pointers +#def cleanup() -> None: +# LLVMDisposeTargetMachine(target.target_machine) +# LLVMDisposeTargetData(target.target_data) + +def init_target() -> None: + LLVMInitializeX86TargetInfo() + LLVMInitializeX86Target() + LLVMInitializeX86TargetMC() + LLVMInitializeX86AsmParser() + LLVMInitializeX86AsmPrinter() + + if WINDOWS: + # LLVM's default is x86_64-pc-windows-msvc + target.triple = "x86_64-pc-windows-gnu" + else: + triple = LLVMGetDefaultTargetTriple() + assert strlen(triple) < sizeof target.triple + strcpy(target.triple, triple) + LLVMDisposeMessage(triple) + + error: byte* = NULL + if LLVMGetTargetFromTriple(target.triple, &target.target, &error) != 0: + assert error != NULL + fprintf(stderr, "LLVMGetTargetFromTriple(\"%s\") failed: %s\n", target.triple, error) + exit(1) + assert error == NULL + assert target.target != NULL + + target.target_machine = LLVMCreateTargetMachine( + target.target, + target.triple, + "x86-64", + "", + LLVMCodeGenOptLevel::Default, + LLVMRelocMode::PIC, + LLVMCodeModel::Default, + ) + assert target.target_machine != NULL + + target.target_data = LLVMCreateTargetDataLayout(target.target_machine) + assert target.target_data != NULL + + tmp = LLVMCopyStringRepOfTargetData(target.target_data) + assert strlen(tmp) < sizeof target.data_layout + strcpy(target.data_layout, tmp) + LLVMDisposeMessage(tmp) diff --git a/self_hosted_old/token.jou b/self_hosted_old/token.jou new file mode 100644 index 00000000..008b1369 --- /dev/null +++ b/self_hosted_old/token.jou @@ -0,0 +1,139 @@ +import "stdlib/io.jou" +import "stdlib/str.jou" +import "stdlib/mem.jou" +import "./errors_and_warnings.jou" + +# TODO: move to stdlib +declare isprint(b: int) -> int + + +enum TokenKind: + Short + Int + Long + Float + Double + Byte # example: 'a' is 97 as a byte + String + Name + Keyword + Newline + Indent + Dedent + Operator + EndOfFile # Marks the end of an array of tokens. + +class Token: + kind: TokenKind + location: Location + + union: + short_value: short # Short + int_value: int # Int + long_value: long # Long + byte_value: byte # Byte + indentation_level: int # Newline (indicates how many spaces there are after the newline) + short_string: byte[100] # Name, Keyword, Operator + long_string: byte* # String + + def print(self) -> None: + if self->kind == TokenKind::Byte: + printf("byte %#02x", self->byte_value) + if isprint(self->byte_value) != 0: + printf(" '%c'", self->byte_value) + printf("\n") + elif self->kind == TokenKind::Short: + printf("short %hd\n", self->short_value) + elif self->kind == TokenKind::Int: + printf("integer %d\n", self->int_value) + elif self->kind == TokenKind::Long: + printf("long %lld\n", self->long_value) + elif self->kind == TokenKind::Float: + printf("float %s\n", self->short_string) + elif self->kind == TokenKind::Double: + printf("double %s\n", self->short_string) + elif self->kind == TokenKind::EndOfFile: + printf("end of file\n") + elif self->kind == TokenKind::Operator: + printf("operator '%s'\n", self->short_string) + elif self->kind == TokenKind::Name: + printf("name \"%s\"\n", self->short_string) + elif self->kind == TokenKind::Keyword: + printf("keyword \"%s\"\n", self->short_string) + elif self->kind == TokenKind::Newline: + printf("newline token (next line has %d spaces of indentation)\n", self->indentation_level) + elif self->kind == TokenKind::String: + printf("string \"") + for s = self->long_string; *s != 0; s++: + if isprint(*s) != 0: + putchar(*s) + elif *s == '\n': + printf("\\n") + else: + printf("\\x%02x", *s) + printf("\"\n") + elif self->kind == TokenKind::Indent: + printf("indent (+4 spaces)\n") + elif self->kind == TokenKind::Dedent: + printf("dedent (-4 spaces)\n") + else: + printf("????\n") + + def is_keyword(self, kw: byte*) -> bool: + return self->kind == TokenKind::Keyword and strcmp(self->short_string, kw) == 0 + + def is_operator(self, op: byte*) -> bool: + return self->kind == TokenKind::Operator and strcmp(self->short_string, op) == 0 + + def is_comparison(self) -> bool: + return ( + self->is_operator("==") + or self->is_operator("!=") + or self->is_operator("<") + or self->is_operator(">") + or self->is_operator("<=") + or self->is_operator(">=") + ) + + def is_open_paren(self) -> bool: + return self->is_operator("(") or self->is_operator("[") or self->is_operator("{") + + def is_close_paren(self) -> bool: + return self->is_operator(")") or self->is_operator("]") or self->is_operator("}") + + def fail_expected_got(self, what_was_expected_instead: byte*) -> None: + got: byte[100] + if self->kind == TokenKind::Short: + got = "a short" + elif self->kind == TokenKind::Int: + got = "an integer" + elif self->kind == TokenKind::Long: + got = "a long integer" + elif self->kind == TokenKind::Float: + got = "a float constant" + elif self->kind == TokenKind::Double: + got = "a double constant" + elif self->kind == TokenKind::Byte: + got = "a byte literal" + elif self->kind == TokenKind::String: + got = "a string" + elif self->kind == TokenKind::Name: + snprintf(got, sizeof got, "a variable name '%s'", self->short_string) + elif self->kind == TokenKind::Keyword: + snprintf(got, sizeof got, "the '%s' keyword", self->short_string) + elif self->kind == TokenKind::Newline: + got = "end of line" + elif self->kind == TokenKind::Indent: + got = "more indentation" + elif self->kind == TokenKind::Dedent: + got = "less indentation" + elif self->kind == TokenKind::Operator: + snprintf(got, sizeof got, "'%s'", self->short_string) + elif self->kind == TokenKind::EndOfFile: + got = "end of file" + else: + assert False + + message: byte* = malloc(strlen(what_was_expected_instead) + 500) + sprintf(message, "expected %s, got %s", what_was_expected_instead, got) + fail(self->location, message) diff --git a/self_hosted_old/tokenizer.jou b/self_hosted_old/tokenizer.jou new file mode 100644 index 00000000..5612e256 --- /dev/null +++ b/self_hosted_old/tokenizer.jou @@ -0,0 +1,624 @@ +import "stdlib/io.jou" +import "stdlib/str.jou" +import "stdlib/mem.jou" +import "stdlib/errno.jou" +import "./errors_and_warnings.jou" +import "./token.jou" + +def is_identifier_or_number_byte(b: byte) -> bool: + return ( + ('A' <= b and b <= 'Z') + or ('a' <= b and b <= 'z') + or ('0' <= b and b <= '9') + or b == '_' + ) + +def is_operator_byte(c: byte) -> bool: + return c != '\0' and strchr("=<>!.,()[]{};:+-*/&%|", c) != NULL + +def is_keyword(word: byte*) -> bool: + # This keyword list is in 3 places. Please keep them in sync: + # - the Jou compiler written in C + # - self-hosted compiler + # - syntax documentation + keywords = [ + "import", "def", "declare", "class", "union", "enum", "global", + "return", "if", "elif", "else", "while", "for", "pass", "break", "continue", + "True", "False", "None", "NULL", "void", "noreturn", + "and", "or", "not", "self", "as", "sizeof", "assert", + "bool", "byte", "short", "int", "long", "float", "double", + ] + + for i = 0; i < sizeof keywords / sizeof keywords[0]; i++: + if strcmp(keywords[i], word) == 0: + return True + return False + +def hexdigit_value(c: byte) -> int: + if 'A' <= c and c <= 'F': + return 10 + (c - 'A') + if 'a' <= c and c <= 'f': + return 10 + (c - 'a') + if '0' <= c and c <= '9': + return c - '0' + return -1 + +def parse_integer(string: byte*, location: Location, nbits: int) -> long: + if starts_with(string, "0b"): + base = 2 + digits = &string[2] + valid_digits = "01" + elif starts_with(string, "0o"): + base = 8 + digits = &string[2] + valid_digits = "01234567" + elif starts_with(string, "0x"): + base = 16 + digits = &string[2] + valid_digits = "0123456789ABCDEFabcdef" + elif starts_with(string, "0") and strlen(string) >= 2: + # wrong syntax like 0777 + fail(location, "unnecessary zero at start of number") + else: + # default decimal number + base = 10 + digits = string + valid_digits = "0123456789" + + if strlen(digits) == 0 or strspn(digits, valid_digits) != strlen(digits): + message = malloc(strlen(string) + 100) + sprintf(message, "invalid number or variable name \"%s\"", string) + fail(location, message) + + # We can't use strtoll() or similar, because there's not yet a way to check errno. + # We would need it to check for overflow. + result = 0L + overflow = False + + for i = 0; i < strlen(digits); i++: + # Overflow isn't UB in Jou + if (result * base) / base != result: + overflow = True + break + result *= base + + if result + hexdigit_value(digits[i]) < 0: + overflow = True + break + result += hexdigit_value(digits[i]) + + assert nbits == 16 or nbits == 32 or nbits == 64 + if nbits == 32 and (result as int) != result: + overflow = True + + if nbits == 16 and (result as short) != result: + overflow = True + + if overflow: + message = malloc(100) + sprintf(message, "value does not fit in a signed %d-bit integer", nbits) + fail(location, message) + + return result + +# Does a string contain multiple occurrences of the given byte? +def has_multiple(s: byte*, b: byte) -> bool: + return strchr(s, b) != strrchr(s, b) + +def is_valid_double(str: byte*) -> bool: + if strspn(str, "0123456789.-e") < strlen(str): + return False + if has_multiple(str, '.') or has_multiple(str, '-') or has_multiple(str, 'e'): + return False + + dot = strchr(str, '.') + minus = strchr(str, '-') + e = strchr(str, 'e') + + return ( + # 123.456 + e == NULL and minus == NULL and dot != NULL + ) or ( + # 1e-4 + e != NULL and e[1] != '\0' and (dot == NULL or dot < e) and (minus == NULL or (&e[1] == minus and e[2] != '\0')) + ) + +def is_valid_float(str: byte[100]) -> bool: + n = strlen(str) + if n == 0 or (str[n-1] != 'F' and str[n-1] != 'f'): + return False + str[n-1] = '\0' + return is_valid_double(str) or strspn(str, "0123456789") == strlen(str) + + +def flip_paren(c: byte) -> byte: + if c == '(': + return ')' + if c == ')': + return '(' + if c == '[': + return ']' + if c == ']': + return '[' + if c == '{': + return '}' + if c == '}': + return '{' + assert False + + +class Tokenizer: + f: FILE* + location: Location + pushback: byte* + pushback_len: int # TODO: dynamic array + # Parens array isn't dynamic, so that you can't segfault + # the compiler by feeding it lots of nested parentheses, + # which would make it recurse too deep. + open_parens: Token[50] + open_parens_len: int + + def read_byte(self) -> byte: + EOF = -1 # FIXME + + c: byte + if self->pushback_len > 0: + c = self->pushback[--self->pushback_len] + else: + temp = fgetc(self->f) + if temp == '\r': + # On Windows, \r just before \n is ignored. + temp = fgetc(self->f) + if temp != EOF and temp != '\n': + # TODO: test this, if possible? + fail(self->location, "source file contains a CR byte ('\\r') that isn't a part of a CRLF line ending") + + if temp == EOF: + if ferror(self->f) != 0: + # TODO: include errno in the error message + fail(self->location, "cannot read file") + # Use the zero byte to denote end of file. + c = '\0' + elif temp == '\0': + fail(self->location, "source file contains a zero byte") + else: + c = temp as byte + + if c == '\n': + self->location.lineno++ + return c + + def unread_byte(self, b: byte) -> None: + if b == '\0': + return + + assert b != '\r' + self->pushback = realloc(self->pushback, self->pushback_len + 1) + self->pushback[self->pushback_len++] = b + if b == '\n': + self->location.lineno-- + + def read_identifier_or_number(self, first_byte: byte) -> byte[100]: + dest: byte[100] + memset(&dest, 0, sizeof dest) + destlen = 0 + + assert is_identifier_or_number_byte(first_byte) + dest[destlen++] = first_byte + is_number = '0' <= first_byte and first_byte <= '9' + + while True: + b = self->read_byte() + if ( + is_identifier_or_number_byte(b) + or (is_number and (b == '.' or (b == '-' and dest[destlen-1] == 'e'))) + ): + if destlen == sizeof dest - 1: + if is_number: + template = "number is too long: %.20s..." + else: + template = "name is too long: %.20s..." + message: byte[100] + sprintf(message, template, dest) + fail(self->location, message) + dest[destlen++] = b + else: + self->unread_byte(b) + return dest + + def consume_rest_of_line(self) -> None: + while True: + c = self->read_byte() + if c == '\0' or c == '\n': + self->unread_byte(c) + break + + # Returns the indentation level for the next line + def read_newline_token(self) -> int: + level = 0 + while True: + c = self->read_byte() + if c == '\0': + # End of file. Do not validate that indentation is a + # multiple of 4 spaces. Add a trailing newline implicitly + # if needed. + # + # TODO: test this + return 0 + elif c == '\n': + level = 0 + elif c == '#': + self->consume_rest_of_line() + elif c == ' ': + level++ + else: + self->unread_byte(c) + return level + + def read_hex_escape_byte(self) -> byte: + n1 = hexdigit_value(self->read_byte()) + n2 = hexdigit_value(self->read_byte()) + if n1 == -1 or n2 == -1: + fail(self->location, "\\x must be followed by two hexadecimal digits (0-9, A-F) to specify a byte") + return (n1*16 + n2) as byte + + # Assumes the initial ' has been read. + def read_byte_literal(self) -> byte: + c = self->read_byte() + if c == '\'': + fail(self->location, "a byte literal cannot be empty, maybe use double quotes to instead make a string?") + if c == '\0' or c == '\n': + if c == '\n': + self->location.lineno-- + fail(self->location, "missing ' to end the byte literal") + + if c == '\\': + after_backslash = self->read_byte() + if after_backslash == '\0' or after_backslash == '\n': + self->location.lineno-- + fail(self->location, "missing ' to end the byte literal") + elif after_backslash == 'n': + c = '\n' + elif after_backslash == 'r': + c = '\r' + elif after_backslash == 't': + c = '\t' + elif after_backslash == '\\': + c = '\\' + elif after_backslash == '\'': + c = '\'' + elif after_backslash == '"': + fail(self->location, "double quotes shouldn't be escaped in byte literals") + elif after_backslash == '0': + c = '\0' + elif after_backslash == 'x': + c = self->read_hex_escape_byte() + elif after_backslash < 0x80 and isprint(after_backslash) != 0: + message: byte* = malloc(100) + sprintf(message, "unknown escape: '\\%c'", after_backslash) + fail(self->location, message) + else: + fail(self->location, "unknown '\\' escape") + + end = self->read_byte() + if end != '\'': + # If there's another single quote later on the same line, suggest using double quotes. + location = self->location + while True: + c = self->read_byte() + if c == '\0' or c == '\n': + break + if c == '\'': + fail(location, "single quotes are for specifying a byte, maybe use double quotes to instead make a string?") + fail(location, "missing ' to end the byte literal") + + return c + + # Assumes the initial " has been read. + def read_string(self) -> byte*: + result: byte* = NULL + len = 0 + + while True: + c = self->read_byte() + if c == '"': + break + elif c == '\n' or c == '\0': + if c == '\n': + self->location.lineno-- + fail(self->location, "missing \" to end the string") + elif c == '\\': + # \n means newline, for example + after_backslash = self->read_byte() + if after_backslash == '\0': + fail(self->location, "missing \" to end the string") + elif after_backslash == 'n': + result = realloc(result, len+1) + result[len++] = '\n' + elif after_backslash == 'r': + result = realloc(result, len+1) + result[len++] = '\r' + elif after_backslash == 't': + result = realloc(result, len+1) + result[len++] = '\t' + elif after_backslash == '\\' or after_backslash == '"': + result = realloc(result, len+1) + result[len++] = after_backslash + elif after_backslash == '\'': + fail(self->location, "single quotes shouldn't be escaped in strings") + elif after_backslash == '0': + fail(self->location, "strings cannot contain zero bytes (\\0), because that is the special end marker byte") + elif after_backslash == 'x': + b = self->read_hex_escape_byte() + if b == '\0': + fail(self->location, "strings cannot contain zero bytes (\\x00), because that is the special end marker byte") + result = realloc(result, len+1) + result[len++] = b + elif after_backslash == '\n': + # \ at end of line, string continues on next line + pass + else: + if after_backslash < 0x80 and isprint(after_backslash) != 0: + message: byte* = malloc(100) + sprintf(message, "unknown escape: '\\%c'", after_backslash) + fail(self->location, message) + else: + fail(self->location, "unknown '\\' escape") + else: + result = realloc(result, len+1) + result[len++] = c + + result = realloc(result, len+1) + result[len] = '\0' + return result + + def read_operator(self) -> byte[100]: + operators = [ + # This list of operators is in 3 places. Please keep them in sync: + # - the Jou compiler written in C + # - self-hosted compiler + # - syntax documentation + # + # Longer operators are first, so that '==' does not tokenize as '=' '=' + "...", "===", "!==", + "==", "!=", "->", "<=", ">=", "++", "--", "+=", "-=", "*=", "/=", "%=", "::", "&&", "||", + ".", ",", ":", ";", "=", "(", ")", "{", "}", "[", "]", "&", "%", "*", "/", "+", "-", "<", ">", "!", + ] + + operator: byte[100] + memset(&operator, 0, sizeof operator) + + # Read as many operator characters as we may need. + while strlen(operator) < 3: + c = self->read_byte() + if not is_operator_byte(c): + self->unread_byte(c) + break + operator[strlen(operator)] = c + + for i = 0; i < sizeof operators / sizeof operators[0]; i++: + if starts_with(operator, operators[i]): + # Unread the bytes we didn't use. + while strlen(operator) > strlen(operators[i]): + last = &operator[strlen(operator) - 1] + self->unread_byte(*last) + *last = '\0' + + # These operators are here only to give a better error message when you + # accidentally use syntax of another programming language. + if strcmp(operator, "===") == 0: + fail(self->location, "use '==' instead of '==='") + if strcmp(operator, "!==") == 0: + fail(self->location, "use '!=' instead of '!=='") + if strcmp(operator, "&&") == 0: + fail(self->location, "use 'and' instead of '&&'") + if strcmp(operator, "||") == 0: + fail(self->location, "use 'or' instead of '||'") + if strcmp(operator, "!") == 0: + fail(self->location, "use 'not' instead of '!'") + + return operator + + message: byte[100] + sprintf(message, "there is no '%s' operator", operator) + fail(self->location, message) + + def handle_parentheses(self, token: Token*) -> None: + if token->kind == TokenKind::EndOfFile and self->open_parens_len > 0: + open_token = self->open_parens[0] + actual_open = open_token.short_string[0] + expected_close = flip_paren(actual_open) + + message = malloc(100) + sprintf(message, "'%c' without a matching '%c'", actual_open, expected_close) + fail(open_token.location, message) + + if token->is_open_paren(): + if self->open_parens_len == sizeof self->open_parens / sizeof self->open_parens[0]: + fail(token->location, "too many nested parentheses") + self->open_parens[self->open_parens_len++] = *token + + if token->is_close_paren(): + actual_close = token->short_string[0] + expected_open = flip_paren(actual_close) + if self->open_parens_len == 0 or self->open_parens[--self->open_parens_len].short_string[0] != expected_open: + message = malloc(100) + sprintf(message, "'%c' without a matching '%c'", actual_close, expected_open) + fail(token->location, message) + + def read_token(self) -> Token: + while True: + token = Token{location = self->location} + b = self->read_byte() + + if b == ' ': + continue + if b == '#': + self->consume_rest_of_line() + continue + + if b == '\n': + if self->open_parens_len > 0: + continue + token.kind = TokenKind::Newline + token.indentation_level = self->read_newline_token() + elif b == '"': + token.kind = TokenKind::String + token.long_string = self->read_string() + elif b == '\'': + token.kind = TokenKind::Byte + token.byte_value = self->read_byte_literal() + elif is_identifier_or_number_byte(b): + token.short_string = self->read_identifier_or_number(b) + if is_keyword(token.short_string): + token.kind = TokenKind::Keyword + elif '0' <= token.short_string[0] and token.short_string[0] <= '9': + if is_valid_double(token.short_string): + token.kind = TokenKind::Double + elif is_valid_float(token.short_string): + token.short_string[strlen(token.short_string) - 1] = '\0' # delete trailing 'f' or 'F' + token.kind = TokenKind::Float + elif token.short_string[strlen(token.short_string) - 1] == 'L': + token.short_string[strlen(token.short_string) - 1] = '\0' + token.kind = TokenKind::Long + token.long_value = parse_integer(token.short_string, token.location, 64) + elif token.short_string[strlen(token.short_string) - 1] == 'S': + token.short_string[strlen(token.short_string) - 1] = '\0' + token.kind = TokenKind::Short + token.short_value = parse_integer(token.short_string, token.location, 16) as short + else: + token.kind = TokenKind::Int + token.int_value = parse_integer(token.short_string, token.location, 32) as int + else: + token.kind = TokenKind::Name + elif is_operator_byte(b): + self->unread_byte(b) + token.kind = TokenKind::Operator + token.short_string = self->read_operator() + elif b == '\t': + fail(self->location, "Jou files cannot contain tab characters (use 4 spaces for indentation)") + elif b == '\0': + token.kind = TokenKind::EndOfFile + else: + message: byte[100] + if b < 0x80 and isprint(b) != 0: + sprintf(message, "unexpected byte '%c' (%#02x)", b, b) + else: + sprintf(message, "unexpected byte %#02x", b) + fail(self->location, message) + + self->handle_parentheses(&token) + return token + + +def tokenize_without_indent_dedent_tokens(file: FILE*, path: byte*) -> Token*: + tokenizer = Tokenizer{ + location = Location{path = path}, + f = file, + } + + # Add a fake newline to the beginning. It does a few things: + # * Less special-casing: blank lines in the beginning of the file can + # cause there to be a newline token anyway. + # * It is easier to detect an unexpected indentation in the beginning + # of the file, as it becomes just like any other indentation. + # * Line numbers start at 1. + tokenizer.pushback = malloc(1) + tokenizer.pushback[0] = '\n' + tokenizer.pushback_len = 1 + + tokens: Token* = NULL + len = 0 + while len == 0 or tokens[len-1].kind != TokenKind::EndOfFile: + tokens = realloc(tokens, sizeof(tokens[0]) * (len+1)) + tokens[len++] = tokenizer.read_token() + + free(tokenizer.pushback) + return tokens + + +# Creates a new array of tokens with indent/dedent tokens added after +# newline tokens that change the indentation level. +def handle_indentations(raw_tokens: Token*) -> Token*: + tokens: Token* = NULL + ntokens = 0 + level = 0 + + for t = raw_tokens; True; t++: + if t->kind == TokenKind::EndOfFile: + # Add an extra newline token at end of file and the dedents after it. + # This makes it similar to how other newline and dedent tokens work: + # the dedents always come after a newline token. + tokens = realloc(tokens, sizeof tokens[0] * (ntokens + level/4 + 1)) + while level != 0: + tokens[ntokens++] = Token{location = t->location, kind = TokenKind::Dedent} + level -= 4 + tokens[ntokens++] = *t + break + + tokens = realloc(tokens, sizeof tokens[0] * (ntokens+1)) + tokens[ntokens++] = *t + + if t->kind == TokenKind::Newline: + after_newline = t->location + after_newline.lineno++ + + if t->indentation_level % 4 != 0: + fail(after_newline, "indentation must be a multiple of 4 spaces") + + while level < t->indentation_level: + tokens = realloc(tokens, sizeof tokens[0] * (ntokens+1)) + tokens[ntokens++] = Token{location = after_newline, kind = TokenKind::Indent} + level += 4 + + while level > t->indentation_level: + tokens = realloc(tokens, sizeof tokens[0] * (ntokens+1)) + tokens[ntokens++] = Token{location = after_newline, kind = TokenKind::Dedent} + level -= 4 + + # Delete the newline token in the beginning. + # + # If the file has indentations after it, they are now represented by separate + # indent tokens and parsing will fail. If the file doesn't have any blank/comment + # lines in the beginning, it has a newline token anyway to avoid special casing. + assert tokens[0].kind == TokenKind::Newline + memmove(&tokens[0], &tokens[1], sizeof tokens[0] * (ntokens - 1)) + + return tokens + + +def tokenize(path: byte*, import_location: Location*) -> Token*: + file = fopen(path, "rb") + if file == NULL: + message: byte[200] + if import_location == NULL: + # File is not imported + snprintf(message, sizeof message, "cannot open file: %s", strerror(get_errno())) + fail(Location{path=path}, message) + else: + snprintf(message, sizeof message, "cannot import from \"%s\": %s", path, strerror(get_errno())) + fail(*import_location, message) + + raw_tokens = tokenize_without_indent_dedent_tokens(file, path) + better_tokens = handle_indentations(raw_tokens) + free(raw_tokens) + return better_tokens + +def print_tokens(tokens: Token*) -> None: + printf("===== Tokens for file \"%s\" =====\n", tokens->location.path) + t = tokens + current_lineno = -1 + + while True: + if t->location.lineno != current_lineno: + current_lineno = t->location.lineno + printf("\nLine %d:\n", current_lineno) + + printf(" ") + t->print() + + if t->kind == TokenKind::EndOfFile: + break + t++ + + printf("\n") diff --git a/self_hosted_old/typecheck.jou b/self_hosted_old/typecheck.jou new file mode 100644 index 00000000..3eddb036 --- /dev/null +++ b/self_hosted_old/typecheck.jou @@ -0,0 +1,1433 @@ +# Type checking is split into several stages: +# 1. Create types. After this, classes defined in Jou exist, but +# they are opaque and contain no members. Enums exist and contain +# their members (although it doesn't really matter whether enum +# members are handled in stage 1 or 2). +# 2. Check signatures, global variables and class bodies, but ignore +# bodies of functions and methods. This stage assumes that all +# types exist, but doesn't need to know what fields each class has. +# 3. Check function and method bodies. +# +# The goal of this design is to make cyclic imports possible. At each +# stage, we don't need the results from the same stage, only from +# previous stages. This means that cyclic imports "just work" if we do +# each stage on all files before moving on to the next stage. + +import "stdlib/io.jou" +import "stdlib/str.jou" +import "stdlib/mem.jou" +import "./ast.jou" +import "./types.jou" +import "./errors_and_warnings.jou" +import "./evaluate.jou" + + +def can_cast_implicitly(from: Type*, to: Type*) -> bool: + # TODO: document these properly. But they are: + # array to pointer, e.g. int[3] --> int* (needs special-casing elsewhere) + # from one integer type to another bigger integer type, unless it is signed-->unsigned + # between two pointer types when one of the two is void* + # from float to double (TODO) + return ( + from == to + or (from->kind == TypeKind::Array and to->kind == TypeKind::Pointer and from->array.item_type == to->value_type) + or (from->kind == TypeKind::Array and to->kind == TypeKind::VoidPointer) + or ( + from->is_integer_type() + and to->is_integer_type() + and from->size_in_bits < to->size_in_bits + and not (from->kind == TypeKind::SignedInteger and to->kind == TypeKind::UnsignedInteger) + ) + or (from == &float_type and to == &double_type) + or (from->is_integer_type() and to->kind == TypeKind::FloatingPoint) + or (from->is_pointer_type() and to->is_pointer_type() and (from == &void_ptr_type or to == &void_ptr_type)) + ) + +def can_cast_explicitly(from: Type*, to: Type*) -> bool: + return ( + from == to + or (from->kind == TypeKind::Array and to->kind == TypeKind::Pointer and from->array.item_type == to->value_type) + or (from->kind == TypeKind::Array and to->kind == TypeKind::VoidPointer) + or (from->is_pointer_type() and to->is_pointer_type()) + or (from->is_number_type() and to->is_number_type()) + or (from->is_integer_type() and to->kind == TypeKind::Enum) + or (from->kind == TypeKind::Enum and to->is_integer_type()) + or (from == &bool_type and to->is_integer_type()) + or (from->is_pointer_type() and to == long_type) + or (from == long_type and to->is_pointer_type()) + ) + +# Implicit casts are used in many places, e.g. function arguments. +# +# When you pass an argument of the wrong type, it's best to give an error message +# that says so, instead of some generic "expected type foo, got object of type bar" +# kind of message. +# +# The template can contain "" and "". They will be substituted with names +# of types. We cannot use printf() style functions because the arguments can be in +# any order. +def fail_with_implicit_cast_error(location: Location, template: byte*, from: Type*, to: Type*) -> None: + assert template != NULL + + n = 0 + for i = 0; template[i] != '\0'; i++: + if template[i] == '<': + n++ + + message: byte* = malloc(sizeof(from->name)*n + strlen(template) + 1) + message[0] = '\0' + while *template != '\0': + if starts_with(template, ""): + template = &template[6] + strcat(message, from->name) + elif starts_with(template, ""): + template = &template[4] + strcat(message, to->name) + else: + s = [*template++, '\0'] + strcat(message, s) + + fail(location, message) + + +# To understand the purpose of ExportSymbol, suppose file A imports file B. +# - Type checking file B produces an ExportSymbol that matches the import in file A. +# - Before the next type checking stage, the ExportSymbol is added to file A's types. +# - During the next stage, file A can use the imported symbol. +enum ExportSymbolKind: + Function + Type + GlobalVariable + +class ExportSymbol: + kind: ExportSymbolKind + name: byte[100] + + union: + signature: Signature # ExportSymbolKind::Function + type: Type* # ExportSymbolKind::Type, ExportSymbolKind::GlobalVariable + + def print(self) -> None: + if self->kind == ExportSymbolKind::Function: + s = self->signature.to_string(True, True) + printf("ExportSymbol: function %s\n", s) + free(s) + elif self->kind == ExportSymbolKind::Type: + printf("ExportSymbol: type %s as \"%s\"\n", self->type->name, self->name) + elif self->kind == ExportSymbolKind::GlobalVariable: + printf("ExportSymbol: variable %s: %s\n", self->name, self->type->name) + else: + assert False + +class ExpressionTypes: + expression: AstExpression* + original_type: Type* + implicit_cast_type: Type* # NULL if no implicit casting is needed + next: ExpressionTypes* # TODO: switch to more efficient structure than linked list? + + # Flags to indicate whether special kinds of implicit casts happened + implicit_array_to_pointer_cast: bool # Foo[N] to Foo* + implicit_string_to_array_cast: bool # "..." to byte[N] + + def get_type_after_implicit_cast(self) -> Type*: + assert self->original_type != NULL + if self->implicit_cast_type == NULL: + return self->original_type + return self->implicit_cast_type + + # TODO: error_location is probably unnecessary, can get location from self->expression + def do_implicit_cast(self, to: Type*, error_location: Location, error_template: byte*) -> None: + # This cannot be called multiple times + assert self->implicit_cast_type == NULL + assert not self->implicit_array_to_pointer_cast + assert not self->implicit_string_to_array_cast + + from = self->original_type + if from == to: + return + + if ( + self->expression->kind == AstExpressionKind::String + and from == byte_type->get_pointer_type() + and to->kind == TypeKind::Array + and to->array.item_type == byte_type + ): + string_size = strlen(self->expression->string) + 1 + if to->array.length < string_size: + message: byte[100] + snprintf( + message, sizeof message, + "a string of %d bytes (including '\\0') does not fit into %s", + string_size, to->name, + ) + fail(error_location, message) + self->implicit_string_to_array_cast = True + # Passing in NULL for error_template can be used to force a cast to happen. + elif error_template != NULL and not can_cast_implicitly(from, to): + fail_with_implicit_cast_error(error_location, error_template, from, to) + + self->implicit_cast_type = to + if from->kind == TypeKind::Array and to->is_pointer_type(): + self->implicit_array_to_pointer_cast = True + ensure_can_take_address( + self->expression, + "cannot create a pointer into an array that comes from %s (try storing it to a local variable first)", + ) + + # Does not store the new type to self, because explicit casts have their own AstExpression which has its own expression types. + def do_explicit_cast(self, to: Type*, error_location: Location) -> None: + assert self->implicit_cast_type == NULL + assert not self->implicit_array_to_pointer_cast + + from = self->original_type + if not can_cast_explicitly(from, to): + message: byte[500] + snprintf(&message[0], sizeof message, "cannot cast from type %s to %s", from->name, to->name) + fail(error_location, message) + + if from->kind == TypeKind::Array and to->is_pointer_type(): + self->cast_array_to_pointer() + + def cast_array_to_pointer(self) -> None: + assert self->original_type->kind == TypeKind::Array + self->do_implicit_cast(self->original_type->array.item_type->get_pointer_type(), Location{}, NULL) + +class LocalVariable: + name: byte[100] + type: Type* + next: LocalVariable* # TODO: switch to more efficient structure than linked list? + +class GlobalVariable: + name: byte[100] + type: Type* + +class FunctionOrMethodTypes: + signature: Signature + expression_types: ExpressionTypes* + local_vars: LocalVariable* + + def get_expression_types(self, expr: AstExpression*) -> ExpressionTypes*: + for et = self->expression_types; et != NULL; et = et->next: + if et->expression == expr: + return et + return NULL + + def find_local_var(self, name: byte*) -> LocalVariable*: + for v = self->local_vars; v != NULL; v = v->next: + if strcmp(v->name, name) == 0: + return v + return NULL + +# All type information for a Jou file. This is initially empty, and is filled during each stage of type checking. +class FileTypes: + # Includes imported and defined functions. + all_functions: Signature* + n_all_functions: int + + defined_functions: FunctionOrMethodTypes* + n_defined_functions: int + + types: Type** + ntypes: int + + globals: GlobalVariable* + nglobals: int + + def add_imported_symbol(self, symbol: ExportSymbol*) -> None: + if symbol->kind == ExportSymbolKind::Type: + self->types = realloc(self->types, (self->ntypes + 1) * sizeof(self->types[0])) + self->types[self->ntypes++] = symbol->type + elif symbol->kind == ExportSymbolKind::Function: + self->all_functions = realloc(self->all_functions, sizeof self->all_functions[0] * (self->n_all_functions + 1)) + self->all_functions[self->n_all_functions++] = symbol->signature.copy() + elif symbol->kind == ExportSymbolKind::GlobalVariable: + pass # TODO + else: + symbol->print() + assert False + + def find_function(self, name: byte*) -> Signature*: + for i = 0; i < self->n_all_functions; i++: + if strcmp(self->all_functions[i].name, name) == 0: + return &self->all_functions[i] + return NULL + + # If class_type is NULL, this finds a function + def find_defined_function_or_method(self, name: byte*, class_type: Type*) -> FunctionOrMethodTypes*: + assert class_type == NULL or class_type->kind == TypeKind::Class + for i = 0; i < self->n_defined_functions; i++: + if ( + strcmp(self->defined_functions[i].signature.name, name) == 0 + and self->defined_functions[i].signature.get_containing_class() == class_type + ): + return &self->defined_functions[i] + return NULL + + def find_type(self, name: byte*) -> Type*: + for i = 0; i < self->ntypes; i++: + if strcmp(self->types[i]->name, name) == 0: + return self->types[i] + return NULL + + def find_global_var(self, name: byte*) -> Type*: + for i = 0; i < self->nglobals; i++: + if strcmp(self->globals[i].name, name) == 0: + return self->globals[i].type + return NULL + +def check_type_doesnt_exist(ft: FileTypes*, name: byte*, location: Location) -> None: + existing = ft->find_type(name) + if existing != NULL: + description = short_type_description(existing) + message: byte[500] + snprintf(message, sizeof message, "%s named '%s' already exists", description, name) + fail(location, message) + +def typecheck_stage1_create_types(ft: FileTypes*, file: AstFile*) -> ExportSymbol*: + exports: ExportSymbol* = NULL + nexports = 0 + + for i = 0; i < file->body.nstatements; i++: + if file->body.statements[i].kind == AstStatementKind::Class: + classdef = &file->body.statements[i].classdef + check_type_doesnt_exist(ft, classdef->name, classdef->name_location) + t = create_opaque_class(classdef->name) + elif file->body.statements[i].kind == AstStatementKind::Enum: + enumdef = &file->body.statements[i].enumdef + check_type_doesnt_exist(ft, enumdef->name, enumdef->name_location) + t = create_enum(enumdef->name, enumdef->member_count, enumdef->member_names) + else: + continue + + ft->types = realloc(ft->types, (ft->ntypes + 1) * sizeof ft->types[0]) + ft->types[ft->ntypes++] = t + exports = realloc(exports, (nexports + 1) * sizeof exports[0]) + exports[nexports++] = ExportSymbol{ + kind = ExportSymbolKind::Type, + name = t->name, + type = t, + } + + exports = realloc(exports, sizeof exports[0] * (nexports + 1)) + exports[nexports] = ExportSymbol{} + return exports + + +def evaluate_array_length(expression: AstExpression*) -> int: + # TODO: support something more fancy? + if expression->kind == AstExpressionKind::Int: + return expression->int_value + fail(expression->location, "cannot evaluate array length at compile time") + +def type_from_ast(ft: FileTypes*, ast_type: AstType*) -> Type*: + if ast_type->is_void(): + fail(ast_type->location, "'void' cannot be used here because it is not a type") + if ast_type->is_none(): + fail(ast_type->location, "'None' cannot be used here because it is not a type") + if ast_type->is_noreturn(): + fail(ast_type->location, "'noreturn' cannot be used here because it is not a type") + + if ast_type->kind == AstTypeKind::Named: + if strcmp(ast_type->name, "short") == 0: + return short_type + if strcmp(ast_type->name, "int") == 0: + return int_type + if strcmp(ast_type->name, "long") == 0: + return long_type + if strcmp(ast_type->name, "byte") == 0: + return byte_type + if strcmp(ast_type->name, "bool") == 0: + return &bool_type + if strcmp(ast_type->name, "float") == 0: + return &float_type + if strcmp(ast_type->name, "double") == 0: + return &double_type + + result = ft->find_type(ast_type->name) + if result != NULL: + return result + + message: byte* = malloc(strlen(ast_type->name) + 100) + sprintf(message, "there is no type named '%s'", ast_type->name) + fail(ast_type->location, message) + + if ast_type->kind == AstTypeKind::Pointer: + if ast_type->value_type->is_void(): + return &void_ptr_type + return type_from_ast(ft, ast_type->value_type)->get_pointer_type() + + if ast_type->kind == AstTypeKind::Array: + member_type = type_from_ast(ft, ast_type->array.member_type) + length = evaluate_array_length(ast_type->array.length) + if length <= 0: + fail(ast_type->array.length->location, "array length must be positive") + return member_type->get_array_type(length) + + ast_type->print(True) + printf("\n") + assert False # TODO + +def handle_signature(ft: FileTypes*, astsig: AstSignature*, self_type: Type*) -> Signature: + assert self_type == NULL or self_type->kind == TypeKind::Class + + sig = Signature{ + name = astsig->name, + nargs = astsig->nargs, + takes_varargs = astsig->takes_varargs, + } + + sig.argnames = malloc(sizeof sig.argnames[0] * sig.nargs) + for i = 0; i < sig.nargs; i++: + sig.argnames[i] = astsig->args[i].name + + sig.argtypes = malloc(sizeof sig.argtypes[0] * sig.nargs) + for i = 0; i < sig.nargs; i++: + if strcmp(astsig->args[i].name, "self") == 0: + assert self_type != NULL + sig.argtypes[i] = self_type->get_pointer_type() + else: + sig.argtypes[i] = type_from_ast(ft, &astsig->args[i].type) + + if astsig->return_type.is_none() or astsig->return_type.is_noreturn(): + sig.return_type = NULL + else: + sig.return_type = type_from_ast(ft, &astsig->return_type) + + if self_type == NULL and strcmp(sig.name, "main") == 0: + # special main() function checks + if sig.return_type != int_type: + fail(astsig->return_type.location, "the main() function must return int") + if sig.nargs != 0 and not ( + sig.nargs == 2 + and sig.argtypes[0] == int_type + and sig.argtypes[1] == byte_type->get_pointer_type()->get_pointer_type() + ): + fail( + astsig->args[0].type.location, + "if the main() function takes parameters, it should be defined like this: def main(argc: int, argv: byte**) -> int" + ) + + return sig + +def handle_class_members_stage2(ft: FileTypes*, classdef: AstClassDef*) -> None: + # Previous type-checking stage created an opaque type. + type: Type* = NULL + for i = 0; i < ft->ntypes; i++: + if strcmp(ft->types[i]->name, classdef->name) == 0: + type = ft->types[i] + break + assert type != NULL + + assert type->kind == TypeKind::OpaqueClass + type->kind = TypeKind::Class + + memset(&type->class_members, 0, sizeof type->class_members) + + union_id = 0 + for i = 0; i < classdef->nmembers; i++: + member = &classdef->members[i] + if member->kind == AstClassMemberKind::Field: + type->class_members.fields = realloc(type->class_members.fields, (type->class_members.nfields + 1) * sizeof type->class_members.fields[0]) + type->class_members.fields[type->class_members.nfields++] = ClassField{ + name = member->field.name, + type = type_from_ast(ft, &member->field.type), + union_id = union_id++, + } + elif member->kind == AstClassMemberKind::Union: + uid = union_id++ + for k = 0; k < member->union_fields.nfields; k++: + type->class_members.fields = realloc(type->class_members.fields, (type->class_members.nfields + 1) * sizeof type->class_members.fields[0]) + type->class_members.fields[type->class_members.nfields++] = ClassField{ + name = member->union_fields.fields[k].name, + type = type_from_ast(ft, &member->union_fields.fields[k].type), + union_id = uid, + } + elif member->kind == AstClassMemberKind::Method: + # Don't handle the method body yet: that is a part of stage 3, not stage 2 + sig = handle_signature(ft, &member->method.signature, type) + type->class_members.methods = realloc(type->class_members.methods, sizeof type->class_members.methods[0] * (type->class_members.nmethods + 1)) + type->class_members.methods[type->class_members.nmethods++] = sig + else: + assert False + +# Returned array is terminated by ExportSymbol with empty name. +def typecheck_stage2_populate_types(ft: FileTypes*, ast_file: AstFile*) -> ExportSymbol*: + message: byte[200] + + exports: ExportSymbol* = NULL + nexports = 0 + + for i = 0; i < ast_file->body.nstatements; i++: + ts = &ast_file->body.statements[i] + + if ts->kind == AstStatementKind::Function: + if ft->find_function(ts->function.signature.name) != NULL: + snprintf( + message, sizeof message, + "a function named '%s' already exists", + ts->function.signature.name, + ) + fail(ts->location, message) + + sig = handle_signature(ft, &ts->function.signature, NULL) + ft->all_functions = realloc(ft->all_functions, sizeof ft->all_functions[0] * (ft->n_all_functions + 1)) + ft->all_functions[ft->n_all_functions++] = sig.copy() + exports = realloc(exports, sizeof exports[0] * (nexports + 1)) + exports[nexports++] = ExportSymbol{ + kind = ExportSymbolKind::Function, + name = sig.name, + signature = sig, + } + + if ts->kind == AstStatementKind::Class: + handle_class_members_stage2(ft, &ts->classdef) + + if ( + ts->kind == AstStatementKind::GlobalVariableDeclaration + or ts->kind == AstStatementKind::GlobalVariableDefinition + ): + if ft->find_global_var(ts->var_declaration.name) != NULL: + snprintf( + message, sizeof message, + "a global variable named '%s' already exists", + ts->var_declaration.name, + ) + fail(ts->location, message) + + assert ts->var_declaration.value == NULL + type = type_from_ast(ft, &ts->var_declaration.type) + ft->globals = realloc(ft->globals, (ft->nglobals + 1) * sizeof ft->globals[0]) + ft->globals[ft->nglobals++] = GlobalVariable{name = ts->var_declaration.name, type = type} + + exports = realloc(exports, sizeof exports[0] * (nexports + 1)) + exports[nexports++] = ExportSymbol{ + kind = ExportSymbolKind::GlobalVariable, + name = ts->var_declaration.name, + type = type, + } + + exports = realloc(exports, sizeof exports[0] * (nexports + 1)) + exports[nexports] = ExportSymbol{} + return exports + + +def plural_s(n: int) -> byte*: + if n == 1: + return "" + return "s" + +def nth(n: int) -> byte[100]: + first_few = [NULL as byte*, "first", "second", "third", "fourth", "fifth", "sixth"] + result: byte[100] + + if n < sizeof first_few / sizeof first_few[0]: + strcpy(result, first_few[n]) + else: + sprintf(result, "%dth", n) + return result + +def short_type_description(t: Type*) -> byte*: + if t->kind == TypeKind::Class or t->kind == TypeKind::OpaqueClass: + return "a class" + if t->kind == TypeKind::Enum: + return "an enum" + if t->is_pointer_type(): + return "a pointer type" + if t->is_number_type(): + return "a number type" + if t->kind == TypeKind::Array: + return "an array type" + if t == &bool_type: + return "the built-in bool type" + assert False + +# TODO: make this a method in class AstExpression? +def short_expression_description(expr: AstExpression*) -> byte[200]: + result: byte[200] + + # Imagine "cannot assign to" in front of these, e.g. "cannot assign to a constant" + if ( + expr->kind == AstExpressionKind::String + or expr->kind == AstExpressionKind::Short + or expr->kind == AstExpressionKind::Int + or expr->kind == AstExpressionKind::Long + or expr->kind == AstExpressionKind::Byte + or expr->kind == AstExpressionKind::Bool + or expr->kind == AstExpressionKind::Null + ): + return "a constant" + elif ( + expr->kind == AstExpressionKind::Negate + or expr->kind == AstExpressionKind::Add + or expr->kind == AstExpressionKind::Subtract + or expr->kind == AstExpressionKind::Multiply + or expr->kind == AstExpressionKind::Divide + or expr->kind == AstExpressionKind::Modulo + ): + return "the result of a calculation" + elif ( + expr->kind == AstExpressionKind::Eq + or expr->kind == AstExpressionKind::Ne + or expr->kind == AstExpressionKind::Gt + or expr->kind == AstExpressionKind::Ge + or expr->kind == AstExpressionKind::Lt + or expr->kind == AstExpressionKind::Le + ): + return "the result of a comparison" + elif expr->kind == AstExpressionKind::Call: + sprintf(result, "a %s call", expr->call.function_or_method()) + return result + elif expr->kind == AstExpressionKind::Instantiate: + return "a newly created instance" + elif expr->kind == AstExpressionKind::GetVariable: + if get_special_constant(expr->varname) == -1: + return "a variable" + return "a special constant" + elif expr->kind == AstExpressionKind::GetEnumMember: + return "an enum member" + elif expr->kind == AstExpressionKind::GetClassField: + snprintf(result, sizeof result, "field '%s'", expr->class_field.field_name) + return result + elif expr->kind == AstExpressionKind::As: + return "the result of a cast" + elif expr->kind == AstExpressionKind::SizeOf: + return "a sizeof expression" + elif expr->kind == AstExpressionKind::AddressOf: + subresult = short_expression_description(expr->operands) + snprintf(result, sizeof result, "address of %s", subresult) + return result + elif expr->kind == AstExpressionKind::Dereference: + return "the value of a pointer" + elif expr->kind == AstExpressionKind::And: + return "the result of 'and'" + elif expr->kind == AstExpressionKind::Or: + return "the result of 'or'" + elif expr->kind == AstExpressionKind::Not: + return "the result of 'not'" + elif expr->kind == AstExpressionKind::PreIncr or expr->kind == AstExpressionKind::PostIncr: + return "the result of incrementing a value" + elif expr->kind == AstExpressionKind::PreDecr or expr->kind == AstExpressionKind::PostDecr: + return "the result of decrementing a value" + elif expr->kind == AstExpressionKind::Indexing: + return "an indexed value" + elif expr->kind == AstExpressionKind::Self: + return "self" + elif expr->kind == AstExpressionKind::Array: + return "an array literal" + else: + expr->print() + printf("*** %d\n", expr->kind) + assert False + +# The & operator can't go in front of most expressions. +# You can't do &(1 + 2), for example. +# +# The same rules apply to assignments: "foo = bar" is treated as setting the +# value of the pointer &foo to bar. +# +# error_template can be e.g. "cannot take address of %s" or "cannot assign to %s" +def ensure_can_take_address(expression: AstExpression*, error_template: byte*) -> None: + if expression->kind == AstExpressionKind::GetClassField: + # &foo.bar --> must ensure we can take address of foo. + # Doesn't apply to &foo->bar because that's foo + some offset, so foo is already a pointer. + if not expression->class_field.uses_arrow_operator: + # Turn "cannot assign to %s" into "cannot assign to a field of %s". + # This assumes that error_template is relatively simple, i.e. it only contains one %s somewhere. + new_template = malloc(strlen(error_template) + 50) + sprintf(new_template, error_template, "a field of %s") + ensure_can_take_address(&expression->operands[0], new_template) + free(new_template) + return + + if expression->kind == AstExpressionKind::GetVariable: + # &foo is usually fine, but &WINDOWS is not + if get_special_constant(expression->varname) == -1: + return + + if ( + expression->kind == AstExpressionKind::Dereference # &*foo + or expression->kind == AstExpressionKind::Indexing # &foo[bar] = foo + some offset (foo is a pointer) + ): + return + + # Anything else is an error. + desc: byte[200] = short_expression_description(expression) + error = malloc(strlen(error_template) + 300) + sprintf(error, error_template, desc) + fail(expression->location, error) + +def max(a: int, b: int) -> int: + if a > b: + return a + return b + +def check_binop( + op: AstExpressionKind, + location: Location, + lhs_types: ExpressionTypes*, + rhs_types: ExpressionTypes*, +) -> Type*: + result_is_bool = False + if op == AstExpressionKind::Add: + do_what = "add" + elif op == AstExpressionKind::Subtract: + do_what = "subtract" + elif op == AstExpressionKind::Multiply: + do_what = "multiply" + elif op == AstExpressionKind::Divide: + do_what = "divide" + elif op == AstExpressionKind::Modulo: + do_what = "take remainder with" + else: + assert ( + op == AstExpressionKind::Eq + or op == AstExpressionKind::Ne + or op == AstExpressionKind::Gt + or op == AstExpressionKind::Ge + or op == AstExpressionKind::Lt + or op == AstExpressionKind::Le + ) + do_what = "compare" + result_is_bool = True + + got_bools = lhs_types->original_type == &bool_type and rhs_types->original_type == &bool_type + got_integers = lhs_types->original_type->is_integer_type() and rhs_types->original_type->is_integer_type() + got_numbers = lhs_types->original_type->is_number_type() and rhs_types->original_type->is_number_type() + got_enums = lhs_types->original_type->kind == TypeKind::Enum and rhs_types->original_type->kind == TypeKind::Enum + got_pointers = ( + lhs_types->original_type->is_pointer_type() + and rhs_types->original_type->is_pointer_type() + and ( + # Ban comparisons like int* == byte*, unless one of the two types is void* + lhs_types->original_type == rhs_types->original_type + or lhs_types->original_type == &void_ptr_type + or rhs_types->original_type == &void_ptr_type + ) + ) + + if ( + (not got_bools and not got_numbers and not got_enums and not got_pointers) + or (op != AstExpressionKind::Eq and op != AstExpressionKind::Ne and not got_numbers) + ): + message: byte[500] + snprintf( + message, sizeof message, + "wrong types: cannot %s %s and %s", + do_what, lhs_types->original_type->name, rhs_types->original_type->name, + ) + fail(location, message) + + if got_bools: + cast_type = &bool_type + elif got_integers: + size = max(lhs_types->original_type->size_in_bits, rhs_types->original_type->size_in_bits) + if ( + lhs_types->original_type->kind == TypeKind::SignedInteger + or rhs_types->original_type->kind == TypeKind::SignedInteger + ): + cast_type = &signed_integers[size] + else: + cast_type = &unsigned_integers[size] + elif got_numbers: + if lhs_types->original_type == &double_type or rhs_types->original_type == &double_type: + cast_type = &double_type + else: + cast_type = &float_type + elif got_pointers: + cast_type = &void_ptr_type + elif got_enums: + cast_type = int_type + else: + assert False + + lhs_types->do_implicit_cast(cast_type, Location{}, NULL) + rhs_types->do_implicit_cast(cast_type, Location{}, NULL) + + if result_is_bool: + return &bool_type + else: + return cast_type + +def check_class_field(location: Location, class_type: Type*, field_name: byte*) -> ClassField*: + assert class_type->kind == TypeKind::Class + + field = class_type->class_members.find_field(field_name) + if field == NULL: + message: byte[500] + snprintf(message, sizeof message, "class %s has no field named '%s'", class_type->name, field_name) + fail(location, message) + return field + +def cast_array_items_to_a_common_type(error_location: Location, types: ExpressionTypes**, ntypes: int) -> Type*: + # Avoid O(ntypes^2) code in a long array where all or almost all items have the same type. + # This is at most O(ntypes*ndistinct). + distinct: Type** = malloc(sizeof distinct[0] * ntypes) + ndistinct = 0 + for i = 0; i < ntypes; i++: + found = False + for k = 0; k < ndistinct; k++: + if distinct[k] == types[i]->original_type: + found = True + break + if not found: + distinct[ndistinct++] = types[i]->original_type + + compatible_with_all: Type** = malloc(sizeof compatible_with_all[0] * ndistinct) + n_compatible_with_all = 0 + for i = 0; i < ndistinct; i++: + compat = True + for k = 0; k < ndistinct; k++: + if not can_cast_implicitly(distinct[k], distinct[i]): + compat = False + break + if compat: + compatible_with_all[n_compatible_with_all++] = distinct[i] + + if n_compatible_with_all != 1: + # Can't make an unambiguous choice. Mention all types we considered in the error message. + assert sizeof distinct[0]->name == 100 + message: byte* = calloc(200, ndistinct+1) + strcpy(message, "array items have different types (") + for i = 0; i < ndistinct; i++: + if i != 0: + strcat(message, ", ") + strcat(message, distinct[i]->name) + strcat(message, ")") + fail(error_location, message) + + item_type = compatible_with_all[0] + free(distinct) + free(compatible_with_all) + + for i = 0; i < ntypes; i++: + types[i]->do_implicit_cast(item_type, Location{}, NULL) + return item_type + + +class Stage3TypeChecker: + file_types: FileTypes* + current_function_or_method: FunctionOrMethodTypes* + nested_loop_count: int + + def add_local_var(self, name: byte*, type: Type*) -> LocalVariable*: + v: LocalVariable* = calloc(1, sizeof *v) + assert strlen(name) < sizeof v->name + strcpy(v->name, name) + v->type = type + + dest_pointer = &self->current_function_or_method->local_vars + while *dest_pointer != NULL: + dest_pointer = &(*dest_pointer)->next + + *dest_pointer = v + return v + + def find_var(self, name: byte*) -> Type*: + if get_special_constant(name) != -1: + return &bool_type + local_var = self->current_function_or_method->find_local_var(name) + if local_var != NULL: + return local_var->type + for i = 0; i < self->file_types->nglobals; i++: + if strcmp(self->file_types->globals[i].name, name) == 0: + return self->file_types->globals[i].type + return NULL + + def find_function_or_method(self, self_type: Type*, name: byte*) -> Signature*: + if self_type == NULL: + return self->file_types->find_function(name) + elif self_type->kind == TypeKind::Class: + return self_type->class_members.find_method(name) + else: + return NULL + + def do_call(self, call: AstCall*) -> Type*: + message: byte[500] + + if call->method_call_self != NULL: + self_type = self->do_expression(call->method_call_self)->original_type + if call->uses_arrow_operator: + if self_type->kind != TypeKind::Pointer or self_type->value_type->kind != TypeKind::Class: + snprintf( + message, sizeof message, + "left side of the '->' operator must be a pointer, not %s", + self_type->name, + ) + fail(call->location, message) + self_type = self_type->value_type + else: + self_type = NULL + + signature = self->find_function_or_method(self_type, call->name) + if signature == NULL: + if self_type == NULL: + snprintf(message, sizeof message, "function '%s' not found", call->name) + elif ( + self_type->kind == TypeKind::Pointer + and self_type->value_type->kind == TypeKind::Class + and self_type->value_type->class_members.find_method(call->name) != NULL + ): + snprintf( + message, sizeof message, + "the method '%s' is defined on class %s, not on the pointer type %s, so you need to dereference the pointer first (e.g. by using '->' instead of '.')", + call->name, self_type->value_type->name, self_type->name, + ) + elif self_type->kind == TypeKind::Class: + snprintf( + message, sizeof message, + "class %s does not have a method named '%s'", self_type->name, call->name, + ) + else: + snprintf( + message, sizeof message, + "type %s does not have any methods because it is %s, not a class", + self_type->name, short_type_description(self_type), + ) + fail(call->location, message) + + if call->method_call_self != NULL and not call->uses_arrow_operator: + snprintf( + message, sizeof message, + "cannot take address of %%s, needed for calling the %s() method", call->name) + ensure_can_take_address(call->method_call_self, message) + + signature_string = signature->to_string(False, False) + + expected = signature->nargs + if self_type != NULL: + expected-- # exclude self + + if call->nargs < expected or (call->nargs > expected and not signature->takes_varargs): + snprintf( + message, sizeof message, + "%s %s takes %d argument%s, but it was called with %d argument%s", + signature->function_or_method(), + signature_string, + expected, + plural_s(expected), + call->nargs, + plural_s(call->nargs), + ) + fail(call->location, message) + + k = 0 + for i = 0; i < signature->nargs; i++: + if strcmp(signature->argnames[i], "self") == 0: + continue + + # This is a common error, so worth spending some effort to get a good error message. + tmp = nth(i+1) + snprintf( + message, sizeof message, + "%s argument of %s %s should have type , not ", + tmp, signature->function_or_method(), signature_string, + ) + self->do_expression_and_implicit_cast(&call->args[k++], signature->argtypes[i], message) + + for i = k; i < call->nargs; i++: + # This code runs for varargs, e.g. the things to format in printf(). + types = self->do_expression(&call->args[i]) + + if ( + (types->original_type->is_integer_type() and types->original_type->size_in_bits < 32) + or types->original_type == &bool_type + ): + # Add implicit cast to signed int, just like in C. + types->do_implicit_cast(int_type, Location{}, NULL) + elif types->original_type == &float_type: + types->do_implicit_cast(&double_type, Location{}, NULL) + elif types->original_type->kind == TypeKind::Array: + types->cast_array_to_pointer() + + free(signature_string) + return signature->return_type + + def do_increment_or_decrement(self, expression: AstExpression*, increment_or_decrement: byte*) -> Type*: + assert strcmp(increment_or_decrement, "increment") == 0 or strcmp(increment_or_decrement, "decrement") == 0 + + bad_expression_error_template: byte[50] + sprintf(bad_expression_error_template, "cannot %s %%s", increment_or_decrement) + ensure_can_take_address(&expression->operands[0], bad_expression_error_template) + + t = self->do_expression(&expression->operands[0])->original_type + if not t->is_integer_type() and not t->is_pointer_type(): + error: byte* = malloc(strlen(t->name) + 100) + sprintf(error, "cannot %s a value of type %s", increment_or_decrement, t->name) + fail(expression->location, error) + return t + + def do_enum_member(self, location: Location, enum_name: byte*, member_name: byte*) -> Type*: + message: byte[200] + + enum_type = self->file_types->find_type(enum_name) + if enum_type == NULL: + snprintf(message, sizeof message, "there is no type named '%s'", enum_name) + fail(location, message) + + if enum_type->kind != TypeKind::Enum: + snprintf( + message, sizeof message, + "the '::' syntax is only for enums, but %s is %s", + enum_name, short_type_description(enum_type), + ) + fail(location, message) + + if enum_type->enum_members.find_index(member_name) == -1: + snprintf(message, sizeof message, "enum %s has no member named '%s'", enum_name, member_name) + fail(location, message) + + return enum_type + + def do_instantiation(self, instantiation: AstInstantiation*) -> Type*: + message:byte[500] + + t = self->file_types->find_type(instantiation->class_name) + if t == NULL: + snprintf( + message, sizeof message, + "there is no type named '%s'", instantiation->class_name, + ) + fail(instantiation->class_name_location, message) + + if t->kind != TypeKind::Class: + description = short_type_description(t) + snprintf( + message, sizeof message, + "the %s{...} syntax is only for classes, but %s is %s", + t->name, t->name, description, + ) + fail(instantiation->class_name_location, message) + + specified_fields: ClassField** = malloc(sizeof specified_fields[0] * instantiation->nfields) + for i = 0; i < instantiation->nfields; i++: + snprintf( + message, sizeof message, + "value for field '%s' of class %s must be of type , not ", + instantiation->field_names[i], t->name, + ) + specified_fields[i] = check_class_field( + instantiation->field_values[i].location, + t, + instantiation->field_names[i], + ) + self->do_expression_and_implicit_cast( + &instantiation->field_values[i], + specified_fields[i]->type, + message, + ) + + for i1 = 0; i1 < instantiation->nfields; i1++: + for i2 = i1+1; i2 < instantiation->nfields; i2++: + if specified_fields[i1]->union_id == specified_fields[i2]->union_id: + snprintf( + message, sizeof message, + "fields '%s' and '%s' cannot be set simultaneously because they belong to the same union", + specified_fields[i1]->name, + specified_fields[i2]->name, + ) + fail(instantiation->field_values[i2].location, message) + + return t + + def do_indexing(self, pointer: AstExpression*, index: AstExpression*) -> Type*: + message: byte[500] + types = self->do_expression(pointer) + + if types->original_type->kind == TypeKind::Array: + types->cast_array_to_pointer() + pointer_type = types->implicit_cast_type + elif types->original_type->kind == TypeKind::Pointer: + pointer_type = types->original_type + else: + snprintf(message, sizeof message[0], "value of type %s cannot be indexed", types->original_type->name) + fail(pointer->location, message) + + index_types = self->do_expression(index) + assert index_types != NULL + + if not index_types->original_type->is_integer_type(): + snprintf(message, sizeof message[0], "the index inside [...] must be an integer, not %s", index_types->original_type->name) + fail(index->location, message) + + # LLVM assumes that indexes smaller than 64 bits are signed. + # https://github.com/Akuli/jou/issues/48 + index_types->do_implicit_cast(long_type, Location{}, NULL) + + return pointer_type->value_type + + def do_expression_maybe_void(self, expression: AstExpression*) -> ExpressionTypes*: + result: Type* + message: byte[200] + + if expression->kind == AstExpressionKind::String: + result = byte_type->get_pointer_type() + elif expression->kind == AstExpressionKind::Bool: + result = &bool_type + elif expression->kind == AstExpressionKind::Byte: + result = byte_type + elif expression->kind == AstExpressionKind::Short: + result = short_type + elif expression->kind == AstExpressionKind::Int: + result = int_type + elif expression->kind == AstExpressionKind::Long: + result = long_type + elif expression->kind == AstExpressionKind::Float: + result = &float_type + elif expression->kind == AstExpressionKind::Double: + result = &double_type + elif expression->kind == AstExpressionKind::Null: + result = &void_ptr_type + elif expression->kind == AstExpressionKind::Array: + n = expression->array.length + item_types: ExpressionTypes** = malloc(n * sizeof item_types[0]) + for i = 0; i < n; i++: + item_types[i] = self->do_expression(&expression->array.items[i]) + member_type = cast_array_items_to_a_common_type(expression->location, item_types, n) + free(item_types) + result = member_type->get_array_type(n) + elif expression->kind == AstExpressionKind::Call: + result = self->do_call(&expression->call) + if result == NULL: + return NULL + elif expression->kind == AstExpressionKind::GetVariable: + result = self->find_var(expression->varname) + if result == NULL: + snprintf(message, sizeof message, "no variable named '%s'", expression->varname) + fail(expression->location, message) + elif expression->kind == AstExpressionKind::As: + value_types = self->do_expression(&expression->as_expression->value) + result = type_from_ast(self->file_types, &expression->as_expression->type) + value_types->do_explicit_cast(result, expression->location) + elif expression->kind == AstExpressionKind::GetEnumMember: + result = self->do_enum_member( + expression->location, + expression->enum_member.enum_name, + expression->enum_member.member_name, + ) + elif expression->kind == AstExpressionKind::And: + self->do_expression_and_implicit_cast(&expression->operands[0], &bool_type, "'and' only works with booleans, not ") + self->do_expression_and_implicit_cast(&expression->operands[1], &bool_type, "'and' only works with booleans, not ") + result = &bool_type + elif expression->kind == AstExpressionKind::Or: + self->do_expression_and_implicit_cast(&expression->operands[0], &bool_type, "'or' only works with booleans, not ") + self->do_expression_and_implicit_cast(&expression->operands[1], &bool_type, "'or' only works with booleans, not ") + result = &bool_type + elif ( + expression->kind == AstExpressionKind::Add + or expression->kind == AstExpressionKind::Subtract + or expression->kind == AstExpressionKind::Multiply + or expression->kind == AstExpressionKind::Divide + or expression->kind == AstExpressionKind::Modulo + or expression->kind == AstExpressionKind::Eq + or expression->kind == AstExpressionKind::Ne + or expression->kind == AstExpressionKind::Gt + or expression->kind == AstExpressionKind::Ge + or expression->kind == AstExpressionKind::Lt + or expression->kind == AstExpressionKind::Le + ): + lhs_types = self->do_expression(&expression->operands[0]) + rhs_types = self->do_expression(&expression->operands[1]) + result = check_binop(expression->kind, expression->location, lhs_types, rhs_types) + elif expression->kind == AstExpressionKind::Negate: + result = self->do_expression(&expression->operands[0])->original_type + # TODO: check for floats/doubles too + if result->kind != TypeKind::SignedInteger and result->kind != TypeKind::FloatingPoint: + snprintf( + message, sizeof message, + "value after '-' must be a float or double or a signed integer, not %s", + result->name, + ) + fail(expression->location, message) + elif expression->kind == AstExpressionKind::PreIncr or expression->kind == AstExpressionKind::PostIncr: + result = self->do_increment_or_decrement(expression, "increment") + elif expression->kind == AstExpressionKind::PreDecr or expression->kind == AstExpressionKind::PostDecr: + result = self->do_increment_or_decrement(expression, "decrement") + elif expression->kind == AstExpressionKind::GetClassField: + lhs_type = self->do_expression(expression->class_field.instance)->original_type + if expression->class_field.uses_arrow_operator: + if lhs_type->kind != TypeKind::Pointer or lhs_type->value_type->kind != TypeKind::Class: + snprintf( + message, sizeof message, + "left side of the '->' operator must be a pointer to a class, not %s", + lhs_type->name, + ) + fail(expression->location, message) + result = check_class_field(expression->location, lhs_type->value_type, expression->class_field.field_name)->type + else: + if lhs_type->kind != TypeKind::Class: + snprintf( + message, sizeof message, + "left side of the '.' operator must be an instance of a class, not %s", + lhs_type->name, + ) + fail(expression->location, message) + result = check_class_field(expression->location, lhs_type, expression->class_field.field_name)->type + elif expression->kind == AstExpressionKind::AddressOf: + ensure_can_take_address(&expression->operands[0], "the '&' operator cannot be used with %s") + result = self->do_expression(&expression->operands[0])->original_type->get_pointer_type() + elif expression->kind == AstExpressionKind::Dereference: + pointer_type = self->do_expression(expression->operands)->original_type + if pointer_type->kind != TypeKind::Pointer: + snprintf( + message, sizeof message, + "the dereference operator '*' is only for pointers, not for %s", + pointer_type->name, + ) + fail(expression->location, message) + result = pointer_type->value_type + elif expression->kind == AstExpressionKind::Instantiate: + result = self->do_instantiation(&expression->instantiation) + elif expression->kind == AstExpressionKind::Indexing: + result = self->do_indexing(&expression->operands[0], &expression->operands[1]) + elif expression->kind == AstExpressionKind::Not: + self->do_expression_and_implicit_cast( + &expression->operands[0], &bool_type, + "value after 'not' must be a boolean, not ", + ) + result = &bool_type + elif expression->kind == AstExpressionKind::Self: + class_type = self->current_function_or_method->signature.get_containing_class() + assert class_type != NULL + result = class_type->get_pointer_type() + elif expression->kind == AstExpressionKind::SizeOf: + self->do_expression(&expression->operands[0]) + result = long_type + else: + printf("*** expr %d\n", expression->kind as int) + expression->print() + assert False + + p: ExpressionTypes* = malloc(sizeof *p) + *p = ExpressionTypes{ + expression = expression, + original_type = result, + next = self->current_function_or_method->expression_types, + } + self->current_function_or_method->expression_types = p + return p + + def do_expression(self, expression: AstExpression*) -> ExpressionTypes*: + types = self->do_expression_maybe_void(expression) + if types == NULL: + assert expression->kind == AstExpressionKind::Call + name = expression->call.name + message = malloc(strlen(name) + 100) + sprintf(message, "%s '%s' does not return a value", expression->call.function_or_method(), name) + fail(expression->location, message) + return types + + def do_expression_and_implicit_cast( + self, + expression: AstExpression*, + cast_type: Type*, + error_message_template: byte*, + ) -> ExpressionTypes*: + types = self->do_expression(expression) + types->do_implicit_cast(cast_type, expression->location, error_message_template) + return types + + def do_in_place_operation( + self, + location: Location, + target: AstExpression*, # the foo of "foo += 1" + value: AstExpression*, # the 1 of "foo += 1" + op_expr_kind: AstExpressionKind, # e.g. AstExpressionKind::Add + op_description: byte[20], # e.g. "addition" + ) -> None: + ensure_can_take_address(target, "cannot assign to %s") + target_types = self->do_expression(target) + value_types = self->do_expression(value) + + t = check_binop(op_expr_kind, location, target_types, value_types) + temp_value_types = ExpressionTypes{ expression = target, original_type = t } + + error_template: byte[200] + strcpy(error_template, op_description) + strcat(error_template, " produced a value of type which cannot be assigned back to ") + temp_value_types.do_implicit_cast(target_types->original_type, location, error_template) + + # I think it is currently impossible to cast target. + # If this assert fails, we probably need a new error message. + assert target_types->implicit_cast_type == NULL + + def do_statement(self, statement: AstStatement*) -> None: + if statement->kind == AstStatementKind::Assert: + self->do_expression_and_implicit_cast( + &statement->expression, &bool_type, "assertion must be a boolean, not " + ) + + elif statement->kind == AstStatementKind::ExpressionStatement: + self->do_expression_maybe_void(&statement->expression) + + elif statement->kind == AstStatementKind::Return: + sig = &self->current_function_or_method->signature + + # TODO: check for noreturn functions + + msg: byte[500] + + if statement->return_value != NULL and sig->return_type == NULL: + snprintf( + msg, sizeof msg, + "%s '%s' cannot return a value because it was defined with '-> None'", + sig->function_or_method(), sig->name, + ) + fail(statement->location, msg) + if statement->return_value == NULL and sig->return_type != NULL: + snprintf( + msg, sizeof msg, + "%s '%s' must return a value because it was defined with '-> %s'", + sig->function_or_method(), sig->name, sig->return_type->name, + ) + fail(statement->location, msg) + + if statement->return_value != NULL: + cast_error_msg: byte[500] + snprintf( + cast_error_msg, sizeof cast_error_msg, + "attempting to return a value of type from %s '%s' defined with '-> '", + sig->function_or_method(), sig->name, + ) + self->do_expression_and_implicit_cast( + statement->return_value, sig->return_type, cast_error_msg + ) + + elif statement->kind == AstStatementKind::Assign: + target_expr = &statement->assignment.target + value_expr = &statement->assignment.value + ensure_can_take_address(target_expr, "cannot assign to %s") + + if ( + target_expr->kind == AstExpressionKind::GetVariable + and self->find_var(target_expr->varname) == NULL + ): + # Making a new variable. Use the type of the value being assigned. + types = self->do_expression(value_expr) + self->add_local_var(target_expr->varname, types->original_type) + else: + # Convert value to the type of an existing variable or other assignment target. + # This tends to fail often, so try to produce a helpful error message. + error_template: byte[500] + if target_expr->kind == AstExpressionKind::Dereference: + error_template = "cannot place a value of type into a pointer of type *" + else: + target_description: byte[200] = short_expression_description(target_expr) + snprintf( + error_template, sizeof error_template, + "cannot assign a value of type to %s of type ", + target_description, + ) + + target_types = self->do_expression(target_expr) + self->do_expression_and_implicit_cast(value_expr, target_types->original_type, error_template) + + elif statement->kind == AstStatementKind::DeclareLocalVar: + ntv: AstNameTypeValue* = &statement->var_declaration + if self->find_var(ntv->name) != NULL: + message: byte[200] + snprintf(message, sizeof message, "a variable named '%s' already exists", ntv->name) + fail(statement->location, message) + + type = type_from_ast(self->file_types, &ntv->type) + self->add_local_var(ntv->name, type) + if ntv->value != NULL: + self->do_expression_and_implicit_cast( + ntv->value, type, + "initial value for variable of type cannot be of type ", + ) + + elif statement->kind == AstStatementKind::If: + for i = 0; i < statement->if_statement.n_if_and_elifs; i++: + if i == 0: + template = "'if' condition must be a boolean, not " + else: + template = "'elif' condition must be a boolean, not " + self->do_expression_and_implicit_cast( + &statement->if_statement.if_and_elifs[i].condition, &bool_type, template + ) + self->do_body(&statement->if_statement.if_and_elifs[i].body) + self->do_body(&statement->if_statement.else_body) + + elif statement->kind == AstStatementKind::WhileLoop: + self->do_expression_and_implicit_cast( + &statement->while_loop.condition, &bool_type, + "'while' condition must be a boolean, not ", + ) + self->nested_loop_count++ + self->do_body(&statement->while_loop.body) + self->nested_loop_count-- + + elif statement->kind == AstStatementKind::ForLoop: + self->do_statement(statement->for_loop.init) + self->do_expression_and_implicit_cast( + &statement->for_loop.cond, &bool_type, + "'for' condition must be a boolean, not ", + ) + self->nested_loop_count++ + self->do_body(&statement->for_loop.body) + self->nested_loop_count-- + self->do_statement(statement->for_loop.incr) + + elif statement->kind == AstStatementKind::Pass: + pass + + elif statement->kind == AstStatementKind::Break: + if self->nested_loop_count == 0: + fail(statement->location, "'break' can only be used inside a loop") + + elif statement->kind == AstStatementKind::Continue: + if self->nested_loop_count == 0: + fail(statement->location, "'continue' can only be used inside a loop") + + elif statement->kind == AstStatementKind::InPlaceAdd: + self->do_in_place_operation(statement->location, &statement->assignment.target, &statement->assignment.value, AstExpressionKind::Add, "addition") + elif statement->kind == AstStatementKind::InPlaceSubtract: + self->do_in_place_operation(statement->location, &statement->assignment.target, &statement->assignment.value, AstExpressionKind::Subtract, "subtraction") + elif statement->kind == AstStatementKind::InPlaceMultiply: + self->do_in_place_operation(statement->location, &statement->assignment.target, &statement->assignment.value, AstExpressionKind::Multiply, "multiplication") + elif statement->kind == AstStatementKind::InPlaceDivide: + self->do_in_place_operation(statement->location, &statement->assignment.target, &statement->assignment.value, AstExpressionKind::Divide, "division") + elif statement->kind == AstStatementKind::InPlaceModulo: + self->do_in_place_operation(statement->location, &statement->assignment.target, &statement->assignment.value, AstExpressionKind::Modulo, "modulo") + + else: + statement->print() + printf("*** typecheck: unknown statement kind %d\n", statement->kind) + assert False + + def do_body(self, body: AstBody*) -> None: + for i = 0; i < body->nstatements; i++: + self->do_statement(&body->statements[i]) + + def define_function_or_method(self, signature: Signature*, body: AstBody*) -> None: + assert self->current_function_or_method == NULL + self->file_types->defined_functions = realloc( + self->file_types->defined_functions, + (self->file_types->n_defined_functions + 1) * sizeof self->file_types->defined_functions[0], + ) + self->current_function_or_method = &self->file_types->defined_functions[self->file_types->n_defined_functions++] + *self->current_function_or_method = FunctionOrMethodTypes{signature = signature->copy()} + + for k = 0; k < signature->nargs; k++: + self->add_local_var(signature->argnames[k], signature->argtypes[k]) + + self->do_body(body) + self->current_function_or_method = NULL + + +def typecheck_stage3_function_and_method_bodies(file_types: FileTypes*, ast_file: AstFile*) -> None: + checker = Stage3TypeChecker{file_types = file_types} + for i = 0; i < ast_file->body.nstatements; i++: + ts = &ast_file->body.statements[i] + if ts->kind == AstStatementKind::Function and ts->function.body.nstatements > 0: + signature = file_types->find_function(ts->function.signature.name) + assert signature != NULL + checker.define_function_or_method(signature, &ts->function.body) + elif ts->kind == AstStatementKind::Class: + class_type = file_types->find_type(ts->classdef.name) + assert class_type != NULL # created in previous typecheck stage + assert class_type->kind == TypeKind::Class + for k = 0; k < ts->classdef.nmembers; k++: + if ts->classdef.members[k].kind == AstClassMemberKind::Method: + signature = class_type->class_members.find_method(ts->classdef.members[k].method.signature.name) + checker.define_function_or_method(signature, &ts->classdef.members[k].method.body) diff --git a/self_hosted_old/types.jou b/self_hosted_old/types.jou new file mode 100644 index 00000000..5656fa0e --- /dev/null +++ b/self_hosted_old/types.jou @@ -0,0 +1,246 @@ +import "stdlib/str.jou" +import "stdlib/mem.jou" + +enum TypeKind: + Bool + SignedInteger + UnsignedInteger + FloatingPoint + Pointer + VoidPointer + Class + OpaqueClass + Enum + Array + +class EnumMembers: + count: int + names: byte[100]* + + # Returns -1 for not found + def find_index(self, name: byte*) -> int: + for i = 0; i < self->count; i++: + if strcmp(self->names[i], name) == 0: + return i + return -1 + +class ClassField: + name: byte[100] + type: Type* + # If multiple fields have the same union_id, they belong to the same union. + # It means that only one of the fields can be used at a time. + union_id: int + +class ClassMembers: + fields: ClassField* + nfields: int + methods: Signature* + nmethods: int + + def find_field(self, name: byte*) -> ClassField*: + for i = 0; i < self->nfields; i++: + if strcmp(self->fields[i].name, name) == 0: + return &self->fields[i] + return NULL + + def find_method(self, name: byte*) -> Signature*: + for i = 0; i < self->nmethods; i++: + if strcmp(self->methods[i].name, name) == 0: + return &self->methods[i] + return NULL + +class ArrayInfo: + length: int + item_type: Type* + +class Type: + name: byte[100] + kind: TypeKind + + union: + size_in_bits: int # SignedInteger, UnsignedInteger, FloatingPoint + value_type: Type* # Pointer (not used for VoidPointer) + enum_members: EnumMembers + class_members: ClassMembers + array: ArrayInfo + + # Pointers and arrays of a given type live as long as the type itself. + # To make it possible, we just store them within the type. + # These are initially NULL and created dynamically as needed. + # + # Do not access these outside this file. + cached_pointer_type: Type* + cached_array_types: Type** + n_cached_array_types: int + + def is_integer_type(self) -> bool: + return self->kind == TypeKind::SignedInteger or self->kind == TypeKind::UnsignedInteger + + def is_number_type(self) -> bool: + return self->is_integer_type() or self->kind == TypeKind::FloatingPoint + + def is_pointer_type(self) -> bool: + return self->kind == TypeKind::Pointer or self->kind == TypeKind::VoidPointer + + def get_pointer_type(self) -> Type*: + if self->cached_pointer_type == NULL: + pointer_name: byte[100] + snprintf(pointer_name, sizeof pointer_name, "%s*", self->name) + + self->cached_pointer_type = malloc(sizeof *self->cached_pointer_type) + *self->cached_pointer_type = Type{ + name = pointer_name, + kind = TypeKind::Pointer, + value_type = self, + } + + return self->cached_pointer_type + + def get_array_type(self, length: int) -> Type*: + assert length > 0 + + for i = 0; i < self->n_cached_array_types; i++: + if self->cached_array_types[i]->array.length == length: + return self->cached_array_types[i] + + array_name: byte[100] + snprintf(array_name, sizeof array_name, "%s[%d]", self->name, length) + + t: Type* = malloc(sizeof *t) + *t = Type{ + name = array_name, + kind = TypeKind::Array, + array = ArrayInfo{length = length, item_type = self}, + } + + self->cached_array_types = realloc(self->cached_array_types, sizeof self->cached_array_types[0] * (self->n_cached_array_types + 1)) + self->cached_array_types[self->n_cached_array_types++] = t + return t + +# Typese are cached into global state, so you can use == between +# pointers to compare them. Also, you don't usually need to copy a +# type, you can just pass around a pointer to it. +global signed_integers: Type[65] # indexed by size in bits (8, 16, 32, 64) +global unsigned_integers: Type[65] # indexed by size in bits (8, 16, 32, 64) +global bool_type: Type +global void_ptr_type: Type +global float_type: Type +global double_type: Type + +# TODO: it seems weird in other files these are pointers but bool_type isn't +global byte_type: Type* +global short_type: Type* +global int_type: Type* +global long_type: Type* + +def init_types() -> None: + void_ptr_type = Type{name = "void*", kind = TypeKind::VoidPointer} + bool_type = Type{name = "bool", kind = TypeKind::Bool} + float_type = Type{name = "float", size_in_bits = 32, kind = TypeKind::FloatingPoint} + double_type = Type{name = "double", size_in_bits = 64, kind = TypeKind::FloatingPoint} + + for size = 8; size <= 64; size *= 2: + sprintf(signed_integers[size].name, "<%d-bit signed integer>", size) + sprintf(unsigned_integers[size].name, "<%d-bit unsigned integer>", size) + signed_integers[size].kind = TypeKind::SignedInteger + unsigned_integers[size].kind = TypeKind::UnsignedInteger + signed_integers[size].size_in_bits = size + unsigned_integers[size].size_in_bits = size + + byte_type = &unsigned_integers[8] + short_type = &signed_integers[16] + int_type = &signed_integers[32] + long_type = &signed_integers[64] + + byte_type->name = "byte" + short_type->name = "short" + int_type->name = "int" + long_type->name = "long" + +def create_opaque_class(name: byte*) -> Type*: + result: Type* = malloc(sizeof *result) + *result = Type{kind = TypeKind::OpaqueClass} + assert strlen(name) < sizeof result->name + strcpy(result->name, name) + return result + +def create_enum(name: byte*, member_count: int, member_names: byte[100]*) -> Type*: + copied_member_names: byte[100]* = malloc(member_count * sizeof copied_member_names[0]) + memcpy(copied_member_names, member_names, member_count * sizeof copied_member_names[0]) + + result: Type* = malloc(sizeof *result) + *result = Type{ + kind = TypeKind::Enum, + enum_members = EnumMembers{count = member_count, names = copied_member_names}, + } + assert strlen(name) < sizeof result->name + strcpy(result->name, name) + return result + + +class Signature: + name: byte[100] # name of function or method, after "def" keyword + nargs: int + argnames: byte[100]* + argtypes: Type** + takes_varargs: bool # True for functions like printf() + return_type: Type* + + def get_containing_class(self) -> Type*: + for i = 0; i < self->nargs; i++: + if strcmp(self->argnames[i], "self") == 0: + assert self->argtypes[i]->kind == TypeKind::Pointer + assert self->argtypes[i]->value_type->kind == TypeKind::Class + return self->argtypes[i]->value_type + return NULL + + def is_method(self) -> bool: + return self->get_containing_class() != NULL + + def function_or_method(self) -> byte*: + if self->is_method(): + return "method" + else: + return "function" + + def to_string(self, include_self: bool, include_return_type: bool) -> byte*: + result: byte* = malloc(500*(self->nargs + 1)) + strcpy(result, self->name) + + strcat(result, "(") + + for i = 0; i < self->nargs; i++: + if strcmp(self->argnames[i], "self") == 0 and not include_self: + continue + strcat(result, self->argnames[i]) + strcat(result, ": ") + strcat(result, self->argtypes[i]->name) + strcat(result, ", ") + + if self->takes_varargs: + strcat(result, "...") + elif ends_with(result, ", "): + result[strlen(result)-2] = '\0' + + strcat(result, ")") + + if include_return_type: + if self->return_type == NULL: + strcat(result, " -> None") + else: + strcat(result, " -> ") + strcat(result, self->return_type->name) + + return result + + def copy(self) -> Signature: + result = *self + result.argnames = malloc(result.nargs * sizeof(result.argnames[0])) + result.argtypes = malloc(result.nargs * sizeof(result.argtypes[0])) + memcpy(result.argnames, self->argnames, result.nargs * sizeof(result.argnames[0])) + memcpy(result.argtypes, self->argtypes, result.nargs * sizeof(result.argtypes[0])) + return result + + def free(self) -> None: + free(self->argnames) + free(self->argtypes)