diff --git a/src/tson.test.ts b/src/tson.test.ts index 452bb9c6..e211888d 100644 --- a/src/tson.test.ts +++ b/src/tson.test.ts @@ -1,5 +1,6 @@ import { expect, test } from "vitest"; +import { expectError } from "./testUtils.js"; import { createTson } from "./tson.js"; import { TsonType } from "./types.js"; @@ -31,3 +32,39 @@ test("duplicate keys", () => { '"Multiple handlers for key string found"', ); }); + +test("no max call stack", () => { + const t = createTson({ + types: [], + }); + + const expected: Record = {}; + expected["a"] = expected; + + // stringify should fail b/c of JSON limitations + const err = expectError(() => t.stringify(expected)); + + expect(err.message).toMatchInlineSnapshot('"Circular reference detected"'); +}); + +test("allow duplicate objects", () => { + const t = createTson({ + types: [], + }); + + const obj = { + a: 1, + b: 2, + c: 3, + }; + + const expected = { + a: obj, + b: obj, + c: obj, + }; + + const actual = t.deserialize(t.serialize(expected)); + + expect(actual).toEqual(expected); +}); diff --git a/src/tson.ts b/src/tson.ts index 928890fc..b2986c2f 100644 --- a/src/tson.ts +++ b/src/tson.ts @@ -75,6 +75,19 @@ export function createTsonStringify(opts: TsonOptions): TsonStringifyFn { JSON.stringify(serializer(obj), null, space)) as TsonStringifyFn; } +export class CircularReferenceError extends Error { + /** + * The circular reference that was found + */ + public readonly value; + + constructor(value: unknown) { + super(`Circular reference detected`); + this.name = this.constructor.name; + this.value = value; + } +} + export function createTsonSerialize(opts: TsonOptions): TsonSerializeFn { const handlers = (() => { const types = opts.types.map((handler) => { @@ -124,24 +137,49 @@ export function createTsonSerialize(opts: TsonOptions): TsonSerializeFn { const [nonPrimitive, byPrimitive] = handlers; const walker: WalkerFactory = (nonce) => { + const seen = new WeakSet(); + const cache = new WeakMap(); + const walk: WalkFn = (value) => { const type = typeof value; + const isComplex = !!value && type === "object"; + + if (isComplex) { + if (seen.has(value)) { + const cached = cache.get(value); + if (!cached) { + throw new CircularReferenceError(value); + } + + return cached; + } + + seen.add(value); + } + + const cacheAndReturn = (result: unknown) => { + if (isComplex) { + cache.set(value, result); + } + + return result; + }; const primitiveHandler = byPrimitive[type]; if ( primitiveHandler && (!primitiveHandler.test || primitiveHandler.test(value)) ) { - return primitiveHandler.$serialize(value, nonce, walk); + return cacheAndReturn(primitiveHandler.$serialize(value, nonce, walk)); } for (const handler of nonPrimitive) { if (handler.test(value)) { - return handler.$serialize(value, nonce, walk); + return cacheAndReturn(handler.$serialize(value, nonce, walk)); } } - return mapOrReturn(value, walk); + return cacheAndReturn(mapOrReturn(value, walk)); }; return walk;