From e9364e3baa3aef08306eabdf21977ff7c5b4699c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francis=20Clairicia-Rose-Claire-Jos=C3=A9phine?= Date: Sat, 23 Nov 2024 19:06:57 +0100 Subject: [PATCH] Improved server performances and memory consumption on error (#385) --- .../backend/_asyncio/_flow_control.py | 2 + .../backend/_asyncio/datagram/endpoint.py | 10 +- .../backend/_asyncio/stream/socket.py | 122 +++++++++++------- .../lowlevel/api_async/endpoints/stream.py | 12 +- .../lowlevel/api_async/servers/stream.py | 2 + .../lowlevel/api_sync/endpoints/stream.py | 12 +- .../test_asyncio_backend/test_stream.py | 72 +++++++++-- 7 files changed, 162 insertions(+), 70 deletions(-) diff --git a/src/easynetwork/lowlevel/api_async/backend/_asyncio/_flow_control.py b/src/easynetwork/lowlevel/api_async/backend/_asyncio/_flow_control.py index 92708e6d..dfe540b3 100644 --- a/src/easynetwork/lowlevel/api_async/backend/_asyncio/_flow_control.py +++ b/src/easynetwork/lowlevel/api_async/backend/_asyncio/_flow_control.py @@ -22,6 +22,7 @@ import asyncio import collections import errno as _errno +import traceback import types from collections.abc import Callable @@ -95,6 +96,7 @@ def connection_lost(self, exc: Exception | None) -> None: self.__connection_lost_exception = exc if exc is not None: self.__connection_lost_exception_tb = exc.__traceback__ + self.__loop.call_soon(traceback.clear_frames, exc.__traceback__) for waiter in self.__drain_waiters: if not waiter.done(): diff --git a/src/easynetwork/lowlevel/api_async/backend/_asyncio/datagram/endpoint.py b/src/easynetwork/lowlevel/api_async/backend/_asyncio/datagram/endpoint.py index 7ab9f79e..a1ffdfdb 100644 --- a/src/easynetwork/lowlevel/api_async/backend/_asyncio/datagram/endpoint.py +++ b/src/easynetwork/lowlevel/api_async/backend/_asyncio/datagram/endpoint.py @@ -28,6 +28,7 @@ import asyncio.trsock import errno as _errno import socket as _socket +import traceback import warnings from typing import Any, final @@ -136,15 +137,15 @@ def get_extra_info(self, name: str, default: Any = None) -> Any: return self.__transport.get_extra_info(name, default) def __check_exceptions(self) -> None: + exc: BaseException | None = None try: exc = self.__exception_queue.get_nowait() except asyncio.QueueEmpty: pass else: - try: - raise exc - finally: - del exc + raise exc + finally: + del exc class DatagramEndpointProtocol(asyncio.DatagramProtocol): @@ -203,6 +204,7 @@ def datagram_received(self, data: bytes, addr: tuple[Any, ...]) -> None: def error_received(self, exc: Exception) -> None: if self.__transport is not None: + self.__loop.call_soon(traceback.clear_frames, exc.__traceback__) self.__exception_queue.put_nowait(exc) self.__recv_queue.put_nowait(None) # Wake up endpoint diff --git a/src/easynetwork/lowlevel/api_async/backend/_asyncio/stream/socket.py b/src/easynetwork/lowlevel/api_async/backend/_asyncio/stream/socket.py index 708fc9d8..e32f1ba4 100644 --- a/src/easynetwork/lowlevel/api_async/backend/_asyncio/stream/socket.py +++ b/src/easynetwork/lowlevel/api_async/backend/_asyncio/stream/socket.py @@ -22,10 +22,11 @@ import asyncio import asyncio.trsock import errno as _errno +import traceback import warnings from collections.abc import Callable, Iterable, Mapping from types import MappingProxyType, TracebackType -from typing import TYPE_CHECKING, Any, final +from typing import TYPE_CHECKING, Any, final, overload from ......exceptions import UnsupportedOperation from ..... import _utils, socket as socket_tools @@ -142,6 +143,7 @@ class StreamReaderBufferedProtocol(asyncio.BufferedProtocol): "__buffer", "__buffer_view", "__buffer_nbytes_written", + "__external_buffer_view", "__transport", "__closed", "__write_flow", @@ -169,10 +171,11 @@ def __init__( self.__loop: asyncio.AbstractEventLoop = loop self.__buffer: bytearray | None = bytearray(self.max_size) self.__buffer_view: memoryview = memoryview(self.__buffer) + self.__external_buffer_view: WriteableBuffer | None = None self.__buffer_nbytes_written: int = 0 self.__transport: asyncio.Transport | None = None self.__closed: asyncio.Future[None] = loop.create_future() - self.__read_waiter: asyncio.Future[None] | None = None + self.__read_waiter: asyncio.Future[int | None] | None = None self.__write_flow: WriteFlowControl self.__read_paused: bool = False self.__connection_lost: bool = False @@ -203,6 +206,7 @@ def connection_lost(self, exc: Exception | None) -> None: self.__eof_reached = True else: self.__connection_lost_exception_tb = exc.__traceback__ + self.__loop.call_soon(traceback.clear_frames, exc.__traceback__) self.__buffer_nbytes_written = 0 self.__buffer = None @@ -217,6 +221,8 @@ def connection_lost(self, exc: Exception | None) -> None: self.__transport = None def get_buffer(self, sizehint: int) -> WriteableBuffer: + if (external_buffer_view := self.__external_buffer_view) is not None: + return external_buffer_view # Ignore sizehint, the buffer is already at its maximum size. # Returns unused buffer part if self.__buffer is None: @@ -227,12 +233,21 @@ def buffer_updated(self, nbytes: int) -> None: assert not self.__connection_lost, "buffer_updated() after connection_lost()" # nosec assert_used assert not self.__eof_reached, "buffer_updated() after eof_received()" # nosec assert_used + if self.__external_buffer_view is not None: + # Early remove to prevent using this buffer between this point and the wakeup of the task. + self.__external_buffer_view = None + self._read_waiter_fut(lambda waiter: waiter.set_result(nbytes)) + # Call to _maybe_pause_transport() is unnecessary: Did not write in internal buffer. + return + self.__buffer_nbytes_written += nbytes assert 0 <= self.__buffer_nbytes_written <= self.__buffer_view.nbytes # nosec assert_used self._wakeup_read_waiter(None) self._maybe_pause_transport() def eof_received(self) -> bool: + # Early remove to prevent using this buffer between this point and the wakeup of the task. + self.__external_buffer_view = None self.__eof_reached = True self._wakeup_read_waiter(None) if self.__over_ssl: @@ -243,14 +258,13 @@ def eof_received(self) -> bool: return True async def receive_data(self, bufsize: int, /) -> bytes: - if self.__connection_lost_exception is not None: - raise self.__connection_lost_exception.with_traceback(self.__connection_lost_exception_tb) + self._check_for_connection_lost() if bufsize == 0: return b"" if bufsize < 0: raise ValueError("'bufsize' must be a positive or null integer") - blocked: bool = await self._wait_for_data("receive_data") + await self._wait_for_data("receive_data", None) nbytes_written = self.__buffer_nbytes_written if nbytes_written: @@ -264,66 +278,80 @@ async def receive_data(self, bufsize: int, /) -> bytes: else: data = b"" self._maybe_resume_transport() - if not blocked: - await TaskUtils.cancel_shielded_coro_yield() return data async def receive_data_into(self, buffer: WriteableBuffer, /) -> int: - if self.__connection_lost_exception is not None: - raise self.__connection_lost_exception.with_traceback(self.__connection_lost_exception_tb) - with memoryview(buffer).cast("B") as buffer: + self._check_for_connection_lost() + + with memoryview(buffer) as buffer: if not buffer: return 0 - blocked: bool = await self._wait_for_data("receive_data_into") - - nbytes_written = self.__buffer_nbytes_written - if nbytes_written: - protocol_buffer_written = self.__buffer_view[:nbytes_written] - bufsize_offset = nbytes_written - buffer.nbytes - if bufsize_offset > 0: - nbytes_written = buffer.nbytes - buffer[:] = protocol_buffer_written[:nbytes_written] - protocol_buffer_written[:bufsize_offset] = protocol_buffer_written[-bufsize_offset:] - self.__buffer_nbytes_written = bufsize_offset - else: - buffer[:nbytes_written] = protocol_buffer_written - self.__buffer_nbytes_written = 0 + with buffer.cast("B") if buffer.itemsize != 1 else buffer as buffer: + nbytes_written = await self._wait_for_data("receive_data_into", buffer) + if nbytes_written is not None: + # Call to _maybe_resume_transport() is unnecessary: Did not write in internal buffer. + return nbytes_written + + nbytes_written = self.__buffer_nbytes_written + if nbytes_written: + protocol_buffer_written = self.__buffer_view[:nbytes_written] + bufsize_offset = nbytes_written - buffer.nbytes + if bufsize_offset > 0: + nbytes_written = buffer.nbytes + buffer[:] = protocol_buffer_written[:nbytes_written] + protocol_buffer_written[:bufsize_offset] = protocol_buffer_written[-bufsize_offset:] + self.__buffer_nbytes_written = bufsize_offset + else: + buffer[:nbytes_written] = protocol_buffer_written + self.__buffer_nbytes_written = 0 self._maybe_resume_transport() - if not blocked: - await TaskUtils.cancel_shielded_coro_yield() return nbytes_written - async def _wait_for_data(self, requester: str) -> bool: - if self.__read_waiter is not None: - raise RuntimeError(f"{requester}() called while another coroutine is already waiting for incoming data") - - if self.__buffer_nbytes_written or self.__eof_reached: - return False + @overload + async def _wait_for_data(self, requester: str, external_buffer: None) -> None: ... - assert not self.__read_paused, "transport reading is paused" # nosec assert_used + @overload + async def _wait_for_data(self, requester: str, external_buffer: WriteableBuffer) -> int | None: ... - if self.__transport is None: - # happening if transport.pause_reading() raises NotImplementedError - raise _utils.error_from_errno(_errno.ECONNABORTED) + async def _wait_for_data(self, requester: str, external_buffer: WriteableBuffer | None) -> int | None: + if self.__read_waiter is not None: + raise RuntimeError(f"{requester}() called while another coroutine is already waiting for incoming data") self.__read_waiter = self.__loop.create_future() try: - await self.__read_waiter + nbytes_written_in_external_buffer: int | None + if self.__buffer_nbytes_written or self.__eof_reached: + self.__read_waiter.set_result(None) + await TaskUtils.coro_yield() + nbytes_written_in_external_buffer = None + else: + assert not self.__read_paused, "transport reading is paused" # nosec assert_used + + if self.__transport is None: + # happening if transport.pause_reading() raises NotImplementedError + raise _utils.error_from_errno(_errno.ECONNABORTED) + + self.__external_buffer_view = external_buffer + try: + nbytes_written_in_external_buffer = await self.__read_waiter + finally: + self.__external_buffer_view = None finally: self.__read_waiter = None - if self.__connection_lost_exception is not None: - raise self.__connection_lost_exception.with_traceback(self.__connection_lost_exception_tb) - return True - def _wakeup_read_waiter(self, exc: Exception | None) -> None: + if nbytes_written_in_external_buffer is None: + self._check_for_connection_lost() + return nbytes_written_in_external_buffer + + def _read_waiter_fut(self, set_result_cb: Callable[[asyncio.Future[int | None]], None]) -> None: if (waiter := self.__read_waiter) is not None: if not waiter.done(): - if exc is None: - waiter.set_result(None) - else: - waiter.set_exception(exc) + set_result_cb(waiter) + + def _wakeup_read_waiter(self, exc: Exception | None) -> None: + self._read_waiter_fut(lambda waiter: waiter.set_result(None) if exc is None else waiter.set_exception(exc)) def _get_read_buffer_size(self) -> int: return self.__buffer_nbytes_written @@ -363,6 +391,10 @@ def _maybe_resume_transport(self) -> None: transport.resume_reading() self.__read_paused = False + def _check_for_connection_lost(self) -> None: + if self.__connection_lost_exception is not None: + raise self.__connection_lost_exception.with_traceback(self.__connection_lost_exception_tb) + def pause_writing(self) -> None: self.__write_flow.pause_writing() diff --git a/src/easynetwork/lowlevel/api_async/endpoints/stream.py b/src/easynetwork/lowlevel/api_async/endpoints/stream.py index e766f8fa..9a3a4eed 100644 --- a/src/easynetwork/lowlevel/api_async/endpoints/stream.py +++ b/src/easynetwork/lowlevel/api_async/endpoints/stream.py @@ -94,8 +94,10 @@ async def aclose(self) -> None: """ Closes the endpoint. """ - await self.__transport.aclose() - self.__receiver.clear() + try: + await self.__transport.aclose() + finally: + self.__receiver.clear() async def recv_packet(self) -> _T_ReceivedPacket: """ @@ -269,8 +271,10 @@ async def aclose(self) -> None: Closes the endpoint. """ with self.__send_guard: - await self.__transport.aclose() - self.__receiver.clear() + try: + await self.__transport.aclose() + finally: + self.__receiver.clear() async def send_packet(self, packet: _T_SentPacket) -> None: """ diff --git a/src/easynetwork/lowlevel/api_async/servers/stream.py b/src/easynetwork/lowlevel/api_async/servers/stream.py index 729f33b6..b26d6796 100644 --- a/src/easynetwork/lowlevel/api_async/servers/stream.py +++ b/src/easynetwork/lowlevel/api_async/servers/stream.py @@ -220,6 +220,8 @@ async def __client_coroutine( case _: # pragma: no cover assert_never(self.__protocol) + # NOTE: It is safe to clear the consumer before the transport here. + # There is no task reading the transport at this point. task_exit_stack.callback(consumer.clear) request_handler_generator = client_connected_cb( diff --git a/src/easynetwork/lowlevel/api_sync/endpoints/stream.py b/src/easynetwork/lowlevel/api_sync/endpoints/stream.py index 98e80f4e..3ddee363 100644 --- a/src/easynetwork/lowlevel/api_sync/endpoints/stream.py +++ b/src/easynetwork/lowlevel/api_sync/endpoints/stream.py @@ -92,8 +92,10 @@ def close(self) -> None: """ Closes the endpoint. """ - self.__transport.close() - self.__receiver.clear() + try: + self.__transport.close() + finally: + self.__receiver.clear() def recv_packet(self, *, timeout: float | None = None) -> _T_ReceivedPacket: """ @@ -270,8 +272,10 @@ def close(self) -> None: """ Closes the endpoint. """ - self.__transport.close() - self.__receiver.clear() + try: + self.__transport.close() + finally: + self.__receiver.clear() def send_packet(self, packet: _T_SentPacket, *, timeout: float | None = None) -> None: """ diff --git a/tests/unit_test/test_async/test_asyncio_backend/test_stream.py b/tests/unit_test/test_async/test_asyncio_backend/test_stream.py index 5befbe4d..efa26c60 100644 --- a/tests/unit_test/test_async/test_asyncio_backend/test_stream.py +++ b/tests/unit_test/test_async/test_asyncio_backend/test_stream.py @@ -1118,7 +1118,8 @@ async def test____receive_data____default( self.write_in_protocol_buffer(protocol, b"abcdef") # Act - data = await data_receiver(protocol, 1024) + async with asyncio.timeout(5): + data = await data_receiver(protocol, 1024) # Assert assert data == b"abcdef" @@ -1135,8 +1136,9 @@ async def test____receive_data____partial_read( self.write_in_protocol_buffer(protocol, b"abcdef") # Act - first = await data_receiver(protocol, 3) - second = await data_receiver(protocol, 3) + async with asyncio.timeout(5): + first = await data_receiver(protocol, 3) + second = await data_receiver(protocol, 3) # Assert assert first == b"abc" @@ -1144,22 +1146,64 @@ async def test____receive_data____partial_read( mock_asyncio_transport.resume_reading.assert_not_called() @pytest.mark.asyncio - async def test____receive_data____buffer_updated_several_times( + @pytest.mark.parametrize("blocking", [False, True], ids=lambda p: f"blocking=={p}") + @pytest.mark.parametrize("data_receiver", ["data"], indirect=True) + async def test____receive_data____owned_data____buffer_updated_several_times( self, + blocking: bool, protocol: StreamReaderBufferedProtocol, mock_asyncio_transport: MagicMock, data_receiver: _ProtocolDataReceiver, ) -> None: # Arrange event_loop = asyncio.get_running_loop() - event_loop.call_soon(self.write_in_protocol_buffer, protocol, b"abc") - event_loop.call_soon(self.write_in_protocol_buffer, protocol, b"def") + if blocking: + event_loop.call_soon(self.write_in_protocol_buffer, protocol, b"abc") + event_loop.call_soon(self.write_in_protocol_buffer, protocol, b"def") + else: + self.write_in_protocol_buffer(protocol, b"abc") + self.write_in_protocol_buffer(protocol, b"def") # Act - data = await data_receiver(protocol, 1024) + async with asyncio.timeout(5): + data = await data_receiver(protocol, 1024) # Assert assert data == b"abcdef" + assert protocol._get_read_buffer_size() == 0 + mock_asyncio_transport.resume_reading.assert_not_called() + + @pytest.mark.asyncio + @pytest.mark.parametrize("blocking", [False, True], ids=lambda p: f"blocking=={p}") + @pytest.mark.parametrize("data_receiver", ["buffer"], indirect=True) + async def test____receive_data____into_buffer____buffer_updated_several_times( + self, + blocking: bool, + protocol: StreamReaderBufferedProtocol, + mock_asyncio_transport: MagicMock, + data_receiver: _ProtocolDataReceiver, + ) -> None: + # Arrange + event_loop = asyncio.get_running_loop() + if blocking: + event_loop.call_soon(self.write_in_protocol_buffer, protocol, b"abc") + event_loop.call_soon(self.write_in_protocol_buffer, protocol, b"def") + else: + self.write_in_protocol_buffer(protocol, b"abc") + self.write_in_protocol_buffer(protocol, b"def") + + # Act + async with asyncio.timeout(5): + data = await data_receiver(protocol, 1024) + + # Assert + if blocking: + assert data == b"abc" + assert protocol._get_read_buffer_size() == 3 # should be b"def" + assert (await data_receiver(protocol, 1024)) == b"def" + else: + assert data == b"abcdef" + assert protocol._get_read_buffer_size() == 0 mock_asyncio_transport.resume_reading.assert_not_called() @pytest.mark.asyncio @@ -1172,7 +1216,8 @@ async def test____receive_data____null_bufsize( # Arrange # Act - data = await data_receiver(protocol, 0) + async with asyncio.timeout(5): + data = await data_receiver(protocol, 0) # Assert assert data == b"" @@ -1208,7 +1253,8 @@ def protocol_eof_handler() -> None: protocol_eof_handler() # Act - data = await data_receiver(protocol, 1024) + async with asyncio.timeout(5): + data = await data_receiver(protocol, 1024) # Assert assert data == b"" @@ -1243,14 +1289,14 @@ async def test____receive_data____connection_reset( data_receiver: _ProtocolDataReceiver, ) -> None: # Arrange - event_loop = asyncio.get_running_loop() - event_loop.call_soon(self.write_in_protocol_buffer, protocol, b"abc") - event_loop.call_soon(protocol.connection_lost, None) + self.write_in_protocol_buffer(protocol, b"abc") + protocol.connection_lost(None) # Act & Assert for _ in range(3): with pytest.raises(ConnectionResetError): - _ = await data_receiver(protocol, 1024) + async with asyncio.timeout(5): + _ = await data_receiver(protocol, 1024) @pytest.mark.asyncio async def test____receive_data____invalid_bufsize(