diff --git a/packages/rpc-subscriptions/src/__tests__/cached-abortable-iterable-test.ts b/packages/rpc-subscriptions/src/__tests__/cached-abortable-iterable-test.ts index 2f1a9093e1ff..5ecc4b3df700 100644 --- a/packages/rpc-subscriptions/src/__tests__/cached-abortable-iterable-test.ts +++ b/packages/rpc-subscriptions/src/__tests__/cached-abortable-iterable-test.ts @@ -1,7 +1,7 @@ import { getCachedAbortableIterableFactory } from '../cached-abortable-iterable'; describe('getCachedAbortableIterableFactory', () => { - let asyncGenerator: jest.Mock>; + let getAsyncIterable: jest.MockedFn<() => AsyncIterable>; let factory: (...args: unknown[]) => Promise>; let getAbortSignalFromInputArgs: jest.Mock; let getCacheKeyFromInputArgs: jest.Mock; @@ -9,7 +9,7 @@ describe('getCachedAbortableIterableFactory', () => { let onCreateIterable: jest.Mock; beforeEach(() => { jest.useFakeTimers(); - asyncGenerator = jest.fn().mockImplementation(async function* () { + getAsyncIterable = jest.fn().mockImplementation(async function* () { yield await new Promise(() => { /* never resolve */ }); @@ -17,8 +17,10 @@ describe('getCachedAbortableIterableFactory', () => { getAbortSignalFromInputArgs = jest.fn().mockImplementation(() => new AbortController().signal); getCacheKeyFromInputArgs = jest.fn().mockReturnValue('cache-key'); onCacheHit = jest.fn(); - onCreateIterable = jest.fn().mockResolvedValue({ - [Symbol.asyncIterator]: asyncGenerator, + onCreateIterable = jest.fn().mockReturnValue({ + [Symbol.asyncIterator]() { + return getAsyncIterable()[Symbol.asyncIterator](); + }, }); factory = getCachedAbortableIterableFactory({ getAbortSignalFromInputArgs, @@ -137,7 +139,7 @@ describe('getCachedAbortableIterableFactory', () => { it('creates a new iterable for a message given that the prior iterable threw', async () => { expect.assertions(2); let throwFromIterable; - asyncGenerator.mockImplementationOnce(async function* () { + getAsyncIterable.mockImplementationOnce(async function* () { yield await new Promise((_, reject) => { throwFromIterable = reject; }); @@ -155,7 +157,7 @@ describe('getCachedAbortableIterableFactory', () => { it('creates a new iterable for a message given that prior iterable returned', async () => { expect.assertions(1); let returnFromIterable; - asyncGenerator.mockImplementationOnce(async function* () { + getAsyncIterable.mockImplementationOnce(async function* () { try { yield await new Promise((_, reject) => { returnFromIterable = reject; @@ -197,7 +199,7 @@ describe('getCachedAbortableIterableFactory', () => { Promise.all([factory('A'), factory('B')]); expect(onCacheHit).not.toHaveBeenCalled(); await jest.runAllTimersAsync(); - const iterable = asyncGenerator(); + const iterable = getAsyncIterable(); // FIXME: https://github.com/microsoft/TypeScript/issues/11498 // eslint-disable-next-line @typescript-eslint/ban-ts-comment // @ts-ignore @@ -206,7 +208,7 @@ describe('getCachedAbortableIterableFactory', () => { expect(onCacheHit).toHaveBeenCalledWith(iterable, 'B'); }); it('calls `onCacheHit` in the same runloop when the cached iterable is already resolved', () => { - const iterable = asyncGenerator(); + const iterable = getAsyncIterable(); onCreateIterable.mockReturnValue(iterable); Promise.all([factory('A'), factory('B')]); expect(onCacheHit).toHaveBeenCalledWith(iterable, 'B'); @@ -225,7 +227,7 @@ describe('getCachedAbortableIterableFactory', () => { factory('B'); await jest.runAllTimersAsync(); expect(onCacheHit).not.toHaveBeenCalled(); - const iterable = asyncGenerator(); + const iterable = getAsyncIterable(); // FIXME: https://github.com/microsoft/TypeScript/issues/11498 // eslint-disable-next-line @typescript-eslint/ban-ts-comment // @ts-ignore @@ -235,7 +237,7 @@ describe('getCachedAbortableIterableFactory', () => { }); it('calls `onCacheHit` in different runloops when the cached iterable is already resolved', async () => { expect.assertions(1); - const iterable = asyncGenerator(); + const iterable = getAsyncIterable(); onCreateIterable.mockReturnValue(iterable); await factory('A'); await factory('B'); diff --git a/packages/rpc-subscriptions/src/__tests__/rpc-subscriptions-coalescer-test.ts b/packages/rpc-subscriptions/src/__tests__/rpc-subscriptions-coalescer-test.ts index 87ca87c3d77c..b41ea77500da 100644 --- a/packages/rpc-subscriptions/src/__tests__/rpc-subscriptions-coalescer-test.ts +++ b/packages/rpc-subscriptions/src/__tests__/rpc-subscriptions-coalescer-test.ts @@ -8,21 +8,23 @@ interface TestRpcSubscriptionNotifications { } describe('getRpcSubscriptionsWithSubscriptionCoalescing', () => { - let asyncGenerator: jest.Mock>; + let getAsyncIterable: jest.MockedFn<() => AsyncIterable>; let createPendingSubscription: jest.Mock; let getDeduplicationKey: jest.Mock; let subscribe: jest.Mock; let rpcSubscriptions: RpcSubscriptions; beforeEach(() => { jest.useFakeTimers(); - asyncGenerator = jest.fn().mockImplementation(async function* () { + getAsyncIterable = jest.fn().mockImplementation(async function* () { yield await new Promise(() => { /* never resolve */ }); }); getDeduplicationKey = jest.fn(); - subscribe = jest.fn().mockResolvedValue({ - [Symbol.asyncIterator]: asyncGenerator, + subscribe = jest.fn().mockReturnValue({ + [Symbol.asyncIterator]() { + return getAsyncIterable()[Symbol.asyncIterator](); + }, }); createPendingSubscription = jest.fn().mockReturnValue({ subscribe }); rpcSubscriptions = getRpcSubscriptionsWithSubscriptionCoalescing({ @@ -98,7 +100,7 @@ describe('getRpcSubscriptionsWithSubscriptionCoalescing', () => { }); it('publishes the same messages through both iterables', async () => { expect.assertions(2); - asyncGenerator.mockImplementation(async function* () { + getAsyncIterable.mockImplementation(async function* () { yield Promise.resolve('hello'); }); const iterableA = await rpcSubscriptions @@ -117,7 +119,7 @@ describe('getRpcSubscriptionsWithSubscriptionCoalescing', () => { }); it('publishes the final message when the iterable returns', async () => { expect.assertions(1); - asyncGenerator.mockImplementation( + getAsyncIterable.mockImplementation( // eslint-disable-next-line require-yield async function* () { return await Promise.resolve('hello'); @@ -132,7 +134,7 @@ describe('getRpcSubscriptionsWithSubscriptionCoalescing', () => { }); it('aborting a subscription causes it to return', async () => { expect.assertions(1); - asyncGenerator.mockImplementation(async function* () { + getAsyncIterable.mockImplementation(async function* () { yield Promise.resolve('hello'); }); const abortController = new AbortController(); @@ -146,7 +148,7 @@ describe('getRpcSubscriptionsWithSubscriptionCoalescing', () => { }); it('aborting one subscription does not abort the other', async () => { expect.assertions(1); - asyncGenerator.mockImplementation(async function* () { + getAsyncIterable.mockImplementation(async function* () { yield Promise.resolve('hello'); }); const abortControllerA = new AbortController(); @@ -258,7 +260,7 @@ describe('getRpcSubscriptionsWithSubscriptionCoalescing', () => { }); it('aborting a subscription causes it to return', async () => { expect.assertions(1); - asyncGenerator.mockImplementation(async function* () { + getAsyncIterable.mockImplementation(async function* () { yield Promise.resolve('hello'); }); const abortController = new AbortController(); @@ -272,7 +274,7 @@ describe('getRpcSubscriptionsWithSubscriptionCoalescing', () => { }); it('aborting one subscription does not abort the other', async () => { expect.assertions(1); - asyncGenerator.mockImplementation(async function* () { + getAsyncIterable.mockImplementation(async function* () { yield Promise.resolve('hello'); }); const abortControllerA = new AbortController();