diff --git a/tgpy/builtin_functions.py b/tgpy/builtin_functions.py index ccadd35..a862c35 100644 --- a/tgpy/builtin_functions.py +++ b/tgpy/builtin_functions.py @@ -7,7 +7,7 @@ from telethon.tl.custom import Message from tgpy import app -from tgpy.message_design import get_code +from tgpy.message_design import parse_message from tgpy.modules import ( Module, delete_module_file, @@ -94,9 +94,10 @@ async def add(self, name: str, code: Optional[str] = None) -> str: original: Message = await app.ctx.msg.get_reply_message() if original is None: return 'Use this function in reply to a message' - code = get_code(original) - if not code: + message_data = parse_message(original) + if not message_data.is_tgpy_message: return 'No code found in reply message' + code = message_data.code origin = f'{filename_prefix}module/{name}' diff --git a/tgpy/handlers/__init__.py b/tgpy/handlers/__init__.py index 657c031..ce9c117 100644 --- a/tgpy/handlers/__init__.py +++ b/tgpy/handlers/__init__.py @@ -2,48 +2,79 @@ from telethon.tl.custom import Message from telethon.tl.types import Channel -from tgpy import app, message_design +from tgpy import app, message_design, reactions_fix from tgpy.handlers.utils import _handle_errors, outgoing_messages_filter +from tgpy.reactions_fix import ReactionsFixResult from tgpy.run_code import eval_message, get_variable_names, parse_code -async def handle_message(message: Message) -> None: - raw_text = message.raw_text - - if not raw_text: - return +async def handle_message( + message: Message, *, only_show_warning: bool = False +) -> Message: + if not message.raw_text: + return message if message.text.startswith('//') and message.text[2:].strip(): - await message.edit(message.text[2:]) - return + return await message.edit(message.text[2:]) locals_ = get_variable_names() - res = parse_code(raw_text, locals_) + res = parse_code(message.raw_text, locals_) if not res.is_code: - return - - await eval_message(raw_text, message, uses_orig=res.uses_orig) + return message + + if only_show_warning: + return await message_design.edit_message( + message, + message.raw_text, + 'Edit message again to evaluate', + ) + else: + return await eval_message(message.raw_text, message, res.uses_orig) @events.register(events.NewMessage(func=outgoing_messages_filter)) @_handle_errors async def on_new_message(event: events.NewMessage.Event) -> None: - await handle_message(event.message) + message = await handle_message(event.message) + reactions_fix.update_hash(message) @events.register(events.MessageEdited(func=outgoing_messages_filter)) @_handle_errors -async def on_message_edited(event: events.NewMessage.Event) -> None: - if isinstance(event.message.chat, Channel) and event.message.chat.broadcast: - return - code = message_design.get_code(event.message) - if not code: - await handle_message(event.message) +async def on_message_edited(event: events.MessageEdited.Event) -> None: + message: Message = event.message + if isinstance(message.chat, Channel) and message.chat.broadcast: return - await eval_message( - code, event.message, uses_orig=parse_code(code, get_variable_names()).uses_orig - ) + message_data = message_design.parse_message(message) + reactions_fix_result = reactions_fix.check_hash(message) + try: + match reactions_fix_result: + case ReactionsFixResult.ignore: + return + case ReactionsFixResult.show_warning: + if message_data.is_tgpy_message: + message = await message_design.edit_message( + message, message_data.code, 'Edit message again to evaluate' + ) + else: + message = await handle_message(message, only_show_warning=True) + return + case ReactionsFixResult.evaluate: + pass + case _: + raise ValueError(f'Bad reactions fix result: {reactions_fix_result}') + + if not message_data.is_tgpy_message: + message = await handle_message(message) + return + message = await eval_message( + message_data.code, + message, + parse_code(message_data.code, get_variable_names()).uses_orig, + ) + finally: + reactions_fix.update_hash(message) @events.register( @@ -55,14 +86,14 @@ async def cancel(message: Message): async for msg in app.client.iter_messages( message.chat_id, max_id=message.id, limit=10 ): - if msg.out and message_design.get_code(msg): + if msg.out and message_design.parse_message(msg).is_tgpy_message: prev = msg break else: return # noinspection PyBroadException try: - await prev.edit(message_design.get_code(prev)) + await prev.edit(message_design.parse_message(prev).code) except Exception: pass else: diff --git a/tgpy/message_design.py b/tgpy/message_design.py index 2509464..8db597d 100644 --- a/tgpy/message_design.py +++ b/tgpy/message_design.py @@ -1,5 +1,6 @@ import sys import traceback as tb +from dataclasses import dataclass from telethon.tl.custom import Message from telethon.tl.types import MessageEntityBold, MessageEntityCode, MessageEntityTextUrl @@ -11,57 +12,96 @@ FORMATTED_ERROR_HEADER = f'TGPy error>' -def utf16_codepoints_len(s: str): - return len(s.encode('utf-16-le')) // 2 +class Utf16CodepointsWrapper(str): + def __len__(self): + return len(self.encode('utf-16-le')) // 2 - -def utf16_codepoints_prefix(s: str, length: int): - s = s.encode('utf-16-le') - s = s[: length * 2] - return s.decode('utf-16-le') + def __getitem__(self, item): + s = self.encode('utf-16-le') + if isinstance(item, slice): + item = slice( + item.start * 2 if item.start else None, + item.stop * 2 if item.stop else None, + item.step * 2 if item.step else None, + ) + s = s[item] + elif isinstance(item, int): + s = s[item * 2 : item * 2 + 2] + else: + raise TypeError(f'{type(item)} is not supported') + return s.decode('utf-16-le') async def edit_message( - message: Message, code: str, result, traceback: str = '', output: str = '' -) -> None: + message: Message, + code: str, + result: str, + traceback: str = '', + output: str = '', +) -> Message: if result is None and output: result = output output = '' - parts = [code.strip(), f'{TITLE} {str(result).strip()}'] - parts += [part for part in (output.strip(), traceback.strip()) if part] - text = '\n\n'.join(parts) + title = Utf16CodepointsWrapper(TITLE) + parts = [ + Utf16CodepointsWrapper(code.strip()), + Utf16CodepointsWrapper(f'{title} {str(result).strip()}'), + ] + parts += [ + Utf16CodepointsWrapper(part) + for part in (output.strip(), traceback.strip()) + if part + ] entities = [] offset = 0 for p in parts: - entities.append(MessageEntityCode(offset, utf16_codepoints_len(p))) - offset += utf16_codepoints_len(p) + 2 + entities.append(MessageEntityCode(offset, len(p))) + offset += len(p) + 2 - entities[1].offset += utf16_codepoints_len(TITLE) + 1 - entities[1].length -= utf16_codepoints_len(TITLE) + 1 + entities[1].offset += len(title) + 1 + entities[1].length -= len(title) + 1 entities[1:1] = [ MessageEntityBold( - utf16_codepoints_len(parts[0]) + 2, - utf16_codepoints_len(TITLE), + len(parts[0]) + 2, + len(title), ), MessageEntityTextUrl( - utf16_codepoints_len(parts[0]) + 2, - utf16_codepoints_len(TITLE), + len(parts[0]) + 2, + len(title), TITLE_URL, ), ] + text = str('\n\n'.join(parts)) if len(text) > 4096: text = text[:4095] + '…' - await message.edit(text, formatting_entities=entities, link_preview=False) + return await message.edit(text, formatting_entities=entities, link_preview=False) + + +@dataclass +class MessageParseResult: + is_tgpy_message: bool + code: str | None + result: str | None -def get_code(message: Message) -> str: +def get_title_entity(message: Message) -> MessageEntityTextUrl | None: for e in message.entities or []: if isinstance(e, MessageEntityTextUrl) and e.url == TITLE_URL: - return utf16_codepoints_prefix(message.raw_text, length=e.offset).strip() - return '' + return e + return None + + +def parse_message(message: Message) -> MessageParseResult: + e = get_title_entity(message) + if not e: + return MessageParseResult(False, None, None) + msg_text = Utf16CodepointsWrapper(message.raw_text) + code = msg_text[: e.offset].strip() + result = msg_text[e.offset + e.length :].strip() + return MessageParseResult(True, code, result) async def send_error(chat) -> None: @@ -71,3 +111,11 @@ async def send_error(chat) -> None: await app.client.send_message( chat, f'{FORMATTED_ERROR_HEADER}\n\n{exc}', link_preview=False ) + + +__all__ = [ + 'edit_message', + 'MessageParseResult', + 'parse_message', + 'send_error', +] diff --git a/tgpy/reactions_fix.py b/tgpy/reactions_fix.py new file mode 100644 index 0000000..0522403 --- /dev/null +++ b/tgpy/reactions_fix.py @@ -0,0 +1,43 @@ +""" +This module tries to fix Telegram bug/undocumented feature where +setting/removing reaction sometimes triggers message edit event. +This bug/feature introduces a security vulnerability in TGPy, +because message reevaluation can be triggered by other users. +""" +import json +from enum import Enum +from hashlib import sha256 + +from telethon.tl.custom import Message + +content_hashes: dict[tuple[int, int], bytes] = {} + + +def get_content_hash(message: Message) -> bytes: + entities = [json.dumps(e.to_dict()) for e in message.entities or []] + data = str(len(entities)) + '\n' + '\n'.join(entities) + message.raw_text + return sha256(data.encode('utf-8')).digest() + + +class ReactionsFixResult(Enum): + ignore = 1 + evaluate = 2 + show_warning = 3 + + +def check_hash(message: Message) -> ReactionsFixResult: + message_uid = (message.chat_id, message.id) + content_hash = get_content_hash(message) + if message_uid not in content_hashes: + return ReactionsFixResult.show_warning + if content_hashes[message_uid] == content_hash: + return ReactionsFixResult.ignore + return ReactionsFixResult.evaluate + + +def update_hash(message: Message) -> None: + message_uid = (message.chat_id, message.id) + content_hashes[message_uid] = get_content_hash(message) + + +__all__ = ['ReactionsFixResult', 'check_hash', 'update_hash'] diff --git a/tgpy/run_code/__init__.py b/tgpy/run_code/__init__.py index 6fab24a..613fbc9 100644 --- a/tgpy/run_code/__init__.py +++ b/tgpy/run_code/__init__.py @@ -1,4 +1,3 @@ -from telethon.errors import MessageIdInvalidError from telethon.tl.custom import Message from tgpy import app, message_design @@ -17,7 +16,7 @@ def get_variable_names(include_orig=True): # fmt: on -async def eval_message(code: str, message: Message, uses_orig=False) -> None: +async def eval_message(code: str, message: Message, uses_orig: bool) -> Message: await message_design.edit_message(message, code, 'Running...') app.ctx.msg = message @@ -48,10 +47,11 @@ async def eval_message(code: str, message: Message, uses_orig=False) -> None: result = convert_result(result) exc = '' - try: - # noinspection PyProtectedMember - await message_design.edit_message( - message, code, result, traceback=exc, output=app.ctx._print_output - ) - except MessageIdInvalidError: - pass + # noinspection PyProtectedMember + return await message_design.edit_message( + message, + code, + result, + traceback=exc, + output=app.ctx._print_output, + ) diff --git a/tgpy/utils.py b/tgpy/utils.py index 8d66713..6bd9e77 100644 --- a/tgpy/utils.py +++ b/tgpy/utils.py @@ -9,6 +9,7 @@ from subprocess import PIPE, Popen import appdirs +from telethon.tl import types from tgpy import version @@ -109,6 +110,17 @@ def get_version(): return 'unknown' +def peer_to_id(peer: types.TypePeer): + if isinstance(peer, types.PeerUser): + return peer.user_id + elif isinstance(peer, types.PeerChat): + return peer.chat_id + elif isinstance(peer, types.PeerChannel): + return peer.channel_id + else: + raise TypeError(f'Unknown peer type: {type(peer)}') + + __all__ = [ 'DATA_DIR', 'MODULES_DIR', @@ -126,4 +138,5 @@ def get_version(): 'get_user', 'get_hostname', 'running_in_docker', + 'peer_to_id', ]