diff --git a/tgpy/builtin_functions.py b/tgpy/builtin_functions.py
index ccadd35..a862c35 100644
--- a/tgpy/builtin_functions.py
+++ b/tgpy/builtin_functions.py
@@ -7,7 +7,7 @@
from telethon.tl.custom import Message
from tgpy import app
-from tgpy.message_design import get_code
+from tgpy.message_design import parse_message
from tgpy.modules import (
Module,
delete_module_file,
@@ -94,9 +94,10 @@ async def add(self, name: str, code: Optional[str] = None) -> str:
original: Message = await app.ctx.msg.get_reply_message()
if original is None:
return 'Use this function in reply to a message'
- code = get_code(original)
- if not code:
+ message_data = parse_message(original)
+ if not message_data.is_tgpy_message:
return 'No code found in reply message'
+ code = message_data.code
origin = f'{filename_prefix}module/{name}'
diff --git a/tgpy/handlers/__init__.py b/tgpy/handlers/__init__.py
index 657c031..ce9c117 100644
--- a/tgpy/handlers/__init__.py
+++ b/tgpy/handlers/__init__.py
@@ -2,48 +2,79 @@
from telethon.tl.custom import Message
from telethon.tl.types import Channel
-from tgpy import app, message_design
+from tgpy import app, message_design, reactions_fix
from tgpy.handlers.utils import _handle_errors, outgoing_messages_filter
+from tgpy.reactions_fix import ReactionsFixResult
from tgpy.run_code import eval_message, get_variable_names, parse_code
-async def handle_message(message: Message) -> None:
- raw_text = message.raw_text
-
- if not raw_text:
- return
+async def handle_message(
+ message: Message, *, only_show_warning: bool = False
+) -> Message:
+ if not message.raw_text:
+ return message
if message.text.startswith('//') and message.text[2:].strip():
- await message.edit(message.text[2:])
- return
+ return await message.edit(message.text[2:])
locals_ = get_variable_names()
- res = parse_code(raw_text, locals_)
+ res = parse_code(message.raw_text, locals_)
if not res.is_code:
- return
-
- await eval_message(raw_text, message, uses_orig=res.uses_orig)
+ return message
+
+ if only_show_warning:
+ return await message_design.edit_message(
+ message,
+ message.raw_text,
+ 'Edit message again to evaluate',
+ )
+ else:
+ return await eval_message(message.raw_text, message, res.uses_orig)
@events.register(events.NewMessage(func=outgoing_messages_filter))
@_handle_errors
async def on_new_message(event: events.NewMessage.Event) -> None:
- await handle_message(event.message)
+ message = await handle_message(event.message)
+ reactions_fix.update_hash(message)
@events.register(events.MessageEdited(func=outgoing_messages_filter))
@_handle_errors
-async def on_message_edited(event: events.NewMessage.Event) -> None:
- if isinstance(event.message.chat, Channel) and event.message.chat.broadcast:
- return
- code = message_design.get_code(event.message)
- if not code:
- await handle_message(event.message)
+async def on_message_edited(event: events.MessageEdited.Event) -> None:
+ message: Message = event.message
+ if isinstance(message.chat, Channel) and message.chat.broadcast:
return
- await eval_message(
- code, event.message, uses_orig=parse_code(code, get_variable_names()).uses_orig
- )
+ message_data = message_design.parse_message(message)
+ reactions_fix_result = reactions_fix.check_hash(message)
+ try:
+ match reactions_fix_result:
+ case ReactionsFixResult.ignore:
+ return
+ case ReactionsFixResult.show_warning:
+ if message_data.is_tgpy_message:
+ message = await message_design.edit_message(
+ message, message_data.code, 'Edit message again to evaluate'
+ )
+ else:
+ message = await handle_message(message, only_show_warning=True)
+ return
+ case ReactionsFixResult.evaluate:
+ pass
+ case _:
+ raise ValueError(f'Bad reactions fix result: {reactions_fix_result}')
+
+ if not message_data.is_tgpy_message:
+ message = await handle_message(message)
+ return
+ message = await eval_message(
+ message_data.code,
+ message,
+ parse_code(message_data.code, get_variable_names()).uses_orig,
+ )
+ finally:
+ reactions_fix.update_hash(message)
@events.register(
@@ -55,14 +86,14 @@ async def cancel(message: Message):
async for msg in app.client.iter_messages(
message.chat_id, max_id=message.id, limit=10
):
- if msg.out and message_design.get_code(msg):
+ if msg.out and message_design.parse_message(msg).is_tgpy_message:
prev = msg
break
else:
return
# noinspection PyBroadException
try:
- await prev.edit(message_design.get_code(prev))
+ await prev.edit(message_design.parse_message(prev).code)
except Exception:
pass
else:
diff --git a/tgpy/message_design.py b/tgpy/message_design.py
index 2509464..8db597d 100644
--- a/tgpy/message_design.py
+++ b/tgpy/message_design.py
@@ -1,5 +1,6 @@
import sys
import traceback as tb
+from dataclasses import dataclass
from telethon.tl.custom import Message
from telethon.tl.types import MessageEntityBold, MessageEntityCode, MessageEntityTextUrl
@@ -11,57 +12,96 @@
FORMATTED_ERROR_HEADER = f'TGPy error>'
-def utf16_codepoints_len(s: str):
- return len(s.encode('utf-16-le')) // 2
+class Utf16CodepointsWrapper(str):
+ def __len__(self):
+ return len(self.encode('utf-16-le')) // 2
-
-def utf16_codepoints_prefix(s: str, length: int):
- s = s.encode('utf-16-le')
- s = s[: length * 2]
- return s.decode('utf-16-le')
+ def __getitem__(self, item):
+ s = self.encode('utf-16-le')
+ if isinstance(item, slice):
+ item = slice(
+ item.start * 2 if item.start else None,
+ item.stop * 2 if item.stop else None,
+ item.step * 2 if item.step else None,
+ )
+ s = s[item]
+ elif isinstance(item, int):
+ s = s[item * 2 : item * 2 + 2]
+ else:
+ raise TypeError(f'{type(item)} is not supported')
+ return s.decode('utf-16-le')
async def edit_message(
- message: Message, code: str, result, traceback: str = '', output: str = ''
-) -> None:
+ message: Message,
+ code: str,
+ result: str,
+ traceback: str = '',
+ output: str = '',
+) -> Message:
if result is None and output:
result = output
output = ''
- parts = [code.strip(), f'{TITLE} {str(result).strip()}']
- parts += [part for part in (output.strip(), traceback.strip()) if part]
- text = '\n\n'.join(parts)
+ title = Utf16CodepointsWrapper(TITLE)
+ parts = [
+ Utf16CodepointsWrapper(code.strip()),
+ Utf16CodepointsWrapper(f'{title} {str(result).strip()}'),
+ ]
+ parts += [
+ Utf16CodepointsWrapper(part)
+ for part in (output.strip(), traceback.strip())
+ if part
+ ]
entities = []
offset = 0
for p in parts:
- entities.append(MessageEntityCode(offset, utf16_codepoints_len(p)))
- offset += utf16_codepoints_len(p) + 2
+ entities.append(MessageEntityCode(offset, len(p)))
+ offset += len(p) + 2
- entities[1].offset += utf16_codepoints_len(TITLE) + 1
- entities[1].length -= utf16_codepoints_len(TITLE) + 1
+ entities[1].offset += len(title) + 1
+ entities[1].length -= len(title) + 1
entities[1:1] = [
MessageEntityBold(
- utf16_codepoints_len(parts[0]) + 2,
- utf16_codepoints_len(TITLE),
+ len(parts[0]) + 2,
+ len(title),
),
MessageEntityTextUrl(
- utf16_codepoints_len(parts[0]) + 2,
- utf16_codepoints_len(TITLE),
+ len(parts[0]) + 2,
+ len(title),
TITLE_URL,
),
]
+ text = str('\n\n'.join(parts))
if len(text) > 4096:
text = text[:4095] + '…'
- await message.edit(text, formatting_entities=entities, link_preview=False)
+ return await message.edit(text, formatting_entities=entities, link_preview=False)
+
+
+@dataclass
+class MessageParseResult:
+ is_tgpy_message: bool
+ code: str | None
+ result: str | None
-def get_code(message: Message) -> str:
+def get_title_entity(message: Message) -> MessageEntityTextUrl | None:
for e in message.entities or []:
if isinstance(e, MessageEntityTextUrl) and e.url == TITLE_URL:
- return utf16_codepoints_prefix(message.raw_text, length=e.offset).strip()
- return ''
+ return e
+ return None
+
+
+def parse_message(message: Message) -> MessageParseResult:
+ e = get_title_entity(message)
+ if not e:
+ return MessageParseResult(False, None, None)
+ msg_text = Utf16CodepointsWrapper(message.raw_text)
+ code = msg_text[: e.offset].strip()
+ result = msg_text[e.offset + e.length :].strip()
+ return MessageParseResult(True, code, result)
async def send_error(chat) -> None:
@@ -71,3 +111,11 @@ async def send_error(chat) -> None:
await app.client.send_message(
chat, f'{FORMATTED_ERROR_HEADER}\n\n{exc}
', link_preview=False
)
+
+
+__all__ = [
+ 'edit_message',
+ 'MessageParseResult',
+ 'parse_message',
+ 'send_error',
+]
diff --git a/tgpy/reactions_fix.py b/tgpy/reactions_fix.py
new file mode 100644
index 0000000..0522403
--- /dev/null
+++ b/tgpy/reactions_fix.py
@@ -0,0 +1,43 @@
+"""
+This module tries to fix Telegram bug/undocumented feature where
+setting/removing reaction sometimes triggers message edit event.
+This bug/feature introduces a security vulnerability in TGPy,
+because message reevaluation can be triggered by other users.
+"""
+import json
+from enum import Enum
+from hashlib import sha256
+
+from telethon.tl.custom import Message
+
+content_hashes: dict[tuple[int, int], bytes] = {}
+
+
+def get_content_hash(message: Message) -> bytes:
+ entities = [json.dumps(e.to_dict()) for e in message.entities or []]
+ data = str(len(entities)) + '\n' + '\n'.join(entities) + message.raw_text
+ return sha256(data.encode('utf-8')).digest()
+
+
+class ReactionsFixResult(Enum):
+ ignore = 1
+ evaluate = 2
+ show_warning = 3
+
+
+def check_hash(message: Message) -> ReactionsFixResult:
+ message_uid = (message.chat_id, message.id)
+ content_hash = get_content_hash(message)
+ if message_uid not in content_hashes:
+ return ReactionsFixResult.show_warning
+ if content_hashes[message_uid] == content_hash:
+ return ReactionsFixResult.ignore
+ return ReactionsFixResult.evaluate
+
+
+def update_hash(message: Message) -> None:
+ message_uid = (message.chat_id, message.id)
+ content_hashes[message_uid] = get_content_hash(message)
+
+
+__all__ = ['ReactionsFixResult', 'check_hash', 'update_hash']
diff --git a/tgpy/run_code/__init__.py b/tgpy/run_code/__init__.py
index 6fab24a..613fbc9 100644
--- a/tgpy/run_code/__init__.py
+++ b/tgpy/run_code/__init__.py
@@ -1,4 +1,3 @@
-from telethon.errors import MessageIdInvalidError
from telethon.tl.custom import Message
from tgpy import app, message_design
@@ -17,7 +16,7 @@ def get_variable_names(include_orig=True):
# fmt: on
-async def eval_message(code: str, message: Message, uses_orig=False) -> None:
+async def eval_message(code: str, message: Message, uses_orig: bool) -> Message:
await message_design.edit_message(message, code, 'Running...')
app.ctx.msg = message
@@ -48,10 +47,11 @@ async def eval_message(code: str, message: Message, uses_orig=False) -> None:
result = convert_result(result)
exc = ''
- try:
- # noinspection PyProtectedMember
- await message_design.edit_message(
- message, code, result, traceback=exc, output=app.ctx._print_output
- )
- except MessageIdInvalidError:
- pass
+ # noinspection PyProtectedMember
+ return await message_design.edit_message(
+ message,
+ code,
+ result,
+ traceback=exc,
+ output=app.ctx._print_output,
+ )
diff --git a/tgpy/utils.py b/tgpy/utils.py
index 8d66713..6bd9e77 100644
--- a/tgpy/utils.py
+++ b/tgpy/utils.py
@@ -9,6 +9,7 @@
from subprocess import PIPE, Popen
import appdirs
+from telethon.tl import types
from tgpy import version
@@ -109,6 +110,17 @@ def get_version():
return 'unknown'
+def peer_to_id(peer: types.TypePeer):
+ if isinstance(peer, types.PeerUser):
+ return peer.user_id
+ elif isinstance(peer, types.PeerChat):
+ return peer.chat_id
+ elif isinstance(peer, types.PeerChannel):
+ return peer.channel_id
+ else:
+ raise TypeError(f'Unknown peer type: {type(peer)}')
+
+
__all__ = [
'DATA_DIR',
'MODULES_DIR',
@@ -126,4 +138,5 @@ def get_version():
'get_user',
'get_hostname',
'running_in_docker',
+ 'peer_to_id',
]