From 47b93094a825e5cdce91f4640d601eb29a7336c5 Mon Sep 17 00:00:00 2001 From: Alex Willmer Date: Wed, 2 Aug 2023 11:00:37 +0100 Subject: [PATCH] mitogen: Factor MessageHeader class out of Message --- mitogen/core.py | 96 +++++++++++++++++++++++++--------- mitogen/service.py | 2 +- tests/message_test.py | 62 ++++++++++++++++++---- tests/mitogen_protocol_test.py | 2 +- 4 files changed, 124 insertions(+), 38 deletions(-) diff --git a/mitogen/core.py b/mitogen/core.py index bee722e63..b55086bf4 100644 --- a/mitogen/core.py +++ b/mitogen/core.py @@ -44,6 +44,7 @@ import linecache import logging import os +import operator import pickle as py_pickle import pstats import signal @@ -767,6 +768,54 @@ def find_class(self, module, func): _Unpickler = pickle.Unpickler +class MessageError(Error): pass +class MessageMagicError(MessageError): pass +class MessageSizeError(MessageError): pass + + +class MessageHeader(tuple): + __slots__ = () + _struct = struct.Struct('>hLLLLLL') + MAGIC = 0x4d49 # b'MI' + SIZE = _struct.size + + def __new__(cls, magic, dst, src, auth, handle, reply_to, data_size): + args = (magic, dst, src, auth, handle, reply_to, data_size) + return tuple.__new__(cls, args) + + magic = property(operator.itemgetter(0)) + dst = property(operator.itemgetter(1)) + src = property(operator.itemgetter(2)) + auth = property(operator.itemgetter(3)) + handle = property(operator.itemgetter(4)) + reply_to = property(operator.itemgetter(5)) + data_size = property(operator.itemgetter(6)) + + @classmethod + def unpack(cls, buffer, max_message_size): + self = cls(*cls._struct.unpack(buffer)) + if self.magic != cls.MAGIC: + raise MessageMagicError( + 'Expected magic %x, got %x' % (cls.MAGIC, self.magic), + ) + if self.data_size > max_message_size: + raise MessageSizeError( + 'Maximum size exceeded (got %d, max %d)' + % (self.data_size, max_message_size), + ) + return self + + def pack(self): + return self._struct.pack(*self) + + def __repr__(self): + return '%s.%s(magic=%d, dst=%d, src=%d, auth=%d, handle=%d, reply_to=%d, data_size=%d)' % ( + self.__class__.__module__, self.__class__.__name__, + self.magic, self.dst, self.src, self.auth, self.handle, + self.reply_to, self.data_size, + ) + + class Message(object): """ Messages are the fundamental unit of communication, comprising fields from @@ -810,10 +859,6 @@ class Message(object): #: the :class:`mitogen.select.Select` interface. Defaults to :data:`None`. receiver = None - HEADER_FMT = '>hLLLLLL' - HEADER_LEN = struct.calcsize(HEADER_FMT) - HEADER_MAGIC = 0x4d49 # 'MI' - def __init__(self, **kwargs): """ Construct a message from from the supplied `kwargs`. :attr:`src_id` and @@ -825,12 +870,11 @@ def __init__(self, **kwargs): assert isinstance(self.data, BytesType), 'Message data is not Bytes' def pack(self): - return ( - struct.pack(self.HEADER_FMT, self.HEADER_MAGIC, self.dst_id, - self.src_id, self.auth_id, self.handle, - self.reply_to or 0, len(self.data)) - + self.data + hdr = MessageHeader( + MessageHeader.MAGIC, self.dst_id, self.src_id, self.auth_id, + self.handle, self.reply_to or 0, len(self.data), ) + return hdr.pack() + self.data def _unpickle_context(self, context_id, name): return _unpickle_context(context_id, name, router=self.router) @@ -2138,37 +2182,32 @@ def on_receive(self, broker, buf): ) def _receive_one(self, broker): - if self._input_buf_len < Message.HEADER_LEN: + if self._input_buf_len < MessageHeader.SIZE: return False - msg = Message() - msg.router = self._router - (magic, msg.dst_id, msg.src_id, msg.auth_id, - msg.handle, msg.reply_to, msg_len) = struct.unpack( - Message.HEADER_FMT, - self._input_buf[0][:Message.HEADER_LEN], - ) - - if magic != Message.HEADER_MAGIC: + try: + hdr = MessageHeader.unpack( + self._input_buf[0][:MessageHeader.SIZE], + self._router.max_message_size, + ) + except MessageMagicError: LOG.error(self.corrupt_msg, self.stream.name, self._input_buf[0][:2048]) self.stream.on_disconnect(broker) return False - - if msg_len > self._router.max_message_size: - LOG.error('%r: Maximum message size exceeded (got %d, max %d)', - self, msg_len, self._router.max_message_size) + except MessageSizeError as exc: + LOG.error('%r: %s', self, exc) self.stream.on_disconnect(broker) return False - total_len = msg_len + Message.HEADER_LEN + total_len = MessageHeader.SIZE + hdr.data_size if self._input_buf_len < total_len: _vv and IOLOG.debug( '%r: Input too short (want %d, got %d)', - self, msg_len, self._input_buf_len - Message.HEADER_LEN + self, hdr.data_size, self._input_buf_len - MessageHeader.SIZE ) return False - start = Message.HEADER_LEN + start = MessageHeader.SIZE prev_start = start remain = total_len bits = [] @@ -2180,6 +2219,11 @@ def _receive_one(self, broker): prev_start = start start = 0 + msg = Message( + dst_id=hdr.dst, src_id=hdr.src, auth_id=hdr.auth, + handle=hdr.handle, reply_to=hdr.reply_to, + ) + msg.router = self._router msg.data = b('').join(bits) self._input_buf.appendleft(buf[prev_start+len(bit):]) self._input_buf_len -= total_len diff --git a/mitogen/service.py b/mitogen/service.py index 0e5f64197..41e7a9097 100644 --- a/mitogen/service.py +++ b/mitogen/service.py @@ -981,7 +981,7 @@ def on_shutdown(self): # The IO loop pumps 128KiB chunks. An ideal message is a multiple of this, # odd-sized messages waste one tiny write() per message on the trailer. # Therefore subtract 10 bytes pickle overhead + 24 bytes header. - IO_SIZE = mitogen.core.CHUNK_SIZE - (mitogen.core.Message.HEADER_LEN + ( + IO_SIZE = mitogen.core.CHUNK_SIZE - (mitogen.core.MessageHeader.SIZE + ( len( mitogen.core.Message.pickled( mitogen.core.Blob(b(' ') * mitogen.core.CHUNK_SIZE) diff --git a/tests/message_test.py b/tests/message_test.py index 2d2299d1b..85ee78b56 100644 --- a/tests/message_test.py +++ b/tests/message_test.py @@ -11,6 +11,57 @@ from mitogen.core import b +class MessageHeaderTest(testlib.TestCase): + def test_attributes(self): + hdr = mitogen.core.MessageHeader(1, 2, 3, 4, 5, 6, 7) + self.assertEqual(hdr.magic, 1) + self.assertEqual(hdr.dst, 2) + self.assertEqual(hdr.src, 3) + self.assertEqual(hdr.auth, 4) + self.assertEqual(hdr.handle, 5) + self.assertEqual(hdr.reply_to, 6) + self.assertEqual(hdr.data_size, 7) + + def test_unpack(self): + hdr1 = mitogen.core.MessageHeader(0x4d49, 2, 3, 4, 5, 6, 7) + hdr2 = mitogen.core.MessageHeader.unpack( + b'MI\0\0\0\x02\0\0\0\x03\0\0\0\x04\0\0\0\x05\0\0\0\x06\0\0\0\x07', + max_message_size=100, + ) + self.assertEqual(hdr1, hdr2) + + self.assertRaises( + mitogen.core.MessageMagicError, + mitogen.core.MessageHeader.unpack, + b'AB\0\0\0\x02\0\0\0\x03\0\0\0\x04\0\0\0\x05\0\0\0\x06\0\0\0\x07', + max_message_size=100, + ) + + self.assertRaises( + mitogen.core.MessageSizeError, + mitogen.core.MessageHeader.unpack, + b'MI\0\0\0\x02\0\0\0\x03\0\0\0\x04\0\0\0\x05\0\0\0\x06\0\0\0\x07', + max_message_size=6, + ) + + def test_pack(self): + hdr = mitogen.core.MessageHeader( + mitogen.core.MessageHeader.MAGIC, 2, 3, 4, 5, 6, 7, + ) + self.assertEqual( + hdr.pack(), + b'MI\0\0\0\x02\0\0\0\x03\0\0\0\x04\0\0\0\x05\0\0\0\x06\0\0\0\x07', + ) + self.assertEqual(len(hdr.pack()), mitogen.core.MessageHeader.SIZE) + + def test_repr(self): + hdr = mitogen.core.MessageHeader(1, 2, 3, 4, 5, 6, 7) + self.assertEqual( + repr(hdr), + 'mitogen.core.MessageHeader(magic=1, dst=2, src=3, auth=4, handle=5, reply_to=6, data_size=7)', + ) + + class ConstructorTest(testlib.TestCase): klass = mitogen.core.Message @@ -64,18 +115,9 @@ def test_data_hates_unicode(self): class PackTest(testlib.TestCase): klass = mitogen.core.Message - def test_header_format_sanity(self): - self.assertEqual(self.klass.HEADER_LEN, - struct.calcsize(self.klass.HEADER_FMT)) - def test_header_length_correct(self): s = self.klass(dst_id=123, handle=123).pack() - self.assertEqual(len(s), self.klass.HEADER_LEN) - - def test_magic(self): - s = self.klass(dst_id=123, handle=123).pack() - magic, = struct.unpack('>h', s[:2]) - self.assertEqual(self.klass.HEADER_MAGIC, magic) + self.assertEqual(len(s), mitogen.core.MessageHeader.SIZE) def test_dst_id(self): s = self.klass(dst_id=123, handle=123).pack() diff --git a/tests/mitogen_protocol_test.py b/tests/mitogen_protocol_test.py index d6e3cc959..eed8e5ecd 100644 --- a/tests/mitogen_protocol_test.py +++ b/tests/mitogen_protocol_test.py @@ -16,7 +16,7 @@ def test_corruption(self): protocol = self.klass(router, 1) protocol.stream = stream - junk = mitogen.core.b('x') * mitogen.core.Message.HEADER_LEN + junk = mitogen.core.b('x') * mitogen.core.MessageHeader.SIZE capture = testlib.LogCapturer() capture.start()