diff --git a/reactbot/simplepattern.py b/reactbot/simplepattern.py index 4d3e2c3..f40d7ce 100644 --- a/reactbot/simplepattern.py +++ b/reactbot/simplepattern.py @@ -1,5 +1,5 @@ # reminder - A maubot plugin that reacts to messages that match predefined rules. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2021 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,39 +13,59 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Callable, List, Dict, Optional +from typing import Callable, List, Dict, Optional, NamedTuple import re -class BlankMatch: - @staticmethod - def groups() -> List[str]: - return [] +class SimpleMatch(NamedTuple): + value: str - @staticmethod - def group(group: int) -> str: - return "" + def groups(self) -> List[str]: + return [self.value] - @staticmethod - def groupdict() -> Dict[str, str]: + def group(self, group: int) -> Optional[str]: + if group == 0: + return self.value + return None + + def groupdict(self) -> Dict[str, str]: return {} -class SimplePattern: - _ptm = BlankMatch() +def matcher_equals(val: str, pattern: str) -> bool: + return val == pattern + + +def matcher_startswith(val: str, pattern: str) -> bool: + return val.startswith(pattern) + - matcher: Callable[[str], bool] +def matcher_endswith(val: str, pattern: str) -> bool: + return val.endswith(pattern) + + +def matcher_contains(val: str, pattern: str) -> bool: + return pattern in val + + +SimpleMatcherFunc = Callable[[str, str], bool] + + +class SimplePattern: + matcher: SimpleMatcherFunc + pattern: str ignorecase: bool - def __init__(self, matcher: Callable[[str], bool], ignorecase: bool) -> None: + def __init__(self, matcher: SimpleMatcherFunc, pattern: str, ignorecase: bool) -> None: self.matcher = matcher + self.pattern = pattern self.ignorecase = ignorecase - def search(self, val: str) -> BlankMatch: + def search(self, val: str) -> SimpleMatch: if self.ignorecase: val = val.lower() - if self.matcher(val): - return self._ptm + if self.matcher(val, self.pattern): + return SimpleMatch(self.pattern) @staticmethod def compile(pattern: str, flags: re.RegexFlag = re.RegexFlag(0), force_raw: bool = False @@ -58,13 +78,16 @@ def compile(pattern: str, flags: re.RegexFlag = re.RegexFlag(0), force_raw: bool first, last = pattern[0], pattern[-1] if first == '^' and last == '$' and (force_raw or esc == f"\\^{pattern[1:-1]}\\$"): s_pattern = s_pattern[1:-1] - return SimplePattern(lambda val: val == s_pattern, ignorecase=ignorecase) + func = matcher_equals elif first == '^' and (force_raw or esc == f"\\^{pattern[1:]}"): s_pattern = s_pattern[1:] - return SimplePattern(lambda val: val.startswith(s_pattern), ignorecase=ignorecase) + func = matcher_startswith elif last == '$' and (force_raw or esc == f"{pattern[:-1]}\\$"): s_pattern = s_pattern[:-1] - return SimplePattern(lambda val: val.endswith(s_pattern), ignorecase=ignorecase) + func = matcher_endswith elif force_raw or esc == pattern: - return SimplePattern(lambda val: s_pattern in val, ignorecase=ignorecase) - return None + func = matcher_contains + else: + # Not a simple pattern + return None + return SimplePattern(matcher=func, pattern=s_pattern, ignorecase=ignorecase)