diff --git a/doc/syntax-spec.md b/doc/syntax-spec.md index ebed4f8d..bbe85b57 100644 --- a/doc/syntax-spec.md +++ b/doc/syntax-spec.md @@ -90,6 +90,7 @@ Jou has a few different kinds of tokens: - `not` - `as` - `sizeof` + - `assert` - `void` - `noreturn` - `bool` diff --git a/self_hosted/ast.jou b/self_hosted/ast.jou index 2dc97124..65c7ba2f 100644 --- a/self_hosted/ast.jou +++ b/self_hosted/ast.jou @@ -302,6 +302,7 @@ class AstCall: enum AstStatementKind: ExpressionStatement # Evaluate an expression. Discard the result. + Assert Return If WhileLoop @@ -321,19 +322,22 @@ class AstStatement: kind: AstStatementKind # TODO: union - expression: AstExpression # AstStatementKind::ExpressionStatement + expression: AstExpression # ExpressionStatement, Assert if_statement: AstIfStatement while_loop: AstConditionAndBody for_loop: AstForLoop - return_value: AstExpression* # AstStatementKind::Return (can be NULL) + return_value: AstExpression* # can be NULL assignment: AstAssignment - var_declaration: AstNameTypeValue # AstStatementKind::DeclareLocalVar + var_declaration: AstNameTypeValue # DeclareLocalVar def print(self, tp: TreePrinter) -> void: printf("[line %d] ", self->location.lineno) if self->kind == AstStatementKind::ExpressionStatement: printf("expression statement\n") self->expression.print(tp.print_prefix(True)) + elif self->kind == AstStatementKind::Assert: + printf("assert\n") + self->expression.print(tp.print_prefix(True)) elif self->kind == AstStatementKind::Return: printf("return\n") if self->return_value != NULL: @@ -479,7 +483,7 @@ class AstNameTypeValue: printf("%s: ", &self->name[0]) self->type.print(True) if tp == NULL: - assert(self->value == NULL) + assert self->value == NULL else: printf("\n") if self->value != NULL: @@ -597,7 +601,7 @@ class AstFile: def next_import(self, imp: AstImport**) -> bool: # Get the corresponding AstToplevelStatement. 
ts = *imp as AstToplevelStatement* - assert(&ts->the_import as void* == ts) # TODO: offsetof() or similar + assert &ts->the_import as void* == ts # TODO: offsetof() or similar # Assume all imports are in the beginning of the file. if ts == NULL: diff --git a/self_hosted/create_llvm_ir.jou b/self_hosted/create_llvm_ir.jou index 6c339fc9..93f2929d 100644 --- a/self_hosted/create_llvm_ir.jou +++ b/self_hosted/create_llvm_ir.jou @@ -4,7 +4,6 @@ import "./typecheck.jou" import "./types.jou" import "./ast.jou" import "./target.jou" -import "./errors_and_warnings.jou" import "stdlib/io.jou" import "stdlib/mem.jou" import "stdlib/str.jou" @@ -21,8 +20,7 @@ class AstToIR: if type->kind == TypeKind::Pointer: return LLVMPointerType(self->do_type(type->value_type), 0) printf("asd-Asd., %s\n", &type->name) - assert(False) - return NULL + assert False def declare_function(self, signature: Signature*) -> void: argtypes: LLVMType** @@ -41,7 +39,7 @@ class AstToIR: LLVMAddFunction(self->module, &signature->name[0], function_type) def new_block(self, name_hint: byte*) -> void: - assert(self->current_function != NULL) + assert self->current_function != NULL block = LLVMAppendBasicBlock(self->current_function, name_hint) LLVMPositionBuilderAtEnd(self->builder, block) @@ -54,6 +52,29 @@ class AstToIR: string_type = LLVMPointerType(LLVMInt8Type(), 0) return LLVMBuildBitCast(self->builder, global_var, string_type, "string_ptr") + def build_assert(self, condition: LLVMValue*) -> void: + true_block = LLVMAppendBasicBlock(self->current_function, "assert_true") + false_block = LLVMAppendBasicBlock(self->current_function, "assert_false") + LLVMBuildCondBr(self->builder, condition, true_block, false_block) + + LLVMPositionBuilderAtEnd(self->builder, false_block) + + argtypes = [LLVMPointerType(LLVMInt8Type(), 0), LLVMPointerType(LLVMInt8Type(), 0), LLVMInt32Type()] + assert_fail_func_type = LLVMFunctionType(LLVMVoidType(), &argtypes[0], 3, False) + assert_fail_func = 
LLVMGetNamedFunction(self->module, "_jou_assert_fail") + if assert_fail_func == NULL: + assert_fail_func = LLVMAddFunction(self->module, "_jou_assert_fail", assert_fail_func_type) + assert assert_fail_func != NULL + + args = [ + self->make_a_string_constant("foo"), + self->make_a_string_constant("bar"), + LLVMConstInt(LLVMInt32Type(), 123, False), + ] + + LLVMBuildCall2(self->builder, assert_fail_func_type, assert_fail_func, &args[0], 3, "") + LLVMPositionBuilderAtEnd(self->builder, true_block) + def do_expression(self, ast: AstExpression*) -> LLVMValue*: if ast->kind == AstExpressionKind::String: return self->make_a_string_constant(ast->string) @@ -66,10 +87,10 @@ class AstToIR: elif ast->kind == AstExpressionKind::FunctionCall: function = LLVMGetNamedFunction(self->module, &ast->call.called_name[0]) - assert(function != NULL) - assert(LLVMGetTypeKind(LLVMTypeOf(function)) == LLVMTypeKind::Pointer) + assert function != NULL + assert LLVMGetTypeKind(LLVMTypeOf(function)) == LLVMTypeKind::Pointer function_type = LLVMGetElementType(LLVMTypeOf(function)) - assert(LLVMGetTypeKind(function_type) == LLVMTypeKind::Function) + assert LLVMGetTypeKind(function_type) == LLVMTypeKind::Function args: LLVMValue** = malloc(sizeof args[0] * ast->call.nargs) for i = 0; i < ast->call.nargs; i++: @@ -80,8 +101,7 @@ class AstToIR: else: printf("Asd-asd. Unknown expr %d...\n", ast->kind) - assert(False) - return NULL + assert False def do_statement(self, ast: AstStatement*) -> void: if ast->kind == AstStatementKind::ExpressionStatement: @@ -94,9 +114,12 @@ class AstToIR: LLVMBuildRetVoid(self->builder) # If more code follows, place it into a new block that never actually runs self->new_block("after_return") + elif ast->kind == AstStatementKind::Assert: + condition = self->do_expression(&ast->expression) + self->build_assert(condition) else: printf("Asd-asd. 
Unknown statement...\n") - assert(False) + assert False def do_body(self, body: AstBody*) -> void: for i = 0; i < body->nstatements; i++: @@ -105,12 +128,12 @@ class AstToIR: # The function must already be declared. def define_function(self, funcdef: AstFunction*) -> void: llvm_func = LLVMGetNamedFunction(self->module, &funcdef->signature.name[0]) - assert(llvm_func != NULL) - assert(self->current_function == NULL) + assert llvm_func != NULL + assert self->current_function == NULL self->current_function = llvm_func self->new_block("start") - assert(funcdef->body.nstatements > 0) # it is a definition + assert funcdef->body.nstatements > 0 # it is a definition self->do_body(&funcdef->body) LLVMBuildUnreachable(self->builder) diff --git a/self_hosted/errors_and_warnings.jou b/self_hosted/errors_and_warnings.jou index 39a1cf49..0e520cd2 100644 --- a/self_hosted/errors_and_warnings.jou +++ b/self_hosted/errors_and_warnings.jou @@ -17,11 +17,3 @@ def fail(location: Location, message: byte*) -> noreturn: fprintf(stderr, ": %s\n", message) exit(1) - -# TODO: doesn't really belong here -def assert(b: bool) -> void: - if not b: - fflush(stdout) - fflush(stderr) - fprintf(stderr, "assertion failed\n") - exit(1) diff --git a/self_hosted/llvm.jou b/self_hosted/llvm.jou index acb724ef..8a189b91 100644 --- a/self_hosted/llvm.jou +++ b/self_hosted/llvm.jou @@ -230,7 +230,7 @@ declare LLVMDisposeBuilder(Builder: LLVMBuilder*) -> void declare LLVMBuildRet(Builder: LLVMBuilder*, V: LLVMValue*) -> LLVMValue* declare LLVMBuildRetVoid(Builder: LLVMBuilder*) -> LLVMValue* declare LLVMBuildBr(Builder: LLVMBuilder*, Dest: LLVMBasicBlock*) -> LLVMValue* -declare LLVMBuildCondBr(Builder: LLVMBuilder*, If: LLVMValue*, Then: LLVMValue*, Else: LLVMBasicBlock*) -> LLVMValue* +declare LLVMBuildCondBr(Builder: LLVMBuilder*, If: LLVMValue*, Then: LLVMBasicBlock*, Else: LLVMBasicBlock*) -> LLVMValue* declare LLVMBuildUnreachable(Builder: LLVMBuilder*) -> LLVMValue* declare LLVMBuildAdd(Builder: 
LLVMBuilder*, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue* declare LLVMBuildSub(Builder: LLVMBuilder*, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue* diff --git a/self_hosted/main.jou b/self_hosted/main.jou index 37021f4e..36e97bf0 100644 --- a/self_hosted/main.jou +++ b/self_hosted/main.jou @@ -1,5 +1,4 @@ import "../config.jou" -import "./errors_and_warnings.jou" import "./ast.jou" import "./tokenizer.jou" import "./parser.jou" @@ -104,7 +103,7 @@ def get_sane_filename(path: byte*) -> byte[50]: name: byte[50] snprintf(&name[0], sizeof name, "%s", path) - assert(name[0] != '\0') + assert name[0] != '\0' if name[0] == '.': name[0] = '_' @@ -129,11 +128,20 @@ class Compiler: args: CommandLineArgs* files: FileState* nfiles: int + automagic_files: byte*[10] + + def determine_automagic_files(self) -> void: + # TODO: this breaks too much stuff + return +# self->automagic_files[0] = malloc(strlen(self->stdlib_path) + 40) +# sprintf(self->automagic_files[0], "%s/_assert_fail.jou", self->stdlib_path) def parse_all_files(self) -> void: - queue: byte** = malloc(sizeof queue[0]) - queue[0] = self->args->main_path - queue_len = 1 + queue: byte** = malloc(50 * sizeof queue[0]) + queue_len = 0 + queue[queue_len++] = self->args->main_path + for i = 0; self->automagic_files[i] != NULL; i++: + queue[queue_len++] = self->automagic_files[i] while queue_len > 0: path = queue[--queue_len] @@ -195,7 +203,7 @@ class Compiler: if self->verbosity >= 1: printf("Type-check stage 2: %s\n", self->files[i].ast.path) - assert(self->files[i].pending_exports == NULL) + assert self->files[i].pending_exports == NULL self->files[i].pending_exports = typecheck_stage2_signatures_globals_structbodies( &self->files[i].typectx, &self->files[i].ast, @@ -280,10 +288,10 @@ class Compiler: error: byte* = NULL if LLVMTargetMachineEmitToFile(target.target_machine, module, path, LLVMCodeGenFileType::ObjectFile, &error): - assert(error != NULL) + assert error != NULL fprintf(stderr, 
"error in LLVMTargetMachineEmitToFile(): %s\n", error) exit(1) - assert(error == NULL) + assert error == NULL return paths @@ -368,6 +376,7 @@ def main(argc: int, argv: byte**) -> int: stdlib_path = find_stdlib(), args = &args, } + compiler.determine_automagic_files() compiler.parse_all_files() compiler.typecheck_stage2_all_files() compiler.process_imports_and_exports() @@ -383,8 +392,10 @@ def main(argc: int, argv: byte**) -> int: if args.mode == CompilerMode::CompileAndRun: compiler.run(executable) free(executable) + for i = 0; compiler.automagic_files[i] != NULL; i++: + free(compiler.automagic_files[i]) else: - assert(False) + assert False return 0 diff --git a/self_hosted/parser.jou b/self_hosted/parser.jou index 356cf215..f99cfec8 100644 --- a/self_hosted/parser.jou +++ b/self_hosted/parser.jou @@ -37,7 +37,7 @@ def parse_type(tokens: Token**) -> AstType: def parse_name_type_value(tokens: Token**, expected_what_for_name: byte*) -> AstNameTypeValue: if (*tokens)->kind != TokenKind::Name: - assert(expected_what_for_name != NULL) + assert expected_what_for_name != NULL (*tokens)->fail_expected_got(expected_what_for_name) result = AstNameTypeValue{name = (*tokens)->short_string, name_location = (*tokens)->location} @@ -144,7 +144,7 @@ def parse_import_path(path_token: Token*, stdlib_path: byte*) -> AstImport: } def parse_call(tokens: Token**, open_paren: byte*, close_paren: byte*) -> AstCall: - assert((*tokens)->kind == TokenKind::Name) # must be checked when calling this function + assert (*tokens)->kind == TokenKind::Name # must be checked when calling this function result = AstCall{location = (*tokens)->location, called_name = (*tokens)->short_string} ++*tokens @@ -233,7 +233,7 @@ def parse_elementary_expression(tokens: Token**) -> AstExpression: # This cannot be used for ++ and --, because with them we can't know the kind from # just the token (e.g. ++ could mean pre-increment or post-increment). 
def build_operator_expression(t: Token*, arity: int, operands: AstExpression*) -> AstExpression: - assert(arity == 1 or arity == 2) + assert arity == 1 or arity == 2 nbytes = arity * sizeof operands[0] ptr = malloc(nbytes) memcpy(ptr, operands, nbytes) @@ -241,31 +241,31 @@ def build_operator_expression(t: Token*, arity: int, operands: AstExpression*) - result = AstExpression{location = t->location, operands = ptr} if t->is_operator("&"): - assert(arity == 1) + assert arity == 1 result.kind = AstExpressionKind::AddressOf elif t->is_operator("["): - assert(arity == 2) + assert arity == 2 result.kind = AstExpressionKind::Indexing elif t->is_operator("=="): - assert(arity == 2) + assert arity == 2 result.kind = AstExpressionKind::Eq elif t->is_operator("!="): - assert(arity == 2) + assert arity == 2 result.kind = AstExpressionKind::Ne elif t->is_operator(">"): - assert(arity == 2) + assert arity == 2 result.kind = AstExpressionKind::Gt elif t->is_operator(">="): - assert(arity == 2) + assert arity == 2 result.kind = AstExpressionKind::Ge elif t->is_operator("<"): - assert(arity == 2) + assert arity == 2 result.kind = AstExpressionKind::Lt elif t->is_operator("<="): - assert(arity == 2) + assert arity == 2 result.kind = AstExpressionKind::Le elif t->is_operator("+"): - assert(arity == 2) + assert arity == 2 result.kind = AstExpressionKind::Add elif t->is_operator("-"): if arity == 2: @@ -278,24 +278,24 @@ def build_operator_expression(t: Token*, arity: int, operands: AstExpression*) - else: result.kind = AstExpressionKind::Dereference elif t->is_operator("/"): - assert(arity == 2) + assert arity == 2 result.kind = AstExpressionKind::Divide elif t->is_operator("%"): - assert(arity == 2) + assert arity == 2 result.kind = AstExpressionKind::Modulo elif t->is_keyword("and"): - assert(arity == 2) + assert arity == 2 result.kind = AstExpressionKind::And elif t->is_keyword("or"): - assert(arity == 2) + assert arity == 2 result.kind = AstExpressionKind::Or elif 
t->is_keyword("not"): - assert(arity == 1) + assert arity == 1 result.kind = AstExpressionKind::Not else: - assert(False) + assert False - assert(result.get_arity() == arity) + assert result.get_arity() == arity return result def parse_expression_with_unary_operators(tokens: Token**) -> AstExpression: @@ -335,7 +335,7 @@ def parse_expression_with_unary_operators(tokens: Token**) -> AstExpression: kind = AstExpressionKind::PostDecr else: # We don't have ++ or --, so it must be something in the prefix - assert(prefix_start != prefix_end and suffix_start == suffix_end) + assert prefix_start != prefix_end and suffix_start == suffix_end token = --prefix_end if token->is_operator("*"): kind = AstExpressionKind::Dereference @@ -344,8 +344,7 @@ def parse_expression_with_unary_operators(tokens: Token**) -> AstExpression: elif token->is_keyword("sizeof"): kind = AstExpressionKind::SizeOf else: - assert(False) - kind = AstExpressionKind::SizeOf # dummy value to silence compiler warning + assert False p: AstExpression* = malloc(sizeof(*p)) *p = result @@ -468,6 +467,10 @@ def parse_oneline_statement(tokens: Token**) -> AstStatement: if (*tokens)->kind != TokenKind::Newline: result.return_value = malloc(sizeof *result.return_value) *result.return_value = parse_expression(tokens) + elif (*tokens)->is_keyword("assert"): + ++*tokens + result.kind = AstStatementKind::Assert + result.expression = parse_expression(tokens) elif (*tokens)->is_keyword("break"): ++*tokens result.kind = AstStatementKind::Break @@ -498,7 +501,7 @@ def parse_if_statement(tokens: Token**) -> AstIfStatement: ifs_and_elifs: AstConditionAndBody* = NULL n = 0 - assert((*tokens)->is_keyword("if")) + assert (*tokens)->is_keyword("if") while True: ++*tokens cond = parse_expression(tokens) @@ -521,14 +524,14 @@ def parse_if_statement(tokens: Token**) -> AstIfStatement: } def parse_while_loop(tokens: Token**) -> AstConditionAndBody: - assert((*tokens)->is_keyword("while")) + assert (*tokens)->is_keyword("while") 
++*tokens cond = parse_expression(tokens) body = parse_body(tokens) return AstConditionAndBody{condition = cond, body = body} def parse_for_loop(tokens: Token**) -> AstForLoop: - assert((*tokens)->is_keyword("for")) + assert (*tokens)->is_keyword("for") ++*tokens init: AstStatement* = malloc(sizeof *init) diff --git a/self_hosted/runs_wrong.txt b/self_hosted/runs_wrong.txt index 4f6193a5..c2e117b6 100644 --- a/self_hosted/runs_wrong.txt +++ b/self_hosted/runs_wrong.txt @@ -145,3 +145,6 @@ tests/should_succeed/indirect_method_import.jou tests/404/indirect_import_symbol.jou tests/other_errors/noreturn_but_return_without_value.jou tests/other_errors/noreturn_but_return_with_value.jou +stdlib/_assert_fail.jou +tests/other_errors/assert_fail.jou +tests/wrong_type/assert.jou diff --git a/self_hosted/target.jou b/self_hosted/target.jou index 212661da..40b74dff 100644 --- a/self_hosted/target.jou +++ b/self_hosted/target.jou @@ -9,7 +9,6 @@ import "./llvm.jou" import "./paths.jou" -import "./errors_and_warnings.jou" import "stdlib/str.jou" import "stdlib/io.jou" import "stdlib/process.jou" @@ -40,17 +39,17 @@ def init_target() -> void: strcpy(&target.triple[0], "x86_64-pc-windows-gnu") else: triple = LLVMGetDefaultTargetTriple() - assert(strlen(triple) < sizeof target.triple) + assert strlen(triple) < sizeof target.triple strcpy(&target.triple[0], triple) LLVMDisposeMessage(triple) error: byte* = NULL if LLVMGetTargetFromTriple(&target.triple[0], &target.target, &error): - assert(error != NULL) + assert error != NULL fprintf(stderr, "LLVMGetTargetFromTriple(\"%s\") failed: %s\n", &target.triple[0], error) exit(1) - assert(error == NULL) - assert(target.target != NULL) + assert error == NULL + assert target.target != NULL target.target_machine = LLVMCreateTargetMachine( target.target, @@ -61,12 +60,12 @@ def init_target() -> void: LLVMRelocMode::Default, LLVMCodeModel::Default, ) - assert(target.target_machine != NULL) + assert target.target_machine != NULL 
target.target_data = LLVMCreateTargetDataLayout(target.target_machine) - assert(target.target_data != NULL) + assert target.target_data != NULL tmp = LLVMCopyStringRepOfTargetData(target.target_data) - assert(strlen(tmp) < sizeof target.data_layout) + assert strlen(tmp) < sizeof target.data_layout strcpy(&target.data_layout[0], tmp) LLVMDisposeMessage(tmp) diff --git a/self_hosted/token.jou b/self_hosted/token.jou index 5847477b..d2535e89 100644 --- a/self_hosted/token.jou +++ b/self_hosted/token.jou @@ -127,7 +127,7 @@ class Token: elif self->kind == TokenKind::EndOfFile: strcpy(&got[0], "end of file") else: - assert(False) + assert False message: byte* = malloc(strlen(what_was_expected_instead) + 500) sprintf(message, "expected %s, got %s", what_was_expected_instead, &got[0]) diff --git a/self_hosted/tokenizer.jou b/self_hosted/tokenizer.jou index a633c275..499e857e 100644 --- a/self_hosted/tokenizer.jou +++ b/self_hosted/tokenizer.jou @@ -24,7 +24,7 @@ def is_keyword(word: byte*) -> bool: "import", "def", "declare", "class", "enum", "global", "return", "if", "elif", "else", "while", "for", "break", "continue", "True", "False", "NULL", "self", - "and", "or", "not", "as", "sizeof", + "and", "or", "not", "as", "sizeof", "assert", "void", "noreturn", "bool", "byte", "int", "long", "float", "double", ] @@ -87,7 +87,7 @@ def parse_integer(string: byte*, location: Location, nbits: int) -> long: break result += hexdigit_value(digits[i]) - assert(nbits == 32 or nbits == 64) + assert nbits == 32 or nbits == 64 if nbits == 32 and (result as int) != result: overflow = True @@ -111,8 +111,7 @@ def flip_paren(c: byte) -> byte: return '}' if c == '}': return '{' - assert(False) - return 'x' # never actually runs, but silences compiler warning + assert False class Tokenizer: @@ -161,7 +160,7 @@ class Tokenizer: if b == '\0': return - assert(b != '\r') + assert b != '\r' self->pushback = realloc(self->pushback, self->pushback_len + 1) self->pushback[self->pushback_len++] = b if 
b == '\n': @@ -172,7 +171,7 @@ class Tokenizer: memset(&dest, 0, sizeof dest) destlen = 0 - assert(is_identifier_or_number_byte(first_byte)) + assert is_identifier_or_number_byte(first_byte) dest[destlen++] = first_byte while True: @@ -520,7 +519,7 @@ def handle_indentations(raw_tokens: Token*) -> Token*: # If the file has indentations after it, they are now represented by separate # indent tokens and parsing will fail. If the file doesn't have any blank/comment # lines in the beginning, it has a newline token anyway to avoid special casing. - assert(tokens[0].kind == TokenKind::Newline) + assert tokens[0].kind == TokenKind::Newline memmove(&tokens[0], &tokens[1], sizeof tokens[0] * (ntokens - 1)) return tokens diff --git a/self_hosted/typecheck.jou b/self_hosted/typecheck.jou index eea6e36e..efc20864 100644 --- a/self_hosted/typecheck.jou +++ b/self_hosted/typecheck.jou @@ -55,7 +55,7 @@ class TypeContext: current_function_signature: Signature* def add_imported_symbol(self, symbol: ExportSymbol*) -> void: - assert(symbol->kind == ExportSymbolKind::Function) + assert symbol->kind == ExportSymbolKind::Function self->functions = realloc(self->functions, sizeof self->functions[0] * (self->nfunctions + 1)) self->functions[self->nfunctions++] = symbol->signature.copy() @@ -68,7 +68,7 @@ class TypeContext: # TODO: implement #def typecheck_stage1_create_types(ctx: TypeContext*, file: AstFile*) -> ExportSymbol*: -# assert(False) +# assert False def type_from_ast(ctx: TypeContext*, ast_type: AstType*) -> Type*: if ast_type->is_void(): @@ -98,8 +98,7 @@ def type_from_ast(ctx: TypeContext*, ast_type: AstType*) -> Type*: ast_type->print(True) printf("\n") - assert(False) # TODO - return NULL # never runs, but silences compiler warning + assert False # TODO def handle_signature(ctx: TypeContext*, astsig: AstSignature*) -> Signature: sig = Signature{ @@ -213,7 +212,7 @@ def do_implicit_cast( if error_template != NULL and not can_cast_implicitly(from, to): 
fail_with_implicit_cast_error(error_location, error_template, from, to) - assert(types->type_after_cast == NULL) + assert types->type_after_cast == NULL types->type_after_cast = to @@ -300,8 +299,7 @@ def typecheck_expression_maybe_void(ctx: TypeContext*, expression: AstExpression return NULL else: printf("*** %d\n", expression->kind as int) - assert(False) - result = NULL # never runs, but silences compiler warning + assert False p: ExpressionTypes* = malloc(sizeof *p) *p = ExpressionTypes{ @@ -315,7 +313,7 @@ def typecheck_expression_maybe_void(ctx: TypeContext*, expression: AstExpression def typecheck_expression(ctx: TypeContext*, expression: AstExpression*) -> ExpressionTypes*: types = typecheck_expression_maybe_void(ctx, expression) if types == NULL: - assert(expression->kind == AstExpressionKind::FunctionCall) + assert expression->kind == AstExpressionKind::FunctionCall name = &expression->call.called_name[0] message = malloc(strlen(name) + 100) sprintf(message, "function '%s' does not return a value", name) @@ -369,7 +367,7 @@ def typecheck_statement(ctx: TypeContext*, statement: AstStatement*) -> void: ) else: - assert(False) + assert False def typecheck_body(ctx: TypeContext*, body: AstBody*) -> void: for i = 0; i < body->nstatements; i++: @@ -382,9 +380,9 @@ def typecheck_stage3_function_and_method_bodies(ctx: TypeContext*, ast_file: Ast continue sig = ctx->find_function(&ts->function.signature.name[0]) - assert(sig != NULL) + assert sig != NULL - assert(ctx->current_function_signature == NULL) + assert ctx->current_function_signature == NULL ctx->current_function_signature = sig typecheck_body(ctx, &ts->function.body) ctx->current_function_signature = NULL diff --git a/src/build_cfg.c b/src/build_cfg.c index e61f43a5..d4f257e4 100644 --- a/src/build_cfg.c +++ b/src/build_cfg.c @@ -703,6 +703,81 @@ static void build_if_statement(struct State *st, const AstIfStatement *ifstmt) add_jump(st, NULL, done, done, done); } +// TODO: this function is just bad... 
+static char *read_assertion_from_file(Location location) +{ + FILE *f = fopen(location.filename, "rb"); + if (!f) + return strdup("???"); + + for (int i = 1; i < location.lineno; i++) { + while(1){ + int c= fgetc(f); + if (c==EOF || c=='\n') break; + } + } + + char line[1024] = {0}; + fgets(line, sizeof line, f); + fclose(f); + + if (strstr(line, "#")) + *strstr(line, "#") = '\0'; + trim_whitespace(line); + + if(!strncmp(line, "assert ",7)) + return strdup(line+7); + else + return strdup(line); +} + +static void build_assert(struct State *st, const AstExpression *cond) +{ + const LocalVariable *condvar = build_expression(st, cond); + + // If the condition is true, we jump to a block where the rest of the code goes. + // If the condition is false, we jump to a block that calls _jou_assert_fail(). + CfBlock *trueblock = add_block(st); + CfBlock *falseblock = add_block(st); + add_jump(st, condvar, trueblock, falseblock, falseblock); + + char (*argnames)[100] = malloc(3 * sizeof *argnames); + strcpy(argnames[0], "assertion"); + strcpy(argnames[1], "path"); + strcpy(argnames[2], "lineno"); + + const Type **argtypes = malloc(3 * sizeof(argtypes[0])); // NOLINT + argtypes[0] = get_pointer_type(byteType); + argtypes[1] = get_pointer_type(byteType); + argtypes[2] = intType; + + const LocalVariable *args[4]; + for (int i = 0; i < 3; i++) + args[i] = add_local_var(st, argtypes[i]); + args[3] = NULL; + + char *tmp = read_assertion_from_file(cond->location); + add_constant(st, cond->location, ((Constant){CONSTANT_STRING,{.str=tmp}}), args[0]); + free(tmp); + tmp = strdup(cond->location.filename); + add_constant(st, cond->location, ((Constant){CONSTANT_STRING,{.str=tmp}}), args[1]); + free(tmp); + add_constant(st, cond->location, int_constant(intType, cond->location.lineno), args[2]); + + union CfInstructionData data = { .signature = { + .name = "_jou_assert_fail", + .nargs = 3, + .argtypes = argtypes, + .argnames = argnames, + .takes_varargs = false, + .is_noreturn = true, + 
.returntype_location = cond->location, + } }; + add_instruction(st, cond->location, CF_CALL, &data, args, NULL); + + st->current_block = trueblock; +} + static void build_statement(struct State *st, const AstStatement *stmt); // for init; cond; incr: @@ -754,6 +829,10 @@ static void build_statement(struct State *st, const AstStatement *stmt) build_if_statement(st, &stmt->data.ifstatement); break; + case AST_STMT_ASSERT: + build_assert(st, &stmt->data.expression); + break; + case AST_STMT_WHILE: build_loop( st, "while", diff --git a/src/free.c b/src/free.c index 5080fdb1..1e379dc8 100644 --- a/src/free.c +++ b/src/free.c @@ -156,6 +156,7 @@ static void free_statement(const AstStatement *stmt) free_ast_body(&stmt->data.forloop.body); break; case AST_STMT_EXPRESSION_STATEMENT: + case AST_STMT_ASSERT: free_expression(&stmt->data.expression); break; case AST_STMT_RETURN: diff --git a/src/jou_compiler.h b/src/jou_compiler.h index db0cc96c..5a011c6b 100644 --- a/src/jou_compiler.h +++ b/src/jou_compiler.h @@ -270,6 +270,7 @@ struct AstStatement { Location location; enum AstStatementKind { AST_STMT_RETURN, + AST_STMT_ASSERT, AST_STMT_IF, AST_STMT_WHILE, AST_STMT_FOR, @@ -285,7 +286,7 @@ struct AstStatement { AST_STMT_EXPRESSION_STATEMENT, // Evaluate an expression and discard the result. 
} kind; union { - AstExpression expression; // AST_STMT_EXPRESSION_STATEMENT + AstExpression expression; // AST_STMT_EXPRESSION_STATEMENT, AST_STMT_ASSERT AstExpression *returnvalue; // AST_STMT_RETURN (can be NULL) AstConditionAndBody whileloop; AstIfStatement ifstatement; diff --git a/src/main.c b/src/main.c index 1ac21045..ef23daf4 100644 --- a/src/main.c +++ b/src/main.c @@ -397,6 +397,14 @@ static void add_imported_symbols(struct CompileState *compst) } } +static void include_special_stdlib_file(struct CompileState *compst, const char *filename) +{ + char *path = malloc(strlen(compst->stdlib_path) + strlen(filename) + 123); + sprintf(path, "%s/%s", compst->stdlib_path, filename); + parse_file(compst, path, NULL); + free(path); +} + int main(int argc, char **argv) { init_target(); @@ -428,11 +436,9 @@ int main(int argc, char **argv) if (command_line_args.verbosity >= 1) printf("Parsing Jou files...\n"); + include_special_stdlib_file(&compst, "_assert_fail.jou"); #ifdef _WIN32 - char *startup_path = malloc(strlen(compst.stdlib_path) + 50); - sprintf(startup_path, "%s/_windows_startup.jou", compst.stdlib_path); - parse_file(&compst, startup_path, NULL); - free(startup_path); + include_special_stdlib_file(&compst, "_windows_startup.jou"); #endif parse_file(&compst, command_line_args.infile, NULL); diff --git a/src/parse.c b/src/parse.c index 0fadada9..28eaaca7 100644 --- a/src/parse.c +++ b/src/parse.c @@ -703,6 +703,10 @@ static AstStatement parse_oneline_statement(const Token **tokens) result.data.returnvalue = malloc(sizeof *result.data.returnvalue); *result.data.returnvalue = parse_expression(tokens); } + } else if (is_keyword(*tokens, "assert")) { + ++*tokens; + result.kind = AST_STMT_ASSERT; + result.data.expression = parse_expression(tokens); } else if (is_keyword(*tokens, "break")) { ++*tokens; result.kind = AST_STMT_BREAK; diff --git a/src/print.c b/src/print.c index b46f6a9d..a8c02fbf 100644 --- a/src/print.c +++ b/src/print.c @@ -300,6 +300,10 @@ static 
void print_ast_statement(const AstStatement *stmt, struct TreePrinter tp) printf("expression statement\n"); print_ast_expression(&stmt->data.expression, print_tree_prefix(tp, true)); break; + case AST_STMT_ASSERT: + printf("assert\n"); + print_ast_expression(&stmt->data.expression, print_tree_prefix(tp, true)); + break; case AST_STMT_RETURN: printf("return\n"); if (stmt->data.returnvalue) diff --git a/src/tokenize.c b/src/tokenize.c index 2f4319e2..655f8b79 100644 --- a/src/tokenize.c +++ b/src/tokenize.c @@ -222,7 +222,7 @@ static bool is_keyword(const char *s) "import", "def", "declare", "class", "enum", "global", "return", "if", "elif", "else", "while", "for", "break", "continue", "True", "False", "NULL", "self", - "and", "or", "not", "as", "sizeof", + "and", "or", "not", "as", "sizeof", "assert", "void", "noreturn", "bool", "byte", "int", "long", "float", "double", }; diff --git a/src/typecheck.c b/src/typecheck.c index 92f31a1b..bd4355f7 100644 --- a/src/typecheck.c +++ b/src/typecheck.c @@ -1183,6 +1183,10 @@ static void typecheck_statement(FileTypes *ft, const AstStatement *stmt) case AST_STMT_EXPRESSION_STATEMENT: typecheck_expression(ft, &stmt->data.expression); break; + + case AST_STMT_ASSERT: + typecheck_expression_with_implicit_cast(ft, &stmt->data.expression, boolType, "assertion must be a boolean, not FROM"); + break; } } diff --git a/src/update.c b/src/update.c index 200d3cd4..bd44e896 100644 --- a/src/update.c +++ b/src/update.c @@ -28,20 +28,6 @@ static noreturn void fail() exit(1); } -static void trim_whitespace(char *s) -{ - char *start = s; - while (*start && isspace(*start)) - start++; - - char *end = &s[strlen(s)]; - while (end > start && isspace(end[-1])) - end--; - - *end = '\0'; - memmove(s, start, end-start+1); -} - static void confirm(const char *prompt) { printf("%s (y/n) ", prompt); diff --git a/src/util.c b/src/util.c index c9aeb8ab..8f515091 100644 --- a/src/util.c +++ b/src/util.c @@ -5,6 +5,7 @@ #endif // _WIN32 #include "util.h" 
+#include <ctype.h> #include <errno.h> #include <stdio.h> #include <stdlib.h> @@ -14,6 +15,20 @@ static void delete_slice(char *start, char *end) memmove(start, end, strlen(end) + 1); } +void trim_whitespace(char *s) +{ + char *start = s; + while (*start && isspace((unsigned char)*start)) /* cast: isspace() on a negative char is undefined behaviour */ + start++; + + char *end = &s[strlen(s)]; + while (end > start && isspace((unsigned char)end[-1])) /* same cast when scanning back from the end */ + end--; + + *end = '\0'; + delete_slice(s, start); +} + void simplify_path(char *path) { #ifdef _WIN32 diff --git a/src/util.h b/src/util.h index fa7e2549..21c382c5 100644 --- a/src/util.h +++ b/src/util.h @@ -88,6 +88,9 @@ Gotchas to watch out for: strcpy((dest),(src)); \ } while(0) +// Similar to .strip() in Python. Depends on the current locale (ctype.h) +void trim_whitespace(char *s); + /* On windows, change backslash to forward slash. Delete unnecessary "." and ".." components. diff --git a/stdlib/_assert_fail.jou b/stdlib/_assert_fail.jou new file mode 100644 index 00000000..ad194b69 --- /dev/null +++ b/stdlib/_assert_fail.jou @@ -0,0 +1,9 @@ +# We could import more stdlib stuff but I prefer to keep this minimal. +# This simplifies debugging the compiler, because this is always added implicitly when compiling. +declare printf(pattern: byte*, ...) -> int +declare exit(status: int) -> noreturn + +def _jou_assert_fail(assertion: byte*, path: byte*, lineno: int) -> noreturn: + # TODO: print to stderr, when self-hosted compiler supports it + printf("Assertion '%s' failed in file \"%s\", line %d.\n", assertion, path, lineno) + exit(1) diff --git a/tests/other_errors/assert_fail.jou b/tests/other_errors/assert_fail.jou new file mode 100644 index 00000000..4000c18e --- /dev/null +++ b/tests/other_errors/assert_fail.jou @@ -0,0 +1,8 @@ +def main() -> int: + lol = True + wat = True + assert lol and wat + assert lol and not wat # Output: Assertion 'lol and not wat' failed in file "tests/other_errors/assert_fail.jou", line 5.
+ + # TODO: Compiler should be clever enough to not warn about missing return statement, but it isn't. + return 0 diff --git a/tests/should_succeed/compiler_cli.jou b/tests/should_succeed/compiler_cli.jou index f4d596b3..b242b7a2 100644 --- a/tests/should_succeed/compiler_cli.jou +++ b/tests/should_succeed/compiler_cli.jou @@ -49,8 +49,8 @@ def main() -> int: # Test that double-verbose kinda works, without asserting the output in too much detail. # See README for an explanation of why CFG is twice. - # TODO: shouldn't need to hide stdlib/io.jou or _windows_startup stuff, ideally it would be precompiled - run_jou("-vv examples/hello.jou | grep === | grep -v stdlib/io.jou | grep -v stdlib/_windows_startup.jou") + # TODO: shouldn't need to hide stdlib/io.jou or _windows_startup or _assert_fail stuff, ideally it would be precompiled + run_jou("-vv examples/hello.jou | grep === | grep -v stdlib/io.jou | grep -v stdlib/_") # Output: ===== Tokens for file "examples/hello.jou" ===== # Output: ===== AST for file "examples/hello.jou" ===== # Output: ===== Control Flow Graphs for file "examples/hello.jou" ===== @@ -59,7 +59,7 @@ def main() -> int: # Output: ===== Optimized LLVM IR for file "examples/hello.jou" ===== # With optimizations disabled, we don't see the optimized LLVM IR. - run_jou("-vv -O0 examples/hello.jou | grep 'LLVM IR for file' | grep -v stdlib/_windows_startup.jou") + run_jou("-vv -O0 examples/hello.jou | grep 'LLVM IR for file' | grep -v stdlib/_") # Output: ===== Unoptimized LLVM IR for file "examples/hello.jou" ===== # Output: ===== Unoptimized LLVM IR for file "<joudir>/stdlib/io.jou" ===== diff --git a/tests/wrong_type/assert.jou b/tests/wrong_type/assert.jou new file mode 100644 index 00000000..9c6e7808 --- /dev/null +++ b/tests/wrong_type/assert.jou @@ -0,0 +1,2 @@ +def main() -> int: + assert 123 # Error: assertion must be a boolean, not int