From 057d90854bdf3c55ab92dad29bbf219fc399a559 Mon Sep 17 00:00:00 2001 From: Oded Engel Date: Tue, 10 Oct 2023 08:37:13 +0300 Subject: [PATCH] add compiled-regex support for readuntil function --- asyncssh/stream.py | 23 +++++++++++++++++------ tests/test_stream.py | 22 ++++++++++++++++++++++ 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/asyncssh/stream.py b/asyncssh/stream.py index 2ee8e3c..c9f2810 100644 --- a/asyncssh/stream.py +++ b/asyncssh/stream.py @@ -26,6 +26,7 @@ from typing import TYPE_CHECKING, Any, AnyStr, AsyncIterator from typing import Callable, Dict, Generic, Iterable from typing import List, Optional, Set, Tuple, Union, cast +from typing.re import Pattern from .constants import EXTENDED_DATA_STDERR from .logging import SSHLogger @@ -180,18 +181,25 @@ async def readline(self) -> AnyStr: except asyncio.IncompleteReadError as exc: return cast(AnyStr, exc.partial) - async def readuntil(self, separator: object) -> AnyStr: + async def readuntil(self, separator: object, max_separator_len: int = None) -> AnyStr: """Read data from the stream until `separator` is seen This method is a coroutine which reads from the stream until the requested separator is seen. If a match is found, the returned data will include the separator at the end. - The separator argument can be either a single `bytes` or + The separator argument can be a single `bytes` or `str` value or a sequence of multiple values to match - against, returning data as soon as any of the separators + against or a compiled regex (typing.re.Pattern), + returning data as soon as any of the separators are found in the stream. + The separator-length argument may only be set when providing + a compiled regex as a separator. + Otherwise, the separator's length would be used. + Note that if compiled regex is provided and the length is not set, + 0 would be used. (regex match on the whole buffer) + If EOF or a signal is received before a match occurs, an :exc:`IncompleteReadError ` is raised and its `partial` attribute will contain the @@ -202,7 +210,7 @@ async def readuntil(self, separator: object) -> AnyStr: """ - return await self._session.readuntil(separator, self._datatype) + return await self._session.readuntil(separator, self._datatype, max_separator_len) async def readexactly(self, n: int) -> AnyStr: """Read an exact amount of data from the stream @@ -558,7 +566,7 @@ async def read(self, n: int, datatype: DataType, exact: bool) -> AnyStr: return result - async def readuntil(self, separator: object, datatype: DataType) -> AnyStr: + async def readuntil(self, separator: object, datatype: DataType, max_separator_len: int) -> AnyStr: """Read data from the channel until a separator is seen""" if not separator: @@ -573,6 +581,9 @@ async def readuntil(self, separator: object, datatype: DataType) -> AnyStr: elif isinstance(separator, (bytes, str)): seplen = len(separator) separators = re.escape(cast(AnyStr, separator)) + elif isinstance(separator, Pattern): + seplen = max_separator_len + separators = separator else: bar = cast(AnyStr, '|' if self._encoding else b'|') seplist = list(cast(Iterable[AnyStr], separator)) @@ -602,7 +613,7 @@ async def readuntil(self, separator: object, datatype: DataType) -> AnyStr: newbuf = cast(AnyStr, recv_buf[curbuf]) buf += newbuf - start = max(buflen + 1 - seplen, 0) + start = 0 if seplen is None else max(buflen + 1 - seplen, 0) match = pat.search(buf, start) if match: diff --git a/tests/test_stream.py b/tests/test_stream.py index 12e3675..63dfbc6 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -21,6 +21,7 @@ """Unit tests for AsyncSSH stream API""" import asyncio +import re import asyncssh @@ -391,6 +392,27 @@ async def test_readuntil_empty_separator(self): stdin.close() + @asynctest + async def test_readuntil_regex(self): + """Test readuntil with a regex pattern""" + + async with self.connect() as conn: + stdin, stdout, _ = await conn.open_session() + stdin.write("hello world\nhello world") + output = await stdout.readuntil( + re.compile('hello world'), len('hello world') + ) + self.assertEqual(output, "hello world") + + output = await stdout.readuntil( + re.compile('hello world'), len('hello world') + ) + self.assertEqual(output, "\nhello world") + + stdin.close() + + await conn.wait_closed() + @asynctest async def test_abort(self): """Test abort on a channel"""