Skip to content

Commit

Permalink
Asyncio drain fix (#522)
Browse files Browse the repository at this point in the history
* Add test for gameconnection error

* Add lock around protocol calls to drain

* Add tests for asyncio drain bug

* Propagate connection errors to server context

* Reset user online guage when the server restarts
  • Loading branch information
Askaholic authored Jan 28, 2020
1 parent e9edac2 commit 0ea040f
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 10 deletions.
4 changes: 3 additions & 1 deletion server/gameconnection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio

from server.db import FAFDatabase
from sqlalchemy import text, select, or_
from sqlalchemy import or_, select, text

from .abc.base_game import GameConnectionState
from .config import TRACE
Expand Down Expand Up @@ -170,6 +170,8 @@ async def handle_action(self, command, args):
)
except (TypeError, ValueError) as e:
self._logger.exception("Bad command arguments: %s", e)
except ConnectionError as e:
raise e
except Exception as e: # pragma: no cover
self._logger.exception(e)
self._logger.exception("Something awful happened in a game thread!")
Expand Down
3 changes: 3 additions & 0 deletions server/lobbyconnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,9 @@ async def on_message_received(self, message):
except (KeyError, ValueError) as ex:
self._logger.exception(ex)
await self.abort("Garbage command: {}".format(message))
except ConnectionError as e:
# Propagate connection errors to the ServerContext error handler.
raise e
except Exception as ex: # pragma: no cover
await self.send({'command': 'invalid'})
self._logger.exception(ex)
Expand Down
13 changes: 10 additions & 3 deletions server/protocol/qdatastreamprotocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ def __init__(self, reader: StreamReader, writer: StreamWriter):
self.reader = reader
self.writer = writer

# drain() cannot be called concurrently by multiple coroutines:
# http://bugs.python.org/issue29930.
self._drain_lock = asyncio.Lock()

@staticmethod
def read_qstring(buffer: bytes, pos: int=0) -> Tuple[int, str]:
"""
Expand Down Expand Up @@ -133,7 +137,8 @@ async def send_message(self, message: dict):
self.writer.write(
self.pack_message(json.dumps(message, separators=(',', ':')))
)
await self.writer.drain()
async with self._drain_lock:
await self.writer.drain()

async def send_messages(self, messages):
server.stats.incr('server.sent_messages')
Expand All @@ -142,9 +147,11 @@ async def send_messages(self, messages):
for msg in messages
]
self.writer.writelines(payload)
await self.writer.drain()
async with self._drain_lock:
await self.writer.drain()

async def send_raw(self, data):
server.stats.incr('server.sent_messages')
self.writer.write(data)
await self.writer.drain()
async with self._drain_lock:
await self.writer.drain()
5 changes: 5 additions & 0 deletions server/servercontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ def __repr__(self):

async def listen(self, host, port):
self._logger.debug("ServerContext.listen(%s, %s)", host, port)
# TODO: Use tags so we don't need to manually reset each one
server.stats.gauge('user.agents.None', 0)
server.stats.gauge('user.agents.downlords_faf_client', 0)
server.stats.gauge('user.agents.faf_client', 0)

self._server = await asyncio.start_server(
self.client_connected,
host=host,
Expand Down
9 changes: 9 additions & 0 deletions tests/unit_tests/test_gameconnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,3 +410,12 @@ async def test_handle_action_OperationComplete_invalid(ugame: Game, game_connect
row = await result.fetchone()

assert row is None


async def test_handle_action_invalid(game_connection: GameConnection):
game_connection.abort = CoroutineMock()

await game_connection.handle_action('ThisDoesntExist', [1, 2, 3])

game_connection.abort.assert_not_called()
game_connection.protocol.send_message.assert_not_called()
71 changes: 65 additions & 6 deletions tests/unit_tests/test_protocol.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from asyncio import StreamReader

import asyncio
from unittest import mock
import pytest
import struct
from asyncio import StreamReader
from unittest import mock

import pytest
from server.protocol import QDataStreamProtocol

pytestmark = pytest.mark.asyncio
Expand All @@ -14,21 +13,49 @@
def reader(event_loop):
return StreamReader(loop=event_loop)


@pytest.fixture
def writer():
return mock.Mock()


@pytest.fixture
def protocol(reader, writer):
return QDataStreamProtocol(reader, writer)


async def test_QDataStreamProtocol_recv_small_message(protocol,reader):
@pytest.fixture
def unix_srv(event_loop):
async def do_nothing(client_reader, client_writer):
await client_reader.read()

srv = event_loop.run_until_complete(
asyncio.start_unix_server(do_nothing, '/tmp/test.sock')
)

yield srv

srv.close()
event_loop.run_until_complete(srv.wait_closed())


@pytest.fixture
def unix_protocol(unix_srv, event_loop):
(reader, writer) = event_loop.run_until_complete(
asyncio.open_unix_connection('/tmp/test.sock')
)
protocol = QDataStreamProtocol(reader, writer)
yield protocol

protocol.close()


async def test_QDataStreamProtocol_recv_small_message(protocol, reader):
data = QDataStreamProtocol.pack_block(b''.join([QDataStreamProtocol.pack_qstring('{"some_header": true}'),
QDataStreamProtocol.pack_qstring('Goodbye')]))
reader.feed_data(data)

message =await protocol.read_message()
message = await protocol.read_message()

assert message == {'some_header': True, 'legacy': ['Goodbye']}

Expand Down Expand Up @@ -60,3 +87,35 @@ async def test_unpacks_evil_qstring(protocol, reader):
message = await protocol.read_message()

assert message == {'command': 'ask_session'}


async def test_send_message_simultaneous_writes(unix_protocol):
msg = {
"command": "test",
"data": '*' * (4096*4)
}

# If drain calls are not synchronized, then this will raise an
# AssertionError from within asyncio
await asyncio.gather(*(unix_protocol.send_message(msg) for i in range(20)))


async def test_send_messages_simultaneous_writes(unix_protocol):
msg = {
"command": "test",
"data": '*' * (4096*4)
}

# If drain calls are not synchronized, then this will raise an
# AssertionError from within asyncio
await asyncio.gather(*(
unix_protocol.send_messages((msg, msg)) for i in range(20))
)


async def test_send_raw_simultaneous_writes(unix_protocol):
msg = b'*' * (4096*4)

# If drain calls are not synchronized, then this will raise an
# AssertionError from within asyncio
await asyncio.gather(*(unix_protocol.send_raw(msg) for i in range(20)))

0 comments on commit 0ea040f

Please sign in to comment.