diff --git a/src/core/streaming/__tests__/cachedStreamFetcher.spec.ts b/src/core/streaming/__tests__/cachedStreamFetcher.spec.ts index 9ba37f8c..28001611 100644 --- a/src/core/streaming/__tests__/cachedStreamFetcher.spec.ts +++ b/src/core/streaming/__tests__/cachedStreamFetcher.spec.ts @@ -2,11 +2,12 @@ import { RequestPool } from '@/src/core/streaming/requestPool'; import { CachedStreamFetcher, + sliceChunks, StopSignal, } from '@/src/core/streaming/cachedStreamFetcher'; import { describe, expect, it } from 'vitest'; -describe('ResumableFetcher', () => { +describe('CachedStreamFetcher', () => { it('should support stopping and resuming', async () => { const pool = new RequestPool(); const fetcher = new CachedStreamFetcher( @@ -51,3 +52,24 @@ describe('ResumableFetcher', () => { fetcher.close(); }); }); + +describe('sliceChunks', () => { + it('should work', () => { + expect(sliceChunks([new Uint8Array([1, 2, 3])], 0)).toEqual([]); + expect(sliceChunks([new Uint8Array([1, 2, 3])], 1)).toEqual([ + new Uint8Array([1]), + ]); + expect(sliceChunks([new Uint8Array([1])], 1)).toEqual([ + new Uint8Array([1]), + ]); + expect(sliceChunks([new Uint8Array([1, 2])], 1)).toEqual([ + new Uint8Array([1]), + ]); + expect(sliceChunks([new Uint8Array([1, 2])], 3)).toEqual([ + new Uint8Array([1, 2]), + ]); + expect( + sliceChunks([new Uint8Array([1, 2]), new Uint8Array([3, 4])], 3) + ).toEqual([new Uint8Array([1, 2]), new Uint8Array([3])]); + }); +}); diff --git a/src/core/streaming/cachedStreamFetcher.ts b/src/core/streaming/cachedStreamFetcher.ts index f98a63e5..a290525b 100644 --- a/src/core/streaming/cachedStreamFetcher.ts +++ b/src/core/streaming/cachedStreamFetcher.ts @@ -5,6 +5,7 @@ import { HttpNotFound, } from '@/src/core/streaming/httpCodes'; import { Fetcher, FetcherInit } from '@/src/core/streaming/types'; +import { parseContentRangeHeader } from '@/src/utils/parseContentRangeHeader'; import { Maybe } from '@/src/types'; type FetchFunction = typeof fetch; @@ -17,6 +18,24 @@ export interface CachedStreamFetcherRequestInit extends RequestInit { export const StopSignal = Symbol('StopSignal'); +export function sliceChunks(chunks: Uint8Array[], start: number) { + const newChunks: Uint8Array[] = []; + let size = 0; + for (let i = 0; i < chunks.length && size < start; i++) { + const chunk = chunks[i]; + if (size + chunk.length > start) { + const offset = start - size; + const newChunk = chunk.slice(0, offset); + newChunks.push(newChunk); + size += newChunk.length; + } else { + newChunks.push(chunk); + size += chunk.length; + } + } + return newChunks; +} + /** * A cached stream fetcher that caches a URI stream. * @@ -94,6 +113,9 @@ export class CachedStreamFetcher implements Fetcher { if (!response.body) throw new Error('Did not receive a response body'); const noMoreContent = response.headers.get('content-length') === '0'; + const contentRange = parseContentRangeHeader( + response.headers.get('content-range') + ); const rangeNotSatisfiable = response.status === HTTP_STATUS_REQUESTED_RANGE_NOT_SATISFIABLE; @@ -109,7 +131,15 @@ export class CachedStreamFetcher implements Fetcher { throw new HttpNotFound(this.request.toString()); } - if (!noMoreContent && response.status !== HTTP_STATUS_PARTIAL_CONTENT) { + if (response.status === HTTP_STATUS_PARTIAL_CONTENT) { + if (contentRange.type === 'invalid-range') + throw new Error('Invalid content-range header'); + if (contentRange.type === 'unsatisfied-range') + throw new Error('Range could not be satisfied'); + + const { start } = contentRange; + this.chunks = sliceChunks(this.chunks, start); + } else if (!noMoreContent) { this.chunks = []; } diff --git a/src/utils/__tests__/parseContentRangeHeader.spec.ts b/src/utils/__tests__/parseContentRangeHeader.spec.ts new file mode 100644 index 00000000..fe397107 --- /dev/null +++ b/src/utils/__tests__/parseContentRangeHeader.spec.ts @@ -0,0 +1,46 @@ +import { parseContentRangeHeader } from '@/src/utils/parseContentRangeHeader'; +import { describe, expect, it } from 'vitest'; + +describe('parseContentRangeHeader', () => { + it('should handle valid ranges', () => { + let range = parseContentRangeHeader('bytes 0-1/123'); + expect(range.type).toEqual('range'); + if (range.type !== 'range') return; // ts can't narrow on expect() + + expect(range.start).toEqual(0); + expect(range.end).toEqual(1); + expect(range.length).toEqual(123); + + range = parseContentRangeHeader('bytes 2-5/*'); + expect(range.type).toEqual('range'); + if (range.type !== 'range') return; // ts can't narrow on expect() + + expect(range.start).toEqual(2); + expect(range.end).toEqual(5); + expect(range.length).to.be.null; + }); + + it('should handle unsatisfied ranges', () => { + const range = parseContentRangeHeader('bytes */12'); + expect(range.type).toEqual('unsatisfied-range'); + if (range.type !== 'unsatisfied-range') return; // ts can't narrow on expect() + + expect(range.length).toEqual(12); + }); + + it('should handle invalid ranges', () => { + [ + '', + 'bytes', + 'bytes */*', + 'byte 0-1/2', + 'bytes 1-0/2', + 'bytes 0-1/1', + 'bytes 1-3/2', + 'bytes 1-/2', + 'bytes -1/2', + ].forEach((range) => { + expect(parseContentRangeHeader(range).type).toEqual('invalid-range'); + }); + }); +}); diff --git a/src/utils/parseContentRangeHeader.ts b/src/utils/parseContentRangeHeader.ts new file mode 100644 index 00000000..5590d9b1 --- /dev/null +++ b/src/utils/parseContentRangeHeader.ts @@ -0,0 +1,39 @@ +const CONTENT_RANGE_REGEXP = + /^bytes (?(?\d+)-(?\d+)|\*)\/(?\d+|\*)$/; + +export type ContentRange = + | { type: 'invalid-range' } + | { type: 'unsatisfied-range'; length: number } + | { type: 'range'; start: number; end: number; length: number | null }; + +/** + * Parses a Content-Range header. + * + * Only supports bytes ranges. + * @param headerValue + * @returns + */ +export function parseContentRangeHeader( + headerValue: string | null +): ContentRange { + if (!headerValue) return { type: 'invalid-range' }; + + const match = CONTENT_RANGE_REGEXP.exec(headerValue); + const groups = match?.groups; + if (!groups) return { type: 'invalid-range' }; + + const length = groups.length === '*' ? null : parseInt(groups.length, 10); + + if (groups.range === '*') { + if (length === null) return { type: 'invalid-range' }; + return { type: 'unsatisfied-range', length }; + } + + const start = parseInt(groups.start, 10); + const end = parseInt(groups.end, 10); + + if (end < start || (length !== null && length <= end)) + return { type: 'invalid-range' }; + + return { type: 'range', start, end, length }; +}