diff --git a/server/ladder_service.py b/server/ladder_service.py index 56032e4aa..ca970db31 100644 --- a/server/ladder_service.py +++ b/server/ladder_service.py @@ -3,10 +3,12 @@ """ import asyncio +import json import random from collections import defaultdict from typing import Dict, List, Optional, Set, Tuple +import aio_pika import aiocron from sqlalchemy import and_, func, select, text, true @@ -29,9 +31,12 @@ matchmaker_queue_map_pool ) from .decorators import with_logger +from .factions import Faction from .game_service import GameService from .games import Game, InitMode, LadderGame from .matchmaker import MapPool, MatchmakerQueue, OnMatchedCallback, Search +from .message_queue_service import MessageQueueService +from .player_service import PlayerService from .players import Player, PlayerState from .protocol import DisconnectedError from .types import GameLaunchOptions, Map, NeroxisGeneratedMap @@ -54,17 +59,35 @@ def __init__( self, database: FAFDatabase, game_service: GameService, + player_service: PlayerService, + message_queue_service: MessageQueueService ): self._db = database self._informed_players: Set[Player] = set() self.game_service = game_service + self.player_service = player_service + self.message_queue_service = message_queue_service self.queues = {} + self._initialized = False self._searches: Dict[Player, Dict[str, Search]] = defaultdict(dict) async def initialize(self) -> None: + if self._initialized: + return + await self.update_data() + await self.message_queue_service.declare_exchange( + config.MQ_EXCHANGE_NAME + ) + await self.message_queue_service.consume( + config.MQ_EXCHANGE_NAME, + "request.match.create", + self.handle_mq_matchmaking_request + ) + self._update_cron = aiocron.crontab("*/10 * * * *", func=self.update_data) + self._initialized = True async def update_data(self) -> None: async with self._db.acquire() as conn: @@ -343,6 +366,135 @@ def write_rating_progress(self, player: Player, rating_type: str) -> None: ) }) + async def handle_mq_matchmaking_request( + self, + message: aio_pika.IncomingMessage + ): + try: + game = await self._handle_mq_matchmaking_request(message) + except Exception as e: + if isinstance(e, GameLaunchError): + code = "launch_failed" + args = [{"player_id": player.id} for player in e.players] + elif isinstance(e, json.JSONDecodeError): + code = "malformed_request" + args = [{"message": str(e)}] + elif isinstance(e, KeyError): + code = "malformed_request" + args = [{"message": f"missing {e.args[0]}"}] + else: + code, *args = e.args + + await self.message_queue_service.publish( + config.MQ_EXCHANGE_NAME, + "error.match.create", + {"error_code": code, "args": args}, + correlation_id=message.correlation_id + ) + else: + await self.message_queue_service.publish( + config.MQ_EXCHANGE_NAME, + "success.match.create", + {"game_id": game.id}, + correlation_id=message.correlation_id + ) + + async def _handle_mq_matchmaking_request( + self, + message: aio_pika.IncomingMessage + ): + self._logger.debug( + "Got matchmaking request: %s", message.correlation_id + ) + request = json.loads(message.body) + # TODO: Use id instead of name? + queue_name = request.get("matchmaker_queue") + map_name = request["map_name"] + game_name = request["game_name"] + participants = request["participants"] + featured_mod = request.get("featured_mod") + if not featured_mod and not queue_name: + raise KeyError("featured_mod") + + if queue_name and queue_name not in self.queues: + raise Exception("invalid_request", "invalid queue") + + if not participants: + raise Exception("invalid_request", "empty participants") + + player_ids = [participant["player_id"] for participant in participants] + missing_players = [ + id for id in player_ids if self.player_service[id] is None + ] + if missing_players: + raise Exception( + "players_not_found", + *[{"player_id": id} for id in missing_players] + ) + + all_players = [ + self.player_service[player_id] for player_id in player_ids + ] + non_idle_players = [ + player for player in all_players + if player.state != PlayerState.IDLE + ] + if non_idle_players: + raise Exception( + "invalid_state", + [ + {"player_id": player.id, "state": player.state.name} + for player in all_players + ] + ) + + queue = self.queues[queue_name] if queue_name else None + featured_mod = featured_mod or queue.featured_mod + host = all_players[0] + guests = all_players[1:] + + for player in all_players: + player.state = PlayerState.STARTING_AUTOMATCH + + try: + game = self.game_service.create_game( + game_class=LadderGame, + game_mode=featured_mod, + host=host, + name="Matchmaker Game", + mapname=map_name, + matchmaker_queue_id=queue.id if queue else None, + rating_type=queue.rating_type if queue else None, + max_players=len(participants) + ) + game.init_mode = InitMode.AUTO_LOBBY + game.set_name_unchecked(game_name) + + for participant in participants: + player_id = participant["player_id"] + faction = Faction.from_value(participant["faction"]) + team = participant["team"] + slot = participant["slot"] + + game.set_player_option(player_id, "Faction", faction.value) + game.set_player_option(player_id, "Team", team) + game.set_player_option(player_id, "StartSpot", slot) + game.set_player_option(player_id, "Army", slot) + game.set_player_option(player_id, "Color", slot) + + await self.launch_game(game, host, guests) + + return game + except Exception: + self._logger.exception("") + await game.on_game_end() + + for player in all_players: + if player.state == PlayerState.STARTING_AUTOMATCH: + player.state = PlayerState.IDLE + + raise + def on_match_found( self, s1: Search, @@ -497,7 +649,7 @@ async def launch_game( def game_options(player: Player) -> GameLaunchOptions: return options._replace( team=game.get_player_option(player.id, "Team"), - faction=player.faction, + faction=game.get_player_option(player.id, "Faction"), map_position=game.get_player_option(player.id, "StartSpot") ) diff --git a/server/message_queue_service.py b/server/message_queue_service.py index cb78de59b..813933b10 100644 --- a/server/message_queue_service.py +++ b/server/message_queue_service.py @@ -4,7 +4,7 @@ import asyncio import json -from typing import Dict +from typing import Callable, Dict, Optional import aio_pika from aio_pika import DeliveryMode, ExchangeType @@ -125,10 +125,11 @@ async def _shutdown(self) -> None: async def publish( self, exchange_name: str, - routing: str, + routing_key: str, payload: Dict, mandatory: bool = False, delivery_mode: DeliveryMode = DeliveryMode.PERSISTENT, + correlation_id: Optional[str] = None ) -> None: if not self._is_ready: self._logger.warning( @@ -136,23 +137,56 @@ async def publish( ) return - exchange = self._exchanges.get(exchange_name) - if exchange is None: - raise KeyError(f"Unknown exchange {exchange_name}.") + exchange = self._get_exchange(exchange_name) message = aio_pika.Message( - json.dumps(payload).encode(), delivery_mode=delivery_mode + json.dumps(payload).encode(), + delivery_mode=delivery_mode, + correlation_id=correlation_id, ) async with self._channel.transaction(): await exchange.publish( message, - routing_key=routing, - mandatory=mandatory + routing_key=routing_key, + mandatory=mandatory, ) self._logger.log( - TRACE, "Published message %s to %s/%s", payload, exchange_name, routing + TRACE, "Published message %s to %s/%s", + payload, + exchange_name, + routing_key + ) + + async def consume( + self, + exchange_name: str, + routing_key: str, + process_message: Callable[[aio_pika.IncomingMessage], None] + ) -> None: + await self.initialize() + if not self._is_ready: + self._logger.warning( + "Not connected to RabbitMQ, unable to declare queue." ) + return + + exchange = self._get_exchange(exchange_name) + queue = await self._channel.declare_queue( + None, + auto_delete=True, + durable=False + ) + + await queue.bind(exchange, routing_key) + await queue.consume(process_message, exclusive=True) + + def _get_exchange(self, exchange_name: str) -> aio_pika.Exchange: + exchange = self._exchanges.get(exchange_name) + if exchange is None: + raise KeyError(f"Unknown exchange {exchange_name}.") + + return exchange @synchronizedmethod("initialization_lock") async def reconnect(self) -> None: diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index 55cb483b0..79f5716a5 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -34,9 +34,20 @@ def mock_games(): @pytest.fixture -async def ladder_service(mocker, database, game_service): +async def ladder_service( + mocker, + database, + game_service, + player_service, + message_queue_service +): mocker.patch("server.matchmaker.pop_timer.config.QUEUE_POP_TIME_MAX", 1) - ladder_service = LadderService(database, game_service) + ladder_service = LadderService( + database, + game_service, + player_service, + message_queue_service + ) await ladder_service.initialize() yield ladder_service await ladder_service.shutdown() @@ -384,17 +395,22 @@ async def channel(): await connection.close() -async def connect_mq_consumer(server, channel, routing_key): - """ - Returns a subclass of Protocol that yields messages read from a rabbitmq - exchange. - """ +async def connect_mq_queue(channel, routing_key): exchange = await channel.declare_exchange( config.MQ_EXCHANGE_NAME, aio_pika.ExchangeType.TOPIC ) - queue = await channel.declare_queue("", exclusive=True) + queue = await channel.declare_queue(None, exclusive=True) await queue.bind(exchange, routing_key=routing_key) + return queue + + +async def connect_mq_consumer(channel, routing_key): + """ + Returns a subclass of Protocol that yields messages read from a rabbitmq + exchange. + """ + queue = await connect_mq_queue(channel, routing_key) proto = AioQueueProtocol(queue) await proto.consume() diff --git a/tests/integration_tests/test_matchmaker_requests.py b/tests/integration_tests/test_matchmaker_requests.py new file mode 100644 index 000000000..3e983ce51 --- /dev/null +++ b/tests/integration_tests/test_matchmaker_requests.py @@ -0,0 +1,211 @@ +# External matchmaker requests over rabbitmq +import asyncio +import json +import uuid + +import pytest + +from server.config import config +from tests.utils import fast_forward + +from .conftest import connect_and_sign_in, connect_mq_queue, read_until_command +from .test_game import client_response + +pytestmark = [pytest.mark.asyncio, pytest.mark.rabbitmq] + + +@fast_forward(10) +async def test_valid_request_1v1( + lobby_server, + channel, + message_queue_service +): + test_id, _, proto1 = await connect_and_sign_in( + ("test", "test_password"), lobby_server + ) + rhiza_id, _, proto2 = await connect_and_sign_in( + ("Rhiza", "puff_the_magic_dragon"), lobby_server + ) + success_queue = await connect_mq_queue(channel, "success.match.create") + error_queue = await connect_mq_queue(channel, "error.match.create") + + await asyncio.gather(*( + read_until_command(proto, "game_info") + for proto in (proto1, proto2) + )) + + # Include all the information we could possibly need + correlation_id = str(uuid.uuid4()) + await message_queue_service.publish( + config.MQ_EXCHANGE_NAME, + "request.match.create", + { + "matchmaker_queue": "ladder1v1", + "featured_mod": "ladder1v1", + "game_name": "test VERSUS Rhiza", + "map_name": "scmp_003", + "participants": [ + { + "player_id": test_id, + "team": 2, + "slot": 1, + "faction": "uef" + }, + { + "player_id": rhiza_id, + "team": 3, + "slot": 2, + "faction": "cybran" + } + ] + }, + correlation_id=correlation_id + ) + + msg1, msg2 = await asyncio.gather( + client_response(proto1), + client_response(proto2) + ) + assert msg1["uid"] == msg2["uid"] + assert msg1["mapname"] == msg2["mapname"] + assert msg1["name"] == msg2["name"] + assert msg1["mod"] == msg2["mod"] + assert msg1["rating_type"] == msg2["rating_type"] + assert msg1["expected_players"] == msg2["expected_players"] + + assert msg1["mapname"] == "scmp_003" + assert msg1["name"] == "test VERSUS Rhiza" + assert msg1["mod"] == "ladder1v1" + assert msg1["rating_type"] == "ladder_1v1" + assert msg1["expected_players"] == 2 + + assert msg1["team"] == 2 + assert msg1["map_position"] == 1 + assert msg1["faction"] == 1 + + assert msg2["team"] == 3 + assert msg2["map_position"] == 2 + assert msg2["faction"] == 3 + + await proto1.send_message({ + "target": "game", + "command": "GameState", + "args": ["Launching"] + }) + + message = await success_queue.iterator(timeout=5).__anext__() + assert message.correlation_id == correlation_id + assert json.loads(message.body.decode()) == { + "game_id": msg1["uid"] + } + assert await error_queue.get(fail=False) is None + + +@fast_forward(10) +async def test_player_offline( + lobby_server, + channel, + message_queue_service +): + rhiza_id, _, proto = await connect_and_sign_in( + ("Rhiza", "puff_the_magic_dragon"), lobby_server + ) + success_queue = await connect_mq_queue(channel, "success.match.create") + error_queue = await connect_mq_queue(channel, "error.match.create") + + await read_until_command(proto, "game_info") + + # Include all the information we could possibly need + correlation_id = str(uuid.uuid4()) + await message_queue_service.publish( + config.MQ_EXCHANGE_NAME, + "request.match.create", + { + "matchmaker_queue": "ladder1v1", + "game_name": "test VERSUS Rhiza", + "map_name": "scmp_003", + "participants": [ + { + "player_id": 1, + "team": 2, + "slot": 1, + "faction": "uef" + }, + { + "player_id": rhiza_id, + "team": 3, + "slot": 2, + "faction": "cybran" + } + ] + }, + correlation_id=correlation_id + ) + + message = await error_queue.iterator(timeout=5).__anext__() + assert message.correlation_id == correlation_id + assert json.loads(message.body.decode()) == { + "error_code": "players_not_found", "args": [{"player_id": 1}] + } + assert await success_queue.get(fail=False) is None + + +@fast_forward(100) +async def test_players_dont_connect( + lobby_server, + channel, + message_queue_service +): + test_id, _, proto1 = await connect_and_sign_in( + ("test", "test_password"), lobby_server + ) + rhiza_id, _, proto2 = await connect_and_sign_in( + ("Rhiza", "puff_the_magic_dragon"), lobby_server + ) + success_queue = await connect_mq_queue(channel, "success.match.create") + error_queue = await connect_mq_queue(channel, "error.match.create") + + await asyncio.gather(*( + read_until_command(proto, "game_info") + for proto in (proto1, proto2) + )) + + # Include all the information we could possibly need + correlation_id = str(uuid.uuid4()) + await message_queue_service.publish( + config.MQ_EXCHANGE_NAME, + "request.match.create", + { + "matchmaker_queue": "ladder1v1", + "featured_mod": "faf", + "game_name": "test VERSUS Rhiza", + "map_name": "scmp_003", + "participants": [ + { + "player_id": test_id, + "team": 2, + "slot": 1, + "faction": "aeon" + }, + { + "player_id": rhiza_id, + "team": 3, + "slot": 2, + "faction": "seraphim" + } + ] + }, + correlation_id=correlation_id + ) + + msg = await client_response(proto1) + assert msg["faction"] == 2 + # Mod field sould override the mod from queue + assert msg["mod"] == "faf" + + message = await error_queue.iterator(timeout=85).__anext__() + assert message.correlation_id == correlation_id + assert json.loads(message.body.decode()) == { + "error_code": "launch_failed", "args": [{"player_id": rhiza_id}] + } + assert await success_queue.get(fail=False) is None diff --git a/tests/integration_tests/test_server.py b/tests/integration_tests/test_server.py index 744fb8055..66ee9e39b 100644 --- a/tests/integration_tests/test_server.py +++ b/tests/integration_tests/test_server.py @@ -99,12 +99,10 @@ async def test_player_info_broadcast(lobby_server): @fast_forward(5) async def test_player_info_broadcast_to_rabbitmq(lobby_server, channel): mq_proto = await connect_mq_consumer( - lobby_server, channel, "broadcast.playerInfo.update" ) mq_proto_all = await connect_mq_consumer( - lobby_server, channel, "broadcast.*.update" ) @@ -339,7 +337,6 @@ async def test_game_info_broadcast_to_players_in_lobby(lobby_server): @fast_forward(10) async def test_info_broadcast_to_rabbitmq(lobby_server, channel): mq_proto_all = await connect_mq_consumer( - lobby_server, channel, "broadcast.*.update" ) diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py index 6c622a387..bff51d2b1 100644 --- a/tests/unit_tests/conftest.py +++ b/tests/unit_tests/conftest.py @@ -22,17 +22,25 @@ def ladder_and_game_service_context( @asynccontextmanager async def make_ladder_and_game_service(): async with database_context(request) as database: + player_service = mock.Mock() + message_queue_service = mock.Mock( + declare_exchange=CoroutineMock(), + consume=CoroutineMock() + ) with mock.patch("server.matchmaker.pop_timer.config.QUEUE_POP_TIME_MAX", 1): game_service = GameService( database, - player_service=mock.Mock(), + player_service=player_service, game_stats_service=mock.Mock(), rating_service=mock.Mock(), - message_queue_service=mock.Mock( - declare_exchange=CoroutineMock() - ) + message_queue_service=message_queue_service + ) + ladder_service = LadderService( + database, + game_service, + player_service, + message_queue_service ) - ladder_service = LadderService(database, game_service) await game_service.initialize() await ladder_service.initialize() @@ -50,9 +58,16 @@ async def ladder_service( mocker, database, game_service: GameService, + player_service, + message_queue_service ): mocker.patch("server.matchmaker.pop_timer.config.QUEUE_POP_TIME_MAX", 1) - ladder_service = LadderService(database, game_service) + ladder_service = LadderService( + database, + game_service, + player_service, + message_queue_service + ) await ladder_service.initialize() yield ladder_service diff --git a/tests/unit_tests/test_ladder_service.py b/tests/unit_tests/test_ladder_service.py index 410b2f430..e5f68689c 100644 --- a/tests/unit_tests/test_ladder_service.py +++ b/tests/unit_tests/test_ladder_service.py @@ -35,8 +35,18 @@ def game(database, game_service, game_stats_service): ) -async def test_queue_initialization(database, game_service): - ladder_service = LadderService(database, game_service) +async def test_queue_initialization( + database, + game_service, + player_service, + message_queue_service +): + ladder_service = LadderService( + database, + game_service, + player_service, + message_queue_service + ) def make_mock_queue(*args, **kwargs): queue = create_autospec(MatchmakerQueue)