Skip to content

Commit

Permalink
Assert (#326)
Browse files Browse the repository at this point in the history
  • Loading branch information
Akuli authored Mar 12, 2023
1 parent 269ef3b commit f32a487
Show file tree
Hide file tree
Showing 27 changed files with 266 additions and 111 deletions.
1 change: 1 addition & 0 deletions doc/syntax-spec.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ Jou has a few different kinds of tokens:
- `not`
- `as`
- `sizeof`
- `assert`
- `void`
- `noreturn`
- `bool`
Expand Down
14 changes: 9 additions & 5 deletions self_hosted/ast.jou
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ class AstCall:

enum AstStatementKind:
ExpressionStatement # Evaluate an expression. Discard the result.
Assert
Return
If
WhileLoop
Expand All @@ -321,19 +322,22 @@ class AstStatement:
kind: AstStatementKind

# TODO: union
expression: AstExpression # AstStatementKind::ExpressionStatement
expression: AstExpression # ExpressionStatement, Assert
if_statement: AstIfStatement
while_loop: AstConditionAndBody
for_loop: AstForLoop
return_value: AstExpression* # AstStatementKind::Return (can be NULL)
return_value: AstExpression* # can be NULL
assignment: AstAssignment
var_declaration: AstNameTypeValue # AstStatementKind::DeclareLocalVar
var_declaration: AstNameTypeValue # DeclareLocalVar

def print(self, tp: TreePrinter) -> void:
printf("[line %d] ", self->location.lineno)
if self->kind == AstStatementKind::ExpressionStatement:
printf("expression statement\n")
self->expression.print(tp.print_prefix(True))
elif self->kind == AstStatementKind::Assert:
printf("assert\n")
self->expression.print(tp.print_prefix(True))
elif self->kind == AstStatementKind::Return:
printf("return\n")
if self->return_value != NULL:
Expand Down Expand Up @@ -479,7 +483,7 @@ class AstNameTypeValue:
printf("%s: ", &self->name[0])
self->type.print(True)
if tp == NULL:
assert(self->value == NULL)
assert self->value == NULL
else:
printf("\n")
if self->value != NULL:
Expand Down Expand Up @@ -597,7 +601,7 @@ class AstFile:
def next_import(self, imp: AstImport**) -> bool:
# Get the corresponding AstToplevelStatement.
ts = *imp as AstToplevelStatement*
assert(&ts->the_import as void* == ts) # TODO: offsetof() or similar
assert &ts->the_import as void* == ts # TODO: offsetof() or similar

# Assume all imports are in the beginning of the file.
if ts == NULL:
Expand Down
49 changes: 36 additions & 13 deletions self_hosted/create_llvm_ir.jou
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import "./typecheck.jou"
import "./types.jou"
import "./ast.jou"
import "./target.jou"
import "./errors_and_warnings.jou"
import "stdlib/io.jou"
import "stdlib/mem.jou"
import "stdlib/str.jou"
Expand All @@ -21,8 +20,7 @@ class AstToIR:
if type->kind == TypeKind::Pointer:
return LLVMPointerType(self->do_type(type->value_type), 0)
printf("asd-Asd., %s\n", &type->name)
assert(False)
return NULL
assert False

def declare_function(self, signature: Signature*) -> void:
argtypes: LLVMType**
Expand All @@ -41,7 +39,7 @@ class AstToIR:
LLVMAddFunction(self->module, &signature->name[0], function_type)

def new_block(self, name_hint: byte*) -> void:
assert(self->current_function != NULL)
assert self->current_function != NULL
block = LLVMAppendBasicBlock(self->current_function, name_hint)
LLVMPositionBuilderAtEnd(self->builder, block)

Expand All @@ -54,6 +52,29 @@ class AstToIR:
string_type = LLVMPointerType(LLVMInt8Type(), 0)
return LLVMBuildBitCast(self->builder, global_var, string_type, "string_ptr")

def build_assert(self, condition: LLVMValue*) -> void:
true_block = LLVMAppendBasicBlock(self->current_function, "assert_true")
false_block = LLVMAppendBasicBlock(self->current_function, "assert_false")
LLVMBuildCondBr(self->builder, condition, true_block, false_block)

LLVMPositionBuilderAtEnd(self->builder, false_block)

argtypes = [LLVMPointerType(LLVMInt8Type(), 0), LLVMPointerType(LLVMInt8Type(), 0), LLVMInt32Type()]
assert_fail_func_type = LLVMFunctionType(LLVMVoidType(), &argtypes[0], 3, False)
assert_fail_func = LLVMGetNamedFunction(self->module, "_jou_assert_fail")
if assert_fail_func == NULL:
assert_fail_func = LLVMAddFunction(self->module, "_jou_assert_fail", assert_fail_func_type)
assert assert_fail_func != NULL

args = [
self->make_a_string_constant("foo"),
self->make_a_string_constant("bar"),
LLVMConstInt(LLVMInt32Type(), 123, False),
]

LLVMBuildCall2(self->builder, assert_fail_func_type, assert_fail_func, &args[0], 3, "")
LLVMPositionBuilderAtEnd(self->builder, true_block)

def do_expression(self, ast: AstExpression*) -> LLVMValue*:
if ast->kind == AstExpressionKind::String:
return self->make_a_string_constant(ast->string)
Expand All @@ -66,10 +87,10 @@ class AstToIR:

elif ast->kind == AstExpressionKind::FunctionCall:
function = LLVMGetNamedFunction(self->module, &ast->call.called_name[0])
assert(function != NULL)
assert(LLVMGetTypeKind(LLVMTypeOf(function)) == LLVMTypeKind::Pointer)
assert function != NULL
assert LLVMGetTypeKind(LLVMTypeOf(function)) == LLVMTypeKind::Pointer
function_type = LLVMGetElementType(LLVMTypeOf(function))
assert(LLVMGetTypeKind(function_type) == LLVMTypeKind::Function)
assert LLVMGetTypeKind(function_type) == LLVMTypeKind::Function

args: LLVMValue** = malloc(sizeof args[0] * ast->call.nargs)
for i = 0; i < ast->call.nargs; i++:
Expand All @@ -80,8 +101,7 @@ class AstToIR:

else:
printf("Asd-asd. Unknown expr %d...\n", ast->kind)
assert(False)
return NULL
assert False

def do_statement(self, ast: AstStatement*) -> void:
if ast->kind == AstStatementKind::ExpressionStatement:
Expand All @@ -94,9 +114,12 @@ class AstToIR:
LLVMBuildRetVoid(self->builder)
# If more code follows, place it into a new block that never actually runs
self->new_block("after_return")
elif ast->kind == AstStatementKind::Assert:
condition = self->do_expression(&ast->expression)
self->build_assert(condition)
else:
printf("Asd-asd. Unknown statement...\n")
assert(False)
assert False

def do_body(self, body: AstBody*) -> void:
for i = 0; i < body->nstatements; i++:
Expand All @@ -105,12 +128,12 @@ class AstToIR:
# The function must already be declared.
def define_function(self, funcdef: AstFunction*) -> void:
llvm_func = LLVMGetNamedFunction(self->module, &funcdef->signature.name[0])
assert(llvm_func != NULL)
assert(self->current_function == NULL)
assert llvm_func != NULL
assert self->current_function == NULL
self->current_function = llvm_func

self->new_block("start")
assert(funcdef->body.nstatements > 0) # it is a definition
assert funcdef->body.nstatements > 0 # it is a definition
self->do_body(&funcdef->body)
LLVMBuildUnreachable(self->builder)

Expand Down
8 changes: 0 additions & 8 deletions self_hosted/errors_and_warnings.jou
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,3 @@ def fail(location: Location, message: byte*) -> noreturn:
fprintf(stderr, ": %s\n", message)

exit(1)

# TODO: doesn't really belong here
def assert(b: bool) -> void:
if not b:
fflush(stdout)
fflush(stderr)
fprintf(stderr, "assertion failed\n")
exit(1)
2 changes: 1 addition & 1 deletion self_hosted/llvm.jou
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ declare LLVMDisposeBuilder(Builder: LLVMBuilder*) -> void
declare LLVMBuildRet(Builder: LLVMBuilder*, V: LLVMValue*) -> LLVMValue*
declare LLVMBuildRetVoid(Builder: LLVMBuilder*) -> LLVMValue*
declare LLVMBuildBr(Builder: LLVMBuilder*, Dest: LLVMBasicBlock*) -> LLVMValue*
declare LLVMBuildCondBr(Builder: LLVMBuilder*, If: LLVMValue*, Then: LLVMValue*, Else: LLVMBasicBlock*) -> LLVMValue*
declare LLVMBuildCondBr(Builder: LLVMBuilder*, If: LLVMValue*, Then: LLVMBasicBlock*, Else: LLVMBasicBlock*) -> LLVMValue*
declare LLVMBuildUnreachable(Builder: LLVMBuilder*) -> LLVMValue*
declare LLVMBuildAdd(Builder: LLVMBuilder*, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue*
declare LLVMBuildSub(Builder: LLVMBuilder*, LHS: LLVMValue*, RHS: LLVMValue*, Name: byte*) -> LLVMValue*
Expand Down
29 changes: 20 additions & 9 deletions self_hosted/main.jou
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import "../config.jou"
import "./errors_and_warnings.jou"
import "./ast.jou"
import "./tokenizer.jou"
import "./parser.jou"
Expand Down Expand Up @@ -104,7 +103,7 @@ def get_sane_filename(path: byte*) -> byte[50]:

name: byte[50]
snprintf(&name[0], sizeof name, "%s", path)
assert(name[0] != '\0')
assert name[0] != '\0'

if name[0] == '.':
name[0] = '_'
Expand All @@ -129,11 +128,20 @@ class Compiler:
args: CommandLineArgs*
files: FileState*
nfiles: int
automagic_files: byte*[10]

def determine_automagic_files(self) -> void:
# TODO: this breaks too much stuff
return
# self->automagic_files[0] = malloc(strlen(self->stdlib_path) + 40)
# sprintf(self->automagic_files[0], "%s/_assert_fail.jou", self->stdlib_path)

def parse_all_files(self) -> void:
queue: byte** = malloc(sizeof queue[0])
queue[0] = self->args->main_path
queue_len = 1
queue: byte** = malloc(50 * sizeof queue[0])
queue_len = 0
queue[queue_len++] = self->args->main_path
for i = 0; self->automagic_files[i] != NULL; i++:
queue[queue_len++] = self->automagic_files[i]

while queue_len > 0:
path = queue[--queue_len]
Expand Down Expand Up @@ -195,7 +203,7 @@ class Compiler:
if self->verbosity >= 1:
printf("Type-check stage 2: %s\n", self->files[i].ast.path)

assert(self->files[i].pending_exports == NULL)
assert self->files[i].pending_exports == NULL
self->files[i].pending_exports = typecheck_stage2_signatures_globals_structbodies(
&self->files[i].typectx,
&self->files[i].ast,
Expand Down Expand Up @@ -280,10 +288,10 @@ class Compiler:

error: byte* = NULL
if LLVMTargetMachineEmitToFile(target.target_machine, module, path, LLVMCodeGenFileType::ObjectFile, &error):
assert(error != NULL)
assert error != NULL
fprintf(stderr, "error in LLVMTargetMachineEmitToFile(): %s\n", error)
exit(1)
assert(error == NULL)
assert error == NULL

return paths

Expand Down Expand Up @@ -368,6 +376,7 @@ def main(argc: int, argv: byte**) -> int:
stdlib_path = find_stdlib(),
args = &args,
}
compiler.determine_automagic_files()
compiler.parse_all_files()
compiler.typecheck_stage2_all_files()
compiler.process_imports_and_exports()
Expand All @@ -383,8 +392,10 @@ def main(argc: int, argv: byte**) -> int:
if args.mode == CompilerMode::CompileAndRun:
compiler.run(executable)
free(executable)
for i = 0; compiler.automagic_files[i] != NULL; i++:
free(compiler.automagic_files[i])

else:
assert(False)
assert False

return 0
Loading

0 comments on commit f32a487

Please sign in to comment.