Skip to content

Commit

Permalink
Fix issue with AsyncFileWriter potentially writing data out of order
Browse files Browse the repository at this point in the history
This commit fixes an issue where data can be written out of order when
redirecting to a file opened using aiofiles. My thanks go to Chan Chun Wai
for reporting this issue and providing code to reproduce it.
  • Loading branch information
ronf committed Oct 6, 2023
1 parent 18252db commit 8f1fe10
Showing 1 changed file with 52 additions and 35 deletions.
87 changes: 52 additions & 35 deletions asyncssh/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
import socket
import stat
from types import TracebackType
from typing import Any, AnyStr, Callable, Dict, Generic, IO
from typing import Iterable, Mapping, Optional, Set, TextIO
from typing import Any, AnyStr, Awaitable, Callable, Dict, Generic, IO
from typing import Iterable, List, Mapping, Optional, Set, TextIO
from typing import Tuple, Type, TypeVar, Union, cast
from typing_extensions import Protocol

Expand Down Expand Up @@ -308,14 +308,33 @@ def __init__(self, process: 'SSHProcess[AnyStr]',
encoding: Optional[str], errors: str):
super().__init__(encoding, errors, hasattr(file, 'encoding'))

self._conn = process.channel.get_connection()
self._process: 'SSHProcess[AnyStr]' = process
self._file = file
self._needs_close = needs_close
self._queue: asyncio.Queue[Optional[AnyStr]] = asyncio.Queue()
self._write_task: Optional[asyncio.Task[None]] = \
process.channel.get_connection().create_task(self._writer())

async def _writer(self) -> None:
"""Process writes to the file"""

while True:
data = await self._queue.get()

if data is None:
self._queue.task_done()
break

await self._file.write(self.encode(data))
self._queue.task_done()

if self._needs_close:
await self._file.close()

def write(self, data: AnyStr) -> None:
"""Write data to the file"""

self._conn.create_task(self._file.write(self.encode(data)))
self._queue.put_nowait(data)

def write_eof(self) -> None:
"""Close output file when end of file is received"""
Expand All @@ -325,8 +344,10 @@ def write_eof(self) -> None:
def close(self) -> None:
"""Stop forwarding data to the file"""

if self._needs_close:
self._conn.create_task(self._file.close())
if self._write_task:
self._write_task = None
self._queue.put_nowait(None)
self._process.add_cleanup_task(self._queue.join())


class _PipeReader(_UnicodeReader[AnyStr], asyncio.BaseProtocol):
Expand Down Expand Up @@ -721,6 +742,8 @@ class SSHProcess(SSHStreamSession, Generic[AnyStr]):
def __init__(self, *args) -> None:
super().__init__(*args)

self._cleanup_tasks: List[Awaitable[None]] = []

self._readers: Dict[Optional[int], _ReaderProtocol] = {}
self._send_eof: Dict[Optional[int], bool] = {}

Expand All @@ -729,6 +752,20 @@ def __init__(self, *args) -> None:

self._paused_write_streams: Set[Optional[int]] = set()

async def __aenter__(self) -> 'SSHProcess[AnyStr]':
"""Allow SSHProcess to be used as an async context manager"""

return self

async def __aexit__(self, _exc_type: Optional[Type[BaseException]],
_exc_value: Optional[BaseException],
_traceback: Optional[TracebackType]) -> bool:
"""Wait for a full channel close when exiting the async context"""

self.close()
await self.wait_closed()
return False

@property
def channel(self) -> SSHChannel[AnyStr]:
"""The channel associated with the process"""
Expand Down Expand Up @@ -931,6 +968,11 @@ def _should_pause_reading(self) -> bool:
return bool(self._paused_write_streams) or \
super()._should_pause_reading()

def add_cleanup_task(self, task: Awaitable[None]) -> None:
"""Add a task to run when the process exits"""

self._cleanup_tasks.append(task)

def connection_lost(self, exc: Optional[Exception]) -> None:
"""Handle a close of the SSH channel"""

Expand Down Expand Up @@ -1091,6 +1133,9 @@ async def wait_closed(self) -> None:
assert self._chan is not None
await self._chan.wait_closed()

for task in self._cleanup_tasks:
await task


class SSHClientProcess(SSHProcess[AnyStr], SSHClientStreamSession[AnyStr]):
"""SSH client process handler"""
Expand All @@ -1105,20 +1150,6 @@ def __init__(self) -> None:
self._stdout: Optional[SSHReader[AnyStr]] = None
self._stderr: Optional[SSHReader[AnyStr]] = None

async def __aenter__(self) -> 'SSHClientProcess[AnyStr]':
"""Allow SSHProcess to be used as an async context manager"""

return self

async def __aexit__(self, _exc_type: Optional[Type[BaseException]],
_exc_value: Optional[BaseException],
_traceback: Optional[TracebackType]) -> bool:
"""Wait for a full channel close when exiting the async context"""

self.close()
await self._chan.wait_closed()
return False

def _collect_output(self, datatype: DataType = None) -> AnyStr:
"""Return output from the process"""

Expand Down Expand Up @@ -1333,7 +1364,7 @@ async def communicate(self, input: Optional[AnyStr] = None) -> \
self._chan.write(input)
self._chan.write_eof()

await self._chan.wait_closed()
await self.wait_closed()

return self.collect_output()
# pylint: enable=redefined-builtin
Expand Down Expand Up @@ -1482,20 +1513,6 @@ def __init__(self, process_factory: SSHServerProcessFactory,
self._stdout: Optional[SSHWriter[AnyStr]] = None
self._stderr: Optional[SSHWriter[AnyStr]] = None

async def __aenter__(self) -> 'SSHServerProcess[AnyStr]':
"""Allow SSHProcess to be used as an async context manager"""

return self

async def __aexit__(self, _exc_type: Optional[Type[BaseException]],
_exc_value: Optional[BaseException],
_traceback: Optional[TracebackType]) -> bool:
"""Wait for a full channel close when exiting the async context"""

self.close()
await self._chan.wait_closed()
return False

def _start_process(self, stdin: SSHReader[AnyStr],
stdout: SSHWriter[AnyStr],
stderr: SSHWriter[AnyStr]) -> MaybeAwait[None]:
Expand Down

0 comments on commit 8f1fe10

Please sign in to comment.