Skip to content

Commit

Permalink
Improved server performances and memory consumption on error (#385)
Browse files Browse the repository at this point in the history
  • Loading branch information
francis-clairicia authored Nov 23, 2024
1 parent 1956c23 commit e9364e3
Show file tree
Hide file tree
Showing 7 changed files with 162 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import asyncio
import collections
import errno as _errno
import traceback
import types
from collections.abc import Callable

Expand Down Expand Up @@ -95,6 +96,7 @@ def connection_lost(self, exc: Exception | None) -> None:
self.__connection_lost_exception = exc
if exc is not None:
self.__connection_lost_exception_tb = exc.__traceback__
self.__loop.call_soon(traceback.clear_frames, exc.__traceback__)

for waiter in self.__drain_waiters:
if not waiter.done():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import asyncio.trsock
import errno as _errno
import socket as _socket
import traceback
import warnings
from typing import Any, final

Expand Down Expand Up @@ -136,15 +137,15 @@ def get_extra_info(self, name: str, default: Any = None) -> Any:
return self.__transport.get_extra_info(name, default)

def __check_exceptions(self) -> None:
exc: BaseException | None = None
try:
exc = self.__exception_queue.get_nowait()
except asyncio.QueueEmpty:
pass
else:
try:
raise exc
finally:
del exc
raise exc
finally:
del exc


class DatagramEndpointProtocol(asyncio.DatagramProtocol):
Expand Down Expand Up @@ -203,6 +204,7 @@ def datagram_received(self, data: bytes, addr: tuple[Any, ...]) -> None:

def error_received(self, exc: Exception) -> None:
if self.__transport is not None:
self.__loop.call_soon(traceback.clear_frames, exc.__traceback__)
self.__exception_queue.put_nowait(exc)
self.__recv_queue.put_nowait(None) # Wake up endpoint

Expand Down
122 changes: 77 additions & 45 deletions src/easynetwork/lowlevel/api_async/backend/_asyncio/stream/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@
import asyncio
import asyncio.trsock
import errno as _errno
import traceback
import warnings
from collections.abc import Callable, Iterable, Mapping
from types import MappingProxyType, TracebackType
from typing import TYPE_CHECKING, Any, final
from typing import TYPE_CHECKING, Any, final, overload

from ......exceptions import UnsupportedOperation
from ..... import _utils, socket as socket_tools
Expand Down Expand Up @@ -142,6 +143,7 @@ class StreamReaderBufferedProtocol(asyncio.BufferedProtocol):
"__buffer",
"__buffer_view",
"__buffer_nbytes_written",
"__external_buffer_view",
"__transport",
"__closed",
"__write_flow",
Expand Down Expand Up @@ -169,10 +171,11 @@ def __init__(
self.__loop: asyncio.AbstractEventLoop = loop
self.__buffer: bytearray | None = bytearray(self.max_size)
self.__buffer_view: memoryview = memoryview(self.__buffer)
self.__external_buffer_view: WriteableBuffer | None = None
self.__buffer_nbytes_written: int = 0
self.__transport: asyncio.Transport | None = None
self.__closed: asyncio.Future[None] = loop.create_future()
self.__read_waiter: asyncio.Future[None] | None = None
self.__read_waiter: asyncio.Future[int | None] | None = None
self.__write_flow: WriteFlowControl
self.__read_paused: bool = False
self.__connection_lost: bool = False
Expand Down Expand Up @@ -203,6 +206,7 @@ def connection_lost(self, exc: Exception | None) -> None:
self.__eof_reached = True
else:
self.__connection_lost_exception_tb = exc.__traceback__
self.__loop.call_soon(traceback.clear_frames, exc.__traceback__)

self.__buffer_nbytes_written = 0
self.__buffer = None
Expand All @@ -217,6 +221,8 @@ def connection_lost(self, exc: Exception | None) -> None:
self.__transport = None

def get_buffer(self, sizehint: int) -> WriteableBuffer:
if (external_buffer_view := self.__external_buffer_view) is not None:
return external_buffer_view
# Ignore sizehint, the buffer is already at its maximum size.
# Returns unused buffer part
if self.__buffer is None:
Expand All @@ -227,12 +233,21 @@ def buffer_updated(self, nbytes: int) -> None:
assert not self.__connection_lost, "buffer_updated() after connection_lost()" # nosec assert_used
assert not self.__eof_reached, "buffer_updated() after eof_received()" # nosec assert_used

if self.__external_buffer_view is not None:
# Early remove to prevent using this buffer between this point and the wakeup of the task.
self.__external_buffer_view = None
self._read_waiter_fut(lambda waiter: waiter.set_result(nbytes))
# Call to _maybe_pause_transport() is unnecessary: Did not write in internal buffer.
return

self.__buffer_nbytes_written += nbytes
assert 0 <= self.__buffer_nbytes_written <= self.__buffer_view.nbytes # nosec assert_used
self._wakeup_read_waiter(None)
self._maybe_pause_transport()

def eof_received(self) -> bool:
# Early remove to prevent using this buffer between this point and the wakeup of the task.
self.__external_buffer_view = None
self.__eof_reached = True
self._wakeup_read_waiter(None)
if self.__over_ssl:
Expand All @@ -243,14 +258,13 @@ def eof_received(self) -> bool:
return True

async def receive_data(self, bufsize: int, /) -> bytes:
if self.__connection_lost_exception is not None:
raise self.__connection_lost_exception.with_traceback(self.__connection_lost_exception_tb)
self._check_for_connection_lost()
if bufsize == 0:
return b""
if bufsize < 0:
raise ValueError("'bufsize' must be a positive or null integer")

blocked: bool = await self._wait_for_data("receive_data")
await self._wait_for_data("receive_data", None)

nbytes_written = self.__buffer_nbytes_written
if nbytes_written:
Expand All @@ -264,66 +278,80 @@ async def receive_data(self, bufsize: int, /) -> bytes:
else:
data = b""
self._maybe_resume_transport()
if not blocked:
await TaskUtils.cancel_shielded_coro_yield()
return data

async def receive_data_into(self, buffer: WriteableBuffer, /) -> int:
if self.__connection_lost_exception is not None:
raise self.__connection_lost_exception.with_traceback(self.__connection_lost_exception_tb)
with memoryview(buffer).cast("B") as buffer:
self._check_for_connection_lost()

with memoryview(buffer) as buffer:
if not buffer:
return 0

blocked: bool = await self._wait_for_data("receive_data_into")

nbytes_written = self.__buffer_nbytes_written
if nbytes_written:
protocol_buffer_written = self.__buffer_view[:nbytes_written]
bufsize_offset = nbytes_written - buffer.nbytes
if bufsize_offset > 0:
nbytes_written = buffer.nbytes
buffer[:] = protocol_buffer_written[:nbytes_written]
protocol_buffer_written[:bufsize_offset] = protocol_buffer_written[-bufsize_offset:]
self.__buffer_nbytes_written = bufsize_offset
else:
buffer[:nbytes_written] = protocol_buffer_written
self.__buffer_nbytes_written = 0
with buffer.cast("B") if buffer.itemsize != 1 else buffer as buffer:
nbytes_written = await self._wait_for_data("receive_data_into", buffer)
if nbytes_written is not None:
# Call to _maybe_resume_transport() is unnecessary: Did not write in internal buffer.
return nbytes_written

nbytes_written = self.__buffer_nbytes_written
if nbytes_written:
protocol_buffer_written = self.__buffer_view[:nbytes_written]
bufsize_offset = nbytes_written - buffer.nbytes
if bufsize_offset > 0:
nbytes_written = buffer.nbytes
buffer[:] = protocol_buffer_written[:nbytes_written]
protocol_buffer_written[:bufsize_offset] = protocol_buffer_written[-bufsize_offset:]
self.__buffer_nbytes_written = bufsize_offset
else:
buffer[:nbytes_written] = protocol_buffer_written
self.__buffer_nbytes_written = 0

self._maybe_resume_transport()
if not blocked:
await TaskUtils.cancel_shielded_coro_yield()
return nbytes_written

async def _wait_for_data(self, requester: str) -> bool:
if self.__read_waiter is not None:
raise RuntimeError(f"{requester}() called while another coroutine is already waiting for incoming data")

if self.__buffer_nbytes_written or self.__eof_reached:
return False
@overload
async def _wait_for_data(self, requester: str, external_buffer: None) -> None: ...

assert not self.__read_paused, "transport reading is paused" # nosec assert_used
@overload
async def _wait_for_data(self, requester: str, external_buffer: WriteableBuffer) -> int | None: ...

if self.__transport is None:
# happening if transport.pause_reading() raises NotImplementedError
raise _utils.error_from_errno(_errno.ECONNABORTED)
async def _wait_for_data(self, requester: str, external_buffer: WriteableBuffer | None) -> int | None:
if self.__read_waiter is not None:
raise RuntimeError(f"{requester}() called while another coroutine is already waiting for incoming data")

self.__read_waiter = self.__loop.create_future()
try:
await self.__read_waiter
nbytes_written_in_external_buffer: int | None
if self.__buffer_nbytes_written or self.__eof_reached:
self.__read_waiter.set_result(None)
await TaskUtils.coro_yield()
nbytes_written_in_external_buffer = None
else:
assert not self.__read_paused, "transport reading is paused" # nosec assert_used

if self.__transport is None:
# happening if transport.pause_reading() raises NotImplementedError
raise _utils.error_from_errno(_errno.ECONNABORTED)

self.__external_buffer_view = external_buffer
try:
nbytes_written_in_external_buffer = await self.__read_waiter
finally:
self.__external_buffer_view = None
finally:
self.__read_waiter = None
if self.__connection_lost_exception is not None:
raise self.__connection_lost_exception.with_traceback(self.__connection_lost_exception_tb)
return True

def _wakeup_read_waiter(self, exc: Exception | None) -> None:
if nbytes_written_in_external_buffer is None:
self._check_for_connection_lost()
return nbytes_written_in_external_buffer

def _read_waiter_fut(self, set_result_cb: Callable[[asyncio.Future[int | None]], None]) -> None:
if (waiter := self.__read_waiter) is not None:
if not waiter.done():
if exc is None:
waiter.set_result(None)
else:
waiter.set_exception(exc)
set_result_cb(waiter)

def _wakeup_read_waiter(self, exc: Exception | None) -> None:
self._read_waiter_fut(lambda waiter: waiter.set_result(None) if exc is None else waiter.set_exception(exc))

def _get_read_buffer_size(self) -> int:
return self.__buffer_nbytes_written
Expand Down Expand Up @@ -363,6 +391,10 @@ def _maybe_resume_transport(self) -> None:
transport.resume_reading()
self.__read_paused = False

def _check_for_connection_lost(self) -> None:
if self.__connection_lost_exception is not None:
raise self.__connection_lost_exception.with_traceback(self.__connection_lost_exception_tb)

def pause_writing(self) -> None:
self.__write_flow.pause_writing()

Expand Down
12 changes: 8 additions & 4 deletions src/easynetwork/lowlevel/api_async/endpoints/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,10 @@ async def aclose(self) -> None:
"""
Closes the endpoint.
"""
await self.__transport.aclose()
self.__receiver.clear()
try:
await self.__transport.aclose()
finally:
self.__receiver.clear()

async def recv_packet(self) -> _T_ReceivedPacket:
"""
Expand Down Expand Up @@ -269,8 +271,10 @@ async def aclose(self) -> None:
Closes the endpoint.
"""
with self.__send_guard:
await self.__transport.aclose()
self.__receiver.clear()
try:
await self.__transport.aclose()
finally:
self.__receiver.clear()

async def send_packet(self, packet: _T_SentPacket) -> None:
"""
Expand Down
2 changes: 2 additions & 0 deletions src/easynetwork/lowlevel/api_async/servers/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,8 @@ async def __client_coroutine(
case _: # pragma: no cover
assert_never(self.__protocol)

# NOTE: It is safe to clear the consumer before the transport here.
# There is no task reading the transport at this point.
task_exit_stack.callback(consumer.clear)

request_handler_generator = client_connected_cb(
Expand Down
12 changes: 8 additions & 4 deletions src/easynetwork/lowlevel/api_sync/endpoints/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,10 @@ def close(self) -> None:
"""
Closes the endpoint.
"""
self.__transport.close()
self.__receiver.clear()
try:
self.__transport.close()
finally:
self.__receiver.clear()

def recv_packet(self, *, timeout: float | None = None) -> _T_ReceivedPacket:
"""
Expand Down Expand Up @@ -270,8 +272,10 @@ def close(self) -> None:
"""
Closes the endpoint.
"""
self.__transport.close()
self.__receiver.clear()
try:
self.__transport.close()
finally:
self.__receiver.clear()

def send_packet(self, packet: _T_SentPacket, *, timeout: float | None = None) -> None:
"""
Expand Down
Loading

0 comments on commit e9364e3

Please sign in to comment.