diff --git a/.envrc b/.envrc new file mode 100644 index 0000000..3550a30 --- /dev/null +++ b/.envrc @@ -0,0 +1 @@ +use flake diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f06235c --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +node_modules +dist diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..7bdd046 --- /dev/null +++ b/flake.lock @@ -0,0 +1,61 @@ +{ + "nodes": { + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1694529238, + "narHash": "sha256-zsNZZGTGnMOf9YpHKJqMSsa0dXbfmxeoJ7xHlrt+xmY=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "ff7b65b44d01cf9ba6a71320833626af21126384", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1699099776, + "narHash": "sha256-X09iKJ27mGsGambGfkKzqvw5esP1L/Rf8H3u3fCqIiU=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "85f1ba3e51676fa8cc604a3d863d729026a6b8eb", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..1b374bb --- /dev/null +++ b/flake.nix @@ -0,0 +1,21 @@ +{ + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable"; + flake-utils.url = "github:numtide/flake-utils"; + }; + + outputs = { self, nixpkgs, flake-utils, ... }: + flake-utils.lib.eachDefaultSystem (system: + let + pkgs = import nixpkgs { + inherit system; + }; + in + { + devShell = pkgs.mkShell { + buildInputs = [ + pkgs.nodejs + ]; + }; + }); +} diff --git a/package-lock.json b/package-lock.json new file mode 100644 index 0000000..2a1baa3 --- /dev/null +++ b/package-lock.json @@ -0,0 +1,386 @@ +{ + "name": "auto-diff-ts", + "version": "1.0.0", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "name": "auto-diff-ts", + "version": "1.0.0", + "license": "ISC", + "devDependencies": { + "@swc-node/register": "^1.6.8", + "@types/node": "^20.9.0", + "typescript": "^5.2.2" + } + }, + "node_modules/@swc-node/core": { + "version": "1.10.6", + "resolved": "https://registry.npmjs.org/@swc-node/core/-/core-1.10.6.tgz", + "integrity": "sha512-lDIi/rPosmKIknWzvs2/Fi9zWRtbkx8OJ9pQaevhsoGzJSal8Pd315k1W5AIrnknfdAB4HqRN12fk6AhqnrEEw==", + "dev": true, + "engines": { + "node": ">= 10" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/Brooooooklyn" + }, + "peerDependencies": { + "@swc/core": ">= 1.3" + } + }, + "node_modules/@swc-node/register": { + "version": "1.6.8", + "resolved": "https://registry.npmjs.org/@swc-node/register/-/register-1.6.8.tgz", + "integrity": "sha512-74ijy7J9CWr1Z88yO+ykXphV29giCrSpANQPQRooE0bObpkTO1g4RzQovIfbIaniBiGDDVsYwDoQ3FIrCE8HcQ==", + "dev": true, + "dependencies": { + "@swc-node/core": "^1.10.6", + "@swc-node/sourcemap-support": "^0.3.0", + "colorette": "^2.0.19", + "debug": "^4.3.4", + "pirates": "^4.0.5", + "tslib": "^2.5.0" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/Brooooooklyn" + }, + "peerDependencies": { + "@swc/core": ">= 1.3", + "typescript": ">= 4.3" + } + }, + "node_modules/@swc-node/sourcemap-support": { + "version": "0.3.0", + "resolved": "https://registry.npmjs.org/@swc-node/sourcemap-support/-/sourcemap-support-0.3.0.tgz", + "integrity": "sha512-gqBJSmJMWomZFxlppaKea7NeAqFrDrrS0RMt24No92M3nJWcyI9YKGEQKl+EyJqZ5gh6w1s0cTklMHMzRwA1NA==", + "dev": true, + "dependencies": { + "source-map-support": "^0.5.21", + "tslib": "^2.5.0" + } + }, + "node_modules/@swc/core": { + "version": "1.3.96", + "resolved": "https://registry.npmjs.org/@swc/core/-/core-1.3.96.tgz", + "integrity": "sha512-zwE3TLgoZwJfQygdv2SdCK9mRLYluwDOM53I+dT6Z5ZvrgVENmY3txvWDvduzkV+/8IuvrRbVezMpxcojadRdQ==", + "dev": true, + "hasInstallScript": true, + "peer": true, + "dependencies": { + "@swc/counter": "^0.1.1", + "@swc/types": "^0.1.5" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/swc" + }, + "optionalDependencies": { + "@swc/core-darwin-arm64": "1.3.96", + "@swc/core-darwin-x64": "1.3.96", + "@swc/core-linux-arm-gnueabihf": "1.3.96", + "@swc/core-linux-arm64-gnu": "1.3.96", + "@swc/core-linux-arm64-musl": "1.3.96", + "@swc/core-linux-x64-gnu": "1.3.96", + "@swc/core-linux-x64-musl": "1.3.96", + "@swc/core-win32-arm64-msvc": "1.3.96", + "@swc/core-win32-ia32-msvc": "1.3.96", + "@swc/core-win32-x64-msvc": "1.3.96" + }, + "peerDependencies": { + "@swc/helpers": "^0.5.0" + }, + "peerDependenciesMeta": { + "@swc/helpers": { + "optional": true + } + } + }, + "node_modules/@swc/core-darwin-arm64": { + "version": "1.3.96", + "resolved": "https://registry.npmjs.org/@swc/core-darwin-arm64/-/core-darwin-arm64-1.3.96.tgz", + "integrity": "sha512-8hzgXYVd85hfPh6mJ9yrG26rhgzCmcLO0h1TIl8U31hwmTbfZLzRitFQ/kqMJNbIBCwmNH1RU2QcJnL3d7f69A==", + "cpu": [ + "arm64" + ], + "dev": true, + "optional": true, + "os": [ + "darwin" + ], + "peer": true, + "engines": { + "node": ">=10" + } + }, + "node_modules/@swc/core-darwin-x64": { + "version": "1.3.96", + "resolved": "https://registry.npmjs.org/@swc/core-darwin-x64/-/core-darwin-x64-1.3.96.tgz", + "integrity": "sha512-mFp9GFfuPg+43vlAdQZl0WZpZSE8sEzqL7sr/7Reul5McUHP0BaLsEzwjvD035ESfkY8GBZdLpMinblIbFNljQ==", + "cpu": [ + "x64" + ], + "dev": true, + "optional": true, + "os": [ + "darwin" + ], + "peer": true, + "engines": { + "node": ">=10" + } + }, + "node_modules/@swc/core-linux-arm-gnueabihf": { + "version": "1.3.96", + "resolved": "https://registry.npmjs.org/@swc/core-linux-arm-gnueabihf/-/core-linux-arm-gnueabihf-1.3.96.tgz", + "integrity": "sha512-8UEKkYJP4c8YzYIY/LlbSo8z5Obj4hqcv/fUTHiEePiGsOddgGf7AWjh56u7IoN/0uEmEro59nc1ChFXqXSGyg==", + "cpu": [ + "arm" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ], + "peer": true, + "engines": { + "node": ">=10" + } + }, + "node_modules/@swc/core-linux-arm64-gnu": { + "version": "1.3.96", + "resolved": "https://registry.npmjs.org/@swc/core-linux-arm64-gnu/-/core-linux-arm64-gnu-1.3.96.tgz", + "integrity": "sha512-c/IiJ0s1y3Ymm2BTpyC/xr6gOvoqAVETrivVXHq68xgNms95luSpbYQ28rqaZC8bQC8M5zdXpSc0T8DJu8RJGw==", + "cpu": [ + "arm64" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ], + "peer": true, + "engines": { + "node": ">=10" + } + }, + "node_modules/@swc/core-linux-arm64-musl": { + "version": "1.3.96", + "resolved": "https://registry.npmjs.org/@swc/core-linux-arm64-musl/-/core-linux-arm64-musl-1.3.96.tgz", + "integrity": "sha512-i5/UTUwmJLri7zhtF6SAo/4QDQJDH2fhYJaBIUhrICmIkRO/ltURmpejqxsM/ye9Jqv5zG7VszMC0v/GYn/7BQ==", + "cpu": [ + "arm64" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ], + "peer": true, + "engines": { + "node": ">=10" + } + }, + "node_modules/@swc/core-linux-x64-gnu": { + "version": "1.3.96", + "resolved": "https://registry.npmjs.org/@swc/core-linux-x64-gnu/-/core-linux-x64-gnu-1.3.96.tgz", + "integrity": "sha512-USdaZu8lTIkm4Yf9cogct/j5eqtdZqTgcTib4I+NloUW0E/hySou3eSyp3V2UAA1qyuC72ld1otXuyKBna0YKQ==", + "cpu": [ + "x64" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ], + "peer": true, + "engines": { + "node": ">=10" + } + }, + "node_modules/@swc/core-linux-x64-musl": { + "version": "1.3.96", + "resolved": "https://registry.npmjs.org/@swc/core-linux-x64-musl/-/core-linux-x64-musl-1.3.96.tgz", + "integrity": "sha512-QYErutd+G2SNaCinUVobfL7jWWjGTI0QEoQ6hqTp7PxCJS/dmKmj3C5ZkvxRYcq7XcZt7ovrYCTwPTHzt6lZBg==", + "cpu": [ + "x64" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ], + "peer": true, + "engines": { + "node": ">=10" + } + }, + "node_modules/@swc/core-win32-arm64-msvc": { + "version": "1.3.96", + "resolved": "https://registry.npmjs.org/@swc/core-win32-arm64-msvc/-/core-win32-arm64-msvc-1.3.96.tgz", + "integrity": "sha512-hjGvvAduA3Un2cZ9iNP4xvTXOO4jL3G9iakhFsgVhpkU73SGmK7+LN8ZVBEu4oq2SUcHO6caWvnZ881cxGuSpg==", + "cpu": [ + "arm64" + ], + "dev": true, + "optional": true, + "os": [ + "win32" + ], + "peer": true, + "engines": { + "node": ">=10" + } + }, + "node_modules/@swc/core-win32-ia32-msvc": { + "version": "1.3.96", + "resolved": "https://registry.npmjs.org/@swc/core-win32-ia32-msvc/-/core-win32-ia32-msvc-1.3.96.tgz", + "integrity": "sha512-Far2hVFiwr+7VPCM2GxSmbh3ikTpM3pDombE+d69hkedvYHYZxtTF+2LTKl/sXtpbUnsoq7yV/32c9R/xaaWfw==", + "cpu": [ + "ia32" + ], + "dev": true, + "optional": true, + "os": [ + "win32" + ], + "peer": true, + "engines": { + "node": ">=10" + } + }, + "node_modules/@swc/core-win32-x64-msvc": { + "version": "1.3.96", + "resolved": "https://registry.npmjs.org/@swc/core-win32-x64-msvc/-/core-win32-x64-msvc-1.3.96.tgz", + "integrity": "sha512-4VbSAniIu0ikLf5mBX81FsljnfqjoVGleEkCQv4+zRlyZtO3FHoDPkeLVoy6WRlj7tyrRcfUJ4mDdPkbfTO14g==", + "cpu": [ + "x64" + ], + "dev": true, + "optional": true, + "os": [ + "win32" + ], + "peer": true, + "engines": { + "node": ">=10" + } + }, + "node_modules/@swc/counter": { + "version": "0.1.2", + "resolved": "https://registry.npmjs.org/@swc/counter/-/counter-0.1.2.tgz", + "integrity": "sha512-9F4ys4C74eSTEUNndnER3VJ15oru2NumfQxS8geE+f3eB5xvfxpWyqE5XlVnxb/R14uoXi6SLbBwwiDSkv+XEw==", + "dev": true, + "peer": true + }, + "node_modules/@swc/types": { + "version": "0.1.5", + "resolved": "https://registry.npmjs.org/@swc/types/-/types-0.1.5.tgz", + "integrity": "sha512-myfUej5naTBWnqOCc/MdVOLVjXUXtIA+NpDrDBKJtLLg2shUjBu3cZmB/85RyitKc55+lUUyl7oRfLOvkr2hsw==", + "dev": true, + "peer": true + }, + "node_modules/@types/node": { + "version": "20.9.0", + "resolved": "https://registry.npmjs.org/@types/node/-/node-20.9.0.tgz", + "integrity": "sha512-nekiGu2NDb1BcVofVcEKMIwzlx4NjHlcjhoxxKBNLtz15Y1z7MYf549DFvkHSId02Ax6kGwWntIBPC3l/JZcmw==", + "dev": true, + "dependencies": { + "undici-types": "~5.26.4" + } + }, + "node_modules/buffer-from": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/buffer-from/-/buffer-from-1.1.2.tgz", + "integrity": "sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==", + "dev": true + }, + "node_modules/colorette": { + "version": "2.0.20", + "resolved": "https://registry.npmjs.org/colorette/-/colorette-2.0.20.tgz", + "integrity": "sha512-IfEDxwoWIjkeXL1eXcDiow4UbKjhLdq6/EuSVR9GMN7KVH3r9gQ83e73hsz1Nd1T3ijd5xv1wcWRYO+D6kCI2w==", + "dev": true + }, + "node_modules/debug": { + "version": "4.3.4", + "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz", + "integrity": "sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ==", + "dev": true, + "dependencies": { + "ms": "2.1.2" + }, + "engines": { + "node": ">=6.0" + }, + "peerDependenciesMeta": { + "supports-color": { + "optional": true + } + } + }, + "node_modules/ms": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz", + "integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==", + "dev": true + }, + "node_modules/pirates": { + "version": "4.0.6", + "resolved": "https://registry.npmjs.org/pirates/-/pirates-4.0.6.tgz", + "integrity": "sha512-saLsH7WeYYPiD25LDuLRRY/i+6HaPYr6G1OUlN39otzkSTxKnubR9RTxS3/Kk50s1g2JTgFwWQDQyplC5/SHZg==", + "dev": true, + "engines": { + "node": ">= 6" + } + }, + "node_modules/source-map": { + "version": "0.6.1", + "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz", + "integrity": "sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/source-map-support": { + "version": "0.5.21", + "resolved": "https://registry.npmjs.org/source-map-support/-/source-map-support-0.5.21.tgz", + "integrity": "sha512-uBHU3L3czsIyYXKX88fdrGovxdSCoTGDRZ6SYXtSRxLZUzHg5P/66Ht6uoUlHu9EZod+inXhKo3qQgwXUT/y1w==", + "dev": true, + "dependencies": { + "buffer-from": "^1.0.0", + "source-map": "^0.6.0" + } + }, + "node_modules/tslib": { + "version": "2.6.2", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.6.2.tgz", + "integrity": "sha512-AEYxH93jGFPn/a2iVAwW87VuUIkR1FVUKB77NwMF7nBTDkDrrT/Hpt/IrCJ0QXhW27jTBDcf5ZY7w6RiqTMw2Q==", + "dev": true + }, + "node_modules/typescript": { + "version": "5.2.2", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.2.2.tgz", + "integrity": "sha512-mI4WrpHsbCIcwT9cF4FZvr80QUeKvsUsUvKDoR+X/7XHQH98xYD8YHZg7ANtz2GtZt/CBq2QJ0thkGJMHfqc1w==", + "dev": true, + "bin": { + "tsc": "bin/tsc", + "tsserver": "bin/tsserver" + }, + "engines": { + "node": ">=14.17" + } + }, + "node_modules/undici-types": { + "version": "5.26.5", + "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-5.26.5.tgz", + "integrity": "sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==", + "dev": true + } + } +} diff --git a/package.json b/package.json new file mode 100644 index 0000000..7ac95b8 --- /dev/null +++ b/package.json @@ -0,0 +1,16 @@ +{ + "name": "auto-diff-ts", + "version": "1.0.0", + "description": "", + "main": "dist/index.js", + "scripts": { + "test": "node --require @swc-node/register src/test.ts" + }, + "author": "", + "license": "ISC", + "devDependencies": { + "@swc-node/register": "^1.6.8", + "@types/node": "^20.9.0", + "typescript": "^5.2.2" + } +} diff --git a/src/eval.ts b/src/eval.ts new file mode 100644 index 0000000..9c43c59 --- /dev/null +++ b/src/eval.ts @@ -0,0 +1,145 @@ +import type { Input } from './operations'; + +export type InputVars = { + [V in Vars]: number; +}; + +export function evalDfs( + op: Input, + vars: InputVars, + memo: Map, number> = new Map() +): [number, Map, number>] { + if (memo.has(op)) { + return [memo.get(op)!, memo]; + } else if (op.type === "var") { + console.log(`${op.name}: ${vars[op.name]}`); + const result = vars[op.name]; + memo.set(op, result); + return [result, memo]; + } else { + const inputs = op.inputs.map((i) => evalDfs(i, vars, memo)[0]); + const result = op.value(inputs); + console.log(`${op.name}: ${result}`); + memo.set(op, result); + return [result, memo]; + } +} + +export type Value = { + x: number; + dx: number; +}; + +export function evalForward( + op: Input, + wrt: Vars, + vars: InputVars, + memo: Map, Value> = new Map() +): [Value, Map, Value>] { + if (memo.has(op)) { + return [memo.get(op)!, memo]; + } else if (op.type === "var") { + const x = vars[op.name]; + const dx = op.name === wrt ? 1 : 0; + const result = { x, dx }; + console.log(`${op.name}: ${JSON.stringify(result)}`); + memo.set(op, result); + return [result, memo]; + } else { + const inputs = op.inputs.map((i) => evalForward(i, wrt, vars, memo)[0]); + const inputVals = inputs.map((i) => i.x); + const x = op.value(inputVals); + const dx = op + .deriv(inputVals) + .reduce((acc, n, idx) => acc + inputs[idx].dx * n, 0); + const result = { x, dx }; + console.log(`${op.name}: ${JSON.stringify(result)}`); + memo.set(op, result); + return [result, memo]; + } +} + +export type Gradient = { [K in Vars]: number }; + +export function evalReverse( + op: Input, + vars: InputVars +): [number, Gradient] { + const valueMemo: Map, number> = new Map(); + const visitorMap: Map, Set>> = new Map(); + const deepestVisit: Map, number> = new Map(); + let maxDepth = 0; + function firstTraversal(op: Input, depth: number = 0): number { + if (maxDepth < depth) { + maxDepth = depth; + } + if (!deepestVisit.has(op)) { + deepestVisit.set(op, depth); + } else if (deepestVisit.get(op)! < depth) { + deepestVisit.set(op, depth); + } + if (valueMemo.has(op)) { + return valueMemo.get(op)!; + } else if (op.type === "var") { + const result = vars[op.name]; + valueMemo.set(op, result); + return result; + } else { + const inputs = op.inputs.map((i) => { + if (!visitorMap.has(i)) { + const initialVisitor = new Set([op]); + visitorMap.set(i, initialVisitor); + } else { + visitorMap.get(i)!.add(op); + } + return firstTraversal(i, depth + 1); + }); + const result = op.value(inputs); + valueMemo.set(op, result); + return result; + } + } + firstTraversal(op); + + if (op.type === "var") { + return [ + valueMemo.get(op)!, + { + [op.name]: 1, + } as InputVars, + ]; + } + + const result: Partial> = {}; + + const backMemo: Map, number> = new Map(); + backMemo.set(op, 1); + let bfs: Input[] = [op]; + for (let depth = 0; depth <= maxDepth; depth++) { + let nextBfs: Input[] = []; + for (let i = 0; i < bfs.length; i++) { + const o = bfs[i]; + if (o.type === "var") { + result[o.name] = backMemo.get(o)!; + continue; + } + const inputs = o.inputs.map((i) => valueMemo.get(i)!); + const derivs = o.deriv(inputs); + const u = backMemo.get(o)!; + for (let j = 0; j < o.inputs.length; j++) { + const input = o.inputs[j]; + const cur = backMemo.get(input) ?? 0; + const dx = u * derivs[j]; + console.log(input.name, dx); + backMemo.set(input, cur + dx); + + if (deepestVisit.get(input)! === depth + 1) { + nextBfs.push(input); + } + } + } + bfs = nextBfs; + } + + return [valueMemo.get(op)!, result as Gradient]; +} diff --git a/src/index.ts b/src/index.ts new file mode 100644 index 0000000..7df7ee1 --- /dev/null +++ b/src/index.ts @@ -0,0 +1,2 @@ +export * from './eval'; +export * from './operations'; diff --git a/src/operations.ts b/src/operations.ts new file mode 100644 index 0000000..ce3cba3 --- /dev/null +++ b/src/operations.ts @@ -0,0 +1,110 @@ +export type Variable = { + type: "var"; + name: Name; +}; + +export function variable(name: Name): Variable { + return { + type: "var", + name, + }; +} + +export interface Op { + type: "op"; + name: string; + inputs: Input[]; + value(inputs: number[]): number; + deriv(inputs: number[]): number[]; +} + +export type Input = { type: "var"; name: Vars } | Op; + +export function add( + a: Input, + b: Input +): Op { + return { + type: "op", + name: `(${a.name} + ${b.name})`, + inputs: [a, b], + value: ([x, y]) => x + y, + deriv: () => [1, 1], + }; +} + +export function constadd(k: number, a: Input): Op { + return { + type: "op", + name: `(${a.name} + ${k})`, + inputs: [a], + value: ([x]) => k + x, + deriv: () => [1], + }; +} + +export function mult( + a: Input, + b: Input +): Op { + return { + type: "op", + name: `(${a.name} * ${b.name})`, + inputs: [a, b], + value: ([x, y]) => x * y, + deriv: ([x, y]) => [y, x], + }; +} + +export function constmult(k: number, a: Input): Op { + return { + type: "op", + name: `${k}${a.name}`, + inputs: [a], + value: ([x]) => k * x, + deriv: () => [k], + }; +} + +export function div( + a: Input, + b: Input +): Op { + return { + type: "op", + name: `(${a.name} / ${b.name})`, + inputs: [a, b], + value: ([x, y]) => x / y, + deriv: ([x, y]) => [1 / y, x * Math.log(y)], + }; +} + +export function constdiv(k: number, a: Input): Op { + return { + type: "op", + name: ` ${k} / ${a.name}`, + inputs: [a], + value: ([x]) => k / x, + deriv: ([x]) => [k * Math.log(x)], + }; +} + +export function constpow(a: Input, k: number): Op { + return { + type: "op", + name: `${a.name}^${k}`, + inputs: [a], + value: ([x]) => Math.pow(x, k), + deriv: ([x]) => [k * Math.pow(x, k - 1)], + }; +} + +export function exp(a: Input): Op { + return { + type: "op", + name: `e^${a.name}`, + inputs: [a], + value: ([x]) => Math.exp(x), + deriv: ([x]) => [Math.exp(x)], + }; +} diff --git a/src/test.ts b/src/test.ts new file mode 100644 index 0000000..df6702f --- /dev/null +++ b/src/test.ts @@ -0,0 +1,44 @@ +import assert from 'node:assert'; +import { test } from 'node:test'; + +import * as op from './operations'; +import { evalDfs, evalForward, evalReverse } from './eval'; + +const x1 = op.variable("x1"); +const x2 = op.variable("x2"); + +const x1sq = op.constpow(x1, 2); +const x2sq = op.constpow(x2, 2); + +const x1x2 = op.add(x1sq, x2sq); +const minhalf = op.constmult(-0.5, x1x2); +const expx1x2 = op.exp(minhalf); + +const f = op.mult(x1, expx1x2); + +function assertApprox(expected: number, actual: number, epsilon: number = 0.003) { + const diff = Math.abs(expected - actual); + assert(diff <= epsilon, `Expected ${actual} to be approximately ${expected}`); +} + +test('simple evaluation', () => { + const [result] = evalDfs(f, { x1: 2, x2: 0.5 }); + assertApprox(0.24, result); +}) + +test('forward mode', () => { + const [result] = evalForward(f, "x1", { x1: 2, x2: 0.5 }); + assertApprox(0.24, result.x); + assertApprox(-0.36, result.dx); + + const [result2] = evalForward(f, "x2", { x1: 2, x2: 0.5 }); + assertApprox(0.24, result2.x); + assertApprox(-0.12, result2.dx); +}) + +test('reverse mode', () => { + const [value, gradients] = evalReverse(f, { x1: 2, x2: 0.5 }); + assertApprox(0.24, value); + assertApprox(-0.36, gradients.x1); + assertApprox(-0.12, gradients.x2); +}) \ No newline at end of file