Skip to content

Commit

Permalink
AsyncSocket: Fixed OSError raised by methods if aclose() is called (#175
Browse files Browse the repository at this point in the history
)
  • Loading branch information
francis-clairicia authored Dec 2, 2023
1 parent 3beb7c4 commit 68fc10c
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 29 deletions.
18 changes: 9 additions & 9 deletions src/easynetwork/lowlevel/asyncio/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,12 @@ async def accept(self) -> _socket.socket:
return client_socket

async def sendall(self, data: ReadableBuffer, /) -> None:
with self.__conflict_detection("send", abort_errno=_errno.ECONNABORTED):
with self.__conflict_detection("send"):
socket = self.__check_not_closed()
await self.__loop.sock_sendall(socket, data)

async def sendmsg(self, buffers: Iterable[ReadableBuffer], /) -> None:
with self.__conflict_detection("send", abort_errno=_errno.ECONNABORTED):
with self.__conflict_detection("send"):
socket = self.__check_not_closed()
if constants.SC_IOV_MAX <= 0 or not _utils.supports_socket_sendmsg(_sock := socket):
raise UnsupportedOperation("sendmsg() is not supported")
Expand All @@ -145,22 +145,22 @@ async def sendmsg(self, buffers: Iterable[ReadableBuffer], /) -> None:
_utils.adjust_leftover_buffer(buffers, sent)

async def sendto(self, data: ReadableBuffer, address: _socket._Address, /) -> None:
with self.__conflict_detection("send", abort_errno=_errno.ECONNABORTED):
with self.__conflict_detection("send"):
socket = self.__check_not_closed()
await self.__loop.sock_sendto(socket, data, address)

async def recv(self, bufsize: int, /) -> bytes:
with self.__conflict_detection("recv", abort_errno=_errno.ECONNABORTED):
with self.__conflict_detection("recv"):
socket = self.__check_not_closed()
return await self.__loop.sock_recv(socket, bufsize)

async def recv_into(self, buffer: WriteableBuffer, /) -> int:
with self.__conflict_detection("recv", abort_errno=_errno.ECONNABORTED):
with self.__conflict_detection("recv"):
socket = self.__check_not_closed()
return await self.__loop.sock_recv_into(socket, buffer)

async def recvfrom(self, bufsize: int, /) -> tuple[bytes, _socket._RetAddress]:
with self.__conflict_detection("recv", abort_errno=_errno.ECONNABORTED):
with self.__conflict_detection("recv"):
socket = self.__check_not_closed()
return await self.__loop.sock_recvfrom(socket, bufsize)

Expand All @@ -180,7 +180,7 @@ async def shutdown(self, how: int, /) -> None:
await asyncio.sleep(0)

@contextlib.contextmanager
def __conflict_detection(self, task_id: _SocketTaskId, *, abort_errno: int = _errno.EINTR) -> Iterator[None]:
def __conflict_detection(self, task_id: _SocketTaskId) -> Iterator[None]:
if task_id in self.__waiters:
raise _utils.error_from_errno(_errno.EBUSY)

Expand All @@ -200,11 +200,11 @@ def __conflict_detection(self, task_id: _SocketTaskId, *, abort_errno: int = _er
yield

if scope.cancelled_caught():
raise _utils.error_from_errno(abort_errno)
raise _utils.error_from_errno(_errno.EBADF)

def __check_not_closed(self) -> _socket.socket:
if (socket := self.__socket) is None:
raise _utils.error_from_errno(_errno.ENOTSOCK)
raise _utils.error_from_errno(_errno.EBADF)
return socket

@property
Expand Down
24 changes: 4 additions & 20 deletions tests/unit_test/test_async/test_asyncio_backend/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import asyncio
import contextlib
from collections.abc import Callable, Coroutine, Iterable, Iterator
from errno import EBUSY, ECONNABORTED, EINTR, ENOTSOCK
from errno import EBADF, EBUSY
from socket import SHUT_RD, SHUT_RDWR, SHUT_WR, socket as Socket
from typing import TYPE_CHECKING, Any, final

Expand Down Expand Up @@ -140,13 +140,12 @@ async def test____method____closed_socket____before_attempt(
await socket_method()

# Assert
assert exc_info.value.errno == ENOTSOCK
assert exc_info.value.errno == EBADF
mock_socket_method.assert_not_called()

async def test____method____closed_socket____during_attempt(
self,
socket: AsyncSocket,
abort_errno: int,
socket_method: Callable[[], Coroutine[Any, Any, Any]],
event_loop: asyncio.AbstractEventLoop,
mock_socket_method: MagicMock,
Expand All @@ -161,7 +160,7 @@ async def test____method____closed_socket____during_attempt(
await busy_method_task

# Assert
assert exc_info.value.errno == abort_errno
assert exc_info.value.errno == EBADF
mock_socket_method.assert_not_called()

@pytest.mark.parametrize("cancellation_requests", [1, 3])
Expand Down Expand Up @@ -325,11 +324,6 @@ def socket(
) -> AsyncSocket:
return AsyncSocket(mock_tcp_listener_socket, event_loop)

@pytest.fixture
@staticmethod
def abort_errno() -> int:
return EINTR

@pytest.fixture
@staticmethod
def socket_method(socket: AsyncSocket) -> Callable[[], Coroutine[Any, Any, Any]]:
Expand Down Expand Up @@ -394,11 +388,6 @@ def socket(
def sock_method_name(request: Any) -> str:
return request.param

@pytest.fixture
@staticmethod
def abort_errno() -> int:
return ECONNABORTED

@pytest.fixture
@staticmethod
def socket_method(sock_method_name: str, socket: AsyncSocket) -> Callable[[], Coroutine[Any, Any, Any]]:
Expand Down Expand Up @@ -625,7 +614,7 @@ async def test____shutdown____closed_socket(
await socket.shutdown(shutdown_how)

# Assert
assert exc_info.value.errno == ENOTSOCK
assert exc_info.value.errno == EBADF
mock_tcp_socket.shutdown.assert_not_called()

@pytest.mark.parametrize(
Expand Down Expand Up @@ -717,11 +706,6 @@ def socket(
def sock_method_name(request: Any) -> str:
return request.param

@pytest.fixture
@staticmethod
def abort_errno() -> int:
return ECONNABORTED

@pytest.fixture
@staticmethod
def socket_method(sock_method_name: str, socket: AsyncSocket) -> Callable[[], Coroutine[Any, Any, Any]]:
Expand Down

0 comments on commit 68fc10c

Please sign in to comment.