Skip to content

Commit

Permalink
mitogen: Factor MessageHeader class out of Message
Browse files Browse the repository at this point in the history
  • Loading branch information
moreati committed Aug 2, 2023
1 parent ec212a1 commit 47b9309
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 38 deletions.
96 changes: 70 additions & 26 deletions mitogen/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import linecache
import logging
import os
import operator
import pickle as py_pickle
import pstats
import signal
Expand Down Expand Up @@ -767,6 +768,54 @@ def find_class(self, module, func):
_Unpickler = pickle.Unpickler


class MessageError(Error): pass
class MessageMagicError(MessageError): pass
class MessageSizeError(MessageError): pass


class MessageHeader(tuple):
__slots__ = ()
_struct = struct.Struct('>hLLLLLL')
MAGIC = 0x4d49 # b'MI'
SIZE = _struct.size

def __new__(cls, magic, dst, src, auth, handle, reply_to, data_size):
args = (magic, dst, src, auth, handle, reply_to, data_size)
return tuple.__new__(cls, args)

magic = property(operator.itemgetter(0))
dst = property(operator.itemgetter(1))
src = property(operator.itemgetter(2))
auth = property(operator.itemgetter(3))
handle = property(operator.itemgetter(4))
reply_to = property(operator.itemgetter(5))
data_size = property(operator.itemgetter(6))

@classmethod
def unpack(cls, buffer, max_message_size):
self = cls(*cls._struct.unpack(buffer))
if self.magic != cls.MAGIC:
raise MessageMagicError(
'Expected magic %x, got %x' % (cls.MAGIC, self.magic),
)
if self.data_size > max_message_size:
raise MessageSizeError(
'Maximum size exceeded (got %d, max %d)'
% (self.data_size, max_message_size),
)
return self

def pack(self):
return self._struct.pack(*self)

def __repr__(self):
return '%s.%s(magic=%d, dst=%d, src=%d, auth=%d, handle=%d, reply_to=%d, data_size=%d)' % (
self.__class__.__module__, self.__class__.__name__,
self.magic, self.dst, self.src, self.auth, self.handle,
self.reply_to, self.data_size,
)


class Message(object):
"""
Messages are the fundamental unit of communication, comprising fields from
Expand Down Expand Up @@ -810,10 +859,6 @@ class Message(object):
#: the :class:`mitogen.select.Select` interface. Defaults to :data:`None`.
receiver = None

HEADER_FMT = '>hLLLLLL'
HEADER_LEN = struct.calcsize(HEADER_FMT)
HEADER_MAGIC = 0x4d49 # 'MI'

def __init__(self, **kwargs):
"""
Construct a message from from the supplied `kwargs`. :attr:`src_id` and
Expand All @@ -825,12 +870,11 @@ def __init__(self, **kwargs):
assert isinstance(self.data, BytesType), 'Message data is not Bytes'

def pack(self):
return (
struct.pack(self.HEADER_FMT, self.HEADER_MAGIC, self.dst_id,
self.src_id, self.auth_id, self.handle,
self.reply_to or 0, len(self.data))
+ self.data
hdr = MessageHeader(
MessageHeader.MAGIC, self.dst_id, self.src_id, self.auth_id,
self.handle, self.reply_to or 0, len(self.data),
)
return hdr.pack() + self.data

def _unpickle_context(self, context_id, name):
return _unpickle_context(context_id, name, router=self.router)
Expand Down Expand Up @@ -2138,37 +2182,32 @@ def on_receive(self, broker, buf):
)

def _receive_one(self, broker):
if self._input_buf_len < Message.HEADER_LEN:
if self._input_buf_len < MessageHeader.SIZE:
return False

msg = Message()
msg.router = self._router
(magic, msg.dst_id, msg.src_id, msg.auth_id,
msg.handle, msg.reply_to, msg_len) = struct.unpack(
Message.HEADER_FMT,
self._input_buf[0][:Message.HEADER_LEN],
)

if magic != Message.HEADER_MAGIC:
try:
hdr = MessageHeader.unpack(
self._input_buf[0][:MessageHeader.SIZE],
self._router.max_message_size,
)
except MessageMagicError:
LOG.error(self.corrupt_msg, self.stream.name, self._input_buf[0][:2048])
self.stream.on_disconnect(broker)
return False

if msg_len > self._router.max_message_size:
LOG.error('%r: Maximum message size exceeded (got %d, max %d)',
self, msg_len, self._router.max_message_size)
except MessageSizeError as exc:
LOG.error('%r: %s', self, exc)
self.stream.on_disconnect(broker)
return False

total_len = msg_len + Message.HEADER_LEN
total_len = MessageHeader.SIZE + hdr.data_size
if self._input_buf_len < total_len:
_vv and IOLOG.debug(
'%r: Input too short (want %d, got %d)',
self, msg_len, self._input_buf_len - Message.HEADER_LEN
self, hdr.data_size, self._input_buf_len - MessageHeader.SIZE
)
return False

start = Message.HEADER_LEN
start = MessageHeader.SIZE
prev_start = start
remain = total_len
bits = []
Expand All @@ -2180,6 +2219,11 @@ def _receive_one(self, broker):
prev_start = start
start = 0

msg = Message(
dst_id=hdr.dst, src_id=hdr.src, auth_id=hdr.auth,
handle=hdr.handle, reply_to=hdr.reply_to,
)
msg.router = self._router
msg.data = b('').join(bits)
self._input_buf.appendleft(buf[prev_start+len(bit):])
self._input_buf_len -= total_len
Expand Down
2 changes: 1 addition & 1 deletion mitogen/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,7 +981,7 @@ def on_shutdown(self):
# The IO loop pumps 128KiB chunks. An ideal message is a multiple of this,
# odd-sized messages waste one tiny write() per message on the trailer.
# Therefore subtract 10 bytes pickle overhead + 24 bytes header.
IO_SIZE = mitogen.core.CHUNK_SIZE - (mitogen.core.Message.HEADER_LEN + (
IO_SIZE = mitogen.core.CHUNK_SIZE - (mitogen.core.MessageHeader.SIZE + (
len(
mitogen.core.Message.pickled(
mitogen.core.Blob(b(' ') * mitogen.core.CHUNK_SIZE)
Expand Down
62 changes: 52 additions & 10 deletions tests/message_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,57 @@
from mitogen.core import b


class MessageHeaderTest(testlib.TestCase):
def test_attributes(self):
hdr = mitogen.core.MessageHeader(1, 2, 3, 4, 5, 6, 7)
self.assertEqual(hdr.magic, 1)
self.assertEqual(hdr.dst, 2)
self.assertEqual(hdr.src, 3)
self.assertEqual(hdr.auth, 4)
self.assertEqual(hdr.handle, 5)
self.assertEqual(hdr.reply_to, 6)
self.assertEqual(hdr.data_size, 7)

def test_unpack(self):
hdr1 = mitogen.core.MessageHeader(0x4d49, 2, 3, 4, 5, 6, 7)
hdr2 = mitogen.core.MessageHeader.unpack(
b'MI\0\0\0\x02\0\0\0\x03\0\0\0\x04\0\0\0\x05\0\0\0\x06\0\0\0\x07',
max_message_size=100,
)
self.assertEqual(hdr1, hdr2)

self.assertRaises(
mitogen.core.MessageMagicError,
mitogen.core.MessageHeader.unpack,
b'AB\0\0\0\x02\0\0\0\x03\0\0\0\x04\0\0\0\x05\0\0\0\x06\0\0\0\x07',
max_message_size=100,
)

self.assertRaises(
mitogen.core.MessageSizeError,
mitogen.core.MessageHeader.unpack,
b'MI\0\0\0\x02\0\0\0\x03\0\0\0\x04\0\0\0\x05\0\0\0\x06\0\0\0\x07',
max_message_size=6,
)

def test_pack(self):
hdr = mitogen.core.MessageHeader(
mitogen.core.MessageHeader.MAGIC, 2, 3, 4, 5, 6, 7,
)
self.assertEqual(
hdr.pack(),
b'MI\0\0\0\x02\0\0\0\x03\0\0\0\x04\0\0\0\x05\0\0\0\x06\0\0\0\x07',
)
self.assertEqual(len(hdr.pack()), mitogen.core.MessageHeader.SIZE)

def test_repr(self):
hdr = mitogen.core.MessageHeader(1, 2, 3, 4, 5, 6, 7)
self.assertEqual(
repr(hdr),
'mitogen.core.MessageHeader(magic=1, dst=2, src=3, auth=4, handle=5, reply_to=6, data_size=7)',
)


class ConstructorTest(testlib.TestCase):
klass = mitogen.core.Message

Expand Down Expand Up @@ -64,18 +115,9 @@ def test_data_hates_unicode(self):
class PackTest(testlib.TestCase):
klass = mitogen.core.Message

def test_header_format_sanity(self):
self.assertEqual(self.klass.HEADER_LEN,
struct.calcsize(self.klass.HEADER_FMT))

def test_header_length_correct(self):
s = self.klass(dst_id=123, handle=123).pack()
self.assertEqual(len(s), self.klass.HEADER_LEN)

def test_magic(self):
s = self.klass(dst_id=123, handle=123).pack()
magic, = struct.unpack('>h', s[:2])
self.assertEqual(self.klass.HEADER_MAGIC, magic)
self.assertEqual(len(s), mitogen.core.MessageHeader.SIZE)

def test_dst_id(self):
s = self.klass(dst_id=123, handle=123).pack()
Expand Down
2 changes: 1 addition & 1 deletion tests/mitogen_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_corruption(self):
protocol = self.klass(router, 1)
protocol.stream = stream

junk = mitogen.core.b('x') * mitogen.core.Message.HEADER_LEN
junk = mitogen.core.b('x') * mitogen.core.MessageHeader.SIZE

capture = testlib.LogCapturer()
capture.start()
Expand Down

0 comments on commit 47b9309

Please sign in to comment.