From 808719f570201ccb50adfd642b522db54a44f478 Mon Sep 17 00:00:00 2001 From: jacoblee93 Date: Wed, 23 Oct 2024 11:04:25 -0700 Subject: [PATCH 01/11] Start Supabase --- libs/checkpoint-supabase/.gitignore | 7 + libs/checkpoint-supabase/langchain.config.js | 21 ++ libs/checkpoint-supabase/package.json | 91 +++++ libs/checkpoint-supabase/src/index.ts | 344 ++++++++++++++++++ libs/checkpoint-supabase/tsconfig.cjs.json | 8 + libs/checkpoint-supabase/tsconfig.json | 23 ++ libs/checkpoint-supabase/turbo.json | 11 + libs/checkpoint-validation/package.json | 1 + .../src/tests/supabase.spec.ts | 5 + .../src/tests/supabase_initializer.ts | 30 ++ yarn.lock | 140 +++++++ 11 files changed, 681 insertions(+) create mode 100644 libs/checkpoint-supabase/.gitignore create mode 100644 libs/checkpoint-supabase/langchain.config.js create mode 100644 libs/checkpoint-supabase/package.json create mode 100644 libs/checkpoint-supabase/src/index.ts create mode 100644 libs/checkpoint-supabase/tsconfig.cjs.json create mode 100644 libs/checkpoint-supabase/tsconfig.json create mode 100644 libs/checkpoint-supabase/turbo.json create mode 100644 libs/checkpoint-validation/src/tests/supabase.spec.ts create mode 100644 libs/checkpoint-validation/src/tests/supabase_initializer.ts diff --git a/libs/checkpoint-supabase/.gitignore b/libs/checkpoint-supabase/.gitignore new file mode 100644 index 00000000..c10034e2 --- /dev/null +++ b/libs/checkpoint-supabase/.gitignore @@ -0,0 +1,7 @@ +index.cjs +index.js +index.d.ts +index.d.cts +node_modules +dist +.yarn diff --git a/libs/checkpoint-supabase/langchain.config.js b/libs/checkpoint-supabase/langchain.config.js new file mode 100644 index 00000000..fe70c345 --- /dev/null +++ b/libs/checkpoint-supabase/langchain.config.js @@ -0,0 +1,21 @@ +import { resolve, dirname } from "node:path"; +import { fileURLToPath } from "node:url"; + +/** + * @param {string} relativePath + * @returns {string} + */ +function abs(relativePath) { + return resolve(dirname(fileURLToPath(import.meta.url)), relativePath); +} + +export const config = { + internals: [/node\:/, /@langchain\/core\//, /async_hooks/], + entrypoints: { + index: "index" + }, + tsConfigPath: resolve("./tsconfig.json"), + cjsSource: "./dist-cjs", + cjsDestination: "./dist", + abs, +}; diff --git a/libs/checkpoint-supabase/package.json b/libs/checkpoint-supabase/package.json new file mode 100644 index 00000000..3b094ad1 --- /dev/null +++ b/libs/checkpoint-supabase/package.json @@ -0,0 +1,91 @@ +{ + "name": "@langchain/langgraph-checkpoint-supabase", + "version": "0.1.2", + "description": "LangGraph", + "type": "module", + "engines": { + "node": ">=18" + }, + "main": "./index.js", + "types": "./index.d.ts", + "repository": { + "type": "git", + "url": "git@github.com:langchain-ai/langgraphjs.git" + }, + "scripts": { + "build": "yarn turbo:command build:internal --filter=@langchain/langgraph-checkpoint-supabase", + "build:internal": "yarn clean && yarn lc_build --create-entrypoints --pre --tree-shaking", + "clean": "rm -rf dist/ dist-cjs/ .turbo/", + "lint:eslint": "NODE_OPTIONS=--max-old-space-size=4096 eslint --cache --ext .ts,.js src/", + "lint:dpdm": "dpdm --exit-code circular:1 --no-warning --no-tree src/*.ts src/**/*.ts", + "lint": "yarn lint:eslint && yarn lint:dpdm", + "lint:fix": "yarn lint:eslint --fix && yarn lint:dpdm", + "prepack": "yarn build", + "test": "NODE_OPTIONS=--experimental-vm-modules jest --testPathIgnorePatterns=\\.int\\.test.ts --testTimeout 30000 --maxWorkers=50%", + "test:watch": "NODE_OPTIONS=--experimental-vm-modules jest --watch --testPathIgnorePatterns=\\.int\\.test.ts", + "test:single": "NODE_OPTIONS=--experimental-vm-modules yarn run jest --config jest.config.cjs --testTimeout 100000", + "test:int": "NODE_OPTIONS=--experimental-vm-modules jest --testPathPattern=\\.int\\.test.ts --testTimeout 100000 --maxWorkers=50%", + "format": "prettier --config .prettierrc --write \"src\"", + "format:check": "prettier --config .prettierrc --check \"src\"" + }, + "author": "LangChain", + "license": "MIT", + "dependencies": { + "@supabase/supabase-js": "^2.45.6" + }, + "peerDependencies": { + "@langchain/core": ">=0.2.31 <0.4.0", + "@langchain/langgraph-checkpoint": "~0.0.6" + }, + "devDependencies": { + "@jest/globals": "^29.5.0", + "@langchain/langgraph-checkpoint": "workspace:*", + "@langchain/scripts": ">=0.1.3 <0.2.0", + "@swc/core": "^1.3.90", + "@swc/jest": "^0.2.29", + "@tsconfig/recommended": "^1.0.3", + "@types/uuid": "^10", + "@typescript-eslint/eslint-plugin": "^6.12.0", + "@typescript-eslint/parser": "^6.12.0", + "dotenv": "^16.3.1", + "dpdm": "^3.12.0", + "eslint": "^8.33.0", + "eslint-config-airbnb-base": "^15.0.0", + "eslint-config-prettier": "^8.6.0", + "eslint-plugin-import": "^2.29.1", + "eslint-plugin-jest": "^28.8.0", + "eslint-plugin-no-instanceof": "^1.0.1", + "eslint-plugin-prettier": "^4.2.1", + "jest": "^29.5.0", + "jest-environment-node": "^29.6.4", + "prettier": "^2.8.3", + "release-it": "^17.6.0", + "rollup": "^4.23.0", + "ts-jest": "^29.1.0", + "tsx": "^4.7.0", + "typescript": "^4.9.5 || ^5.4.5" + }, + "publishConfig": { + "access": "public", + "registry": "https://registry.npmjs.org/" + }, + "exports": { + ".": { + "types": { + "import": "./index.d.ts", + "require": "./index.d.cts", + "default": "./index.d.ts" + }, + "import": "./index.js", + "require": "./index.cjs" + }, + "./package.json": "./package.json" + }, + "files": [ + "dist/", + "index.cjs", + "index.js", + "index.d.ts", + "index.d.cts" + ] +} diff --git a/libs/checkpoint-supabase/src/index.ts b/libs/checkpoint-supabase/src/index.ts new file mode 100644 index 00000000..7be0c66c --- /dev/null +++ b/libs/checkpoint-supabase/src/index.ts @@ -0,0 +1,344 @@ +import type { SupabaseClient } from "@supabase/supabase-js"; + +import type { RunnableConfig } from "@langchain/core/runnables"; +import { + BaseCheckpointSaver, + type Checkpoint, + type CheckpointListOptions, + type CheckpointTuple, + type SerializerProtocol, + type PendingWrite, + type CheckpointMetadata, +} from "@langchain/langgraph-checkpoint"; + +interface CheckpointRow { + checkpoint: string; + metadata: string; + parent_checkpoint_id?: string; + thread_id: string; + checkpoint_id: string; + checkpoint_ns?: string; + type?: string; +} + +interface WritesRow { + thread_id: string; + checkpoint_ns: string; + checkpoint_id: string; + task_id: string; + idx: number; + channel: string; + type?: string; + value?: string; +} + +// In the `SqliteSaver.list` method, we need to sanitize the `options.filter` argument to ensure it only contains keys +// that are part of the `CheckpointMetadata` type. The lines below ensure that we get compile-time errors if the list +// of keys that we use is out of sync with the `CheckpointMetadata` type. +const checkpointMetadataKeys = ["source", "step", "writes", "parents"] as const; + +type CheckKeys = [K[number]] extends [ + keyof T +] + ? [keyof T] extends [K[number]] + ? K + : never + : never; + +function validateKeys( + keys: CheckKeys +): K { + return keys; +} + +// If this line fails to compile, the list of keys that we use in the `SqliteSaver.list` method is out of sync with the +// `CheckpointMetadata` type. In that case, just update `checkpointMetadataKeys` to contain all the keys in +// `CheckpointMetadata` +const validCheckpointMetadataKeys = validateKeys< + CheckpointMetadata, + typeof checkpointMetadataKeys +>(checkpointMetadataKeys); + +export class SupaSaver extends BaseCheckpointSaver { + constructor(private client: SupabaseClient, serde?: SerializerProtocol) { + super(serde); + } + + async getTuple(config: RunnableConfig): Promise { + const { + thread_id, + checkpoint_ns = "", + checkpoint_id, + } = config.configurable ?? {}; + let res; + if (checkpoint_id) { + // data = this.db + // .prepare( + // `SELECT thread_id, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata FROM checkpoints WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ?` + // ) + // .get(thread_id, checkpoint_ns, checkpoint_id) as CheckpointRow; + res = await this.client + .from("chat_session_checkpoints") + .select("*") + .eq("session_id", "6b3cffb2-e521-46e3-9509-266f5380245d") + .eq("id", checkpoint_id) + .eq("thread_id", thread_id) + .eq("checkpoint_ns", checkpoint_ns) + .maybeSingle() + .throwOnError(); + } else { + // row = this.db + // .prepare( + // `SELECT thread_id, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata FROM checkpoints WHERE thread_id = ? AND checkpoint_ns = ? ORDER BY checkpoint_id DESC LIMIT 1` + // ) + // .get(thread_id, checkpoint_ns) as CheckpointRow; + res = await this.client + .from("chat_session_checkpoints") + .select("*") + .eq("session_id", "6b3cffb2-e521-46e3-9509-266f5380245d") + .eq("thread_id", thread_id) + .eq("checkpoint_ns", checkpoint_ns) + .maybeSingle() + .throwOnError(); + } + const row = res?.data?.[0]; + if (row === undefined) { + return undefined; + } + let finalConfig = config; + if (!checkpoint_id) { + finalConfig = { + configurable: { + thread_id: row.thread_id, + checkpoint_ns, + checkpoint_id: row.checkpoint_id, + }, + }; + } + if ( + finalConfig.configurable?.thread_id === undefined || + finalConfig.configurable?.checkpoint_id === undefined + ) { + throw new Error("Missing thread_id or checkpoint_id"); + } + // find any pending writes + // const pendingWritesRows = this.db + // .prepare( + // `SELECT task_id, channel, type, value FROM writes WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ?` + // ) + // .all( + // finalConfig.configurable.thread_id.toString(), + // checkpoint_ns, + // finalConfig.configurable.checkpoint_id.toString() + // ) as WritesRow[]; + const pendingWritesRes = await this.client + .from("chat_session_writes") + .select("*") + .eq("thread_id", finalConfig.configurable.thread_id.toString()) + .eq("checkpoint_ns", checkpoint_ns) + .eq("checkpoint_id", finalConfig.configurable.checkpoint_id.toString()); + const pendingWritesRows = pendingWritesRes.data ?? []; + const pendingWrites = await Promise.all( + pendingWritesRows.map(async (row) => { + return [ + row.task_id, + row.channel, + await this.serde.loadsTyped(row.type ?? "json", row.value ?? ""), + ] as [string, string, unknown]; + }) + ); + return { + config: finalConfig, + checkpoint: (await this.serde.loadsTyped( + row.type ?? "json", + row.checkpoint + )) as Checkpoint, + metadata: (await this.serde.loadsTyped( + row.type ?? "json", + row.metadata + )) as CheckpointMetadata, + parentConfig: row.parent_checkpoint_id + ? { + configurable: { + thread_id: row.thread_id, + checkpoint_ns, + checkpoint_id: row.parent_checkpoint_id, + }, + } + : undefined, + pendingWrites, + }; + } + + async *list( + config: RunnableConfig, + options?: CheckpointListOptions + ): AsyncGenerator { + const { limit, before, filter } = options ?? {}; + const thread_id = config.configurable?.thread_id; + const checkpoint_ns = config.configurable?.checkpoint_ns; + + let sql = + `SELECT\n` + + " thread_id,\n" + + " checkpoint_ns,\n" + + " checkpoint_id,\n" + + " parent_checkpoint_id,\n" + + " type,\n" + + " checkpoint,\n" + + " metadata\n" + + "FROM checkpoints\n"; + + const whereClause: string[] = []; + + if (thread_id) { + whereClause.push("thread_id = ?"); + } + + if (checkpoint_ns !== undefined && checkpoint_ns !== null) { + whereClause.push("checkpoint_ns = ?"); + } + + if (before?.configurable?.checkpoint_id !== undefined) { + whereClause.push("checkpoint_id < ?"); + } + + const sanitizedFilter = Object.fromEntries( + Object.entries(filter ?? {}).filter( + ([key, value]) => + value !== undefined && + validCheckpointMetadataKeys.includes(key as keyof CheckpointMetadata) + ) + ); + + whereClause.push( + ...Object.entries(sanitizedFilter).map( + ([key]) => `jsonb(CAST(metadata AS TEXT))->'$.${key}' = ?` + ) + ); + + if (whereClause.length > 0) { + sql += `WHERE\n ${whereClause.join(" AND\n ")}\n`; + } + + sql += "\nORDER BY checkpoint_id DESC"; + + if (limit) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + sql += ` LIMIT ${parseInt(limit as any, 10)}`; // parseInt here (with cast to make TS happy) to sanitize input, as limit may be user-provided + } + + const args = [ + thread_id, + checkpoint_ns, + before?.configurable?.checkpoint_id, + ...Object.values(sanitizedFilter).map((value) => JSON.stringify(value)), + ].filter((value) => value !== undefined && value !== null); + + const rows: CheckpointRow[] = this.db + .prepare(sql) + .all(...args) as CheckpointRow[]; + + if (rows) { + for (const row of rows) { + yield { + config: { + configurable: { + thread_id: row.thread_id, + checkpoint_ns: row.checkpoint_ns, + checkpoint_id: row.checkpoint_id, + }, + }, + checkpoint: (await this.serde.loadsTyped( + row.type ?? "json", + row.checkpoint + )) as Checkpoint, + metadata: (await this.serde.loadsTyped( + row.type ?? "json", + row.metadata + )) as CheckpointMetadata, + parentConfig: row.parent_checkpoint_id + ? { + configurable: { + thread_id: row.thread_id, + checkpoint_ns: row.checkpoint_ns, + checkpoint_id: row.parent_checkpoint_id, + }, + } + : undefined, + }; + } + } + } + + async put( + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata + ): Promise { + const [type1, serializedCheckpoint] = this.serde.dumpsTyped(checkpoint); + const [type2, serializedMetadata] = this.serde.dumpsTyped(metadata); + if (type1 !== type2) { + throw new Error( + "Failed to serialized checkpoint and metadata to the same type." + ); + } + const row = [ + config.configurable?.thread_id?.toString(), + config.configurable?.checkpoint_ns, + checkpoint.id, + config.configurable?.checkpoint_id, + type1, + serializedCheckpoint, + serializedMetadata, + ]; + + this.db + .prepare( + `INSERT OR REPLACE INTO checkpoints (thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata) VALUES (?, ?, ?, ?, ?, ?, ?)` + ) + .run(...row); + + return { + configurable: { + thread_id: config.configurable?.thread_id, + checkpoint_ns: config.configurable?.checkpoint_ns, + checkpoint_id: checkpoint.id, + }, + }; + } + + async putWrites( + config: RunnableConfig, + writes: PendingWrite[], + taskId: string + ): Promise { + const stmt = this.db.prepare(` + INSERT OR REPLACE INTO writes + (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, value) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + `); + + const transaction = this.db.transaction((rows) => { + for (const row of rows) { + stmt.run(...row); + } + }); + + const rows = writes.map((write, idx) => { + const [type, serializedWrite] = this.serde.dumpsTyped(write[1]); + return [ + config.configurable?.thread_id, + config.configurable?.checkpoint_ns, + config.configurable?.checkpoint_id, + taskId, + idx, + write[0], + type, + serializedWrite, + ]; + }); + + transaction(rows); + } +} diff --git a/libs/checkpoint-supabase/tsconfig.cjs.json b/libs/checkpoint-supabase/tsconfig.cjs.json new file mode 100644 index 00000000..3b7026ea --- /dev/null +++ b/libs/checkpoint-supabase/tsconfig.cjs.json @@ -0,0 +1,8 @@ +{ + "extends": "./tsconfig.json", + "compilerOptions": { + "module": "commonjs", + "declaration": false + }, + "exclude": ["node_modules", "dist", "docs", "**/tests"] +} diff --git a/libs/checkpoint-supabase/tsconfig.json b/libs/checkpoint-supabase/tsconfig.json new file mode 100644 index 00000000..bc85d83b --- /dev/null +++ b/libs/checkpoint-supabase/tsconfig.json @@ -0,0 +1,23 @@ +{ + "extends": "@tsconfig/recommended", + "compilerOptions": { + "outDir": "../dist", + "rootDir": "./src", + "target": "ES2021", + "lib": ["ES2021", "ES2022.Object", "DOM"], + "module": "ES2020", + "moduleResolution": "nodenext", + "esModuleInterop": true, + "declaration": true, + "noImplicitReturns": true, + "noFallthroughCasesInSwitch": true, + "noUnusedLocals": true, + "noUnusedParameters": true, + "useDefineForClassFields": true, + "strictPropertyInitialization": false, + "allowJs": true, + "strict": true + }, + "include": ["src/**/*"], + "exclude": ["node_modules", "dist", "docs"] +} diff --git a/libs/checkpoint-supabase/turbo.json b/libs/checkpoint-supabase/turbo.json new file mode 100644 index 00000000..d1bb60a7 --- /dev/null +++ b/libs/checkpoint-supabase/turbo.json @@ -0,0 +1,11 @@ +{ + "extends": ["//"], + "tasks": { + "build": { + "outputs": ["**/dist/**"] + }, + "build:internal": { + "dependsOn": ["^build:internal"] + } + } +} diff --git a/libs/checkpoint-validation/package.json b/libs/checkpoint-validation/package.json index bc07827a..b615f07c 100644 --- a/libs/checkpoint-validation/package.json +++ b/libs/checkpoint-validation/package.json @@ -51,6 +51,7 @@ "@langchain/langgraph-checkpoint-mongodb": "workspace:*", "@langchain/langgraph-checkpoint-postgres": "workspace:*", "@langchain/langgraph-checkpoint-sqlite": "workspace:*", + "@langchain/langgraph-checkpoint-supabase": "workspace:*", "@langchain/scripts": ">=0.1.3 <0.2.0", "@testcontainers/mongodb": "^10.13.2", "@testcontainers/postgresql": "^10.13.2", diff --git a/libs/checkpoint-validation/src/tests/supabase.spec.ts b/libs/checkpoint-validation/src/tests/supabase.spec.ts new file mode 100644 index 00000000..84813940 --- /dev/null +++ b/libs/checkpoint-validation/src/tests/supabase.spec.ts @@ -0,0 +1,5 @@ +// eslint-disable-next-line import/no-extraneous-dependencies +import { specTest } from "../spec/index.js"; +import { initializer } from "./supabase_initializer.js"; + +specTest(initializer); diff --git a/libs/checkpoint-validation/src/tests/supabase_initializer.ts b/libs/checkpoint-validation/src/tests/supabase_initializer.ts new file mode 100644 index 00000000..7ad2b8ee --- /dev/null +++ b/libs/checkpoint-validation/src/tests/supabase_initializer.ts @@ -0,0 +1,30 @@ +// eslint-disable-next-line import/no-extraneous-dependencies +import { SupaSaver } from "@langchain/langgraph-checkpoint-supabase"; +import { CheckpointerTestInitializer } from "../types.js"; +import { createClient } from "@supabase/supabase-js"; + +const SUPABASE_URL = process.env.SUPABASE_URL!; +const SUPABASE_KEY = process.env.SUPABASE_KEY!; + +export const initializer: CheckpointerTestInitializer = { + checkpointerName: "@langchain/langgraph-checkpoint-supabase", + + async createCheckpointer() { + const client = createClient(SUPABASE_URL, SUPABASE_KEY); + return new SupaSaver(client); + }, + + async afterAll() { + const client = createClient(SUPABASE_URL, SUPABASE_KEY); + await client + .from("chat_session_checkpoints") + .delete() + .eq("session_id", "6b3cffb2-e521-46e3-9509-266f5380245d"); + await client + .from("chat_session_writes") + .delete() + .neq("session_id", "6b3cffb2-e521-46e3-9509-266f5380245d"); + }, +}; + +export default initializer; diff --git a/yarn.lock b/yarn.lock index e8d385ea..67a08a52 100644 --- a/yarn.lock +++ b/yarn.lock @@ -1768,6 +1768,43 @@ __metadata: languageName: unknown linkType: soft +"@langchain/langgraph-checkpoint-supabase@workspace:*, @langchain/langgraph-checkpoint-supabase@workspace:libs/checkpoint-supabase": + version: 0.0.0-use.local + resolution: "@langchain/langgraph-checkpoint-supabase@workspace:libs/checkpoint-supabase" + dependencies: + "@jest/globals": ^29.5.0 + "@langchain/langgraph-checkpoint": "workspace:*" + "@langchain/scripts": ">=0.1.3 <0.2.0" + "@supabase/supabase-js": ^2.45.6 + "@swc/core": ^1.3.90 + "@swc/jest": ^0.2.29 + "@tsconfig/recommended": ^1.0.3 + "@types/uuid": ^10 + "@typescript-eslint/eslint-plugin": ^6.12.0 + "@typescript-eslint/parser": ^6.12.0 + dotenv: ^16.3.1 + dpdm: ^3.12.0 + eslint: ^8.33.0 + eslint-config-airbnb-base: ^15.0.0 + eslint-config-prettier: ^8.6.0 + eslint-plugin-import: ^2.29.1 + eslint-plugin-jest: ^28.8.0 + eslint-plugin-no-instanceof: ^1.0.1 + eslint-plugin-prettier: ^4.2.1 + jest: ^29.5.0 + jest-environment-node: ^29.6.4 + prettier: ^2.8.3 + release-it: ^17.6.0 + rollup: ^4.23.0 + ts-jest: ^29.1.0 + tsx: ^4.7.0 + typescript: ^4.9.5 || ^5.4.5 + peerDependencies: + "@langchain/core": ">=0.2.31 <0.4.0" + "@langchain/langgraph-checkpoint": ~0.0.6 + languageName: unknown + linkType: soft + "@langchain/langgraph-checkpoint-validation@workspace:libs/checkpoint-validation": version: 0.0.0-use.local resolution: "@langchain/langgraph-checkpoint-validation@workspace:libs/checkpoint-validation" @@ -1778,6 +1815,7 @@ __metadata: "@langchain/langgraph-checkpoint-mongodb": "workspace:*" "@langchain/langgraph-checkpoint-postgres": "workspace:*" "@langchain/langgraph-checkpoint-sqlite": "workspace:*" + "@langchain/langgraph-checkpoint-supabase": "workspace:*" "@langchain/scripts": ">=0.1.3 <0.2.0" "@swc-node/register": ^1.10.9 "@swc/core": ^1.3.90 @@ -2821,6 +2859,77 @@ __metadata: languageName: node linkType: hard +"@supabase/auth-js@npm:2.65.1": + version: 2.65.1 + resolution: "@supabase/auth-js@npm:2.65.1" + dependencies: + "@supabase/node-fetch": ^2.6.14 + checksum: 5e4a9c4d94b5d8d3e4c6ea113eb4adf84d5bf0b187c775e4577693d18bfba4ffa6fdf9ef236e1f7a2cebf1696948cba1ec8cafd705a6493b63ecb7807cee86ac + languageName: node + linkType: hard + +"@supabase/functions-js@npm:2.4.3": + version: 2.4.3 + resolution: "@supabase/functions-js@npm:2.4.3" + dependencies: + "@supabase/node-fetch": ^2.6.14 + checksum: 1c2d58b498c19bd0c8984407f1d4c207ac6816df5e38c52f0d009a9ae55cfd80cc3b74b66414b386c3dc5c972b7db99452aeed545f9f5d6472ebb631274261a8 + languageName: node + linkType: hard + +"@supabase/node-fetch@npm:2.6.15, @supabase/node-fetch@npm:^2.6.14": + version: 2.6.15 + resolution: "@supabase/node-fetch@npm:2.6.15" + dependencies: + whatwg-url: ^5.0.0 + checksum: 9673b49236a56df49eb7ea5cb789cf4e8b1393069b84b4964ac052995e318a34872f428726d128f232139e17c3375a531e45e99edd3e96a25cce60d914b53879 + languageName: node + linkType: hard + +"@supabase/postgrest-js@npm:1.16.3": + version: 1.16.3 + resolution: "@supabase/postgrest-js@npm:1.16.3" + dependencies: + "@supabase/node-fetch": ^2.6.14 + checksum: e89f3d75b8d7253de19356c9f57ca1674cd09a62a5229bf80705450bebf0cbe0ca667333a5e349c13eb10a74dfcbb316d574399b734591d96389d5e2ff2f0801 + languageName: node + linkType: hard + +"@supabase/realtime-js@npm:2.10.7": + version: 2.10.7 + resolution: "@supabase/realtime-js@npm:2.10.7" + dependencies: + "@supabase/node-fetch": ^2.6.14 + "@types/phoenix": ^1.5.4 + "@types/ws": ^8.5.10 + ws: ^8.14.2 + checksum: fd0a39a096c691782732eac5a08f5b150c7fbb0b8d73e91c0d7a4df9accd5835a760d3ed984c4640e3c7a72e4e8ece31ce1bbeedd47e4aacb3476c5a53e95791 + languageName: node + linkType: hard + +"@supabase/storage-js@npm:2.7.1": + version: 2.7.1 + resolution: "@supabase/storage-js@npm:2.7.1" + dependencies: + "@supabase/node-fetch": ^2.6.14 + checksum: ed8f3a3178856c331b36588f4fff5cbb7f2f89977fff9716ab20b1977d13816bda5a887a316638f2a05ac35fdef46e18eab8a543d6113de76d3a06b15bf9ae8e + languageName: node + linkType: hard + +"@supabase/supabase-js@npm:^2.45.6": + version: 2.45.6 + resolution: "@supabase/supabase-js@npm:2.45.6" + dependencies: + "@supabase/auth-js": 2.65.1 + "@supabase/functions-js": 2.4.3 + "@supabase/node-fetch": 2.6.15 + "@supabase/postgrest-js": 1.16.3 + "@supabase/realtime-js": 2.10.7 + "@supabase/storage-js": 2.7.1 + checksum: d97c18180a7e4725615e6d22ab322eb6f68de0fe8b5bb9cdd921544c16274917b5d7594ff46fb4401d175329e768e82a9ad8dd3c362a6c5ab2231afb021481b9 + languageName: node + linkType: hard + "@swc-node/core@npm:^1.13.3": version: 1.13.3 resolution: "@swc-node/core@npm:1.13.3" @@ -3594,6 +3703,13 @@ __metadata: languageName: node linkType: hard +"@types/phoenix@npm:^1.5.4": + version: 1.6.5 + resolution: "@types/phoenix@npm:1.6.5" + checksum: b87416393159f0ba2812875fc2721914a3284cde8b1f263dfcd46f4149dae7f4efc2bfa062d558c8bbfb7ae2a9d802487b0dd4744ff08799386cbc49c19368f0 + languageName: node + linkType: hard + "@types/qs@npm:^6.9.15": version: 6.9.15 resolution: "@types/qs@npm:6.9.15" @@ -3689,6 +3805,15 @@ __metadata: languageName: node linkType: hard +"@types/ws@npm:^8.5.10": + version: 8.5.12 + resolution: "@types/ws@npm:8.5.12" + dependencies: + "@types/node": "*" + checksum: ddefb6ad1671f70ce73b38a5f47f471d4d493864fca7c51f002a86e5993d031294201c5dced6d5018fb8905ad46888d65c7f20dd54fc165910b69f42fba9a6d0 + languageName: node + linkType: hard + "@types/yargs-parser@npm:*": version: 21.0.3 resolution: "@types/yargs-parser@npm:21.0.3" @@ -13013,6 +13138,21 @@ __metadata: languageName: node linkType: hard +"ws@npm:^8.14.2": + version: 8.18.0 + resolution: "ws@npm:8.18.0" + peerDependencies: + bufferutil: ^4.0.1 + utf-8-validate: ">=5.0.2" + peerDependenciesMeta: + bufferutil: + optional: true + utf-8-validate: + optional: true + checksum: 91d4d35bc99ff6df483bdf029b9ea4bfd7af1f16fc91231a96777a63d263e1eabf486e13a2353970efc534f9faa43bdbf9ee76525af22f4752cbc5ebda333975 + languageName: node + linkType: hard + "xdg-basedir@npm:^5.0.1, xdg-basedir@npm:^5.1.0": version: 5.1.0 resolution: "xdg-basedir@npm:5.1.0" From f649e311c8001ed7a08b078a60c92c5163870658 Mon Sep 17 00:00:00 2001 From: William Overton Date: Wed, 23 Oct 2024 22:15:28 +0100 Subject: [PATCH 02/11] Work on supabase checkpointer --- libs/checkpoint-supabase/src/index.ts | 54 +++++++++------------------ 1 file changed, 17 insertions(+), 37 deletions(-) diff --git a/libs/checkpoint-supabase/src/index.ts b/libs/checkpoint-supabase/src/index.ts index 7be0c66c..f226f14a 100644 --- a/libs/checkpoint-supabase/src/index.ts +++ b/libs/checkpoint-supabase/src/index.ts @@ -5,10 +5,10 @@ import { BaseCheckpointSaver, type Checkpoint, type CheckpointListOptions, + type CheckpointMetadata, type CheckpointTuple, - type SerializerProtocol, type PendingWrite, - type CheckpointMetadata, + type SerializerProtocol, } from "@langchain/langgraph-checkpoint"; interface CheckpointRow { @@ -178,29 +178,21 @@ export class SupaSaver extends BaseCheckpointSaver { const thread_id = config.configurable?.thread_id; const checkpoint_ns = config.configurable?.checkpoint_ns; - let sql = - `SELECT\n` + - " thread_id,\n" + - " checkpoint_ns,\n" + - " checkpoint_id,\n" + - " parent_checkpoint_id,\n" + - " type,\n" + - " checkpoint,\n" + - " metadata\n" + - "FROM checkpoints\n"; - - const whereClause: string[] = []; + let query = this.client + .from("chat_session_checkpoints") + .select("*") + .eq("session_id", "6b3cffb2-e521-46e3-9509-266f5380245d"); if (thread_id) { - whereClause.push("thread_id = ?"); + query = query.eq("thread_id", thread_id); } if (checkpoint_ns !== undefined && checkpoint_ns !== null) { - whereClause.push("checkpoint_ns = ?"); + query = query.eq("checkpoint_ns", checkpoint_ns); } if (before?.configurable?.checkpoint_id !== undefined) { - whereClause.push("checkpoint_id < ?"); + query = query.lt("checkpoint_id", before.configurable.checkpoint_id); } const sanitizedFilter = Object.fromEntries( @@ -211,33 +203,21 @@ export class SupaSaver extends BaseCheckpointSaver { ) ); - whereClause.push( - ...Object.entries(sanitizedFilter).map( - ([key]) => `jsonb(CAST(metadata AS TEXT))->'$.${key}' = ?` - ) - ); - - if (whereClause.length > 0) { - sql += `WHERE\n ${whereClause.join(" AND\n ")}\n`; + for (const [key, value] of Object.entries(sanitizedFilter)) { + query = query.eq(`metadata->${key}`, JSON.stringify(value)); } - sql += "\nORDER BY checkpoint_id DESC"; + query = query.order("checkpoint_id", { ascending: false }); if (limit) { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - sql += ` LIMIT ${parseInt(limit as any, 10)}`; // parseInt here (with cast to make TS happy) to sanitize input, as limit may be user-provided + query = query.limit(parseInt(limit as any, 10)); } - const args = [ - thread_id, - checkpoint_ns, - before?.configurable?.checkpoint_id, - ...Object.values(sanitizedFilter).map((value) => JSON.stringify(value)), - ].filter((value) => value !== undefined && value !== null); + const { data: rows, error } = await query; - const rows: CheckpointRow[] = this.db - .prepare(sql) - .all(...args) as CheckpointRow[]; + if (error) { + throw error; + } if (rows) { for (const row of rows) { From a55b7a7ca8a25ccc0e0996e8cf31b0645d12fac8 Mon Sep 17 00:00:00 2001 From: Andrew Morrison Date: Thu, 24 Oct 2024 10:16:34 +0100 Subject: [PATCH 03/11] Update supabase checkpointer Co-authored-by: William Overton --- libs/checkpoint-supabase/jest.config.cjs | 20 +++ libs/checkpoint-supabase/jest.env.cjs | 12 ++ libs/checkpoint-supabase/src/index.ts | 128 +++++++++++------- libs/checkpoint-validation/jest.config.cjs | 1 + .../src/tests/supabase_initializer.ts | 6 +- 5 files changed, 113 insertions(+), 54 deletions(-) create mode 100644 libs/checkpoint-supabase/jest.config.cjs create mode 100644 libs/checkpoint-supabase/jest.env.cjs diff --git a/libs/checkpoint-supabase/jest.config.cjs b/libs/checkpoint-supabase/jest.config.cjs new file mode 100644 index 00000000..385d19f6 --- /dev/null +++ b/libs/checkpoint-supabase/jest.config.cjs @@ -0,0 +1,20 @@ +/** @type {import('ts-jest').JestConfigWithTsJest} */ +module.exports = { + preset: "ts-jest/presets/default-esm", + testEnvironment: "./jest.env.cjs", + modulePathIgnorePatterns: ["dist/"], + moduleNameMapper: { + "^(\\.{1,2}/.*)\\.js$": "$1", + }, + transform: { + "^.+\\.tsx?$": ["@swc/jest"], + }, + transformIgnorePatterns: [ + "/node_modules/", + "\\.pnp\\.[^\\/]+$", + "./scripts/jest-setup-after-env.js", + ], + setupFiles: ["dotenv/config"], + testTimeout: 20_000, + passWithNoTests: true, +}; diff --git a/libs/checkpoint-supabase/jest.env.cjs b/libs/checkpoint-supabase/jest.env.cjs new file mode 100644 index 00000000..2ccedccb --- /dev/null +++ b/libs/checkpoint-supabase/jest.env.cjs @@ -0,0 +1,12 @@ +const { TestEnvironment } = require("jest-environment-node"); + +class AdjustedTestEnvironmentToSupportFloat32Array extends TestEnvironment { + constructor(config, context) { + // Make `instanceof Float32Array` return true in tests + // to avoid https://github.com/xenova/transformers.js/issues/57 and https://github.com/jestjs/jest/issues/2549 + super(config, context); + this.global.Float32Array = Float32Array; + } +} + +module.exports = AdjustedTestEnvironmentToSupportFloat32Array; diff --git a/libs/checkpoint-supabase/src/index.ts b/libs/checkpoint-supabase/src/index.ts index f226f14a..37ab0088 100644 --- a/libs/checkpoint-supabase/src/index.ts +++ b/libs/checkpoint-supabase/src/index.ts @@ -5,10 +5,10 @@ import { BaseCheckpointSaver, type Checkpoint, type CheckpointListOptions, - type CheckpointMetadata, type CheckpointTuple, - type PendingWrite, type SerializerProtocol, + type PendingWrite, + type CheckpointMetadata, } from "@langchain/langgraph-checkpoint"; interface CheckpointRow { @@ -70,7 +70,9 @@ export class SupaSaver extends BaseCheckpointSaver { checkpoint_ns = "", checkpoint_id, } = config.configurable ?? {}; + let res; + if (checkpoint_id) { // data = this.db // .prepare( @@ -78,9 +80,8 @@ export class SupaSaver extends BaseCheckpointSaver { // ) // .get(thread_id, checkpoint_ns, checkpoint_id) as CheckpointRow; res = await this.client - .from("chat_session_checkpoints") + .from("langchain_checkpoints") .select("*") - .eq("session_id", "6b3cffb2-e521-46e3-9509-266f5380245d") .eq("id", checkpoint_id) .eq("thread_id", thread_id) .eq("checkpoint_ns", checkpoint_ns) @@ -93,18 +94,20 @@ export class SupaSaver extends BaseCheckpointSaver { // ) // .get(thread_id, checkpoint_ns) as CheckpointRow; res = await this.client - .from("chat_session_checkpoints") + .from("langchain_checkpoints") .select("*") - .eq("session_id", "6b3cffb2-e521-46e3-9509-266f5380245d") .eq("thread_id", thread_id) .eq("checkpoint_ns", checkpoint_ns) .maybeSingle() .throwOnError(); } - const row = res?.data?.[0]; + + const row = res?.data as CheckpointRow; + if (row === undefined) { return undefined; } + let finalConfig = config; if (!checkpoint_id) { finalConfig = { @@ -132,14 +135,14 @@ export class SupaSaver extends BaseCheckpointSaver { // finalConfig.configurable.checkpoint_id.toString() // ) as WritesRow[]; const pendingWritesRes = await this.client - .from("chat_session_writes") + .from("langchain_writes") .select("*") .eq("thread_id", finalConfig.configurable.thread_id.toString()) .eq("checkpoint_ns", checkpoint_ns) .eq("checkpoint_id", finalConfig.configurable.checkpoint_id.toString()); const pendingWritesRows = pendingWritesRes.data ?? []; const pendingWrites = await Promise.all( - pendingWritesRows.map(async (row) => { + pendingWritesRows.map(async (row: WritesRow) => { return [ row.task_id, row.channel, @@ -179,9 +182,9 @@ export class SupaSaver extends BaseCheckpointSaver { const checkpoint_ns = config.configurable?.checkpoint_ns; let query = this.client - .from("chat_session_checkpoints") + .from("langchain_checkpoints") .select("*") - .eq("session_id", "6b3cffb2-e521-46e3-9509-266f5380245d"); + if (thread_id) { query = query.eq("thread_id", thread_id); @@ -258,26 +261,38 @@ export class SupaSaver extends BaseCheckpointSaver { ): Promise { const [type1, serializedCheckpoint] = this.serde.dumpsTyped(checkpoint); const [type2, serializedMetadata] = this.serde.dumpsTyped(metadata); + + const { thread_id, checkpoint_ns, checkpoint_id } = + config.configurable ?? {}; if (type1 !== type2) { throw new Error( "Failed to serialized checkpoint and metadata to the same type." ); } - const row = [ - config.configurable?.thread_id?.toString(), - config.configurable?.checkpoint_ns, - checkpoint.id, - config.configurable?.checkpoint_id, - type1, - serializedCheckpoint, - serializedMetadata, - ]; - - this.db - .prepare( - `INSERT OR REPLACE INTO checkpoints (thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata) VALUES (?, ?, ?, ?, ?, ?, ?)` - ) - .run(...row); + await this.client.from("langchain_checkpoints").insert({ + thread_id: thread_id, + checkpoint_ns: checkpoint_ns, + parent_checkpoint_id: checkpoint_id, + type: type1, + checkpoint: serializedCheckpoint, + metadata: serializedMetadata, + }); + // + // const row = [ + // config.configurable?.thread_id?.toString(), + // config.configurable?.checkpoint_ns, + // checkpoint.id, + // config.configurable?.checkpoint_id, + // type1, + // serializedCheckpoint, + // serializedMetadata, + // ]; + // + // this.db + // .prepare( + // `INSERT OR REPLACE INTO checkpoints (thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata) VALUES (?, ?, ?, ?, ?, ?, ?)` + // ) + // .run(...row); return { configurable: { @@ -293,32 +308,45 @@ export class SupaSaver extends BaseCheckpointSaver { writes: PendingWrite[], taskId: string ): Promise { - const stmt = this.db.prepare(` - INSERT OR REPLACE INTO writes - (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, value) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) - `); + // const stmt = this.db.prepare(` + // INSERT OR REPLACE INTO writes + // (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, value) + // VALUES (?, ?, ?, ?, ?, ?, ?, ?) + // `); + // + // const transaction = this.db.transaction((rows) => { + // for (const row of rows) { + // stmt.run(...row); + // } + // }); - const transaction = this.db.transaction((rows) => { - for (const row of rows) { - stmt.run(...row); - } - }); + const thread_id = config.configurable!.thread_id; + const checkpoint_id = config.configurable!.checkpoint_id; + const checkpoint_ns = config.configurable!.checkpoint_ns; - const rows = writes.map((write, idx) => { - const [type, serializedWrite] = this.serde.dumpsTyped(write[1]); - return [ - config.configurable?.thread_id, - config.configurable?.checkpoint_ns, - config.configurable?.checkpoint_id, - taskId, - idx, - write[0], - type, - serializedWrite, - ]; - }); + if ( + thread_id === undefined || + checkpoint_id === undefined || + checkpoint_ns === undefined + ) { + throw new Error("checkpoint_id, sessionId or checkpoint_ns is undefined"); + } + + await Promise.all( + writes.map(async (write, idx) => { + const [type, serializedWrite] = this.serde.dumpsTyped(write[1]); - transaction(rows); + await this.client.from("langchain_writes").insert({ + thread_id: thread_id, + checkpoint_ns: checkpoint_ns, + checkpoint_id: checkpoint_id, + task_id: taskId, + idx, + channel: write[0], + type, + value: serializedWrite, + }); + }) + ); } } diff --git a/libs/checkpoint-validation/jest.config.cjs b/libs/checkpoint-validation/jest.config.cjs index ab56a4c3..418f44e6 100644 --- a/libs/checkpoint-validation/jest.config.cjs +++ b/libs/checkpoint-validation/jest.config.cjs @@ -11,6 +11,7 @@ module.exports = { "/libs/checkpoint-mongodb/src/index.ts", "/libs/checkpoint-postgres/src/index.ts", "/libs/checkpoint-sqlite/src/index.ts", + "/libs/checkpoint-supabase/src/index.ts", ], coveragePathIgnorePatterns: [ diff --git a/libs/checkpoint-validation/src/tests/supabase_initializer.ts b/libs/checkpoint-validation/src/tests/supabase_initializer.ts index 7ad2b8ee..f832ad08 100644 --- a/libs/checkpoint-validation/src/tests/supabase_initializer.ts +++ b/libs/checkpoint-validation/src/tests/supabase_initializer.ts @@ -17,13 +17,11 @@ export const initializer: CheckpointerTestInitializer = { async afterAll() { const client = createClient(SUPABASE_URL, SUPABASE_KEY); await client - .from("chat_session_checkpoints") + .from("langgraph_checkpoints") .delete() - .eq("session_id", "6b3cffb2-e521-46e3-9509-266f5380245d"); await client - .from("chat_session_writes") + .from("langgraph_writes") .delete() - .neq("session_id", "6b3cffb2-e521-46e3-9509-266f5380245d"); }, }; From bd2eba990a24850be9c11fc069365673d671e4b3 Mon Sep 17 00:00:00 2001 From: Andrew Morrison Date: Thu, 24 Oct 2024 11:57:31 +0100 Subject: [PATCH 04/11] Update parsing --- libs/checkpoint-supabase/src/index.ts | 120 ++++++++++-------- libs/checkpoint-validation/src/spec/list.ts | 21 +-- .../src/tests/supabase_initializer.ts | 20 ++- 3 files changed, 100 insertions(+), 61 deletions(-) diff --git a/libs/checkpoint-supabase/src/index.ts b/libs/checkpoint-supabase/src/index.ts index 37ab0088..74e412c9 100644 --- a/libs/checkpoint-supabase/src/index.ts +++ b/libs/checkpoint-supabase/src/index.ts @@ -41,8 +41,8 @@ type CheckKeys = [K[number]] extends [ keyof T ] ? [keyof T] extends [K[number]] - ? K - : never + ? K + : never : never; function validateKeys( @@ -63,14 +63,31 @@ export class SupaSaver extends BaseCheckpointSaver { constructor(private client: SupabaseClient, serde?: SerializerProtocol) { super(serde); } + protected _dumpCheckpoint(checkpoint: Checkpoint) { + const serialized: Record = { + ...checkpoint, + pending_sends: [], + }; + if ("channel_values" in serialized) { + delete serialized.channel_values; + } + return serialized; + } + protected _dumpMetadata(metadata: CheckpointMetadata) { + const [, serializedMetadata] = this.serde.dumpsTyped(metadata); + // We need to remove null characters before writing + return JSON.parse( + new TextDecoder().decode(serializedMetadata).replace(/\0/g, "") + ); + } async getTuple(config: RunnableConfig): Promise { const { thread_id, checkpoint_ns = "", checkpoint_id, } = config.configurable ?? {}; - + let res; if (checkpoint_id) { @@ -80,9 +97,9 @@ export class SupaSaver extends BaseCheckpointSaver { // ) // .get(thread_id, checkpoint_ns, checkpoint_id) as CheckpointRow; res = await this.client - .from("langchain_checkpoints") + .from("langgraph_checkpoints") .select("*") - .eq("id", checkpoint_id) + .eq("checkpoint_id", checkpoint_id) .eq("thread_id", thread_id) .eq("checkpoint_ns", checkpoint_ns) .maybeSingle() @@ -94,17 +111,18 @@ export class SupaSaver extends BaseCheckpointSaver { // ) // .get(thread_id, checkpoint_ns) as CheckpointRow; res = await this.client - .from("langchain_checkpoints") + .from("langgraph_checkpoints") .select("*") .eq("thread_id", thread_id) .eq("checkpoint_ns", checkpoint_ns) + .order("checkpoint_id", { ascending: false }) .maybeSingle() .throwOnError(); } - + const row = res?.data as CheckpointRow; - if (row === undefined) { + if (row === null) { return undefined; } @@ -135,39 +153,42 @@ export class SupaSaver extends BaseCheckpointSaver { // finalConfig.configurable.checkpoint_id.toString() // ) as WritesRow[]; const pendingWritesRes = await this.client - .from("langchain_writes") + .from("langgraph_writes") .select("*") .eq("thread_id", finalConfig.configurable.thread_id.toString()) .eq("checkpoint_ns", checkpoint_ns) - .eq("checkpoint_id", finalConfig.configurable.checkpoint_id.toString()); + .eq("checkpoint_id", finalConfig.configurable.checkpoint_id.toString()) + .throwOnError(); + const pendingWritesRows = pendingWritesRes.data ?? []; const pendingWrites = await Promise.all( pendingWritesRows.map(async (row: WritesRow) => { return [ row.task_id, row.channel, - await this.serde.loadsTyped(row.type ?? "json", row.value ?? ""), + await this.serde.loadsTyped(row.type ?? "json", JSON.stringify(row.value) ?? ""), ] as [string, string, unknown]; }) ); + return { config: finalConfig, checkpoint: (await this.serde.loadsTyped( row.type ?? "json", - row.checkpoint + JSON.stringify(row.checkpoint) )) as Checkpoint, metadata: (await this.serde.loadsTyped( row.type ?? "json", - row.metadata + JSON.stringify(row.metadata) )) as CheckpointMetadata, parentConfig: row.parent_checkpoint_id ? { - configurable: { - thread_id: row.thread_id, - checkpoint_ns, - checkpoint_id: row.parent_checkpoint_id, - }, - } + configurable: { + thread_id: row.thread_id, + checkpoint_ns, + checkpoint_id: row.parent_checkpoint_id, + }, + } : undefined, pendingWrites, }; @@ -182,10 +203,9 @@ export class SupaSaver extends BaseCheckpointSaver { const checkpoint_ns = config.configurable?.checkpoint_ns; let query = this.client - .from("langchain_checkpoints") + .from("langgraph_checkpoints") .select("*") - if (thread_id) { query = query.eq("thread_id", thread_id); } @@ -207,7 +227,7 @@ export class SupaSaver extends BaseCheckpointSaver { ); for (const [key, value] of Object.entries(sanitizedFilter)) { - query = query.eq(`metadata->${key}`, JSON.stringify(value)); + query = query.eq(`metadata@>${key}`, JSON.stringify(value)); } query = query.order("checkpoint_id", { ascending: false }); @@ -216,7 +236,7 @@ export class SupaSaver extends BaseCheckpointSaver { query = query.limit(parseInt(limit as any, 10)); } - const { data: rows, error } = await query; + const { data: rows, error } = await query.throwOnError(); if (error) { throw error; @@ -234,20 +254,20 @@ export class SupaSaver extends BaseCheckpointSaver { }, checkpoint: (await this.serde.loadsTyped( row.type ?? "json", - row.checkpoint + JSON.stringify(row.checkpoint) )) as Checkpoint, metadata: (await this.serde.loadsTyped( row.type ?? "json", - row.metadata + JSON.stringify(row.metadata) )) as CheckpointMetadata, parentConfig: row.parent_checkpoint_id ? { - configurable: { - thread_id: row.thread_id, - checkpoint_ns: row.checkpoint_ns, - checkpoint_id: row.parent_checkpoint_id, - }, - } + configurable: { + thread_id: row.thread_id, + checkpoint_ns: row.checkpoint_ns, + checkpoint_id: row.parent_checkpoint_id, + }, + } : undefined, }; } @@ -259,24 +279,20 @@ export class SupaSaver extends BaseCheckpointSaver { checkpoint: Checkpoint, metadata: CheckpointMetadata ): Promise { - const [type1, serializedCheckpoint] = this.serde.dumpsTyped(checkpoint); - const [type2, serializedMetadata] = this.serde.dumpsTyped(metadata); - - const { thread_id, checkpoint_ns, checkpoint_id } = - config.configurable ?? {}; - if (type1 !== type2) { - throw new Error( - "Failed to serialized checkpoint and metadata to the same type." - ); - } - await this.client.from("langchain_checkpoints").insert({ + const serializedCheckpoint = this._dumpCheckpoint(checkpoint); + const serializedMetadata = this._dumpMetadata(metadata); + const { thread_id, checkpoint_ns, checkpoint_id } = config.configurable!; + + await this.client.from("langgraph_checkpoints").insert({ + checkpoint_id: checkpoint.id, thread_id: thread_id, checkpoint_ns: checkpoint_ns, parent_checkpoint_id: checkpoint_id, - type: type1, + type: 'json', checkpoint: serializedCheckpoint, metadata: serializedMetadata, - }); + }).throwOnError(); + // // const row = [ // config.configurable?.thread_id?.toString(), @@ -334,18 +350,22 @@ export class SupaSaver extends BaseCheckpointSaver { await Promise.all( writes.map(async (write, idx) => { - const [type, serializedWrite] = this.serde.dumpsTyped(write[1]); - - await this.client.from("langchain_writes").insert({ + const [, serializedWrite] = this.serde.dumpsTyped(write[1]); + ; + await this.client.from("langgraph_writes").upsert({ thread_id: thread_id, checkpoint_ns: checkpoint_ns, checkpoint_id: checkpoint_id, task_id: taskId, idx, channel: write[0], - type, - value: serializedWrite, - }); + type: "json", + value: JSON.parse( + new TextDecoder().decode(serializedWrite).replace(/\0/g, "") + ), + }, { + onConfict: "thread_id,checkpoint_ns,checkpoint_id,task_id,idx", + }).throwOnError(); }) ); } diff --git a/libs/checkpoint-validation/src/spec/list.ts b/libs/checkpoint-validation/src/spec/list.ts index d9f63097..ff11c37b 100644 --- a/libs/checkpoint-validation/src/spec/list.ts +++ b/libs/checkpoint-validation/src/spec/list.ts @@ -121,9 +121,11 @@ export function listTests( // see: https://github.com/langchain-ai/langgraphjs/issues/590 const checkpointerIncludesPendingWritesOnList = initializer.checkpointerName !== - "@langchain/langgraph-checkpoint-mongodb" && + "@langchain/langgraph-checkpoint-mongodb" && initializer.checkpointerName !== - "@langchain/langgraph-checkpoint-sqlite"; + "@langchain/langgraph-checkpoint-sqlite" && + initializer.checkpointerName !== + "@langchain/langgraph-checkpoint-supabase"; const expectedTuple = expectedTuplesMap.get(key); if (!checkpointerIncludesPendingWritesOnList) { @@ -172,7 +174,7 @@ export function listTests( // TODO: MongoDBSaver support for filter is broken and can't be fixed without a breaking change // see: https://github.com/langchain-ai/langgraphjs/issues/581 initializer.checkpointerName === - "@langchain/langgraph-checkpoint-mongodb" + "@langchain/langgraph-checkpoint-mongodb" ? [undefined] : [undefined, {}, { source: "input" }, { source: "loop" }], }; @@ -194,17 +196,17 @@ export function listTests( tuple.config.configurable?.thread_id === thread_id) && (checkpoint_ns === undefined || tuple.config.configurable?.checkpoint_ns === - checkpoint_ns) && + checkpoint_ns) && (before === undefined || tuple.checkpoint.id < - before.configurable?.checkpoint_id) && + before.configurable?.checkpoint_id) && (filter === undefined || Object.entries(filter).every( ([key, value]) => ( tuple.metadata as - | Record - | undefined + | Record + | undefined )?.[key] === value )) ) @@ -322,9 +324,8 @@ export function listTests( const descriptionWhen = descriptionWhenParts.length > 1 - ? `${descriptionWhenParts.slice(0, -1).join(", ")}, and ${ - descriptionWhenParts[descriptionWhenParts.length - 1] - }` + ? `${descriptionWhenParts.slice(0, -1).join(", ")}, and ${descriptionWhenParts[descriptionWhenParts.length - 1] + }` : descriptionWhenParts[0]; return `should return ${descriptionTupleCount} when ${descriptionWhen}`; diff --git a/libs/checkpoint-validation/src/tests/supabase_initializer.ts b/libs/checkpoint-validation/src/tests/supabase_initializer.ts index f832ad08..dcfcdc37 100644 --- a/libs/checkpoint-validation/src/tests/supabase_initializer.ts +++ b/libs/checkpoint-validation/src/tests/supabase_initializer.ts @@ -15,14 +15,32 @@ export const initializer: CheckpointerTestInitializer = { }, async afterAll() { + // const client = createClient(SUPABASE_URL, SUPABASE_KEY); + // await client + // .from("langgraph_checkpoints") + // .delete() + // .neq("thread_id", "filter-needs-a-value") + // .throwOnError() + // await client + // .from("langgraph_writes") + // .delete() + // .neq("thread_id", "filter-needs-a-value") + // .throwOnError() + }, + + async beforeAll() { const client = createClient(SUPABASE_URL, SUPABASE_KEY); await client .from("langgraph_checkpoints") .delete() + .neq("thread_id", "filter-needs-a-value") + .throwOnError() await client .from("langgraph_writes") .delete() - }, + .neq("thread_id", "filter-needs-a-value") + .throwOnError() + } }; export default initializer; From 1aed561de2259e856a408188b4c0c5d31b044434 Mon Sep 17 00:00:00 2001 From: Andrew Morrison Date: Thu, 24 Oct 2024 12:19:54 +0100 Subject: [PATCH 05/11] update metadata filter --- libs/checkpoint-supabase/src/index.ts | 7 ++++--- libs/checkpoint-validation/src/spec/list.ts | 6 +++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/libs/checkpoint-supabase/src/index.ts b/libs/checkpoint-supabase/src/index.ts index 74e412c9..9a8d2eaa 100644 --- a/libs/checkpoint-supabase/src/index.ts +++ b/libs/checkpoint-supabase/src/index.ts @@ -199,8 +199,7 @@ export class SupaSaver extends BaseCheckpointSaver { options?: CheckpointListOptions ): AsyncGenerator { const { limit, before, filter } = options ?? {}; - const thread_id = config.configurable?.thread_id; - const checkpoint_ns = config.configurable?.checkpoint_ns; + const {thread_id, checkpoint_ns} = config.configurable ?? {}; let query = this.client .from("langgraph_checkpoints") @@ -227,7 +226,9 @@ export class SupaSaver extends BaseCheckpointSaver { ); for (const [key, value] of Object.entries(sanitizedFilter)) { - query = query.eq(`metadata@>${key}`, JSON.stringify(value)); + let searchObject = {} as any; + searchObject[key] = value + query = query.contains(`metadata`, JSON.stringify(searchObject)); } query = query.order("checkpoint_id", { ascending: false }); diff --git a/libs/checkpoint-validation/src/spec/list.ts b/libs/checkpoint-validation/src/spec/list.ts index ff11c37b..0bbeb6c0 100644 --- a/libs/checkpoint-validation/src/spec/list.ts +++ b/libs/checkpoint-validation/src/spec/list.ts @@ -116,7 +116,7 @@ export function listTests( } else { expect(actualTuplesMap.size).toEqual(expectedTuplesMap.size); for (const [key, value] of actualTuplesMap.entries()) { - // TODO: MongoDBSaver and SQLiteSaver don't return pendingWrites on list, so we need to special case them + // TODO: MongoDBSaver, SQLiteSaver And SupabaseSaver don't return pendingWrites on list, so we need to special case them // see: https://github.com/langchain-ai/langgraphjs/issues/589 // see: https://github.com/langchain-ai/langgraphjs/issues/590 const checkpointerIncludesPendingWritesOnList = @@ -132,6 +132,10 @@ export function listTests( delete expectedTuple?.pendingWrites; } + if(expectedTuple === undefined) { + console.log("Blammo!", value, expectedTuple); + } + expect(value).toEqual(expectedTuple); } } From 1578a7320fc5a8a03b336ae638b9b10c29611cb6 Mon Sep 17 00:00:00 2001 From: Andrew Morrison Date: Thu, 24 Oct 2024 12:33:24 +0100 Subject: [PATCH 06/11] skipped test in put method --- libs/checkpoint-supabase/src/index.ts | 25 +++++++++------------ libs/checkpoint-validation/src/spec/list.ts | 4 ---- libs/checkpoint-validation/src/spec/put.ts | 2 ++ 3 files changed, 12 insertions(+), 19 deletions(-) diff --git a/libs/checkpoint-supabase/src/index.ts b/libs/checkpoint-supabase/src/index.ts index 9a8d2eaa..0d6de191 100644 --- a/libs/checkpoint-supabase/src/index.ts +++ b/libs/checkpoint-supabase/src/index.ts @@ -136,12 +136,14 @@ export class SupaSaver extends BaseCheckpointSaver { }, }; } + if ( finalConfig.configurable?.thread_id === undefined || finalConfig.configurable?.checkpoint_id === undefined ) { throw new Error("Missing thread_id or checkpoint_id"); } + // find any pending writes // const pendingWritesRows = this.db // .prepare( @@ -152,6 +154,7 @@ export class SupaSaver extends BaseCheckpointSaver { // checkpoint_ns, // finalConfig.configurable.checkpoint_id.toString() // ) as WritesRow[]; + const pendingWritesRes = await this.client .from("langgraph_writes") .select("*") @@ -160,7 +163,7 @@ export class SupaSaver extends BaseCheckpointSaver { .eq("checkpoint_id", finalConfig.configurable.checkpoint_id.toString()) .throwOnError(); - const pendingWritesRows = pendingWritesRes.data ?? []; + const pendingWritesRows = (pendingWritesRes.data ?? []) as WritesRow[]; const pendingWrites = await Promise.all( pendingWritesRows.map(async (row: WritesRow) => { return [ @@ -284,14 +287,16 @@ export class SupaSaver extends BaseCheckpointSaver { const serializedMetadata = this._dumpMetadata(metadata); const { thread_id, checkpoint_ns, checkpoint_id } = config.configurable!; - await this.client.from("langgraph_checkpoints").insert({ + await this.client.from("langgraph_checkpoints").upsert({ checkpoint_id: checkpoint.id, - thread_id: thread_id, + thread_id: thread_id.toString(), checkpoint_ns: checkpoint_ns, parent_checkpoint_id: checkpoint_id, type: 'json', checkpoint: serializedCheckpoint, metadata: serializedMetadata, + }, { + onConflict: "thread_id,checkpoint_ns,checkpoint_id", }).throwOnError(); // @@ -337,22 +342,12 @@ export class SupaSaver extends BaseCheckpointSaver { // } // }); - const thread_id = config.configurable!.thread_id; - const checkpoint_id = config.configurable!.checkpoint_id; - const checkpoint_ns = config.configurable!.checkpoint_ns; - - if ( - thread_id === undefined || - checkpoint_id === undefined || - checkpoint_ns === undefined - ) { - throw new Error("checkpoint_id, sessionId or checkpoint_ns is undefined"); - } + const {thread_id, checkpoint_id, checkpoint_ns} = config.configurable!; await Promise.all( writes.map(async (write, idx) => { const [, serializedWrite] = this.serde.dumpsTyped(write[1]); - ; + await this.client.from("langgraph_writes").upsert({ thread_id: thread_id, checkpoint_ns: checkpoint_ns, diff --git a/libs/checkpoint-validation/src/spec/list.ts b/libs/checkpoint-validation/src/spec/list.ts index 0bbeb6c0..bada2f42 100644 --- a/libs/checkpoint-validation/src/spec/list.ts +++ b/libs/checkpoint-validation/src/spec/list.ts @@ -132,10 +132,6 @@ export function listTests( delete expectedTuple?.pendingWrites; } - if(expectedTuple === undefined) { - console.log("Blammo!", value, expectedTuple); - } - expect(value).toEqual(expectedTuple); } } diff --git a/libs/checkpoint-validation/src/spec/put.ts b/libs/checkpoint-validation/src/spec/put.ts index 5ecc343d..63db2f73 100644 --- a/libs/checkpoint-validation/src/spec/put.ts +++ b/libs/checkpoint-validation/src/spec/put.ts @@ -222,6 +222,8 @@ export function putTests( "TODO: MongoDBSaver doesn't store channel deltas", "@langchain/langgraph-checkpoint-sqlite": "TODO: SQLiteSaver doesn't store channel deltas", + "@langchain/langgraph-checkpoint-supabase": + "TODO: SupabaseSaver doesn't store channel deltas", })( "should only store channel_values that have changed (based on newVersions)", async () => { From 3ccbc715988be47655411ac1ca400d6c15a03a08 Mon Sep 17 00:00:00 2001 From: William Overton Date: Tue, 29 Oct 2024 10:56:36 +0000 Subject: [PATCH 07/11] Include Jacob's work and continue attacking tests --- libs/checkpoint-sqlite/src/index.ts | 6 +- libs/checkpoint-supabase/src/index.ts | 201 ++++++++------------- libs/checkpoint-validation/src/spec/put.ts | 6 +- 3 files changed, 86 insertions(+), 127 deletions(-) diff --git a/libs/checkpoint-sqlite/src/index.ts b/libs/checkpoint-sqlite/src/index.ts index 6740bb22..622fee6b 100644 --- a/libs/checkpoint-sqlite/src/index.ts +++ b/libs/checkpoint-sqlite/src/index.ts @@ -1,14 +1,14 @@ -import Database, { Database as DatabaseType } from "better-sqlite3"; import type { RunnableConfig } from "@langchain/core/runnables"; import { BaseCheckpointSaver, type Checkpoint, type CheckpointListOptions, + type CheckpointMetadata, type CheckpointTuple, - type SerializerProtocol, type PendingWrite, - type CheckpointMetadata, + type SerializerProtocol, } from "@langchain/langgraph-checkpoint"; +import Database, { Database as DatabaseType } from "better-sqlite3"; interface CheckpointRow { checkpoint: string; diff --git a/libs/checkpoint-supabase/src/index.ts b/libs/checkpoint-supabase/src/index.ts index 0d6de191..10ea1898 100644 --- a/libs/checkpoint-supabase/src/index.ts +++ b/libs/checkpoint-supabase/src/index.ts @@ -5,10 +5,10 @@ import { BaseCheckpointSaver, type Checkpoint, type CheckpointListOptions, + type CheckpointMetadata, type CheckpointTuple, - type SerializerProtocol, type PendingWrite, - type CheckpointMetadata, + type SerializerProtocol, } from "@langchain/langgraph-checkpoint"; interface CheckpointRow { @@ -41,8 +41,8 @@ type CheckKeys = [K[number]] extends [ keyof T ] ? [keyof T] extends [K[number]] - ? K - : never + ? K + : never : never; function validateKeys( @@ -63,16 +63,6 @@ export class SupaSaver extends BaseCheckpointSaver { constructor(private client: SupabaseClient, serde?: SerializerProtocol) { super(serde); } - protected _dumpCheckpoint(checkpoint: Checkpoint) { - const serialized: Record = { - ...checkpoint, - pending_sends: [], - }; - if ("channel_values" in serialized) { - delete serialized.channel_values; - } - return serialized; - } protected _dumpMetadata(metadata: CheckpointMetadata) { const [, serializedMetadata] = this.serde.dumpsTyped(metadata); @@ -91,38 +81,27 @@ export class SupaSaver extends BaseCheckpointSaver { let res; if (checkpoint_id) { - // data = this.db - // .prepare( - // `SELECT thread_id, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata FROM checkpoints WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ?` - // ) - // .get(thread_id, checkpoint_ns, checkpoint_id) as CheckpointRow; res = await this.client .from("langgraph_checkpoints") - .select("*") + .select() .eq("checkpoint_id", checkpoint_id) .eq("thread_id", thread_id) .eq("checkpoint_ns", checkpoint_ns) - .maybeSingle() .throwOnError(); } else { - // row = this.db - // .prepare( - // `SELECT thread_id, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata FROM checkpoints WHERE thread_id = ? AND checkpoint_ns = ? ORDER BY checkpoint_id DESC LIMIT 1` - // ) - // .get(thread_id, checkpoint_ns) as CheckpointRow; res = await this.client .from("langgraph_checkpoints") - .select("*") + .select() .eq("thread_id", thread_id) .eq("checkpoint_ns", checkpoint_ns) .order("checkpoint_id", { ascending: false }) - .maybeSingle() .throwOnError(); } - const row = res?.data as CheckpointRow; + const rows = res.data as CheckpointRow[]; - if (row === null) { + const row = rows[0]; + if (row == null) { return undefined; } @@ -157,7 +136,7 @@ export class SupaSaver extends BaseCheckpointSaver { const pendingWritesRes = await this.client .from("langgraph_writes") - .select("*") + .select() .eq("thread_id", finalConfig.configurable.thread_id.toString()) .eq("checkpoint_ns", checkpoint_ns) .eq("checkpoint_id", finalConfig.configurable.checkpoint_id.toString()) @@ -169,7 +148,10 @@ export class SupaSaver extends BaseCheckpointSaver { return [ row.task_id, row.channel, - await this.serde.loadsTyped(row.type ?? "json", JSON.stringify(row.value) ?? ""), + await this.serde.loadsTyped( + row.type ?? "json", + JSON.stringify(row.value) ?? "" + ), ] as [string, string, unknown]; }) ); @@ -186,12 +168,12 @@ export class SupaSaver extends BaseCheckpointSaver { )) as CheckpointMetadata, parentConfig: row.parent_checkpoint_id ? { - configurable: { - thread_id: row.thread_id, - checkpoint_ns, - checkpoint_id: row.parent_checkpoint_id, - }, - } + configurable: { + thread_id: row.thread_id, + checkpoint_ns, + checkpoint_id: row.parent_checkpoint_id, + }, + } : undefined, pendingWrites, }; @@ -202,13 +184,12 @@ export class SupaSaver extends BaseCheckpointSaver { options?: CheckpointListOptions ): AsyncGenerator { const { limit, before, filter } = options ?? {}; - const {thread_id, checkpoint_ns} = config.configurable ?? {}; + const thread_id = config.configurable?.thread_id; + const checkpoint_ns = config.configurable?.checkpoint_ns; - let query = this.client - .from("langgraph_checkpoints") - .select("*") + let query = this.client.from("langgraph_checkpoints").select("*"); - if (thread_id) { + if (thread_id !== undefined && thread_id !== null) { query = query.eq("thread_id", thread_id); } @@ -230,8 +211,8 @@ export class SupaSaver extends BaseCheckpointSaver { for (const [key, value] of Object.entries(sanitizedFilter)) { let searchObject = {} as any; - searchObject[key] = value - query = query.contains(`metadata`, JSON.stringify(searchObject)); + searchObject[key] = value; + query = query.eq(`metadata->>${key}`, value); } query = query.order("checkpoint_id", { ascending: false }); @@ -240,10 +221,10 @@ export class SupaSaver extends BaseCheckpointSaver { query = query.limit(parseInt(limit as any, 10)); } - const { data: rows, error } = await query.throwOnError(); + const { data: rows } = await query.throwOnError(); - if (error) { - throw error; + if (rows === null) { + throw new Error("Unexpected error listing checkpoints"); } if (rows) { @@ -266,12 +247,12 @@ export class SupaSaver extends BaseCheckpointSaver { )) as CheckpointMetadata, parentConfig: row.parent_checkpoint_id ? { - configurable: { - thread_id: row.thread_id, - checkpoint_ns: row.checkpoint_ns, - checkpoint_id: row.parent_checkpoint_id, - }, - } + configurable: { + thread_id: row.thread_id, + checkpoint_ns: row.checkpoint_ns, + checkpoint_id: row.parent_checkpoint_id, + }, + } : undefined, }; } @@ -283,43 +264,25 @@ export class SupaSaver extends BaseCheckpointSaver { checkpoint: Checkpoint, metadata: CheckpointMetadata ): Promise { - const serializedCheckpoint = this._dumpCheckpoint(checkpoint); - const serializedMetadata = this._dumpMetadata(metadata); - const { thread_id, checkpoint_ns, checkpoint_id } = config.configurable!; - - await this.client.from("langgraph_checkpoints").upsert({ - checkpoint_id: checkpoint.id, - thread_id: thread_id.toString(), - checkpoint_ns: checkpoint_ns, - parent_checkpoint_id: checkpoint_id, - type: 'json', - checkpoint: serializedCheckpoint, - metadata: serializedMetadata, - }, { - onConflict: "thread_id,checkpoint_ns,checkpoint_id", - }).throwOnError(); - - // - // const row = [ - // config.configurable?.thread_id?.toString(), - // config.configurable?.checkpoint_ns, - // checkpoint.id, - // config.configurable?.checkpoint_id, - // type1, - // serializedCheckpoint, - // serializedMetadata, - // ]; - // - // this.db - // .prepare( - // `INSERT OR REPLACE INTO checkpoints (thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata) VALUES (?, ?, ?, ?, ?, ?, ?)` - // ) - // .run(...row); + await this.client + .from("langgraph_checkpoints") + .upsert( + { + thread_id: config.configurable?.thread_id, + checkpoint_ns: config.configurable?.checkpoint_ns, + checkpoint_id: checkpoint.id, + parent_checkpoint_id: config.configurable?.checkpoint_id, + type: "json", + checkpoint: checkpoint, + metadata: metadata, + } + ) + .throwOnError(); return { configurable: { thread_id: config.configurable?.thread_id, - checkpoint_ns: config.configurable?.checkpoint_ns, + checkpoint_ns: config.configurable?.checkpoint_ns ?? "", checkpoint_id: checkpoint.id, }, }; @@ -330,39 +293,33 @@ export class SupaSaver extends BaseCheckpointSaver { writes: PendingWrite[], taskId: string ): Promise { - // const stmt = this.db.prepare(` - // INSERT OR REPLACE INTO writes - // (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, value) - // VALUES (?, ?, ?, ?, ?, ?, ?, ?) - // `); - // - // const transaction = this.db.transaction((rows) => { - // for (const row of rows) { - // stmt.run(...row); - // } - // }); - - const {thread_id, checkpoint_id, checkpoint_ns} = config.configurable!; - - await Promise.all( - writes.map(async (write, idx) => { - const [, serializedWrite] = this.serde.dumpsTyped(write[1]); - - await this.client.from("langgraph_writes").upsert({ - thread_id: thread_id, - checkpoint_ns: checkpoint_ns, - checkpoint_id: checkpoint_id, - task_id: taskId, - idx, - channel: write[0], - type: "json", - value: JSON.parse( - new TextDecoder().decode(serializedWrite).replace(/\0/g, "") - ), - }, { - onConfict: "thread_id,checkpoint_ns,checkpoint_id,task_id,idx", - }).throwOnError(); - }) - ); + const thread_id = config.configurable?.thread_id; + const checkpoint_id = config.configurable?.checkpoint_id; + const checkpoint_ns = config.configurable?.checkpoint_ns; + + // Process writes sequentially + for (const [idx, write] of writes.entries()) { + const [, serializedWrite] = this.serde.dumpsTyped(write[1]); + + await this.client + .from("langgraph_writes") + .upsert( + [ + { + thread_id, + checkpoint_ns, + checkpoint_id, + task_id: taskId, + idx, + channel: write[0], + type: "json", + value: JSON.parse( + new TextDecoder().decode(serializedWrite).replace(/\0/g, "") + ), + }, + ] + ) + .throwOnError(); + } } -} +} \ No newline at end of file diff --git a/libs/checkpoint-validation/src/spec/put.ts b/libs/checkpoint-validation/src/spec/put.ts index 63db2f73..8fc552b0 100644 --- a/libs/checkpoint-validation/src/spec/put.ts +++ b/libs/checkpoint-validation/src/spec/put.ts @@ -1,3 +1,4 @@ +import { RunnableConfig } from "@langchain/core/runnables"; import { Checkpoint, CheckpointMetadata, @@ -5,13 +6,12 @@ import { uuid6, type BaseCheckpointSaver, } from "@langchain/langgraph-checkpoint"; -import { RunnableConfig } from "@langchain/core/runnables"; -import { CheckpointerTestInitializer } from "../types.js"; import { initialCheckpointTuple, it_skipForSomeModules, putTuples, } from "../test_utils.js"; +import { CheckpointerTestInitializer } from "../types.js"; export function putTests( initializer: CheckpointerTestInitializer @@ -185,6 +185,8 @@ export function putTests( // see: https://github.com/langchain-ai/langgraphjs/issues/592 "@langchain/langgraph-checkpoint-sqlite": "TODO: SqliteSaver stores config with no checkpoint_ns instead of default namespace", + "@langchain/langgraph-checkpoint-supabase": + "TODO: SupabaseSaver stores config with no checkpoint_ns instead of default namespace", })( "should default to empty namespace if the checkpoint namespace is missing from config.configurable", async () => { From cd62b74d14711207ae820f08d0080dd9172dfa13 Mon Sep 17 00:00:00 2001 From: James Birtles Date: Tue, 29 Oct 2024 12:30:55 +0000 Subject: [PATCH 08/11] clean up db in destroyCheckpointer --- libs/checkpoint-validation/src/tests/supabase_initializer.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/checkpoint-validation/src/tests/supabase_initializer.ts b/libs/checkpoint-validation/src/tests/supabase_initializer.ts index dcfcdc37..0ac10e3d 100644 --- a/libs/checkpoint-validation/src/tests/supabase_initializer.ts +++ b/libs/checkpoint-validation/src/tests/supabase_initializer.ts @@ -28,7 +28,7 @@ export const initializer: CheckpointerTestInitializer = { // .throwOnError() }, - async beforeAll() { + async destroyCheckpointer() { const client = createClient(SUPABASE_URL, SUPABASE_KEY); await client .from("langgraph_checkpoints") From 396170f8b5c579ef39486520498009c4e79686e7 Mon Sep 17 00:00:00 2001 From: William Overton Date: Wed, 30 Oct 2024 09:52:17 +0000 Subject: [PATCH 09/11] Refactor and add readme --- libs/checkpoint-supabase/README.md | 106 +++++++++ libs/checkpoint-supabase/src/index.ts | 207 +++++++++--------- libs/checkpoint-validation/package.json | 1 + .../src/tests/supabase_initializer.ts | 8 +- yarn.lock | 15 ++ 5 files changed, 227 insertions(+), 110 deletions(-) create mode 100644 libs/checkpoint-supabase/README.md diff --git a/libs/checkpoint-supabase/README.md b/libs/checkpoint-supabase/README.md new file mode 100644 index 00000000..184f81f7 --- /dev/null +++ b/libs/checkpoint-supabase/README.md @@ -0,0 +1,106 @@ +# @langchain/langgraph-checkpoint-supabase + +Implementation of a [LangGraph.js](https://github.com/langchain-ai/langgraphjs) CheckpointSaver that uses the Supabase JS SDK. + +## Setup + +Create the following tables in your Supabase database, you can change the table names if required when also setting the `checkPointTable` and `writeTable` options of the `SupabaseSaver` class. + +> [!CAUTION] +> ⚠️ Make sure to enable RLS policies on the tables! + +```sql +create table + public.langgraph_checkpoints ( + thread_id text not null, + created_at timestamp with time zone not null default now(), + checkpoint_ns text not null default '', + checkpoint_id text not null, + parent_checkpoint_id text null, + type text null, + checkpoint jsonb null, + metadata jsonb null, + constraint langgraph_checkpoints_pkey primary key (thread_id, checkpoint_ns, checkpoint_id) + ) tablespace pg_default; + +create table + public.langgraph_writes ( + thread_id text not null, + created_at timestamp with time zone not null default now(), + checkpoint_ns text not null default '', + checkpoint_id text not null, + task_id text not null, + idx bigint not null, + channel text not null, + type text null, + value jsonb null, + constraint langgraph_writes_pkey primary key ( + thread_id, + checkpoint_ns, + checkpoint_id, + task_id, + idx + ) + ) tablespace pg_default; +``` + +## Usage + +```ts +import { SqliteSaver } from "@langchain/langgraph-checkpoint-supabase"; + +const writeConfig = { + configurable: { + thread_id: "1", + checkpoint_ns: "" + } +}; +const readConfig = { + configurable: { + thread_id: "1" + } +}; + +const supabaseClient = createClient(SUPABASE_URL, SUPABASE_KEY); +const checkpointer = new SupabaseSaver(supabaseClient, { + checkPointTable: "langgraph_checkpoints", + writeTable: "langgraph_writes", +}); + +const checkpoint = { + v: 1, + ts: "2024-07-31T20:14:19.804150+00:00", + id: "1ef4f797-8335-6428-8001-8a1503f9b875", + channel_values: { + my_key: "meow", + node: "node" + }, + channel_versions: { + __start__: 2, + my_key: 3, + start:node: 3, + node: 3 + }, + versions_seen: { + __input__: {}, + __start__: { + __start__: 1 + }, + node: { + start:node: 2 + } + }, + pending_sends: [], +} + +// store checkpoint +await checkpointer.put(writeConfig, checkpoint, {}, {}) + +// load checkpoint +await checkpointer.get(readConfig) + +// list checkpoints +for await (const checkpoint of checkpointer.list(readConfig)) { + console.log(checkpoint); +} +``` diff --git a/libs/checkpoint-supabase/src/index.ts b/libs/checkpoint-supabase/src/index.ts index 10ea1898..d75347bc 100644 --- a/libs/checkpoint-supabase/src/index.ts +++ b/libs/checkpoint-supabase/src/index.ts @@ -32,9 +32,7 @@ interface WritesRow { value?: string; } -// In the `SqliteSaver.list` method, we need to sanitize the `options.filter` argument to ensure it only contains keys -// that are part of the `CheckpointMetadata` type. The lines below ensure that we get compile-time errors if the list -// of keys that we use is out of sync with the `CheckpointMetadata` type. +const DEFAULT_TYPE = 'json' as const; const checkpointMetadataKeys = ["source", "step", "writes", "parents"] as const; type CheckKeys = [K[number]] extends [ @@ -51,132 +49,129 @@ function validateKeys( return keys; } -// If this line fails to compile, the list of keys that we use in the `SqliteSaver.list` method is out of sync with the -// `CheckpointMetadata` type. In that case, just update `checkpointMetadataKeys` to contain all the keys in -// `CheckpointMetadata` const validCheckpointMetadataKeys = validateKeys< CheckpointMetadata, typeof checkpointMetadataKeys >(checkpointMetadataKeys); -export class SupaSaver extends BaseCheckpointSaver { - constructor(private client: SupabaseClient, serde?: SerializerProtocol) { +export class SupabaseSaver extends BaseCheckpointSaver { + + private options: { + checkPointTable: string; + writeTable: string; + } = { + checkPointTable: "langgraph_checkpoints", + writeTable: "langgraph_writes", + }; + + constructor(private client: SupabaseClient, config?: { + checkPointTable?: string; + writeTable?: string; + },serde?: SerializerProtocol) { super(serde); + + // Apply config + if (config) { + this.options = { + ...this.options, + ...config, + }; + } } - protected _dumpMetadata(metadata: CheckpointMetadata) { + protected _dumpMetadata(metadata: CheckpointMetadata): unknown { const [, serializedMetadata] = this.serde.dumpsTyped(metadata); - // We need to remove null characters before writing + return this.parseAndCleanJson(serializedMetadata); + } + + private parseAndCleanJson(data: Uint8Array): unknown { return JSON.parse( - new TextDecoder().decode(serializedMetadata).replace(/\0/g, "") + new TextDecoder().decode(data).replace(/\0/g, "") ); } - async getTuple(config: RunnableConfig): Promise { - const { - thread_id, - checkpoint_ns = "", - checkpoint_id, - } = config.configurable ?? {}; - - let res; - - if (checkpoint_id) { - res = await this.client - .from("langgraph_checkpoints") - .select() - .eq("checkpoint_id", checkpoint_id) - .eq("thread_id", thread_id) - .eq("checkpoint_ns", checkpoint_ns) - .throwOnError(); - } else { - res = await this.client - .from("langgraph_checkpoints") - .select() - .eq("thread_id", thread_id) - .eq("checkpoint_ns", checkpoint_ns) - .order("checkpoint_id", { ascending: false }) - .throwOnError(); + + private validateConfig(config: RunnableConfig): asserts config is Required { + if (!config.configurable?.thread_id || !config.configurable?.checkpoint_id) { + throw new Error("Missing required config: thread_id or checkpoint_id"); } + } - const rows = res.data as CheckpointRow[]; + async getTuple(config: RunnableConfig): Promise { + const { thread_id, checkpoint_ns = "", checkpoint_id } = config.configurable ?? {}; + + const query = this.client + .from(this.options.checkPointTable) + .select() + .eq("thread_id", thread_id) + .eq("checkpoint_ns", checkpoint_ns); - const row = rows[0]; - if (row == null) { - return undefined; - } + const res = await (checkpoint_id + ? query.eq("checkpoint_id", checkpoint_id) + : query.order("checkpoint_id", { ascending: false }) + ).throwOnError(); + + const [row] = res.data as CheckpointRow[]; + if (!row) return undefined; + + const finalConfig = !checkpoint_id ? { + configurable: { + thread_id: row.thread_id, + checkpoint_ns, + checkpoint_id: row.checkpoint_id, + }, + } : config; + + this.validateConfig(finalConfig); + + const pendingWrites = await this.fetchPendingWrites( + finalConfig.configurable.thread_id, + checkpoint_ns, + finalConfig.configurable.checkpoint_id + ); - let finalConfig = config; - if (!checkpoint_id) { - finalConfig = { + return { + config: finalConfig, + checkpoint: await this.deserializeField(row.type, row.checkpoint) as Checkpoint, + metadata: await this.deserializeField(row.type, row.metadata) as CheckpointMetadata, + parentConfig: row.parent_checkpoint_id ? { configurable: { thread_id: row.thread_id, checkpoint_ns, - checkpoint_id: row.checkpoint_id, + checkpoint_id: row.parent_checkpoint_id, }, - }; - } + } : undefined, + pendingWrites, + }; + } - if ( - finalConfig.configurable?.thread_id === undefined || - finalConfig.configurable?.checkpoint_id === undefined - ) { - throw new Error("Missing thread_id or checkpoint_id"); - } + private async deserializeField(type: string | undefined, value: string): Promise { + return this.serde.loadsTyped( + type ?? DEFAULT_TYPE, + JSON.stringify(value) + ); + } - // find any pending writes - // const pendingWritesRows = this.db - // .prepare( - // `SELECT task_id, channel, type, value FROM writes WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ?` - // ) - // .all( - // finalConfig.configurable.thread_id.toString(), - // checkpoint_ns, - // finalConfig.configurable.checkpoint_id.toString() - // ) as WritesRow[]; - - const pendingWritesRes = await this.client - .from("langgraph_writes") + private async fetchPendingWrites( + threadId: string, + checkpointNs: string, + checkpointId: string + ): Promise<[string, string, unknown][]> { + const { data } = await this.client + .from(this.options.writeTable) .select() - .eq("thread_id", finalConfig.configurable.thread_id.toString()) - .eq("checkpoint_ns", checkpoint_ns) - .eq("checkpoint_id", finalConfig.configurable.checkpoint_id.toString()) + .eq("thread_id", threadId) + .eq("checkpoint_ns", checkpointNs) + .eq("checkpoint_id", checkpointId) .throwOnError(); - const pendingWritesRows = (pendingWritesRes.data ?? []) as WritesRow[]; - const pendingWrites = await Promise.all( - pendingWritesRows.map(async (row: WritesRow) => { - return [ - row.task_id, - row.channel, - await this.serde.loadsTyped( - row.type ?? "json", - JSON.stringify(row.value) ?? "" - ), - ] as [string, string, unknown]; - }) + const rows = data as WritesRow[]; + return Promise.all( + rows.map(async (row) => [ + row.task_id, + row.channel, + await this.deserializeField(row.type, row.value ?? ""), + ]) ); - - return { - config: finalConfig, - checkpoint: (await this.serde.loadsTyped( - row.type ?? "json", - JSON.stringify(row.checkpoint) - )) as Checkpoint, - metadata: (await this.serde.loadsTyped( - row.type ?? "json", - JSON.stringify(row.metadata) - )) as CheckpointMetadata, - parentConfig: row.parent_checkpoint_id - ? { - configurable: { - thread_id: row.thread_id, - checkpoint_ns, - checkpoint_id: row.parent_checkpoint_id, - }, - } - : undefined, - pendingWrites, - }; } async *list( @@ -187,7 +182,7 @@ export class SupaSaver extends BaseCheckpointSaver { const thread_id = config.configurable?.thread_id; const checkpoint_ns = config.configurable?.checkpoint_ns; - let query = this.client.from("langgraph_checkpoints").select("*"); + let query = this.client.from(this.options.checkPointTable).select("*"); if (thread_id !== undefined && thread_id !== null) { query = query.eq("thread_id", thread_id); @@ -265,7 +260,7 @@ export class SupaSaver extends BaseCheckpointSaver { metadata: CheckpointMetadata ): Promise { await this.client - .from("langgraph_checkpoints") + .from(this.options.checkPointTable) .upsert( { thread_id: config.configurable?.thread_id, @@ -302,7 +297,7 @@ export class SupaSaver extends BaseCheckpointSaver { const [, serializedWrite] = this.serde.dumpsTyped(write[1]); await this.client - .from("langgraph_writes") + .from(this.options.writeTable) .upsert( [ { diff --git a/libs/checkpoint-validation/package.json b/libs/checkpoint-validation/package.json index b615f07c..133d15a6 100644 --- a/libs/checkpoint-validation/package.json +++ b/libs/checkpoint-validation/package.json @@ -53,6 +53,7 @@ "@langchain/langgraph-checkpoint-sqlite": "workspace:*", "@langchain/langgraph-checkpoint-supabase": "workspace:*", "@langchain/scripts": ">=0.1.3 <0.2.0", + "@supabase/supabase-js": "^2.46.1", "@testcontainers/mongodb": "^10.13.2", "@testcontainers/postgresql": "^10.13.2", "@tsconfig/recommended": "^1.0.3", diff --git a/libs/checkpoint-validation/src/tests/supabase_initializer.ts b/libs/checkpoint-validation/src/tests/supabase_initializer.ts index 0ac10e3d..9f87e3c3 100644 --- a/libs/checkpoint-validation/src/tests/supabase_initializer.ts +++ b/libs/checkpoint-validation/src/tests/supabase_initializer.ts @@ -1,17 +1,17 @@ // eslint-disable-next-line import/no-extraneous-dependencies -import { SupaSaver } from "@langchain/langgraph-checkpoint-supabase"; -import { CheckpointerTestInitializer } from "../types.js"; +import { SupabaseSaver } from "@langchain/langgraph-checkpoint-supabase"; import { createClient } from "@supabase/supabase-js"; +import { CheckpointerTestInitializer } from "../types.js"; const SUPABASE_URL = process.env.SUPABASE_URL!; const SUPABASE_KEY = process.env.SUPABASE_KEY!; -export const initializer: CheckpointerTestInitializer = { +export const initializer: CheckpointerTestInitializer = { checkpointerName: "@langchain/langgraph-checkpoint-supabase", async createCheckpointer() { const client = createClient(SUPABASE_URL, SUPABASE_KEY); - return new SupaSaver(client); + return new SupabaseSaver(client); }, async afterAll() { diff --git a/yarn.lock b/yarn.lock index 67a08a52..1b32cab4 100644 --- a/yarn.lock +++ b/yarn.lock @@ -1817,6 +1817,7 @@ __metadata: "@langchain/langgraph-checkpoint-sqlite": "workspace:*" "@langchain/langgraph-checkpoint-supabase": "workspace:*" "@langchain/scripts": ">=0.1.3 <0.2.0" + "@supabase/supabase-js": ^2.46.1 "@swc-node/register": ^1.10.9 "@swc/core": ^1.3.90 "@swc/jest": ^0.2.29 @@ -2930,6 +2931,20 @@ __metadata: languageName: node linkType: hard +"@supabase/supabase-js@npm:^2.46.1": + version: 2.46.1 + resolution: "@supabase/supabase-js@npm:2.46.1" + dependencies: + "@supabase/auth-js": 2.65.1 + "@supabase/functions-js": 2.4.3 + "@supabase/node-fetch": 2.6.15 + "@supabase/postgrest-js": 1.16.3 + "@supabase/realtime-js": 2.10.7 + "@supabase/storage-js": 2.7.1 + checksum: 5f8c143124adab36a145c78a1c9799e0fd80598d64904e99e907efa3c6ae1a6f3c95c97b2c9a493690d6761169d9fe138f885c7f3885e376a5928676c616c5fc + languageName: node + linkType: hard + "@swc-node/core@npm:^1.13.3": version: 1.13.3 resolution: "@swc-node/core@npm:1.13.3" From 34567ffc61a5a7c8c8a16c1230295add33ff7302 Mon Sep 17 00:00:00 2001 From: William Overton Date: Wed, 30 Oct 2024 09:55:15 +0000 Subject: [PATCH 10/11] Enable RLS in readme setup --- libs/checkpoint-supabase/README.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/libs/checkpoint-supabase/README.md b/libs/checkpoint-supabase/README.md index 184f81f7..cc531d05 100644 --- a/libs/checkpoint-supabase/README.md +++ b/libs/checkpoint-supabase/README.md @@ -7,7 +7,7 @@ Implementation of a [LangGraph.js](https://github.com/langchain-ai/langgraphjs) Create the following tables in your Supabase database, you can change the table names if required when also setting the `checkPointTable` and `writeTable` options of the `SupabaseSaver` class. > [!CAUTION] -> ⚠️ Make sure to enable RLS policies on the tables! +> Make sure to enable RLS policies on the tables! ```sql create table @@ -42,6 +42,10 @@ create table idx ) ) tablespace pg_default; + +--- Important to disable public access to the tables! +alter table "langgraph_checkpoints" enable row level security; +alter table "langgraph_writes" enable row level security; ``` ## Usage From 9bc1b87d13dd890402ee7dd401f9dfb8f0d13f68 Mon Sep 17 00:00:00 2001 From: William Overton Date: Wed, 30 Oct 2024 10:06:07 +0000 Subject: [PATCH 11/11] Cleanup --- libs/checkpoint-sqlite/src/index.ts | 6 +++--- .../src/tests/supabase_initializer.ts | 15 +-------------- 2 files changed, 4 insertions(+), 17 deletions(-) diff --git a/libs/checkpoint-sqlite/src/index.ts b/libs/checkpoint-sqlite/src/index.ts index 622fee6b..6740bb22 100644 --- a/libs/checkpoint-sqlite/src/index.ts +++ b/libs/checkpoint-sqlite/src/index.ts @@ -1,14 +1,14 @@ +import Database, { Database as DatabaseType } from "better-sqlite3"; import type { RunnableConfig } from "@langchain/core/runnables"; import { BaseCheckpointSaver, type Checkpoint, type CheckpointListOptions, - type CheckpointMetadata, type CheckpointTuple, - type PendingWrite, type SerializerProtocol, + type PendingWrite, + type CheckpointMetadata, } from "@langchain/langgraph-checkpoint"; -import Database, { Database as DatabaseType } from "better-sqlite3"; interface CheckpointRow { checkpoint: string; diff --git a/libs/checkpoint-validation/src/tests/supabase_initializer.ts b/libs/checkpoint-validation/src/tests/supabase_initializer.ts index 9f87e3c3..cfd9899b 100644 --- a/libs/checkpoint-validation/src/tests/supabase_initializer.ts +++ b/libs/checkpoint-validation/src/tests/supabase_initializer.ts @@ -14,20 +14,7 @@ export const initializer: CheckpointerTestInitializer = { return new SupabaseSaver(client); }, - async afterAll() { - // const client = createClient(SUPABASE_URL, SUPABASE_KEY); - // await client - // .from("langgraph_checkpoints") - // .delete() - // .neq("thread_id", "filter-needs-a-value") - // .throwOnError() - // await client - // .from("langgraph_writes") - // .delete() - // .neq("thread_id", "filter-needs-a-value") - // .throwOnError() - }, - + // Reset the tables between groups of tests async destroyCheckpointer() { const client = createClient(SUPABASE_URL, SUPABASE_KEY); await client