From 0ee3d57f563e732fd9c3b3b4738124e79e35d914 Mon Sep 17 00:00:00 2001 From: Ajani Bilby Date: Mon, 25 Mar 2024 11:42:21 +1100 Subject: [PATCH] Fix operator precedence not applying correctly (#12) --- source/compiler/codegen/expression/infix.ts | 154 +++++++++++++----- .../compiler/codegen/expression/precedence.ts | 127 +++++++++------ tests/e2e/compiler/numeric.test.ts | 111 ++++++++----- 3 files changed, 257 insertions(+), 135 deletions(-) diff --git a/source/compiler/codegen/expression/infix.ts b/source/compiler/codegen/expression/infix.ts index 9860fb6..049fa6a 100644 --- a/source/compiler/codegen/expression/infix.ts +++ b/source/compiler/codegen/expression/infix.ts @@ -281,52 +281,128 @@ function CompileRem(ctx: Context, lhs: IntrinsicValue, rhs: IntrinsicValue, ref: return lhs; } - if (lhs === f32.value) { - const regA = ctx.scope.register.allocate(f32.bitcode, false); - const regB = ctx.scope.register.allocate(f32.bitcode, false); - ctx.block.push(Instruction.local.set(regB.ref)); - ctx.block.push(Instruction.local.set(regA.ref)); - - ctx.block.push(Instruction.local.get(regA.ref)); // a - - - ctx.block.push(Instruction.local.get(regA.ref)); // floor(a/b) - ctx.block.push(Instruction.local.get(regB.ref)); - ctx.block.push(Instruction.f32.div()); - ctx.block.push(Instruction.f32.trunc()); - - ctx.block.push(Instruction.local.get(regB.ref)); // * b - ctx.block.push(Instruction.f32.mul()); + if (lhs === f32.value || lhs === f64.value) return CompileFloatRemainder(ctx, lhs, ref); - ctx.block.push(Instruction.f32.sub()); - - regA.free(); - regB.free(); - return lhs; - } + Panic(`${colors.red("Error")}: Unhandled type ${lhs.type.name}\n`, { + path: ctx.file.path, name: ctx.file.name, ref + }); +} - if (lhs === f64.value) { - const regA = ctx.scope.register.allocate(f64.bitcode, false); - const regB = ctx.scope.register.allocate(f64.bitcode, false); - ctx.block.push(Instruction.local.set(regA.ref)); - ctx.block.push(Instruction.local.set(regB.ref)); +function CompileFloatRemainder(ctx: Context, type: IntrinsicValue, ref: ReferenceRange) { + /** + * float fmod(float x, float y) { + if (y == 0.0) return NaN; - ctx.block.push(Instruction.local.get(regA.ref)); - ctx.block.push(Instruction.local.get(regB.ref)); - ctx.block.push(Instruction.f64.div()); - ctx.block.push(Instruction.f64.trunc()); + float quotient = x / y; + float remainder = x - trunc(quotient) * y; - ctx.block.push(Instruction.local.get(regB.ref)); - ctx.block.push(Instruction.f64.mul()); + if (remainder == 0.0 && quotient < 0.0) return -0.0; + else return remainder; + }*/ - ctx.block.push(Instruction.local.get(regA.ref)); - ctx.block.push(Instruction.f64.sub()); + if (type === f32.value) { + const x = ctx.scope.register.allocate(f32.bitcode); + const y = ctx.scope.register.allocate(f32.bitcode); + ctx.block.push(Instruction.local.set(y.ref)); + ctx.block.push(Instruction.local.set(x.ref)); - regA.free(); - regB.free(); - return lhs; - } + const q = ctx.scope.register.allocate(f32.bitcode); + const r = ctx.scope.register.allocate(f32.bitcode); - Panic(`${colors.red("Error")}: Unhandled type ${lhs.type.name}\n`, { + // if (y == 0) return NaN; + ctx.block.push(Instruction.local.get(y.ref)); + ctx.block.push(Instruction.const.f32(0.0)); + ctx.block.push(Instruction.f32.eq()); + ctx.block.push(Instruction.if(type.type.bitcode, [ + Instruction.const.f32(NaN) + ], [ + Instruction.local.get(x.ref), // q = x / y + Instruction.local.get(y.ref), + Instruction.f32.div(), + Instruction.local.set(q.ref), + + Instruction.local.get(x.ref), // x - trunc(q)*y + Instruction.local.get(q.ref), + Instruction.f32.trunc(), + Instruction.local.get(y.ref), + Instruction.f32.mul(), + Instruction.f32.sub(), + Instruction.local.set(r.ref), + + Instruction.local.get(r.ref), // remainder == 0.0 + Instruction.const.f32(0.0), + Instruction.f32.eq(), + + Instruction.local.get(q.ref), // quotient < 0.0 + Instruction.const.f32(0.0), + Instruction.f32.lt(), + + Instruction.i32.and(), // && + Instruction.if(f32.bitcode, [ + Instruction.const.f32(-0.0) + ], [ + Instruction.local.get(r.ref) + ]) + ])); + + x.free(); y.free(); + q.free(); r.free(); + + return type; + } + + if (type === f64.value) { + const x = ctx.scope.register.allocate(f64.bitcode); + const y = ctx.scope.register.allocate(f64.bitcode); + ctx.block.push(Instruction.local.set(y.ref)); + ctx.block.push(Instruction.local.set(x.ref)); + + const q = ctx.scope.register.allocate(f64.bitcode); + const r = ctx.scope.register.allocate(f64.bitcode); + + // if (y == 0) return NaN; + ctx.block.push(Instruction.local.get(y.ref)); + ctx.block.push(Instruction.const.f64(0.0)); + ctx.block.push(Instruction.f64.eq()); + ctx.block.push(Instruction.if(type.type.bitcode, [ + Instruction.const.f64(NaN) + ], [ + Instruction.local.get(x.ref), // q = x / y + Instruction.local.get(y.ref), + Instruction.f64.div(), + Instruction.local.set(q.ref), + + Instruction.local.get(x.ref), // x - trunc(q)*y + Instruction.local.get(q.ref), + Instruction.f64.trunc(), + Instruction.local.get(y.ref), + Instruction.f64.mul(), + Instruction.f64.sub(), + Instruction.local.set(r.ref), + + Instruction.local.get(r.ref), // remainder == 0.0 + Instruction.const.f64(0.0), + Instruction.f64.eq(), + + Instruction.local.get(q.ref), // quotient < 0.0 + Instruction.const.f64(0.0), + Instruction.f64.lt(), + + Instruction.i32.and(), // && + Instruction.if(f64.bitcode, [ + Instruction.const.f64(-0.0) + ], [ + Instruction.local.get(r.ref) + ]) + ])); + + x.free(); y.free(); + q.free(); r.free(); + + return type; + } + + Panic(`${colors.red("Error")}: Unhandled type ${type.type.name}\n`, { path: ctx.file.path, name: ctx.file.name, ref }); } diff --git a/source/compiler/codegen/expression/precedence.ts b/source/compiler/codegen/expression/precedence.ts index b7ceb0d..bb1fbc0 100644 --- a/source/compiler/codegen/expression/precedence.ts +++ b/source/compiler/codegen/expression/precedence.ts @@ -1,29 +1,32 @@ import type { Term_Expr, Term_Expr_arg, _Literal } from "~/bnf/syntax.d.ts"; import { ReferenceRange } from "~/parser.ts"; import { Panic } from "~/compiler/helper.ts"; +import { assert } from "https://deno.land/std@0.201.0/assert/assert.ts"; const precedence = { ".": 1, "->": 1, - "*" : 3, "/" : 3, "%" : 3, - "+" : 4, "-" : 4, - "<<": 5, ">>": 5, - "<" : 6, ">" : 6, "<=": 6, ">=": 6, - "instanceof": 6.5, - "==": 7, "!=": 7, - "as": 7.5, - "&": 8, - "^": 9, - "|": 10, - "&&": 11, - "||": 12, + "**": 2, + "%": 3, + "*" : 4, "/" : 4, + "+" : 5, "-" : 5, + "<<": 6, ">>": 6, + "<" : 7, ">" : 7, "<=": 7, ">=": 7, + "instanceof": 8, + "==": 9, "!=": 9, + "&": 10, + "^": 11, + "|": 12, + "as": 13, + "&&": 14, + "||": 15, } as { [key: string]: number }; export function GetPrecedence (a: string, b: string) { const A = precedence[a]; const B = precedence[b]; - if (!A) Panic(`Unknown infix operation ${a}`); - if (!B) Panic(`Unknown infix operation ${a}`); + if (A === undefined) Panic(`Unknown infix operation ${a} 0x${a.charCodeAt(0).toString(16)}`); + if (B === undefined) Panic(`Unknown infix operation ${b} 0x${b.charCodeAt(0).toString(16)}`); return A !== B ? Math.min(1, Math.max(-1, A-B)) @@ -39,51 +42,69 @@ export type PrecedenceTree = Term_Expr_arg | { }; export function ApplyPrecedence(syntax: Term_Expr) { - let root: PrecedenceTree = syntax.value[0] as PrecedenceTree; + const rpn = new Array(); + const op_stack = new Array(); + rpn.push(syntax.value[0]); for (const action of syntax.value[1].value) { - const op = action.value[0].value; - const arg = action.value[1] - - // First action - if (root.type !== "infix") { - root = { - type: "infix", - lhs: root, - op, - rhs: arg, - ref: ReferenceRange.union(root.ref, arg.ref) - }; - continue; + const op = action.value[0].value; + while (op_stack.length > 0) { + const prev = op_stack[op_stack.length - 1]!; // peak + if (GetPrecedence(prev, op) <= 0) { + rpn.push(op_stack.pop()!); + } else break; } + op_stack.push(op); + rpn.push(action.value[1]); + } + + // Drain remaining operators + while (op_stack.length > 0) { + rpn.push(op_stack.pop()!); + } - const p = GetPrecedence(root.op, op); - if (p > 0) { - // Transform stealing previous operand - // (1 + 2) * 3 -> (2 * 3) + 1 - root = { - type: "infix", - lhs: { - type: "infix", - lhs: root.rhs, - op, - rhs: arg, - ref: ReferenceRange.union(root.ref, arg.ref) - }, - op: root.op, - rhs: root.lhs, - ref: ReferenceRange.union(arg.ref, arg.ref) - } - } else { - root = { - type: "infix", - lhs: root, - op: op, - rhs: arg, - ref: ReferenceRange.union(root.ref, arg.ref) - } + // This could probably be optimised in the future to not use a stack, and just manipulate a raw root node + const stack = new Array(); + while (rpn.length > 0) { + const token = rpn.shift()!; + + if (typeof token != "string") { + stack.push(token); + continue; } + + const rhs = stack.pop()!; + const lhs = stack.pop()!; + + stack.push({ + type: "infix", + lhs: lhs, + op: token, + rhs: rhs, + ref: ReferenceRange.union(lhs.ref, rhs.ref) + }) } + const root = stack.pop()!; + assert(typeof root !== "string", "Expression somehow has no arguments during precedence calculation"); + assert(stack.length == 0, "Expression somehow has only operators during precedence calculation"); + return root; +} + + +// For debugging assistance when hell breaks loose +function StringifyPrecedence(tree: PrecedenceTree | string): string { + if (typeof tree === "string") return tree; + + if (tree.type === "infix") return `(${StringifyPrecedence(tree.lhs)} ${tree.op} ${StringifyPrecedence(tree.rhs)})`; + + const arg = tree.value[1].value[0]; + if (arg.type == "expr_brackets") return `(...)`; + if (arg.type != "constant") return `type[${arg.type}]`; + + if (arg.value[0].type == "boolean") return arg.value[0].value[0].value; + if (arg.value[0].type == "integer") return arg.value[0].value[0].value; + if (arg.value[0].type == "float") return arg.value[0].value[0].value; + return "str"; } \ No newline at end of file diff --git a/tests/e2e/compiler/numeric.test.ts b/tests/e2e/compiler/numeric.test.ts index 9311509..a0234d3 100644 --- a/tests/e2e/compiler/numeric.test.ts +++ b/tests/e2e/compiler/numeric.test.ts @@ -1,18 +1,55 @@ /// -import { fail, assertEquals, assertNotEquals, assert } from "https://deno.land/std@0.201.0/assert/mod.ts"; +import { fail, assertNotEquals, assert } from "https://deno.land/std@0.201.0/assert/mod.ts"; import * as CompilerFunc from "~/compiler/function.ts"; import Package from "~/compiler/package.ts"; import Project from "~/compiler/project.ts"; import { FuncRef } from "~/wasm/funcRef.ts"; -const decoder = new TextDecoder(); +const source = ` +fn left(): f32 { + // (-2.5 % 2.0) * -3.0; + return -2.5 % 2.0 * -3.0; +} + +fn right(): i32 { + // 10.0 - 0.5 - 8.0 + // return 10.0 - 1.0 / 2.0 % 10.0 - 8.0; + return 10 - 0 - 8; +} + +fn main(): i32 { + + // if ( 10.0 - ( 3.0 / 2.0 ) - 8.0 != 10.0 - 3.0 / 2.0 - 8.0 ) { + // return 20; + // }; + + // (-2.5 % 2.0) * -3.0 == 10.0 - ((1.0 / 2.0) % 10.0) - 8.0; + // 1.5 == 1.5 + // true == 1 + + // doing this in a single expression to also ensure == is applied correctly + if ( (-2.5 % 2.0) * -3.0 ) != ( 10.0 - ( (1.0 / 2.0) % 10.0 ) - 8.0 ) { + return 29; + }; -const goalStdout = ""; + if ( (-2.5 % 2.0) * -3.0 != 10.0 - ( (1.0 / 2.0) % 10.0 ) - 8.0 ) { + return 33; + }; -const source = ` -fn main(): f32 { - return -3.5 % 2.0; + if ( (-2.5 % 2.0) * -3.0 != 10.0 - ( 1.0 / 2.0 % 10.0 ) - 8.0 ) { + return 37; + }; + + if ( -2.5 % 2.0 * -3.0 != 10.0 - ( 1.0 / 2.0 % 10.0 ) - 8.0 ) { + return 41; + }; + + if ( -2.5 % 2.0 * -3.0 != 10.0 - 1.0 / 2.0 % 10.0 - 8.0 ) { + return 45; + }; + + return 0; }`; Deno.test(`Numeric logic test`, async () => { @@ -27,48 +64,36 @@ Deno.test(`Numeric logic test`, async () => { assertNotEquals(mainFunc.ref, null, "Main function hasn't compiled"); project.module.exportFunction("_start", mainFunc.ref as FuncRef); - let stdout = ""; - let memory: WebAssembly.Memory; - - const imports = { - wasi_snapshot_preview1: { - fd_write: (fd: number, iovs: number, iovs_len: number, n_written: number) => { - const memoryArray = new Int32Array(memory.buffer); - const byteArray = new Uint8Array(memory.buffer); - for (let iovIdx = 0; iovIdx < iovs_len; iovIdx++) { - const bufPtr = memoryArray.at(iovs/4 + iovIdx*2) || 0; - const bufLen = memoryArray.at(iovs/4 + iovIdx*2 + 1) || 0; - const data = decoder.decode(byteArray.slice(bufPtr, bufPtr + bufLen)); - stdout += data; - } - return 0; // Return 0 to indicate success - } - } - }; + const left = mainFile.namespace["left"]; + assert(left instanceof CompilerFunc.default, "Missing left function"); + left.compile(); + assertNotEquals(left.ref, null, "Left function hasn't compiled"); + project.module.exportFunction("left", left.ref as FuncRef); - // Load the wasm module - const wasmModule = new WebAssembly.Module(project.module.toBinary()); + const right = mainFile.namespace["right"]; + assert(right instanceof CompilerFunc.default, "Missing right function"); + right.compile(); + assertNotEquals(right.ref, null, "Right function hasn't compiled"); + project.module.exportFunction("right", right.ref as FuncRef); - try { - // Instantiate the wasm module - const instance = await WebAssembly.instantiate(wasmModule, imports); + const wasmModule = new WebAssembly.Module(project.module.toBinary()); + const instance = await WebAssembly.instantiate(wasmModule, {}); - const exports = instance.exports; - memory = exports.memory as WebAssembly.Memory; + const exports = instance.exports; - // Call the _start function - if (typeof exports._start === "function") { - (exports._start as Function)() as any; - } else { - fail(`Expected _start to be a function`); - } + // Call the _start function + let main: () => number = typeof exports._start === "function" + ? exports._start as any + : fail(`Expected _start to be a function`); - // Check stdout - assertEquals(stdout, goalStdout); + const code = main() as number; + if (code !== 0) { + const leftFn: () => number = exports.left as any; + assert(leftFn instanceof Function, "Missing left function"); - } catch (err) { - // If there's an error, the test will fail - fail(`Failed to run wasm module: ${err}`); - } + const rightFn: () => number = exports.right as any; + assert(rightFn instanceof Function, "Missing right function"); + fail(`equivalence checks failed ${leftFn()} != ${rightFn()} at line ${code}`); + }; }); \ No newline at end of file