From 8a0bb4b50ef77932243f4f65ae90be84e763ec45 Mon Sep 17 00:00:00 2001 From: Akuli Date: Mon, 27 Jan 2025 23:02:49 +0200 Subject: [PATCH] Evaluate all if statements at compile time when possible (#710) --- compiler/evaluate.jou | 126 +++++++++++++----- compiler/main.jou | 2 +- ..._at_runtime.jou => WINDOWS_at_runtime.jou} | 5 +- tests/should_succeed/if_WINDOWS_in_class.jou | 34 +++++ .../should_succeed/if_WINDOWS_in_function.jou | 15 +++ 5 files changed, 150 insertions(+), 32 deletions(-) rename tests/should_succeed/{if_WINDOWS_at_runtime.jou => WINDOWS_at_runtime.jou} (85%) create mode 100644 tests/should_succeed/if_WINDOWS_in_class.jou create mode 100644 tests/should_succeed/if_WINDOWS_in_function.jou diff --git a/compiler/evaluate.jou b/compiler/evaluate.jou index f0325647..496fb009 100644 --- a/compiler/evaluate.jou +++ b/compiler/evaluate.jou @@ -25,37 +25,60 @@ def get_special_constant(name: byte*) -> int: return -1 -def evaluate_condition(expr: AstExpression*) -> bool: +def evaluate_condition(expr: AstExpression*) -> int: # 1=true 0=false -1=error match expr->kind: case AstExpressionKind.GetVariable: - v = get_special_constant(expr->varname) - if v == 0: - return False - if v == 1: - return True - fail(expr->location, "cannot evaluate condition at compile time") + return get_special_constant(expr->varname) case AstExpressionKind.And: - return evaluate_condition(&expr->operands[0]) and evaluate_condition(&expr->operands[1]) + match evaluate_condition(&expr->operands[0]): + case 1: + return evaluate_condition(&expr->operands[1]) + case 0: + return 0 # left side false, don't evaluate right side + case -1: + return -1 + case _: + assert False case AstExpressionKind.Or: - return evaluate_condition(&expr->operands[0]) or evaluate_condition(&expr->operands[1]) + match evaluate_condition(&expr->operands[0]): + case 1: + return 1 # left side true, don't evaluate right side + case 0: + return evaluate_condition(&expr->operands[1]) + case -1: + return -1 # error + case _: + assert False case AstExpressionKind.Not: - return not evaluate_condition(&expr->operands[0]) + match evaluate_condition(&expr->operands[0]): + case 1: + return 0 + case 0: + return 1 + case -1: + return -1 + case _: + assert False case _: - fail(expr->location, "cannot evaluate condition at compile time") + return -1 -# returns the statements to replace if statement with -@public -def evaluate_compile_time_if_statement(if_stmt: AstIfStatement*) -> AstBody: - result = &if_stmt->else_body +# returns the statements to replace if statement with, as a pointer inside if_stmt +def choose_if_elif_branch(if_stmt: AstIfStatement*) -> AstBody*: for i = 0; i < if_stmt->n_if_and_elifs; i++: - if evaluate_condition(&if_stmt->if_and_elifs[i].condition): - result = &if_stmt->if_and_elifs[i].body - break - - ret = *result - *result = AstBody{} # avoid double-free - return ret + match evaluate_condition(&if_stmt->if_and_elifs[i].condition): + case -1: + # don't know how to evaluate it + return NULL + case 0: + # try the next elif or else + pass + case 1: + # condition is true, let's use this if or elif + return &if_stmt->if_and_elifs[i].body + case _: + assert False + return &if_stmt->else_body # Replace body->statements[i] with zero or more statements from another body. @@ -72,12 +95,55 @@ def replace(body: AstBody*, i: int, new: AstBody) -> None: body->nstatements += new.nstatements -# This handles nested if statements. +def evaluate_if_statements_in_body(body: AstBody*, must_succeed: bool) -> None: + for i = 0; i < body->nstatements; i++: + match body->statements[i].kind: + case AstStatementKind.If: + ptr = choose_if_elif_branch(&body->statements[i].if_statement) + if ptr == NULL and must_succeed: + fail(body->statements[i].location, "cannot evaluate condition at compile time") + if ptr != NULL: + replacement = *ptr + *ptr = AstBody{} # avoid double-free + replace(body, i, replacement) + i-- # cancels i++ to do same index again, so that we handle nested if statements + case AstStatementKind.WhileLoop: + evaluate_if_statements_in_body(&body->statements[i].while_loop.body, False) + case AstStatementKind.ForLoop: + evaluate_if_statements_in_body(&body->statements[i].for_loop.body, False) + case AstStatementKind.Class: + evaluate_if_statements_in_body(body->statements[i].classdef.body, True) + case AstStatementKind.FunctionDef: + evaluate_if_statements_in_body(&body->statements[i].function.body, False) + case AstStatementKind.MethodDef: + evaluate_if_statements_in_body(&body->statements[i].method.body, False) + case ( + AstStatementKind.ExpressionStatement + | AstStatementKind.Assert + | AstStatementKind.Pass + | AstStatementKind.Return + | AstStatementKind.Match + | AstStatementKind.Break + | AstStatementKind.Continue + | AstStatementKind.DeclareLocalVar + | AstStatementKind.Assign + | AstStatementKind.InPlaceAdd + | AstStatementKind.InPlaceSub + | AstStatementKind.InPlaceMul + | AstStatementKind.InPlaceDiv + | AstStatementKind.InPlaceMod + | AstStatementKind.FunctionDeclare + | AstStatementKind.Enum + | AstStatementKind.GlobalVariableDeclare + | AstStatementKind.GlobalVariableDef + | AstStatementKind.Import + | AstStatementKind.ClassField + | AstStatementKind.ClassUnion + ): + # these statements cannot contain if statements, no need to recurse inside + pass + + @public -def evaluate_compile_time_if_statements(body: AstBody*) -> None: - i = 0 - while i < body->nstatements: - if body->statements[i].kind == AstStatementKind.If: - replace(body, i, evaluate_compile_time_if_statement(&body->statements[i].if_statement)) - else: - i++ +def evaluate_compile_time_if_statements(file: AstFile*) -> None: + evaluate_if_statements_in_body(&file->body, True) diff --git a/compiler/main.jou b/compiler/main.jou index f3dfbbe6..37da2898 100644 --- a/compiler/main.jou +++ b/compiler/main.jou @@ -202,7 +202,7 @@ class CompileState: if command_line_args.verbosity >= 1: printf("Evaluating compile-time if statements in %s\n", filename) - evaluate_compile_time_if_statements(&fs.ast.body) + evaluate_compile_time_if_statements(&fs.ast) if command_line_args.verbosity >= 2: fs.ast.print() diff --git a/tests/should_succeed/if_WINDOWS_at_runtime.jou b/tests/should_succeed/WINDOWS_at_runtime.jou similarity index 85% rename from tests/should_succeed/if_WINDOWS_at_runtime.jou rename to tests/should_succeed/WINDOWS_at_runtime.jou index 46e2d7c5..8d6e530c 100644 --- a/tests/should_succeed/if_WINDOWS_at_runtime.jou +++ b/tests/should_succeed/WINDOWS_at_runtime.jou @@ -1,6 +1,9 @@ import "stdlib/io.jou" import "stdlib/process.jou" +def foo() -> bool: + return WINDOWS + def main() -> int: f = fopen("tmp/tests/asdasd.txt", "w") assert f != NULL @@ -8,7 +11,7 @@ def main() -> int: fclose(f) # Output: asd asd - if WINDOWS: + if foo(): system("type tmp\\tests\\asdasd.txt") else: system("cat tmp/tests/asdasd.txt") diff --git a/tests/should_succeed/if_WINDOWS_in_class.jou b/tests/should_succeed/if_WINDOWS_in_class.jou new file mode 100644 index 00000000..50eb1775 --- /dev/null +++ b/tests/should_succeed/if_WINDOWS_in_class.jou @@ -0,0 +1,34 @@ +if WINDOWS: + import "stdlib/io.jou" +else: + import "stdlib/str.jou" + # Low level function for writing to file. Does not exist on Windows. + declare write(fd: int, buf: byte*, count: long) -> long + + +class EnterprisePrintWriterFactory: + if WINDOWS: + windows_message: byte* + + def set_message(self, m: byte*) -> None: + self->windows_message = m + + def show_message(self) -> None: + puts(self->windows_message) + + else: + posix_message: byte* + + def set_message(self, m: byte*) -> None: + self->posix_message = m + + def show_message(self) -> None: + write(1, self->posix_message, strlen(self->posix_message)) + write(1, "\n", 1) + + +def main() -> int: + e = EnterprisePrintWriterFactory{} + e.set_message("hello") + e.show_message() # Output: hello + return 0 diff --git a/tests/should_succeed/if_WINDOWS_in_function.jou b/tests/should_succeed/if_WINDOWS_in_function.jou new file mode 100644 index 00000000..3d89ade4 --- /dev/null +++ b/tests/should_succeed/if_WINDOWS_in_function.jou @@ -0,0 +1,15 @@ +if WINDOWS: + import "stdlib/io.jou" +else: + # Low level function for writing to file. Does not exist on Windows. + declare write(fd: int, buf: byte*, count: long) -> long + +def main() -> int: + # Output: hello + if WINDOWS: + printf("hello\n") + else: + # If this code is compiled on Windows, there will be linker errors. + write(1, "hello\n", 6) + + return 0