Skip to content

Commit

Permalink
Evaluate all if statements at compile time when possible (#710)
Browse files Browse the repository at this point in the history
  • Loading branch information
Akuli authored Jan 27, 2025
1 parent edc0ba8 commit 8a0bb4b
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 32 deletions.
126 changes: 96 additions & 30 deletions compiler/evaluate.jou
Original file line number Diff line number Diff line change
Expand Up @@ -25,37 +25,60 @@ def get_special_constant(name: byte*) -> int:
return -1


def evaluate_condition(expr: AstExpression*) -> bool:
def evaluate_condition(expr: AstExpression*) -> int: # 1=true 0=false -1=error
match expr->kind:
case AstExpressionKind.GetVariable:
v = get_special_constant(expr->varname)
if v == 0:
return False
if v == 1:
return True
fail(expr->location, "cannot evaluate condition at compile time")
return get_special_constant(expr->varname)
case AstExpressionKind.And:
return evaluate_condition(&expr->operands[0]) and evaluate_condition(&expr->operands[1])
match evaluate_condition(&expr->operands[0]):
case 1:
return evaluate_condition(&expr->operands[1])
case 0:
return 0 # left side false, don't evaluate right side
case -1:
return -1
case _:
assert False
case AstExpressionKind.Or:
return evaluate_condition(&expr->operands[0]) or evaluate_condition(&expr->operands[1])
match evaluate_condition(&expr->operands[0]):
case 1:
return 1 # left side true, don't evaluate right side
case 0:
return evaluate_condition(&expr->operands[1])
case -1:
return -1 # error
case _:
assert False
case AstExpressionKind.Not:
return not evaluate_condition(&expr->operands[0])
match evaluate_condition(&expr->operands[0]):
case 1:
return 0
case 0:
return 1
case -1:
return -1
case _:
assert False
case _:
fail(expr->location, "cannot evaluate condition at compile time")
return -1


# returns the statements to replace if statement with
@public
def evaluate_compile_time_if_statement(if_stmt: AstIfStatement*) -> AstBody:
result = &if_stmt->else_body
# returns the statements to replace if statement with, as a pointer inside if_stmt
def choose_if_elif_branch(if_stmt: AstIfStatement*) -> AstBody*:
for i = 0; i < if_stmt->n_if_and_elifs; i++:
if evaluate_condition(&if_stmt->if_and_elifs[i].condition):
result = &if_stmt->if_and_elifs[i].body
break

ret = *result
*result = AstBody{} # avoid double-free
return ret
match evaluate_condition(&if_stmt->if_and_elifs[i].condition):
case -1:
# don't know how to evaluate it
return NULL
case 0:
# try the next elif or else
pass
case 1:
# condition is true, let's use this if or elif
return &if_stmt->if_and_elifs[i].body
case _:
assert False
return &if_stmt->else_body


# Replace body->statements[i] with zero or more statements from another body.
Expand All @@ -72,12 +95,55 @@ def replace(body: AstBody*, i: int, new: AstBody) -> None:
body->nstatements += new.nstatements


# This handles nested if statements.
def evaluate_if_statements_in_body(body: AstBody*, must_succeed: bool) -> None:
for i = 0; i < body->nstatements; i++:
match body->statements[i].kind:
case AstStatementKind.If:
ptr = choose_if_elif_branch(&body->statements[i].if_statement)
if ptr == NULL and must_succeed:
fail(body->statements[i].location, "cannot evaluate condition at compile time")
if ptr != NULL:
replacement = *ptr
*ptr = AstBody{} # avoid double-free
replace(body, i, replacement)
i-- # cancels i++ to do same index again, so that we handle nested if statements
case AstStatementKind.WhileLoop:
evaluate_if_statements_in_body(&body->statements[i].while_loop.body, False)
case AstStatementKind.ForLoop:
evaluate_if_statements_in_body(&body->statements[i].for_loop.body, False)
case AstStatementKind.Class:
evaluate_if_statements_in_body(body->statements[i].classdef.body, True)
case AstStatementKind.FunctionDef:
evaluate_if_statements_in_body(&body->statements[i].function.body, False)
case AstStatementKind.MethodDef:
evaluate_if_statements_in_body(&body->statements[i].method.body, False)
case (
AstStatementKind.ExpressionStatement
| AstStatementKind.Assert
| AstStatementKind.Pass
| AstStatementKind.Return
| AstStatementKind.Match
| AstStatementKind.Break
| AstStatementKind.Continue
| AstStatementKind.DeclareLocalVar
| AstStatementKind.Assign
| AstStatementKind.InPlaceAdd
| AstStatementKind.InPlaceSub
| AstStatementKind.InPlaceMul
| AstStatementKind.InPlaceDiv
| AstStatementKind.InPlaceMod
| AstStatementKind.FunctionDeclare
| AstStatementKind.Enum
| AstStatementKind.GlobalVariableDeclare
| AstStatementKind.GlobalVariableDef
| AstStatementKind.Import
| AstStatementKind.ClassField
| AstStatementKind.ClassUnion
):
# these statements cannot contain if statements, no need to recurse inside
pass


@public
def evaluate_compile_time_if_statements(body: AstBody*) -> None:
i = 0
while i < body->nstatements:
if body->statements[i].kind == AstStatementKind.If:
replace(body, i, evaluate_compile_time_if_statement(&body->statements[i].if_statement))
else:
i++
def evaluate_compile_time_if_statements(file: AstFile*) -> None:
evaluate_if_statements_in_body(&file->body, True)
2 changes: 1 addition & 1 deletion compiler/main.jou
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ class CompileState:

if command_line_args.verbosity >= 1:
printf("Evaluating compile-time if statements in %s\n", filename)
evaluate_compile_time_if_statements(&fs.ast.body)
evaluate_compile_time_if_statements(&fs.ast)

if command_line_args.verbosity >= 2:
fs.ast.print()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import "stdlib/io.jou"
import "stdlib/process.jou"

def foo() -> bool:
return WINDOWS

def main() -> int:
f = fopen("tmp/tests/asdasd.txt", "w")
assert f != NULL
fprintf(f, "asd asd\n")
fclose(f)

# Output: asd asd
if WINDOWS:
if foo():
system("type tmp\\tests\\asdasd.txt")
else:
system("cat tmp/tests/asdasd.txt")
Expand Down
34 changes: 34 additions & 0 deletions tests/should_succeed/if_WINDOWS_in_class.jou
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
if WINDOWS:
import "stdlib/io.jou"
else:
import "stdlib/str.jou"
# Low level function for writing to file. Does not exist on Windows.
declare write(fd: int, buf: byte*, count: long) -> long


class EnterprisePrintWriterFactory:
if WINDOWS:
windows_message: byte*

def set_message(self, m: byte*) -> None:
self->windows_message = m

def show_message(self) -> None:
puts(self->windows_message)

else:
posix_message: byte*

def set_message(self, m: byte*) -> None:
self->posix_message = m

def show_message(self) -> None:
write(1, self->posix_message, strlen(self->posix_message))
write(1, "\n", 1)


def main() -> int:
e = EnterprisePrintWriterFactory{}
e.set_message("hello")
e.show_message() # Output: hello
return 0
15 changes: 15 additions & 0 deletions tests/should_succeed/if_WINDOWS_in_function.jou
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
if WINDOWS:
import "stdlib/io.jou"
else:
# Low level function for writing to file. Does not exist on Windows.
declare write(fd: int, buf: byte*, count: long) -> long

def main() -> int:
# Output: hello
if WINDOWS:
printf("hello\n")
else:
# If this code is compiled on Windows, there will be linker errors.
write(1, "hello\n", 6)

return 0

0 comments on commit 8a0bb4b

Please sign in to comment.