Skip to content

Commit

Permalink
fix: setting/removing reaction no longer triggers reevaluation
Browse files Browse the repository at this point in the history
however, when TGPy is restarted you will need to edit old messages twice to reevaluate
  • Loading branch information
vanutp committed Aug 8, 2022
1 parent 30c08ce commit cf6e64e
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 60 deletions.
7 changes: 4 additions & 3 deletions tgpy/builtin_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from telethon.tl.custom import Message

from tgpy import app
from tgpy.message_design import get_code
from tgpy.message_design import parse_message
from tgpy.modules import (
Module,
delete_module_file,
Expand Down Expand Up @@ -94,9 +94,10 @@ async def add(self, name: str, code: Optional[str] = None) -> str:
original: Message = await app.ctx.msg.get_reply_message()
if original is None:
return 'Use this function in reply to a message'
code = get_code(original)
if not code:
message_data = parse_message(original)
if not message_data.is_tgpy_message:
return 'No code found in reply message'
code = message_data.code

origin = f'{filename_prefix}module/{name}'

Expand Down
79 changes: 55 additions & 24 deletions tgpy/handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,48 +2,79 @@
from telethon.tl.custom import Message
from telethon.tl.types import Channel

from tgpy import app, message_design
from tgpy import app, message_design, reactions_fix
from tgpy.handlers.utils import _handle_errors, outgoing_messages_filter
from tgpy.reactions_fix import ReactionsFixResult
from tgpy.run_code import eval_message, get_variable_names, parse_code


async def handle_message(message: Message) -> None:
raw_text = message.raw_text

if not raw_text:
return
async def handle_message(
message: Message, *, only_show_warning: bool = False
) -> Message:
if not message.raw_text:
return message

if message.text.startswith('//') and message.text[2:].strip():
await message.edit(message.text[2:])
return
return await message.edit(message.text[2:])

locals_ = get_variable_names()

res = parse_code(raw_text, locals_)
res = parse_code(message.raw_text, locals_)
if not res.is_code:
return

await eval_message(raw_text, message, uses_orig=res.uses_orig)
return message

if only_show_warning:
return await message_design.edit_message(
message,
message.raw_text,
'Edit message again to evaluate',
)
else:
return await eval_message(message.raw_text, message, res.uses_orig)


@events.register(events.NewMessage(func=outgoing_messages_filter))
@_handle_errors
async def on_new_message(event: events.NewMessage.Event) -> None:
await handle_message(event.message)
message = await handle_message(event.message)
reactions_fix.update_hash(message)


@events.register(events.MessageEdited(func=outgoing_messages_filter))
@_handle_errors
async def on_message_edited(event: events.NewMessage.Event) -> None:
if isinstance(event.message.chat, Channel) and event.message.chat.broadcast:
return
code = message_design.get_code(event.message)
if not code:
await handle_message(event.message)
async def on_message_edited(event: events.MessageEdited.Event) -> None:
message: Message = event.message
if isinstance(message.chat, Channel) and message.chat.broadcast:
return
await eval_message(
code, event.message, uses_orig=parse_code(code, get_variable_names()).uses_orig
)
message_data = message_design.parse_message(message)
reactions_fix_result = reactions_fix.check_hash(message)
try:
match reactions_fix_result:
case ReactionsFixResult.ignore:
return
case ReactionsFixResult.show_warning:
if message_data.is_tgpy_message:
message = await message_design.edit_message(
message, message_data.code, 'Edit message again to evaluate'
)
else:
message = await handle_message(message, only_show_warning=True)
return
case ReactionsFixResult.evaluate:
pass
case _:
raise ValueError(f'Bad reactions fix result: {reactions_fix_result}')

if not message_data.is_tgpy_message:
message = await handle_message(message)
return
message = await eval_message(
message_data.code,
message,
parse_code(message_data.code, get_variable_names()).uses_orig,
)
finally:
reactions_fix.update_hash(message)


@events.register(
Expand All @@ -55,14 +86,14 @@ async def cancel(message: Message):
async for msg in app.client.iter_messages(
message.chat_id, max_id=message.id, limit=10
):
if msg.out and message_design.get_code(msg):
if msg.out and message_design.parse_message(msg).is_tgpy_message:
prev = msg
break
else:
return
# noinspection PyBroadException
try:
await prev.edit(message_design.get_code(prev))
await prev.edit(message_design.parse_message(prev).code)
except Exception:
pass
else:
Expand Down
96 changes: 72 additions & 24 deletions tgpy/message_design.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
import traceback as tb
from dataclasses import dataclass

from telethon.tl.custom import Message
from telethon.tl.types import MessageEntityBold, MessageEntityCode, MessageEntityTextUrl
Expand All @@ -11,57 +12,96 @@
FORMATTED_ERROR_HEADER = f'<b><a href="{TITLE_URL}">TGPy error&gt;</a></b>'


def utf16_codepoints_len(s: str):
return len(s.encode('utf-16-le')) // 2
class Utf16CodepointsWrapper(str):
    """A ``str`` whose ``len()`` and indexing are measured in UTF-16 code units.

    The message-formatting code uses instances of this class to compute
    entity offsets and lengths and to cut text at an entity offset, so a
    character outside the Basic Multilingual Plane counts as two positions.

    Indexing and slicing return plain ``str`` objects. Cutting through the
    middle of a surrogate pair raises ``UnicodeDecodeError``, because half a
    pair cannot be decoded on its own.
    """

    def __len__(self):
        """Return the number of UTF-16 code units in the string."""
        return len(self.encode('utf-16-le')) // 2

    def __getitem__(self, item):
        """Return the code unit(s) selected by an int index or a slice.

        Raises:
            IndexError: an integer index is out of range.
            TypeError: ``item`` is neither an ``int`` nor a ``slice``.
        """
        encoded = self.encode('utf-16-le')
        unit_count = len(encoded) // 2
        if isinstance(item, slice):
            # slice.indices() resolves None, negative and out-of-range bounds
            # the same way str does. Testing the bounds for truthiness would
            # treat an explicit 0 as "missing", so text[:0] would return the
            # whole string instead of an empty one.
            start, stop, step = item.indices(unit_count)
            if step == 1:
                chunk = encoded[start * 2 : stop * 2]
            else:
                # A step has to move over whole code units (two bytes each),
                # so gather them one by one instead of striding over bytes.
                chunk = b''.join(
                    encoded[i * 2 : i * 2 + 2] for i in range(start, stop, step)
                )
        elif isinstance(item, int):
            index = item + unit_count if item < 0 else item
            if not 0 <= index < unit_count:
                raise IndexError('string index out of range')
            chunk = encoded[index * 2 : index * 2 + 2]
        else:
            raise TypeError(f'{type(item)} is not supported')
        return chunk.decode('utf-16-le')


# Stand-alone prefix helper from the original span, kept unchanged.
def utf16_codepoints_prefix(s: str, length: int):
    s = s.encode('utf-16-le')
    s = s[: length * 2]
    return s.decode('utf-16-le')


async def edit_message(
message: Message, code: str, result, traceback: str = '', output: str = ''
) -> None:
message: Message,
code: str,
result: str,
traceback: str = '',
output: str = '',
) -> Message:
if result is None and output:
result = output
output = ''

parts = [code.strip(), f'{TITLE} {str(result).strip()}']
parts += [part for part in (output.strip(), traceback.strip()) if part]
text = '\n\n'.join(parts)
title = Utf16CodepointsWrapper(TITLE)
parts = [
Utf16CodepointsWrapper(code.strip()),
Utf16CodepointsWrapper(f'{title} {str(result).strip()}'),
]
parts += [
Utf16CodepointsWrapper(part)
for part in (output.strip(), traceback.strip())
if part
]

entities = []
offset = 0
for p in parts:
entities.append(MessageEntityCode(offset, utf16_codepoints_len(p)))
offset += utf16_codepoints_len(p) + 2
entities.append(MessageEntityCode(offset, len(p)))
offset += len(p) + 2

entities[1].offset += utf16_codepoints_len(TITLE) + 1
entities[1].length -= utf16_codepoints_len(TITLE) + 1
entities[1].offset += len(title) + 1
entities[1].length -= len(title) + 1
entities[1:1] = [
MessageEntityBold(
utf16_codepoints_len(parts[0]) + 2,
utf16_codepoints_len(TITLE),
len(parts[0]) + 2,
len(title),
),
MessageEntityTextUrl(
utf16_codepoints_len(parts[0]) + 2,
utf16_codepoints_len(TITLE),
len(parts[0]) + 2,
len(title),
TITLE_URL,
),
]

text = str('\n\n'.join(parts))
if len(text) > 4096:
text = text[:4095] + '…'
await message.edit(text, formatting_entities=entities, link_preview=False)
return await message.edit(text, formatting_entities=entities, link_preview=False)


@dataclass
class MessageParseResult:
    """Outcome of splitting a message into its TGPy code and result parts.

    Produced by ``parse_message``.
    """

    # True when the message contains the TGPy title link entity.
    is_tgpy_message: bool
    # Text before the title entity, stripped; None when not a TGPy message.
    code: str | None
    # Text after the title entity, stripped; None when not a TGPy message.
    result: str | None


def get_code(message: Message) -> str:
def get_title_entity(message: Message) -> MessageEntityTextUrl | None:
for e in message.entities or []:
if isinstance(e, MessageEntityTextUrl) and e.url == TITLE_URL:
return utf16_codepoints_prefix(message.raw_text, length=e.offset).strip()
return ''
return e
return None


def parse_message(message: Message) -> MessageParseResult:
    """Split a message into the code and result surrounding the TGPy title.

    Returns a result with ``is_tgpy_message=False`` and no code/result when
    the message has no title entity. Otherwise the text before the entity
    becomes ``code`` and the text after it becomes ``result``, both stripped.
    """
    title_entity = get_title_entity(message)
    if not title_entity:
        return MessageParseResult(is_tgpy_message=False, code=None, result=None)

    # Entity offsets are applied through the UTF-16 aware wrapper rather than
    # to the plain str.
    text = Utf16CodepointsWrapper(message.raw_text)
    title_start = title_entity.offset
    title_end = title_start + title_entity.length
    return MessageParseResult(
        is_tgpy_message=True,
        code=text[:title_start].strip(),
        result=text[title_end:].strip(),
    )


async def send_error(chat) -> None:
Expand All @@ -71,3 +111,11 @@ async def send_error(chat) -> None:
await app.client.send_message(
chat, f'{FORMATTED_ERROR_HEADER}\n\n<code>{exc}</code>', link_preview=False
)


__all__ = [
'edit_message',
'MessageParseResult',
'parse_message',
'send_error',
]
43 changes: 43 additions & 0 deletions tgpy/reactions_fix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""
This module works around a Telegram bug/undocumented feature where
setting or removing a reaction sometimes triggers a message edit event.
This bug/feature introduces a security vulnerability in TGPy,
because message reevaluation can then be triggered by other users.
"""
import json
from enum import Enum
from hashlib import sha256

from telethon.tl.custom import Message

content_hashes: dict[tuple[int, int], bytes] = {}


def get_content_hash(message: Message) -> bytes:
    """Return a SHA-256 digest of the message's entities and raw text.

    The hashed payload is the entity count, a newline, the JSON dump of each
    entity's ``to_dict()`` joined by newlines, then the raw text.
    """
    serialized = []
    for entity in message.entities or []:
        serialized.append(json.dumps(entity.to_dict()))
    payload = f'{len(serialized)}\n' + '\n'.join(serialized) + message.raw_text
    hasher = sha256()
    hasher.update(payload.encode('utf-8'))
    return hasher.digest()


class ReactionsFixResult(Enum):
    """Verdict returned by ``check_hash`` for an edited message."""

    # The stored hash equals the message's current content hash.
    ignore = 1
    # A hash is stored for the message but differs from the current one.
    evaluate = 2
    # No hash is stored for the message.
    show_warning = 3


def check_hash(message: Message) -> ReactionsFixResult:
    """Compare the message's current content hash with the stored one.

    Returns ``show_warning`` when nothing is stored for this
    ``(chat_id, id)`` pair, ``ignore`` when the stored hash matches the
    current content, and ``evaluate`` when it differs.
    """
    key = (message.chat_id, message.id)
    # Hash before the lookup, so the hash is computed on every call.
    current = get_content_hash(message)
    try:
        stored = content_hashes[key]
    except KeyError:
        return ReactionsFixResult.show_warning
    return (
        ReactionsFixResult.ignore
        if stored == current
        else ReactionsFixResult.evaluate
    )


def update_hash(message: Message) -> None:
    """Store the message's current content hash under ``(chat_id, id)``."""
    # NOTE(review): nothing in this module removes entries from
    # content_hashes, so it gains one entry per distinct message.
    key = message.chat_id, message.id
    digest = get_content_hash(message)
    content_hashes[key] = digest


__all__ = ['ReactionsFixResult', 'check_hash', 'update_hash']
18 changes: 9 additions & 9 deletions tgpy/run_code/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from telethon.errors import MessageIdInvalidError
from telethon.tl.custom import Message

from tgpy import app, message_design
Expand All @@ -17,7 +16,7 @@ def get_variable_names(include_orig=True):
# fmt: on


async def eval_message(code: str, message: Message, uses_orig=False) -> None:
async def eval_message(code: str, message: Message, uses_orig: bool) -> Message:
await message_design.edit_message(message, code, 'Running...')

app.ctx.msg = message
Expand Down Expand Up @@ -48,10 +47,11 @@ async def eval_message(code: str, message: Message, uses_orig=False) -> None:
result = convert_result(result)
exc = ''

try:
# noinspection PyProtectedMember
await message_design.edit_message(
message, code, result, traceback=exc, output=app.ctx._print_output
)
except MessageIdInvalidError:
pass
# noinspection PyProtectedMember
return await message_design.edit_message(
message,
code,
result,
traceback=exc,
output=app.ctx._print_output,
)
13 changes: 13 additions & 0 deletions tgpy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from subprocess import PIPE, Popen

import appdirs
from telethon.tl import types

from tgpy import version

Expand Down Expand Up @@ -109,6 +110,17 @@ def get_version():
return 'unknown'


def peer_to_id(peer: types.TypePeer):
    """Return the numeric id stored in a peer object.

    Gives ``user_id`` for ``PeerUser``, ``chat_id`` for ``PeerChat`` and
    ``channel_id`` for ``PeerChannel``.

    Raises:
        TypeError: ``peer`` is none of the three peer types above.
    """
    id_attr_by_type = (
        (types.PeerUser, 'user_id'),
        (types.PeerChat, 'chat_id'),
        (types.PeerChannel, 'channel_id'),
    )
    for peer_type, id_attr in id_attr_by_type:
        if isinstance(peer, peer_type):
            return getattr(peer, id_attr)
    raise TypeError(f'Unknown peer type: {type(peer)}')


__all__ = [
'DATA_DIR',
'MODULES_DIR',
Expand All @@ -126,4 +138,5 @@ def get_version():
'get_user',
'get_hostname',
'running_in_docker',
'peer_to_id',
]

0 comments on commit cf6e64e

Please sign in to comment.