Skip to content

Commit

Permalink
add compiled-regex support for readuntil function
Browse files Browse the repository at this point in the history
  • Loading branch information
engeloded committed Oct 10, 2023
1 parent fa6a1c9 commit bfb4526
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 6 deletions.
23 changes: 17 additions & 6 deletions asyncssh/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from typing import TYPE_CHECKING, Any, AnyStr, AsyncIterator
from typing import Callable, Dict, Generic, Iterable
from typing import List, Optional, Set, Tuple, Union, cast
from typing.re import Pattern

from .constants import EXTENDED_DATA_STDERR
from .logging import SSHLogger
Expand Down Expand Up @@ -180,18 +181,25 @@ async def readline(self) -> AnyStr:
except asyncio.IncompleteReadError as exc:
return cast(AnyStr, exc.partial)

async def readuntil(self, separator: object) -> AnyStr:
async def readuntil(self, separator: object, max_separator_len: int = None) -> AnyStr:
"""Read data from the stream until `separator` is seen
This method is a coroutine which reads from the stream until
the requested separator is seen. If a match is found, the
returned data will include the separator at the end.
The separator argument can be either a single `bytes` or
The separator argument can be a single `bytes` or
`str` value or a sequence of multiple values to match
against, returning data as soon as any of the separators
against or a compiled regex (typing.re.Pattern),
returning data as soon as any of the separators
are found in the stream.
The separator-length argument may only be set when providing
a compiled regex as a separator.
Otherwise, the separator's length would be used.
Note that if compiled regex is provided and the length is not set,
0 would be used. (regex match on the whole buffer)
If EOF or a signal is received before a match occurs, an
:exc:`IncompleteReadError <asyncio.IncompleteReadError>`
is raised and its `partial` attribute will contain the
Expand All @@ -202,7 +210,7 @@ async def readuntil(self, separator: object) -> AnyStr:
"""

return await self._session.readuntil(separator, self._datatype)
return await self._session.readuntil(separator, self._datatype, max_separator_len)

async def readexactly(self, n: int) -> AnyStr:
"""Read an exact amount of data from the stream
Expand Down Expand Up @@ -558,7 +566,7 @@ async def read(self, n: int, datatype: DataType, exact: bool) -> AnyStr:

return result

async def readuntil(self, separator: object, datatype: DataType) -> AnyStr:
async def readuntil(self, separator: object, datatype: DataType, max_separator_len: int) -> AnyStr:
"""Read data from the channel until a separator is seen"""

if not separator:
Expand All @@ -573,6 +581,9 @@ async def readuntil(self, separator: object, datatype: DataType) -> AnyStr:
elif isinstance(separator, (bytes, str)):
seplen = len(separator)
separators = re.escape(cast(AnyStr, separator))
elif isinstance(separator, Pattern):
seplen = max_separator_len
separators = separator
else:
bar = cast(AnyStr, '|' if self._encoding else b'|')
seplist = list(cast(Iterable[AnyStr], separator))
Expand Down Expand Up @@ -602,7 +613,7 @@ async def readuntil(self, separator: object, datatype: DataType) -> AnyStr:

newbuf = cast(AnyStr, recv_buf[curbuf])
buf += newbuf
start = max(buflen + 1 - seplen, 0)
start = 0 if seplen is None else max(buflen + 1 - seplen, 0)

match = pat.search(buf, start)
if match:
Expand Down
22 changes: 22 additions & 0 deletions tests/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"""Unit tests for AsyncSSH stream API"""

import asyncio
import re

import asyncssh

Expand Down Expand Up @@ -391,6 +392,27 @@ async def test_readuntil_empty_separator(self):

stdin.close()

@asynctest
async def test_readuntil_regex(self):
"""Test readuntil with a regex pattern"""

async with self.connect() as conn:
stdin, stdout, _ = await conn.open_session()
stdin.write("hello world\nhello world")
output = await stdout.readuntil(
re.compile('hello world'), len('hello world')
)
self.assertEqual(output, "hello world")

output = await stdout.readuntil(
re.compile('hello world'), len('hello world')
)
self.assertEqual(output, "\nhello world")

stdin.close()

await conn.wait_closed()

@asynctest
async def test_abort(self):
"""Test abort on a channel"""
Expand Down

0 comments on commit bfb4526

Please sign in to comment.