diff --git a/asyncssh/process.py b/asyncssh/process.py index b4bea25..33a0cdb 100644 --- a/asyncssh/process.py +++ b/asyncssh/process.py @@ -30,8 +30,8 @@ import socket import stat from types import TracebackType -from typing import Any, AnyStr, Callable, Dict, Generic, IO -from typing import Iterable, Mapping, Optional, Set, TextIO +from typing import Any, AnyStr, Awaitable, Callable, Dict, Generic, IO +from typing import Iterable, List, Mapping, Optional, Set, TextIO from typing import Tuple, Type, TypeVar, Union, cast from typing_extensions import Protocol @@ -308,14 +308,33 @@ def __init__(self, process: 'SSHProcess[AnyStr]', encoding: Optional[str], errors: str): super().__init__(encoding, errors, hasattr(file, 'encoding')) - self._conn = process.channel.get_connection() + self._process: 'SSHProcess[AnyStr]' = process self._file = file self._needs_close = needs_close + self._queue: asyncio.Queue[Optional[AnyStr]] = asyncio.Queue() + self._write_task: Optional[asyncio.Task[None]] = \ + process.channel.get_connection().create_task(self._writer()) + + async def _writer(self) -> None: + """Process writes to the file""" + + while True: + data = await self._queue.get() + + if data is None: + self._queue.task_done() + break + + await self._file.write(self.encode(data)) + self._queue.task_done() + + if self._needs_close: + await self._file.close() def write(self, data: AnyStr) -> None: """Write data to the file""" - self._conn.create_task(self._file.write(self.encode(data))) + self._queue.put_nowait(data) def write_eof(self) -> None: """Close output file when end of file is received""" @@ -325,8 +344,10 @@ def write_eof(self) -> None: def close(self) -> None: """Stop forwarding data to the file""" - if self._needs_close: - self._conn.create_task(self._file.close()) + if self._write_task: + self._write_task = None + self._queue.put_nowait(None) + self._process.add_cleanup_task(self._queue.join()) class _PipeReader(_UnicodeReader[AnyStr], asyncio.BaseProtocol): @@ -721,6 +742,8 @@ class SSHProcess(SSHStreamSession, Generic[AnyStr]): def __init__(self, *args) -> None: super().__init__(*args) + self._cleanup_tasks: List[Awaitable[None]] = [] + self._readers: Dict[Optional[int], _ReaderProtocol] = {} self._send_eof: Dict[Optional[int], bool] = {} @@ -729,6 +752,20 @@ def __init__(self, *args) -> None: self._paused_write_streams: Set[Optional[int]] = set() + async def __aenter__(self) -> 'SSHProcess[AnyStr]': + """Allow SSHProcess to be used as an async context manager""" + + return self + + async def __aexit__(self, _exc_type: Optional[Type[BaseException]], + _exc_value: Optional[BaseException], + _traceback: Optional[TracebackType]) -> bool: + """Wait for a full channel close when exiting the async context""" + + self.close() + await self.wait_closed() + return False + @property def channel(self) -> SSHChannel[AnyStr]: """The channel associated with the process""" @@ -931,6 +968,11 @@ def _should_pause_reading(self) -> bool: return bool(self._paused_write_streams) or \ super()._should_pause_reading() + def add_cleanup_task(self, task: Awaitable[None]) -> None: + """Add a task to run when the process exits""" + + self._cleanup_tasks.append(task) + def connection_lost(self, exc: Optional[Exception]) -> None: """Handle a close of the SSH channel""" @@ -1091,6 +1133,9 @@ async def wait_closed(self) -> None: assert self._chan is not None await self._chan.wait_closed() + for task in self._cleanup_tasks: + await task + class SSHClientProcess(SSHProcess[AnyStr], SSHClientStreamSession[AnyStr]): """SSH client process handler""" @@ -1105,20 +1150,6 @@ def __init__(self) -> None: self._stdout: Optional[SSHReader[AnyStr]] = None self._stderr: Optional[SSHReader[AnyStr]] = None - async def __aenter__(self) -> 'SSHClientProcess[AnyStr]': - """Allow SSHProcess to be used as an async context manager""" - - return self - - async def __aexit__(self, _exc_type: Optional[Type[BaseException]], - _exc_value: Optional[BaseException], - _traceback: Optional[TracebackType]) -> bool: - """Wait for a full channel close when exiting the async context""" - - self.close() - await self._chan.wait_closed() - return False - def _collect_output(self, datatype: DataType = None) -> AnyStr: """Return output from the process""" @@ -1333,7 +1364,7 @@ async def communicate(self, input: Optional[AnyStr] = None) -> \ self._chan.write(input) self._chan.write_eof() - await self._chan.wait_closed() + await self.wait_closed() return self.collect_output() # pylint: enable=redefined-builtin @@ -1482,20 +1513,6 @@ def __init__(self, process_factory: SSHServerProcessFactory, self._stdout: Optional[SSHWriter[AnyStr]] = None self._stderr: Optional[SSHWriter[AnyStr]] = None - async def __aenter__(self) -> 'SSHServerProcess[AnyStr]': - """Allow SSHProcess to be used as an async context manager""" - - return self - - async def __aexit__(self, _exc_type: Optional[Type[BaseException]], - _exc_value: Optional[BaseException], - _traceback: Optional[TracebackType]) -> bool: - """Wait for a full channel close when exiting the async context""" - - self.close() - await self._chan.wait_closed() - return False - def _start_process(self, stdin: SSHReader[AnyStr], stdout: SSHWriter[AnyStr], stderr: SSHWriter[AnyStr]) -> MaybeAwait[None]: