From e54eb53a6952381428fb321b9e695932f8dd92f2 Mon Sep 17 00:00:00 2001
From: BdEgh
Date: Mon, 26 Feb 2024 04:01:37 +0500
Subject: [PATCH] Read stream for downloads

Add Download.read_stream() to the async and sync APIs. It yields the
bytes of a download in chunks read from the artifact stream, so callers
can consume a download without calling save_as().

---
Reviewer notes (this section is ignored by `git am`):

- sync_api: the original hunk passed an async generator to `_sync` and
  iterated the result with a plain `for`. The generator is now advanced
  one chunk per `_sync` call through a coroutine wrapper.
  NOTE(review): `_generated.py` is presumably produced by a generator
  script; confirm whether that script needs the same change.
- tests: the anchor markup in the two new tests was missing from the
  `set_content` calls. It is restored from the routes registered in this
  patch, assuming the test server exposes `server.PREFIX` - TODO confirm.
- The post-image blob ids on the `index` lines of
  async_api/_generated.py, sync_api/_generated.py and
  tests/async/test_download.py were not regenerated after these edits.

 playwright/_impl/_artifact.py      |  7 ++++++-
 playwright/_impl/_download.py      |  6 +++++-
 playwright/_impl/_stream.py        |  9 ++++++++-
 playwright/async_api/_generated.py | 11 ++++++++++-
 playwright/sync_api/_generated.py  | 24 +++++++++++++++++++++++-
 tests/async/test_download.py       | 30 ++++++++++++++++++++++++++++++
 6 files changed, 82 insertions(+), 5 deletions(-)

diff --git a/playwright/_impl/_artifact.py b/playwright/_impl/_artifact.py
index 63833fe04..f61aca489 100644
--- a/playwright/_impl/_artifact.py
+++ b/playwright/_impl/_artifact.py
@@ -14,7 +14,7 @@
 
 import pathlib
 from pathlib import Path
-from typing import Dict, Optional, Union, cast
+from typing import AsyncIterator, Dict, Optional, Union, cast
 
 from playwright._impl._connection import ChannelOwner, from_channel
 from playwright._impl._helper import Error, make_dirs_for_file, patch_error_message
@@ -41,6 +41,11 @@ async def save_as(self, path: Union[str, Path]) -> None:
         make_dirs_for_file(path)
         await stream.save_as(path)
 
+    async def read_stream(self) -> AsyncIterator[bytes]:
+        stream = cast(Stream, from_channel(await self._channel.send("stream")))
+        async for chunk in stream.read_stream():
+            yield chunk
+
     async def failure(self) -> Optional[str]:
         return patch_error_message(await self._channel.send("failure"))
 
diff --git a/playwright/_impl/_download.py b/playwright/_impl/_download.py
index ffaf5cacd..64ffb56a0 100644
--- a/playwright/_impl/_download.py
+++ b/playwright/_impl/_download.py
@@ -14,7 +14,7 @@
 
 import pathlib
 from pathlib import Path
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, AsyncIterator, Optional, Union
 
 from playwright._impl._artifact import Artifact
 
@@ -60,5 +60,9 @@ async def path(self) -> pathlib.Path:
     async def save_as(self, path: Union[str, Path]) -> None:
         await self._artifact.save_as(path)
 
+    async def read_stream(self) -> AsyncIterator[bytes]:
+        async for chunk in self._artifact.read_stream():
+            yield chunk
+
     async def cancel(self) -> None:
         return await self._artifact.cancel()
diff --git a/playwright/_impl/_stream.py b/playwright/_impl/_stream.py
index d27427589..c34a520ab 100644
--- a/playwright/_impl/_stream.py
+++ b/playwright/_impl/_stream.py
@@ -14,7 +14,7 @@
 
 import base64
 from pathlib import Path
-from typing import Dict, Union
+from typing import AsyncIterator, Dict, Union
 
 from playwright._impl._connection import ChannelOwner
 
@@ -36,6 +36,13 @@ async def save_as(self, path: Union[str, Path]) -> None:
             )
         await self._loop.run_in_executor(None, lambda: file.close())
 
+    async def read_stream(self) -> AsyncIterator[bytes]:
+        while True:
+            binary = await self._channel.send("read", {"size": 1024 * 1024})
+            if not binary:
+                break
+            yield base64.b64decode(binary)
+
     async def read_all(self) -> bytes:
         binary = b""
         while True:
diff --git a/playwright/async_api/_generated.py b/playwright/async_api/_generated.py
index e484baa09..d8e6e0b09 100644
--- a/playwright/async_api/_generated.py
+++ b/playwright/async_api/_generated.py
@@ -15,7 +15,7 @@
 
 import pathlib
 import typing
-from typing import Literal
+from typing import AsyncIterator, Literal
 
 from playwright._impl._accessibility import Accessibility as AccessibilityImpl
 from playwright._impl._api_structures import (
@@ -6852,6 +6852,15 @@ async def save_as(self, path: typing.Union[str, pathlib.Path]) -> None:
 
         return mapping.from_maybe_impl(await self._impl_obj.save_as(path=path))
 
+    async def read_stream(self) -> AsyncIterator[bytes]:
+        """Download.read_stream
+
+        Yields the chunks of a successful download as they are read, or throws for a failed/canceled download.
+        """
+
+        async for chunk in mapping.from_maybe_impl(self._impl_obj.read_stream()):
+            yield chunk
+
     async def cancel(self) -> None:
         """Download.cancel
 
diff --git a/playwright/sync_api/_generated.py b/playwright/sync_api/_generated.py
index a861367be..51e7d3d09 100644
--- a/playwright/sync_api/_generated.py
+++ b/playwright/sync_api/_generated.py
@@ -15,7 +15,7 @@
 
 import pathlib
 import typing
-from typing import Literal
+from typing import Iterable, Literal
 
 from playwright._impl._accessibility import Accessibility as AccessibilityImpl
 from playwright._impl._api_structures import (
@@ -6962,6 +6962,28 @@ def save_as(self, path: typing.Union[str, pathlib.Path]) -> None:
 
         return mapping.from_maybe_impl(self._sync(self._impl_obj.save_as(path=path)))
 
+    def read_stream(self) -> Iterable[bytes]:
+        """Download.read_stream
+
+        Yields the chunks of a successful download as they are read, or throws for a failed/canceled download.
+        """
+
+        # NOTE(review): `_sync` is assumed to run one coroutine to completion
+        # (as in save_as), so the async generator is advanced chunk by chunk.
+        chunks = self._impl_obj.read_stream()
+
+        async def next_chunk() -> typing.Optional[bytes]:
+            try:
+                return await chunks.__anext__()
+            except StopAsyncIteration:
+                return None
+
+        while True:
+            chunk = self._sync(next_chunk())
+            if chunk is None:
+                break
+            yield mapping.from_maybe_impl(chunk)
+
     def cancel(self) -> None:
         """Download.cancel
 
diff --git a/tests/async/test_download.py b/tests/async/test_download.py
index 96d06820e..8faaf8fe7 100644
--- a/tests/async/test_download.py
+++ b/tests/async/test_download.py
@@ -43,8 +43,16 @@ def handle_download_with_file_name(request: TestServerRequest) -> None:
         request.write(b"Hello world")
         request.finish()
 
+    def handle_download_big_file(request: TestServerRequest) -> None:
+        request.setHeader("Content-Type", "application/octet-stream")
+        request.setHeader("Content-Disposition", "attachment")
+        request.write(b"A" * 1024 * 1024)
+        request.write(b"B")
+        request.finish()
+
     server.set_route("/download", handle_download)
     server.set_route("/downloadWithFilename", handle_download_with_file_name)
+    server.set_route("/downloadBigFile", handle_download_big_file)
     yield
 
 
@@ -381,3 +389,25 @@ def handle_download(request: TestServerRequest) -> None:
     await download.cancel()
     assert await download.failure() == "canceled"
     await page.close()
+
+
+async def test_stream_reading(browser: Browser, server: Server) -> None:
+    page = await browser.new_page(accept_downloads=True)
+    await page.set_content(f'<a href="{server.PREFIX}/download">download</a>')
+    async with page.expect_download() as download_info:
+        await page.click("a")
+    download = await download_info.value
+    data = b"".join([chunk async for chunk in download.read_stream()])
+    assert data == b"Hello world"
+    await page.close()
+
+
+async def test_stream_reading_multiple_chunks(browser: Browser, server: Server) -> None:
+    page = await browser.new_page(accept_downloads=True)
+    await page.set_content(f'<a href="{server.PREFIX}/downloadBigFile">download</a>')
+    async with page.expect_download() as download_info:
+        await page.click("a")
+    download = await download_info.value
+    data = b"".join([chunk async for chunk in download.read_stream()])
+    assert data == b"A" * 1024 * 1024 + b"B"
+    await page.close()