Skip to content

Commit

Permalink
Add timeout for executing seds
Browse files Browse the repository at this point in the history
  • Loading branch information
tulir committed Apr 20, 2022
1 parent 60ce8a7 commit 4145389
Showing 1 changed file with 27 additions and 2 deletions.
29 changes: 27 additions & 2 deletions sed.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Tuple, NamedTuple, Pattern, Optional, Dict, Deque
from types import FrameType
from collections import deque, defaultdict
from contextlib import contextmanager
from difflib import SequenceMatcher
from html import escape
import string
import signal
import time
import re

Expand All @@ -35,6 +38,20 @@
SedMatch = Tuple[str, str, str, str, str]


def raise_timeout(sig: signal.Signals, frame_type: FrameType) -> None:
raise TimeoutError()


@contextmanager
def timeout(max_time: float = 0.5) -> None:
signal.signal(signal.SIGALRM, raise_timeout)
signal.setitimer(signal.ITIMER_REAL, max_time)
try:
yield
finally:
signal.alarm(0)


class SedBot(Plugin):
prev_user_events: Dict[RoomID, Dict[UserID, MessageEvent]]
prev_room_events: Dict[RoomID, Deque[MessageEvent]]
Expand Down Expand Up @@ -120,7 +137,8 @@ def _compile_passive_statement(cls, match: SedMatch) -> Optional[SedStatement]:

@staticmethod
def _exec(stmt: SedStatement, body: str) -> str:
return stmt.find.sub(stmt.replace, body, count=0 if stmt.is_global else 1)
with timeout():
return stmt.find.sub(stmt.replace, body, count=0 if stmt.is_global else 1)

@staticmethod
def op_to_str(tag: str, old_text: str, new_text: str) -> str:
Expand Down Expand Up @@ -183,7 +201,14 @@ def _is_recent(evt: MessageEvent) -> bool:
@command.passive(r"(?:^|[^a-zA-Z0-9])sed (s.+)")
@command.passive(r"^(s[#/].+[#/].+)$")
async def command_handler(self, evt: MessageEvent, match: SedMatch) -> None:
stmt = self._compile_passive_statement(match)
try:
await self._command_handler(evt, match)
except TimeoutError:
await evt.reply("3:<")

async def _command_handler(self, evt: MessageEvent, match: SedMatch):
with timeout():
stmt = self._compile_passive_statement(match)
if not stmt:
return
try:
Expand Down

0 comments on commit 4145389

Please sign in to comment.