diff --git a/server/__init__.py b/server/__init__.py index 8f59ee643..e697034d6 100644 --- a/server/__init__.py +++ b/server/__init__.py @@ -15,11 +15,11 @@ import server.metrics as metrics from .api.api_accessor import ApiAccessor -from .asyncio_extensions import synchronizedmethod from .config import TRACE, config from .configuration_service import ConfigurationService from .control import run_control_server from .core import Service, create_services +from .core.asyncio_extensions import synchronizedmethod from .db import FAFDatabase from .game_service import GameService from .gameconnection import GameConnection @@ -101,7 +101,8 @@ def __init__( "loop": self.loop, }) - self.connection_factory = lambda: LobbyConnection( + self.connection_factory = lambda proto, addr: LobbyConnection( + proto, addr, database=database, geoip=self.services["geo_ip_service"], game_service=self.services["game_service"], diff --git a/server/core/__init__.py b/server/core/__init__.py index f0c3dc424..320475a77 100644 --- a/server/core/__init__.py +++ b/server/core/__init__.py @@ -1,8 +1,15 @@ +from .connection import Connection from .dependency_injector import DependencyInjector +from .protocol import Protocol +from .router import RouteError, Router from .service import Service, create_services __all__ = ( + "Connection", "DependencyInjector", + "Protocol", + "RouteError", + "Router", "Service", "create_services" ) diff --git a/server/asyncio_extensions.py b/server/core/asyncio_extensions.py similarity index 95% rename from server/asyncio_extensions.py rename to server/core/asyncio_extensions.py index 4a44d214e..5842b0972 100644 --- a/server/asyncio_extensions.py +++ b/server/core/asyncio_extensions.py @@ -57,17 +57,17 @@ def __call__(self, *args): @overload def synchronized() -> AsyncDecorator: - ... + ... # pragma: no cover @overload def synchronized(function: AsyncFunc) -> AsyncFunc: - ... + ... # pragma: no cover @overload def synchronized(lock: Optional[asyncio.Lock]) -> AsyncDecorator: - ... + ... # pragma: no cover def synchronized(*args): @@ -102,17 +102,17 @@ async def wrapped(*args, **kwargs): @overload def synchronizedmethod() -> AsyncDecorator: - ... + ... # pragma: no cover @overload def synchronizedmethod(function: AsyncFunc) -> AsyncFunc: - ... + ... # pragma: no cover @overload def synchronizedmethod(lock_name: Optional[str]) -> AsyncDecorator: - ... + ... # pragma: no cover def synchronizedmethod(*args): diff --git a/server/core/connection.py b/server/core/connection.py new file mode 100644 index 000000000..259d9a9de --- /dev/null +++ b/server/core/connection.py @@ -0,0 +1,95 @@ +from typing import Any + +from .protocol import Protocol +from .router import RouteError, Router +from .typedefs import Address, Handler, Message + + +class handler(): + """ + Decorator for adding a handler to a connection + """ + def __init__(self, key: Any = Router.missing, **filters: Any): + self.func = None + self.key = key + self.filters = filters + + def __call__(self, func: Handler) -> Handler: + self.func = func + return self + + +class ConnectionMeta(type): + def __new__(cls, name, bases, namespace): + if "router" not in namespace: + namespace["router"] = Router("command") + + router = namespace["router"] + for attrname, value in list(namespace.items()): + if isinstance(value, handler) and value.func: + router.register_func( + attrname, + value.key, + **value.filters + ) + # Unwrap the handler function + namespace[attrname] = value.func + return super().__new__(cls, name, bases, namespace) + + +class Connection(metaclass=ConnectionMeta): + """ + An object responsible for handling the lifecycle of a connection. Message + handlers can be added with the `handler` decorator. + + # Example + ``` + class FooConnection(Connection): + @handler("bar") + async def handle_bar(self, message): + print(message) + + conn = FooConnection(protocol, address) + await conn.on_message_received({"command": "bar"}) + await conn.handle_bar({"command": "bar"}) + # Both calls print "{'command': 'bar'}" + ``` + """ + def __init__(self, protocol: Protocol, address: Address): + self.protocol = protocol + self.address = address + + def dispatch(self, message: Message) -> Handler: + """ + Get the function registered to handle a message. + + :raises: RouteError if no handler was found + """ + try: + handler_name = self.router.dispatch(message) + except RouteError: + # We use type(self) because we want to properly follow the MRO + handler_name = super(type(self), self).router.dispatch(message) + + return getattr(self, handler_name) + + async def on_message_received(self, message: Message): + """ + Forward the message to the registered handler function. + + :raises: RouteError if no handler was found + """ + handler_func = self.dispatch(message) + return await handler_func(message) + + async def send(self, message): + """Send a message and wait for it to be sent.""" + self.write(message) + await self.protocol.drain() + + def write(self, message): + """Write a message into the send buffer.""" + self.protocol.write_message(message) + + async def on_connection_lost(self): + pass diff --git a/server/core/protocol.py b/server/core/protocol.py new file mode 100644 index 000000000..b27895de0 --- /dev/null +++ b/server/core/protocol.py @@ -0,0 +1,135 @@ +import contextlib +from abc import ABCMeta, abstractmethod +from asyncio import StreamReader, StreamWriter +from typing import List + +from .asyncio_extensions import synchronizedmethod + + +class DisconnectedError(ConnectionError): + """For signaling that a protocol has lost connection to the remote.""" + + +class Protocol(metaclass=ABCMeta): + def __init__(self, reader: StreamReader, writer: StreamWriter): + self.reader = reader + self.writer = writer + # Force calls to drain() to only return once the data has been sent + self.writer.transport.set_write_buffer_limits(high=0) + + @staticmethod + @abstractmethod + def encode_message(message: dict) -> bytes: + """ + Encode a message as raw bytes. Can be used along with `*_raw` methods. + """ + pass # pragma: no cover + + def is_connected(self) -> bool: + """ + Return whether or not the connection is still alive + """ + return not self.writer.is_closing() + + @abstractmethod + async def read_message(self) -> dict: + """ + Asynchronously read a message from the stream + + :raises: IncompleteReadError + :return dict: Parsed message + """ + pass # pragma: no cover + + async def send_message(self, message: dict) -> None: + """ + Send a single message in the form of a dictionary + + :param message: Message to send + :raises: DisconnectedError + """ + await self.send_raw(self.encode_message(message)) + + async def send_messages(self, messages: List[dict]) -> None: + """ + Send multiple messages in the form of a list of dictionaries. + + May be more optimal than sending a single message. + + :param messages: + :raises: DisconnectedError + """ + self.write_messages(messages) + await self.drain() + + async def send_raw(self, data: bytes) -> None: + """ + Send raw bytes. Should generally not be used. + + :param data: bytes to send + :raises: DisconnectedError + """ + self.write_raw(data) + await self.drain() + + def write_message(self, message: dict) -> None: + """ + Write a single message into the message buffer. Should be used when + sending broadcasts or when sending messages that are triggered by + incoming messages from other players. + + :param message: Message to send + """ + if not self.is_connected(): + raise DisconnectedError("Protocol is not connected!") + + self.write_raw(self.encode_message(message)) + + def write_messages(self, messages: List[dict]) -> None: + """ + Write multiple message into the message buffer. + + :param messages: List of messages to write + """ + if not self.is_connected(): + raise DisconnectedError("Protocol is not connected!") + + self.writer.writelines([self.encode_message(msg) for msg in messages]) + + def write_raw(self, data: bytes) -> None: + """ + Write raw bytes into the message buffer. Should generally not be used. + + :param data: bytes to send + """ + if not self.is_connected(): + raise DisconnectedError("Protocol is not connected!") + + self.writer.write(data) + + async def close(self) -> None: + """ + Close the underlying writer as soon as the buffer has emptied. + :return: + """ + self.writer.close() + with contextlib.suppress(Exception): + await self.writer.wait_closed() + + @synchronizedmethod + async def drain(self) -> None: + """ + Await the write buffer to empty. + See StreamWriter.drain() + + :raises: DisconnectedError if the client disconnects while waiting for + the write buffer to empty. + """ + # Method needs to be synchronized as drain() cannot be called + # concurrently by multiple coroutines: + # http://bugs.python.org/issue29930. + try: + await self.writer.drain() + except Exception as e: + await self.close() + raise DisconnectedError("Protocol connection lost!") from e diff --git a/server/core/router.py b/server/core/router.py new file mode 100644 index 000000000..3bce4db89 --- /dev/null +++ b/server/core/router.py @@ -0,0 +1,164 @@ +import contextlib +from typing import Any, Dict, List, Optional + +from .typedefs import Handler, HandlerDecorator, Message + + +class RouteError(Exception): + """ + Raised when no matching route can be found for a message + """ + + +class Router(): + """ + Matches messages to handler functions. + """ + missing = object() + + def __init__(self, dispatch_key: Any = missing): + self.dispatch_key = dispatch_key + self.registry = SearchTree() + + def register(self, key: Any = missing, **filters: Any) -> HandlerDecorator: + def decorator(func: Handler) -> Handler: + self.register_func(func, key, **filters) + return func + return decorator + + def register_func( + self, + func: Handler, + key: Any = missing, + **filters: Any + ) -> None: + """ + Register a handler with a set of filters. Note that the order of + repeated calls matters as does the order of keyword arguments. + + :param key: Optional convenience for adding a filter using the default + key: {self.dispatch_key: key} + """ + if key is not self.missing: + if self.dispatch_key is self.missing: + raise RuntimeError("No default `dispatch_key` provided!") + filters = {self.dispatch_key: key, **filters} + self.registry.insert(func, filters) + + def dispatch(self, message: Message) -> Handler: + """ + Get the handler function that matches this message. + + :raises: RouteError if no matching route is found + """ + with contextlib.suppress(KeyError): + return self.registry[message] + + raise RouteError("No matching route") + + +class SearchTreeKeyNode(): + """ + Even-level node for matching against keys. + """ + def __init__(self) -> None: + self.handler: Optional[Handler] = None + self.nodes: List[SearchTreeValueNode] = [] + + def __getitem__(self, message: Message) -> Handler: + """ + Return the matching handler. + + :raises: KeyError if none exists + """ + for node in self.nodes: + with contextlib.suppress(KeyError): + return node[message] + + if self.handler: + return self.handler + + raise KeyError() + + def get(self, message: Message) -> Optional[Handler]: + """ + Return the matching handler if it exists, else None. + """ + with contextlib.suppress(KeyError): + return self[message] + return None + + def insert(self, handler: Handler, filters: Dict[Any, Any]) -> None: + """ + Add a handler to the search tree given a set of filters. Note that the + order of repeated insert calls matters, as does the iteration order of + filters. + + # Examples + ``` + tree.insert("foo_handler", {"first": "foo"}) + tree.insert("bar_handler", {"second": "bar"}) + + # Both are present, so first match is returned + assert tree[{"first": "foo", "second": "bar"}] == "foo_handler" + ``` + + ``` + tree.insert("foo_handler", {"first": "foo", "second": "bar"}) + + assert tree.get({"first": "foo"}) == "foo_handler" + # Second is a subkey of first and does not match on its own + assert tree.get({"second": "bar"}) is None + ``` + """ + try: + key, value = next(iter(filters.items())) + except StopIteration: + self.handler = handler + return + + # Find matching node for `key` + for value_node in self.nodes: + if value_node.key == key: + break + else: + value_node = SearchTreeValueNode(key) + self.nodes.append(value_node) + + # Get the sub-node for `value` + node = value_node.values.get(value) + if node is None: + node = SearchTreeKeyNode() + value_node.values[value] = node + + # Recurse + del filters[key] + node.insert(handler, filters) + + def __repr__(self, level: int = 0) -> str: + spacing = " " * level + nodes = "\n".join(node.__repr__(level + 1) for node in self.nodes) + return f"{spacing}handler: {self.handler}\n{spacing}nodes:\n{nodes}" + + +class SearchTreeValueNode(): + """ + Odd-level node for matching against values. + """ + def __init__(self, key: Any) -> None: + self.key = key + self.values: Dict[Any, SearchTreeKeyNode] = {} + + def __getitem__(self, message: Message) -> Handler: + return self.values[message[self.key]][message] + + def __repr__(self, level: int = 0) -> str: + spacing = " " * level + values = "\n".join( + f"{spacing}value: {value}\n{node.__repr__(level + 1)}" + for value, node in self.values.items() + ) + return f"{spacing}key: {self.key}\n{spacing}values:\n{values}" + + +SearchTree = SearchTreeKeyNode diff --git a/server/core/service.py b/server/core/service.py index ed2be8118..52bf6af43 100644 --- a/server/core/service.py +++ b/server/core/service.py @@ -1,5 +1,5 @@ import re -from typing import Dict, List +from typing import Any, Dict, List, Optional from .dependency_injector import DependencyInjector @@ -7,28 +7,22 @@ DependencyGraph = Dict[str, List[str]] -class ServiceMeta(type): - """ - For tracking which Services have been defined. - """ - - # Mapping from parameter name to class - services: Dict[str, type] = {} +service_registry: Dict[str, type] = {} - def __new__(cls, name, bases, attrs): - klass = type.__new__(cls, name, bases, attrs) - if name != "Service": - arg_name = snake_case(name) - cls.services[arg_name] = klass - return klass - -class Service(metaclass=ServiceMeta): +class Service(): """ All services should inherit from this class. Services are singleton objects which manage some server task. """ + def __init_subclass__(cls, name: Optional[str] = None, **kwargs: Any): + """ + For tracking which services have been defined. + """ + super().__init_subclass__(**kwargs) + arg_name = name or snake_case(cls.__name__) + service_registry[arg_name] = cls async def initialize(self) -> None: """ @@ -51,7 +45,7 @@ def create_services(injectables: Dict[str, object] = {}) -> Dict[str, Service]: injector = DependencyInjector() injector.add_injectables(**injectables) - return injector.build_classes(ServiceMeta.services) + return injector.build_classes(service_registry) def snake_case(string: str) -> str: diff --git a/server/core/typedefs.py b/server/core/typedefs.py new file mode 100644 index 000000000..16c6761f7 --- /dev/null +++ b/server/core/typedefs.py @@ -0,0 +1,19 @@ +from typing import Any, Callable, Dict, NamedTuple + +# Type aliases +Handler = Callable[..., Any] +HandlerDecorator = Callable[[Handler], Handler] +Message = Dict[Any, Any] + + +# Named tuples +class Address(NamedTuple): + """A peer IP address""" + + host: str + port: int + + @classmethod + def from_string(cls, address: str) -> "Address": + host, port = address.rsplit(":", 1) + return cls(host, int(port)) diff --git a/server/gameconnection.py b/server/gameconnection.py index 81e361668..42d0daf63 100644 --- a/server/gameconnection.py +++ b/server/gameconnection.py @@ -1,5 +1,6 @@ import asyncio import contextlib +from typing import Union from sqlalchemy import or_, select, text @@ -7,32 +8,30 @@ from .abc.base_game import GameConnectionState from .config import TRACE -from .db.models import ( - coop_leaderboard, - coop_map, - login, - teamkills -) +from .core import Protocol, RouteError +from .core.connection import Connection, handler +from .core.typedefs import Address +from .db.models import coop_leaderboard, coop_map, login, teamkills from .decorators import with_logger from .game_service import GameService from .games.game import Game, GameError, GameState, ValidityState, Victory from .player_service import PlayerService from .players import Player, PlayerState -from .protocol import DisconnectedError, GpgNetServerProtocol, Protocol +from .protocol import DisconnectedError @with_logger -class GameConnection(GpgNetServerProtocol): +class GameConnection(Connection): """ Responsible for connections to the game, using the GPGNet protocol """ - def __init__( self, + protocol: Protocol, + address: Address, database: FAFDatabase, game: Game, player: Player, - protocol: Protocol, player_service: PlayerService, games: GameService, state: GameConnectionState = GameConnectionState.INITIALIZING @@ -40,11 +39,9 @@ def __init__( """ Construct a new GameConnection """ - super().__init__() + super().__init__(protocol, address) self._db = database - self._logger.debug("GameConnection initializing") - self.protocol = protocol self._state = state self.game_service = games self.player_service = player_service @@ -54,6 +51,7 @@ def __init__( self._game = game self.finished_sim = False + self._logger.debug("GameConnection initialized") @property def state(self) -> GameConnectionState: @@ -84,6 +82,14 @@ def is_host(self) -> bool: self.player == self.game.host ) + async def send_gpgnet_message( + self, + command: str, + *arguments: Union[int, str, bool] + ): + message = {"command": command, "args": arguments} + await self.send(message) + async def send(self, message): """ Send a game message to the client. @@ -99,60 +105,6 @@ async def send(self, message): self._logger.log(TRACE, ">> %s: %s", self.player.login, message) await self.protocol.send_message(message) - async def _handle_idle_state(self): - """ - This message is sent by FA when it doesn't know what to do. - :return: None - """ - assert self.game - state = self.player.state - - if state == PlayerState.HOSTING: - self.game.state = GameState.LOBBY - self._state = GameConnectionState.CONNECTED_TO_HOST - self.game.add_game_connection(self) - self.game.host = self.player - elif state == PlayerState.JOINING: - return - else: - self._logger.error("Unknown PlayerState: %s", state) - await self.abort() - - async def _handle_lobby_state(self): - """ - The game has told us it is ready and listening on - self.player.game_port for UDP. - We determine the connectivity of the peer and respond - appropriately - """ - player_state = self.player.state - if player_state == PlayerState.HOSTING: - await self.send_HostGame(self.game.map_folder_name) - self.game.set_hosted() - # If the player is joining, we connect him to host - # followed by the rest of the players. - elif player_state == PlayerState.JOINING: - await self.connect_to_host(self.game.host.game_connection) - - if self._state is GameConnectionState.ENDED: - # We aborted while trying to connect - return - - self._state = GameConnectionState.CONNECTED_TO_HOST - - try: - self.game.add_game_connection(self) - except GameError as e: - await self.abort(f"GameError while joining {self.game.id}: {e}") - return - - tasks = [] - for peer in self.game.connections: - if peer != self and peer.player != self.game.host: - self._logger.debug("%s connecting to %s", self.player, peer) - tasks.append(self.connect_to_peer(peer)) - await asyncio.gather(*tasks) - async def connect_to_host(self, peer: "GameConnection"): """ Connect self to a given peer (host) @@ -162,16 +114,16 @@ async def connect_to_host(self, peer: "GameConnection"): await self.abort("The host left the lobby") return - await self.send_JoinGame(peer.player.login, peer.player.id) + await self.send_gpgnet_message( + "JoinGame", peer.player.login, peer.player.id + ) if not peer: await self.abort("The host left the lobby") return - await peer.send_ConnectToPeer( - player_name=self.player.login, - player_uid=self.player.id, - offer=True + await peer.send_gpgnet_message( + "ConnectToPeer", self.player.login, self.player.id, True ) async def connect_to_peer(self, peer: "GameConnection"): @@ -180,30 +132,25 @@ async def connect_to_peer(self, peer: "GameConnection"): :return: None """ if peer is not None: - await self.send_ConnectToPeer( - player_name=peer.player.login, - player_uid=peer.player.id, - offer=True + await self.send_gpgnet_message( + "ConnectToPeer", peer.player.login, peer.player.id, True ) if peer is not None: with contextlib.suppress(DisconnectedError): - await peer.send_ConnectToPeer( - player_name=self.player.login, - player_uid=self.player.id, - offer=False + await peer.send_gpgnet_message( + "ConnectToPeer", self.player.login, self.player.id, False ) - async def handle_action(self, command, args): + async def on_message_received(self, message) -> None: """ Handle GpgNetSend messages, wrapped in the JSON protocol - :param command: command type - :param args: command arguments - :return: None """ try: - await COMMAND_HANDLERS[command](self, *args) - except KeyError: + command, args = message.get("command"), message.get("args", []) + handler_func = self.dispatch(message) + await handler_func(*args) + except RouteError: self._logger.warning( "Unrecognized command %s: %s from player %s", command, args, self.player @@ -216,9 +163,11 @@ async def handle_action(self, command, args): self._logger.exception("Something awful happened in a game thread!") await self.abort() - async def handle_desync(self, *_args): # pragma: no cover + @handler("Desync") + async def handle_desync(self, *_args): self.game.desyncs += 1 + @handler("GameOption") async def handle_game_option(self, key, value): if not self.is_host(): return @@ -244,6 +193,7 @@ async def handle_game_option(self, key, value): self._mark_dirty() + @handler("GameMods") async def handle_game_mods(self, mode, args): if not self.is_host(): return @@ -268,6 +218,7 @@ async def handle_game_mods(self, mode, args): self._mark_dirty() + @handler("PlayerOption") async def handle_player_option(self, player_id, command, value): if not self.is_host(): return @@ -275,6 +226,7 @@ async def handle_player_option(self, player_id, command, value): self.game.set_player_option(int(player_id), command, value) self._mark_dirty() + @handler("AIOption") async def handle_ai_option(self, name, key, value): if not self.is_host(): return @@ -282,6 +234,7 @@ async def handle_ai_option(self, name, key, value): self.game.set_ai_option(str(name), key, value) self._mark_dirty() + @handler("ClearSlot") async def handle_clear_slot(self, slot): if not self.is_host(): return @@ -289,6 +242,7 @@ async def handle_clear_slot(self, slot): self.game.clear_slot(int(slot)) self._mark_dirty() + @handler("GameResult") async def handle_game_result(self, army, result): army = int(army) result = str(result).lower() @@ -298,7 +252,12 @@ async def handle_game_result(self, army, result): except (KeyError, ValueError): # pragma: no cover self._logger.warning("Invalid result for %s reported: %s", army, result) + @handler("OperationComplete") async def handle_operation_complete(self, army, secondary, delta): + # FIXME: This check is meant to prevent double insertion into the + # leaderboards, but it also requires that a player must be in the first + # lobby slot in order for the results to count. + # https://github.com/FAForever/server/issues/560 if not int(army) == 1: return @@ -329,12 +288,15 @@ async def handle_operation_complete(self, army, secondary, delta): ) ) + @handler("JsonStats") async def handle_json_stats(self, stats): self.game.report_army_stats(stats) + @handler("EnforceRating") async def handle_enforce_rating(self): self.game.enforce_rating = True + @handler("TeamkillReport") async def handle_teamkill_report(self, gametime, reporter_id, reporter_name, teamkiller_id, teamkiller_name): """ Sent when a player is teamkilled and clicks the 'Report' button. @@ -348,6 +310,7 @@ async def handle_teamkill_report(self, gametime, reporter_id, reporter_name, tea pass + @handler("TeamkillHappened") async def handle_teamkill_happened(self, gametime, victim_id, victim_name, teamkiller_id, teamkiller_name): """ Send automatically by the game whenever a teamkill happens. Takes @@ -376,6 +339,7 @@ async def handle_teamkill_happened(self, gametime, victim_id, victim_name, teamk ) ) + @handler("IceMsg") async def handle_ice_message(self, receiver_id, ice_msg): receiver_id = int(receiver_id) peer = self.player_service.get_player(receiver_id) @@ -403,6 +367,7 @@ async def handle_ice_message(self, receiver_id, ice_msg): receiver_id ) + @handler("GameState") async def handle_game_state(self, state): """ Changes in game state @@ -416,10 +381,6 @@ async def handle_game_state(self, state): return elif state == "Lobby": - # TODO: Do we still need to schedule with `ensure_future`? - # - # We do not yield from the task, since we - # need to keep processing other commands while it runs await self._handle_lobby_state() elif state == "Launching": @@ -446,6 +407,61 @@ async def handle_game_state(self, state): await self.on_connection_lost() self._mark_dirty() + async def _handle_idle_state(self): + """ + This message is sent by FA when it doesn't know what to do. + :return: None + """ + assert self.game + state = self.player.state + + if state == PlayerState.HOSTING: + self.game.state = GameState.LOBBY + self._state = GameConnectionState.CONNECTED_TO_HOST + self.game.add_game_connection(self) + self.game.host = self.player + elif state == PlayerState.JOINING: + return + else: + self._logger.error("Unknown PlayerState: %s", state) + await self.abort() + + async def _handle_lobby_state(self): + """ + The game has told us it is ready and listening on + self.player.game_port for UDP. + We determine the connectivity of the peer and respond + appropriately + """ + player_state = self.player.state + if player_state == PlayerState.HOSTING: + await self.send_gpgnet_message("HostGame", self.game.map_folder_name) + self.game.set_hosted() + # If the player is joining, we connect him to host + # followed by the rest of the players. + elif player_state == PlayerState.JOINING: + await self.connect_to_host(self.game.host.game_connection) + + if self._state is GameConnectionState.ENDED: + # We aborted while trying to connect + return + + self._state = GameConnectionState.CONNECTED_TO_HOST + + try: + self.game.add_game_connection(self) + except GameError as e: + await self.abort(f"GameError while joining {self.game.id}: {e}") + return + + tasks = [] + for peer in self.game.connections: + if peer != self and peer.player != self.game.host: + self._logger.debug("%s connecting to %s", self.player, peer) + tasks.append(self.connect_to_peer(peer)) + await asyncio.gather(*tasks) + + @handler("GameEnded") async def handle_game_ended(self, *args): """ Signals that the simulation has ended. @@ -457,6 +473,7 @@ async def handle_game_ended(self, *args): if self.game.ended: await self.game.on_game_end() + @handler("Rehost") async def handle_rehost(self, *args): """ Signals that the user has rehosted the game. This is currently unused but @@ -464,6 +481,7 @@ async def handle_rehost(self, *args): """ pass + @handler("Bottleneck") async def handle_bottleneck(self, *args): """ Not sure what this command means. This is currently unused but @@ -471,6 +489,7 @@ async def handle_bottleneck(self, *args): """ pass + @handler("BottleneckCleared") async def handle_bottleneck_cleared(self, *args): """ Not sure what this command means. This is currently unused but @@ -478,6 +497,7 @@ async def handle_bottleneck_cleared(self, *args): """ pass + @handler("Disconnected") async def handle_disconnected(self, *args): """ Not sure what this command means. This is currently unused but @@ -485,12 +505,14 @@ async def handle_disconnected(self, *args): """ pass + @handler("Chat") async def handle_chat(self, message: str): """ Whenever the player sends a chat message during the game lobby. """ pass + @handler("GameFull") async def handle_game_full(self): """ Sent when all game slots are full @@ -534,7 +556,9 @@ async def disconnect_all_peers(self): if peer == self: continue - tasks.append(peer.send_DisconnectFromPeer(self.player.id)) + tasks.append( + peer.send_gpgnet_message("DisconnectFromPeer", self.player.id) + ) for fut in asyncio.as_completed(tasks): try: @@ -556,28 +580,3 @@ async def on_connection_lost(self): def __str__(self): return "GameConnection({}, {})".format(self.player, self.game) - - -COMMAND_HANDLERS = { - "Desync": GameConnection.handle_desync, - "GameState": GameConnection.handle_game_state, - "GameOption": GameConnection.handle_game_option, - "GameMods": GameConnection.handle_game_mods, - "PlayerOption": GameConnection.handle_player_option, - "AIOption": GameConnection.handle_ai_option, - "ClearSlot": GameConnection.handle_clear_slot, - "GameResult": GameConnection.handle_game_result, - "OperationComplete": GameConnection.handle_operation_complete, - "JsonStats": GameConnection.handle_json_stats, - "EnforceRating": GameConnection.handle_enforce_rating, - "TeamkillReport": GameConnection.handle_teamkill_report, - "TeamkillHappened": GameConnection.handle_teamkill_happened, - "GameEnded": GameConnection.handle_game_ended, - "Rehost": GameConnection.handle_rehost, - "Bottleneck": GameConnection.handle_bottleneck, - "BottleneckCleared": GameConnection.handle_bottleneck_cleared, - "Disconnected": GameConnection.handle_disconnected, - "IceMsg": GameConnection.handle_ice_message, - "Chat": GameConnection.handle_chat, - "GameFull": GameConnection.handle_game_full -} diff --git a/server/lobbyconnection.py b/server/lobbyconnection.py index 21e604de1..0b2feb843 100644 --- a/server/lobbyconnection.py +++ b/server/lobbyconnection.py @@ -1,4 +1,3 @@ -import asyncio import contextlib import hashlib import json @@ -6,6 +5,7 @@ import urllib.parse import urllib.request from datetime import datetime +from functools import wraps from typing import Optional import aiohttp @@ -17,9 +17,12 @@ import server.metrics as metrics from server.db import FAFDatabase -from . import asyncio_extensions as asyncio_ from .abc.base_game import GameConnectionState, InitMode from .config import TRACE, config +from .core import Protocol, RouteError +from .core import asyncio_extensions as asyncio_ +from .core.connection import Connection, handler +from .core.typedefs import Address, Handler, HandlerDecorator from .db.models import ( avatars, avatars_list, @@ -39,9 +42,9 @@ from .ladder_service import LadderService from .player_service import PlayerService from .players import Player, PlayerState -from .protocol import DisconnectedError, Protocol +from .protocol import DisconnectedError from .rating import InclusiveRange, RatingType -from .types import Address, GameLaunchOptions +from .types import GameLaunchOptions class ClientError(Exception): @@ -85,11 +88,45 @@ def __init__(self, message, *args, **kwargs): self.message = message +def public(func: Handler) -> Handler: + """ + Mark a handler function as being callable without having logged in yet. + """ + func.public = True + return func + + +def permission(role: str) -> HandlerDecorator: + """ + Ensure that a handler is only called if the player has the appropriate + permission role. + """ + def decorator(func: Handler) -> Handler: + @wraps(func) + async def wrapper(self, message): + if ( + not self.player + or not await self.player_service.has_permission_role( + self.player, + role + ) + ): + await self.send({"command": "permission_denied"}) + return + + await func(self, message) + + return wrapper + return decorator + + @with_logger -class LobbyConnection: +class LobbyConnection(Connection): @timed() def __init__( self, + protocol: Protocol, + address: Address, database: FAFDatabase, game_service: GameService, players: PlayerService, @@ -97,30 +134,29 @@ def __init__( geoip: GeoIpService, ladder_service: LadderService ): + super().__init__(protocol, address) self._db = database self.geoip_service = geoip self.game_service = game_service self.player_service = players self.nts_client = nts_client - self.coturn_generator = CoturnHMAC(config.COTURN_HOSTS, config.COTURN_KEYS) + self.coturn_generator = CoturnHMAC( + config.COTURN_HOSTS, + config.COTURN_KEYS + ) self.ladder_service = ladder_service - self._authenticated = False - self.player = None # type: Player - self.game_connection = None # type: GameConnection - self.peer_address = None # type: Optional[Address] + self.authenticated = False + self.player: Optional[Player] = None + self.game_connection: Optional[GameConnection] = None self.session = int(random.randrange(0, 4294967295)) - self.protocol: Protocol = None self.user_agent = None self._version = None self._attempted_connectivity_test = False + metrics.server_connections.inc() self._logger.debug("LobbyConnection initialized") - @property - def authenticated(self): - return self._authenticated - def get_user_identifier(self) -> str: """For logging purposes""" if self.player: @@ -128,14 +164,8 @@ def get_user_identifier(self) -> str: return str(self.session) - @asyncio.coroutine - def on_connection_made(self, protocol: Protocol, peername: Address): - self.protocol = protocol - self.peer_address = peername - metrics.server_connections.inc() - async def abort(self, logspam=""): - self._authenticated = False + self.authenticated = False if self.player: self._logger.warning( "Client %s dropped. %s", self.player.login, logspam @@ -144,21 +174,13 @@ async def abort(self, logspam=""): self.player = None else: self._logger.warning( - "Aborting %s. %s", self.peer_address.host, logspam + "Aborting %s. %s", self.address.host, logspam ) if self.game_connection: await self.game_connection.abort() await self.protocol.close() - async def ensure_authenticated(self, cmd): - if not self._authenticated: - if cmd not in ["hello", "ask_session", "create_account", "ping", "pong", "Bottleneck"]: # Bottleneck is sent by the game during reconnect - metrics.unauth_messages.labels(cmd).inc() - await self.abort("Message invalid for unauthenticated connection: %s" % cmd) - return False - return True - async def on_message_received(self, message): """ Dispatches incoming messages @@ -166,24 +188,16 @@ async def on_message_received(self, message): self._logger.log(TRACE, "<< %s: %s", self.get_user_identifier(), message) try: - cmd = message["command"] - if not await self.ensure_authenticated(cmd): - return - target = message.get("target") - if target == "game": - if not self.game_connection: - return - - await self.game_connection.handle_action(cmd, message.get("args", [])) + handler_func = self.dispatch(message) + if not hasattr(handler_func, "public") and not self.authenticated: + cmd = message["command"] + metrics.unauth_messages.labels(cmd).inc() + await self.abort( + f"Message invalid for unauthenticated connection: {cmd}" + ) return - if target == "connectivity" and message.get("command") == "InitiateTest": - self._attempted_connectivity_test = True - raise ClientError("Your client version is no longer supported. Please update to the newest version: https://faforever.com") - - handler = getattr(self, "command_{}".format(cmd)) - await handler(message) - + await handler_func(message) except AuthenticationError as ex: await self.send({ "command": "authentication_failed", @@ -207,26 +221,61 @@ async def on_message_received(self, message): await self.abort(ex.message) except (KeyError, ValueError) as ex: self._logger.exception(ex) - await self.abort("Garbage command: {}".format(message)) + await self.abort(f"Garbage command: {message}") except ConnectionError as e: # Propagate connection errors to the ServerContext error handler. raise e except Exception as ex: # pragma: no cover + if not isinstance(ex, RouteError): + self._logger.exception(ex) await self.send({"command": "invalid"}) - self._logger.exception(ex) - await self.abort("Error processing command") + await self.abort(f"Error processing message: {message}") + + async def _handle_game_message(self, message): + if not self.game_connection: + return + + await self.game_connection.on_message_received(message) + + @handler(target="game") + async def target_game(self, message): + await self._handle_game_message(message) + + @handler(target="game", command="Bottleneck") + @public + async def target_game_bottleneck(self, message): + await self._handle_game_message(message) + + @handler(target="connectivity", command="InitiateTest") + async def handle_initiate_test(self, message): + self._attempted_connectivity_test = True + raise ClientError( + "Your client version is no longer supported. Please update to the " + "newest version: https://faforever.com" + ) - async def command_ping(self, msg): + @handler("ping") + @public + async def command_ping(self, message): await self.send({"command": "pong"}) - async def command_pong(self, msg): + @handler("pong") + @public + async def command_pong(self, message): pass + @handler("create_account") + @public async def command_create_account(self, message): - raise ClientError("FAF no longer supports direct registration. Please use the website to register.", recoverable=True) + raise ClientError( + "FAF no longer supports direct registration. Please use the " + "website to register.", + recoverable=True + ) + @handler("coop_list") async def command_coop_list(self, message): - """ Request for coop map list""" + """Request for coop map list""" async with self._db.acquire() as conn: result = await conn.execute(select([coop_map])) @@ -257,18 +306,14 @@ async def command_coop_list(self, message): await self.protocol.send_messages(maps) + @handler("matchmaker_info") async def command_matchmaker_info(self, message): await self.send({ "command": "matchmaker_info", "queues": [queue.to_dict() for queue in self.ladder_service.queues.values()] }) - async def send_game_list(self): - await self.send({ - "command": "game_info", - "games": [game.to_dict() for game in self.game_service.open_games] - }) - + @handler("social_remove") async def command_social_remove(self, message): if "friend" in message: subject_id = message["friend"] @@ -289,6 +334,7 @@ async def command_social_remove(self, message): with contextlib.suppress(KeyError): player_attr.remove(subject_id) + @handler("social_add") async def command_social_add(self, message): if "friend" in message: status = "FRIEND" @@ -310,86 +356,76 @@ async def command_social_add(self, message): player_attr.add(subject_id) - async def kick(self): - await self.send({ - "command": "notice", - "style": "kick", - }) - await self.abort() - - async def send_updated_achievements(self, updated_achievements): - await self.send({ - "command": "updated_achievements", - "updated_achievements": updated_achievements - }) - - async def command_admin(self, message): - action = message["action"] - - if action == "closeFA": - if await self.player_service.has_permission_role( - self.player, "ADMIN_KICK_SERVER" - ): - player = self.player_service[message["user_id"]] - if player: - self._logger.info( - "Administrative action: %s closed game for %s", - self.player, player - ) - with contextlib.suppress(DisconnectedError): - await player.send_message({ - "command": "notice", - "style": "kill", - }) - - elif action == "closelobby": - if await self.player_service.has_permission_role( - self.player, "ADMIN_KICK_SERVER" - ): - player = self.player_service[message["user_id"]] - if player and player.lobby_connection is not None: - self._logger.info( - "Administrative action: %s closed client for %s", - self.player, player - ) - with contextlib.suppress(DisconnectedError): - await player.lobby_connection.kick() + @handler("admin", action="closeFA") + @permission("ADMIN_KICK_SERVER") + async def command_admin_closefa(self, message): + """Tell a client to kill ForgedAlliance.exe""" + player = self.player_service[message["user_id"]] + if player: + self._logger.info( + "Administrative action: %s closed game for %s", + self.player, player + ) + player.write_message({ + "command": "notice", + "style": "kill", + }) - elif action == "broadcast": - message_text = message.get("message") - if not message_text: - return - if await self.player_service.has_permission_role( - self.player, "ADMIN_BROADCAST_MESSAGE" - ): - tasks = [] - for player in self.player_service: - # Check if object still exists: - # https://docs.python.org/3/library/weakref.html#weak-reference-objects - if player.lobby_connection is not None: - tasks.append( - player.lobby_connection.send_warning(message_text) - ) + @handler("admin", action="closelobby") + @permission("ADMIN_KICK_SERVER") + async def command_admin_closelobby(self, message): + """Tell a client to close entirely""" + player = self.player_service[message["user_id"]] + if player and player.lobby_connection is not None: + self._logger.info( + "Administrative action: %s closed client for %s", + self.player, player + ) + with contextlib.suppress(DisconnectedError): + await player.lobby_connection.kick() + + @handler("admin", action="broadcast") + @permission("ADMIN_BROADCAST_MESSAGE") + async def command_admin_broadcast(self, message): + """Send a notice message to all online players""" + message_text = message.get("message") + if not message_text: + return - self._logger.info( - "%s broadcasting message to all players: %s", - self.player.login, message_text + tasks = [] + for player in self.player_service: + # Check if object still exists: + # https://docs.python.org/3/library/weakref.html#weak-reference-objects + if player.lobby_connection is not None: + tasks.append( + player.lobby_connection.send_warning(message_text) ) - await asyncio_.gather_without_exceptions(tasks, Exception) - elif action == "join_channel": - if await self.player_service.has_permission_role( - self.player, "ADMIN_JOIN_CHANNEL" - ): - user_ids = message["user_ids"] - channel = message["channel"] - for user_id in user_ids: - player = self.player_service[user_id] - if player: - player.write_message({ - "command": "social", - "autojoin": [channel] - }) + self._logger.info( + "%s broadcasting message to all players: %s", + self.player.login, message_text + ) + await asyncio_.gather_without_exceptions(tasks, Exception) + + @handler("admin", action="join_channel") + @permission("ADMIN_JOIN_CHANNEL") + async def command_admin_join_channel(self, message): + """Tell a client to join an IRC channel""" + user_ids = message["user_ids"] + channel = message["channel"] + + for user_id in user_ids: + player = self.player_service[user_id] + if player: + player.write_message({ + "command": "social", + "autojoin": [channel] + }) + + @handler("admin") + async def command_admin_other(self, message): + """Ignore any other actions""" + pass async def check_user_login(self, conn, username, password): # TODO: Hash passwords server-side so the hashing actually *does* something. @@ -568,6 +604,18 @@ async def check_policy_conformity(self, player_id, uid_hash, session, ignore_res return response.get("result", "") == "honest" + @handler("ask_session") + @public + async def command_ask_session(self, message): + user_agent = message.get("user_agent") + version = message.get("version") + self._set_user_agent_and_version(user_agent, version) + + if await self._check_version(): + await self.send({"command": "session", "session": self.session}) + + @handler("hello") + @public async def command_hello(self, message): login = message["login"].strip() password = message["password"] @@ -580,7 +628,7 @@ async def command_hello(self, message): t_login.update().where( t_login.c.id == player_id ).values( - ip=self.peer_address.host, + ip=self.address.host, user_agent=self.user_agent, last_login=func.now() ) @@ -634,11 +682,11 @@ async def command_hello(self, message): await self.player_service.fetch_player_data(self.player) self.player_service[self.player.id] = self.player - self._authenticated = True + self.authenticated = True # Country # ------- - self.player.country = self.geoip_service.country(self.peer_address.host) + self.player.country = self.geoip_service.country(self.address.host) # Send the player their own player info. await self.send({ @@ -704,6 +752,7 @@ async def command_hello(self, message): await self.send_game_list() + @handler("restore_game_session") async def command_restore_game_session(self, message): assert self.player is not None @@ -711,20 +760,21 @@ async def command_restore_game_session(self, message): # Restore the player's game connection, if the game still exists and is live if not game_id or game_id not in self.game_service: - await self.send_warning("The game you were connected to does no longer exist") + await self.send_warning("The game you were connected to no longer exists") return - game = self.game_service[game_id] # type: Game + game: "Game" = self.game_service[game_id] if game.state is not GameState.LOBBY and game.state is not GameState.LIVE: await self.send_warning("The game you were connected to is no longer available") return self._logger.debug("Restoring game session of player %s to game %s", self.player, game) self.game_connection = GameConnection( + self.protocol, + self.address, database=self._db, game=game, player=self.player, - protocol=self.protocol, player_service=self.player_service, games=self.game_service, state=GameConnectionState.CONNECTED_TO_HOST @@ -734,93 +784,135 @@ async def command_restore_game_session(self, message): self.player.state = PlayerState.PLAYING self.player.game = game - async def command_ask_session(self, message): - user_agent = message.get("user_agent") - version = message.get("version") - self._set_user_agent_and_version(user_agent, version) + @handler("avatar", action="list_avatar") + async def command_avatar_list_avatar(self, message): + avatarList = [] - if await self._check_version(): - await self.send({"command": "session", "session": self.session}) + async with self._db.acquire() as conn: + result = await conn.execute( + select([ + avatars_list.c.url, + avatars_list.c.tooltip + ]).select_from( + avatars.outerjoin( + avatars_list + ) + ).where( + avatars.c.idUser == self.player.id + ) + ) + + async for row in result: + avatar = {"url": row["url"], "tooltip": row["tooltip"]} + avatarList.append(avatar) - async def command_avatar(self, message): - action = message["action"] + if avatarList: + await self.send({"command": "avatar", "avatarlist": avatarList}) - if action == "list_avatar": - avatarList = [] + @handler("avatar", action="select") + async def command_avatar_select(self, message): + avatar_url = message["avatar"] - async with self._db.acquire() as conn: + async with self._db.acquire() as conn: + if avatar_url is not None: result = await conn.execute( select([ - avatars_list.c.url, - avatars_list.c.tooltip + avatars_list.c.id, avatars_list.c.tooltip ]).select_from( - avatars.outerjoin( - avatars_list - ) + avatars.join(avatars_list) ).where( - avatars.c.idUser == self.player.id + and_( + avatars_list.c.url == avatar_url, + avatars.c.idUser == self.player.id + ) ) ) + row = await result.fetchone() + if not row: + return - async for row in result: - avatar = {"url": row["url"], "tooltip": row["tooltip"]} - avatarList.append(avatar) - - if avatarList: - await self.send({"command": "avatar", "avatarlist": avatarList}) - - elif action == "select": - avatar_url = message["avatar"] - - async with self._db.acquire() as conn: - if avatar_url is not None: - result = await conn.execute( - select([ - avatars_list.c.id, avatars_list.c.tooltip - ]).select_from( - avatars.join(avatars_list) - ).where( - and_( - avatars_list.c.url == avatar_url, - avatars.c.idUser == self.player.id - ) - ) - ) - row = await result.fetchone() - if not row: - return + await conn.execute( + avatars.update().where( + avatars.c.idUser == self.player.id + ).values( + selected=0 + ) + ) + self.player.avatar = None + if avatar_url is not None: await conn.execute( avatars.update().where( - avatars.c.idUser == self.player.id + and_( + avatars.c.idUser == self.player.id, + avatars.c.idAvatar == row[avatars_list.c.id] + ) ).values( - selected=0 + selected=1 ) ) - self.player.avatar = None + self.player.avatar = { + "url": avatar_url, + "tooltip": row[avatars_list.c.tooltip] + } + self.player_service.mark_dirty(self.player) - if avatar_url is not None: - await conn.execute( - avatars.update().where( - and_( - avatars.c.idUser == self.player.id, - avatars.c.idAvatar == row[avatars_list.c.id] - ) - ).values( - selected=1 - ) - ) - self.player.avatar = { - "url": avatar_url, - "tooltip": row[avatars_list.c.tooltip] - } - self.player_service.mark_dirty(self.player) - else: - raise KeyError("invalid action") + @handler("avatar") + async def command_avatar_other(self, message): + raise KeyError("invalid action") + + @handler("game_host") + async def command_game_host(self, message): + """Host a new custom game lobby""" + assert isinstance(self.player, Player) + if self._attempted_connectivity_test: + raise ClientError("Cannot join game. Please update your client to the newest version.") + + await self.abort_connection_if_banned() + + visibility = VisibilityState(message["visibility"]) + title = message.get("title") or f"{self.player.login}'s game" + + try: + title.encode("ascii") + except UnicodeEncodeError: + await self.send({ + "command": "notice", + "style": "error", + "text": "Non-ascii characters in game name detected." + }) + return + + mod = message.get("mod") or FeaturedModType.FAF + mapname = message.get("mapname") or "scmp_007" + password = message.get("password") + game_mode = mod.lower() + rating_min = message.get("rating_min") + rating_max = message.get("rating_max") + enforce_rating_range = bool(message.get("enforce_rating_range", False)) + if rating_min is not None: + rating_min = float(rating_min) + if rating_max is not None: + rating_max = float(rating_max) + + game = self.game_service.create_game( + visibility=visibility, + game_mode=game_mode, + host=self.player, + name=title, + mapname=mapname, + password=password, + rating_type=RatingType.GLOBAL, + displayed_rating_range=InclusiveRange(rating_min, rating_max), + enforce_rating_range=enforce_rating_range + ) + await self.launch_game(game, is_host=True) + + @handler("game_join") async def command_game_join(self, message): """ - We are going to join a game. + Join an existing custom game lobby """ assert isinstance(self.player, Player) @@ -865,7 +957,9 @@ async def command_game_join(self, message): await self.launch_game(game, is_host=False) + @handler("game_matchmaking") async def command_game_matchmaking(self, message): + """Join or leave a matchmaker queue""" queue_name = str( message.get("queue_name") or message.get("mod", "ladder1v1") ) @@ -889,52 +983,6 @@ async def command_game_matchmaking(self, message): queue_name=queue_name ) - async def command_game_host(self, message): - assert isinstance(self.player, Player) - - if self._attempted_connectivity_test: - raise ClientError("Cannot join game. Please update your client to the newest version.") - - await self.abort_connection_if_banned() - - visibility = VisibilityState(message["visibility"]) - title = message.get("title") or f"{self.player.login}'s game" - - try: - title.encode("ascii") - except UnicodeEncodeError: - await self.send({ - "command": "notice", - "style": "error", - "text": "Non-ascii characters in game name detected." - }) - return - - mod = message.get("mod") or FeaturedModType.FAF - mapname = message.get("mapname") or "scmp_007" - password = message.get("password") - game_mode = mod.lower() - rating_min = message.get("rating_min") - rating_max = message.get("rating_max") - enforce_rating_range = bool(message.get("enforce_rating_range", False)) - if rating_min is not None: - rating_min = float(rating_min) - if rating_max is not None: - rating_max = float(rating_max) - - game = self.game_service.create_game( - visibility=visibility, - game_mode=game_mode, - host=self.player, - name=title, - mapname=mapname, - password=password, - rating_type=RatingType.GLOBAL, - displayed_rating_range=InclusiveRange(rating_min, rating_max), - enforce_rating_range=enforce_rating_range - ) - await self.launch_game(game, is_host=True) - async def launch_game( self, game, @@ -950,10 +998,11 @@ async def launch_game( game.host = self.player self.game_connection = GameConnection( + self.protocol, + self.address, database=self._db, game=game, player=self.player, - protocol=self.protocol, player_service=self.player_service, games=self.game_service ) @@ -965,11 +1014,6 @@ async def launch_game( "args": ["/numgames", self.player.game_count[game.rating_type]], "uid": game.id, "mod": game.game_mode, - # Following parameters may not be used by the client yet. They are - # needed for setting up auto-lobby style matches such as ladder, gw, - # and team machmaking where the server decides what these game - # options are. Currently, options for ladder are hardcoded into the - # client. "name": game.name, "init_mode": game.init_mode.value, **options._asdict() @@ -977,6 +1021,7 @@ async def launch_game( await self.send({k: v for k, v in cmd.items() if v is not None}) + @handler("modvault") async def command_modvault(self, message): type = message["type"] @@ -1044,6 +1089,7 @@ async def command_modvault(self, message): else: raise ValueError("invalid type argument") + @handler("ice_servers") async def command_ice_servers(self, message): if not self.player: return @@ -1063,6 +1109,25 @@ async def command_ice_servers(self, message): "ttl": ttl }) + async def kick(self): + await self.send({ + "command": "notice", + "style": "kick", + }) + await self.abort() + + async def send_game_list(self): + await self.send({ + "command": "game_info", + "games": [game.to_dict() for game in self.game_service.open_games] + }) + + async def send_updated_achievements(self, updated_achievements): + await self.send({ + "command": "updated_achievements", + "updated_achievements": updated_achievements + }) + async def send_warning(self, message: str, fatal: bool = False): """ Display a warning message to the client diff --git a/server/message_queue_service.py b/server/message_queue_service.py index 64d87fb07..a1cf44d14 100644 --- a/server/message_queue_service.py +++ b/server/message_queue_service.py @@ -6,9 +6,9 @@ from aio_pika import DeliveryMode, ExchangeType from aio_pika.exceptions import ProbableAuthenticationError -from .asyncio_extensions import synchronizedmethod from .config import TRACE, config from .core import Service +from .core.asyncio_extensions import synchronizedmethod from .decorators import with_logger diff --git a/server/protocol/__init__.py b/server/protocol/__init__.py index 460ac4dc1..dd98350f5 100644 --- a/server/protocol/__init__.py +++ b/server/protocol/__init__.py @@ -1,12 +1,10 @@ -from .gpgnet import GpgNetClientProtocol, GpgNetServerProtocol -from .protocol import DisconnectedError, Protocol +from ..core.protocol import DisconnectedError +from .protocol import Protocol from .qdatastream import QDataStreamProtocol from .simple_json import SimpleJsonProtocol __all__ = ( "DisconnectedError", - "GpgNetClientProtocol", - "GpgNetServerProtocol", "Protocol", "QDataStreamProtocol", "SimpleJsonProtocol" diff --git a/server/protocol/gpgnet.py b/server/protocol/gpgnet.py deleted file mode 100644 index 5e8a60ce8..000000000 --- a/server/protocol/gpgnet.py +++ /dev/null @@ -1,59 +0,0 @@ -from abc import ABCMeta, abstractmethod -from typing import List, Union - - -class GpgNetServerProtocol(metaclass=ABCMeta): - """ - Defines an interface for the server side GPGNet protocol - """ - async def send_ConnectToPeer(self, player_name: str, player_uid: int, offer: bool): - """ - Tells a client that has a listening LobbyComm instance to connect to the given peer - :param player_name: Remote player name - :param player_uid: Remote player identifier - """ - await self.send_gpgnet_message("ConnectToPeer", [player_name, player_uid, offer]) - - async def send_JoinGame(self, remote_player_name: str, remote_player_uid: int): - """ - Tells the game to join the given peer by ID - :param remote_player_name: - :param remote_player_uid: - """ - await self.send_gpgnet_message("JoinGame", [remote_player_name, remote_player_uid]) - - async def send_HostGame(self, map_path): - """ - Tells the game to start listening for incoming connections as a host - :param map_path: Which scenario to use - """ - await self.send_gpgnet_message("HostGame", [str(map_path)]) - - async def send_DisconnectFromPeer(self, id: int): - """ - Instructs the game to disconnect from the peer given by id - - :param id: - :return: - """ - await self.send_gpgnet_message("DisconnectFromPeer", [id]) - - async def send_gpgnet_message(self, command_id: str, arguments: List[Union[int, str, bool]]): - message = {"command": command_id, "args": arguments} - await self.send(message) - - @abstractmethod - async def send(self, message): - pass # pragma: no cover - - -class GpgNetClientProtocol(metaclass=ABCMeta): - def send_GameState(self, arguments: List[Union[int, str, bool]]) -> None: - """ - Sent by the client when the state of LobbyComm changes - """ - self.send_gpgnet_message("GameState", arguments) - - @abstractmethod - def send_gpgnet_message(self, command_id, arguments: List[Union[int, str, bool]]) -> None: - pass # pragma: no cover diff --git a/server/protocol/protocol.py b/server/protocol/protocol.py index 9622bf75a..07d92e280 100644 --- a/server/protocol/protocol.py +++ b/server/protocol/protocol.py @@ -1,142 +1,20 @@ -import contextlib import json -from abc import ABCMeta, abstractmethod -from asyncio import StreamReader, StreamWriter from typing import List import server.metrics as metrics -from ..asyncio_extensions import synchronizedmethod +from ..core import Protocol as _Protocol json_encoder = json.JSONEncoder(separators=(",", ":")) -class DisconnectedError(ConnectionError): - """For signaling that a protocol has lost connection to the remote.""" - - -class Protocol(metaclass=ABCMeta): - def __init__(self, reader: StreamReader, writer: StreamWriter): - self.reader = reader - self.writer = writer - # Force calls to drain() to only return once the data has been sent - self.writer.transport.set_write_buffer_limits(high=0) - - @staticmethod - @abstractmethod - def encode_message(message: dict) -> bytes: - """ - Encode a message as raw bytes. Can be used along with `*_raw` methods. - """ - pass # pragma: no cover - - def is_connected(self) -> bool: - """ - Return whether or not the connection is still alive - """ - return not self.writer.is_closing() - - @abstractmethod - async def read_message(self) -> dict: - """ - Asynchronously read a message from the stream - - :raises: IncompleteReadError - :return dict: Parsed message - """ - pass # pragma: no cover - - async def send_message(self, message: dict) -> None: - """ - Send a single message in the form of a dictionary - - :param message: Message to send - :raises: DisconnectedError - """ - await self.send_raw(self.encode_message(message)) - - async def send_messages(self, messages: List[dict]) -> None: - """ - Send multiple messages in the form of a list of dictionaries. - - May be more optimal than sending a single message. - - :param messages: - :raises: DisconnectedError - """ - self.write_messages(messages) - await self.drain() - - async def send_raw(self, data: bytes) -> None: - """ - Send raw bytes. Should generally not be used. - - :param data: bytes to send - :raises: DisconnectedError - """ - self.write_raw(data) - await self.drain() - - def write_message(self, message: dict) -> None: - """ - Write a single message into the message buffer. Should be used when - sending broadcasts or when sending messages that are triggered by - incoming messages from other players. - - :param message: Message to send - """ - if not self.is_connected(): - raise DisconnectedError("Protocol is not connected!") - - self.write_raw(self.encode_message(message)) +class Protocol(_Protocol): + """For hooking in metric collection""" def write_messages(self, messages: List[dict]) -> None: - """ - Write multiple message into the message buffer. - - :param messages: List of messages to write - """ metrics.sent_messages.labels(self.__class__.__name__).inc() - if not self.is_connected(): - raise DisconnectedError("Protocol is not connected!") - - self.writer.writelines([self.encode_message(msg) for msg in messages]) + super().write_messages(messages) def write_raw(self, data: bytes) -> None: - """ - Write raw bytes into the message buffer. Should generally not be used. - - :param data: bytes to send - """ metrics.sent_messages.labels(self.__class__.__name__).inc() - if not self.is_connected(): - raise DisconnectedError("Protocol is not connected!") - - self.writer.write(data) - - async def close(self) -> None: - """ - Close the underlying writer as soon as the buffer has emptied. - :return: - """ - self.writer.close() - with contextlib.suppress(Exception): - await self.writer.wait_closed() - - @synchronizedmethod - async def drain(self) -> None: - """ - Await the write buffer to empty. - See StreamWriter.drain() - - :raises: DisconnectedError if the client disconnects while waiting for - the write buffer to empty. - """ - # Method needs to be synchronized as drain() cannot be called - # concurrently by multiple coroutines: - # http://bugs.python.org/issue29930. - try: - await self.writer.drain() - except Exception as e: - await self.close() - raise DisconnectedError("Protocol connection lost!") from e + super().write_raw(data) diff --git a/server/servercontext.py b/server/servercontext.py index e8360af61..0424da639 100644 --- a/server/servercontext.py +++ b/server/servercontext.py @@ -4,14 +4,14 @@ import server.metrics as metrics +from .core.typedefs import Address from .decorators import with_logger from .lobbyconnection import LobbyConnection from .protocol import Protocol, QDataStreamProtocol -from .types import Address @with_logger -class ServerContext: +class ServerContext(): """ Base class for managing connections and holding state about them. """ @@ -19,7 +19,7 @@ class ServerContext: def __init__( self, name: str, - connection_factory: Callable[[], LobbyConnection], + connection_factory: Callable[[Protocol, Address], LobbyConnection], protocol_class: Type[Protocol] = QDataStreamProtocol, ): super().__init__() @@ -79,11 +79,13 @@ def write_broadcast_raw(self, data, validate_fn=lambda _: True): async def client_connected(self, stream_reader, stream_writer): self._logger.debug("%s: Client connected", self.name) protocol = self.protocol_class(stream_reader, stream_writer) - connection = self._connection_factory() + connection = self._connection_factory( + protocol, + Address(*stream_writer.get_extra_info("peername")) + ) self.connections[connection] = protocol try: - await connection.on_connection_made(protocol, Address(*stream_writer.get_extra_info("peername"))) metrics.user_connections.labels("None").inc() while protocol.is_connected(): message = await protocol.read_message() diff --git a/server/types.py b/server/types.py index 2f80cca08..0281905c0 100644 --- a/server/types.py +++ b/server/types.py @@ -1,18 +1,6 @@ from typing import NamedTuple, Optional -class Address(NamedTuple): - """A peer IP address""" - - host: str - port: int - - @classmethod - def from_string(cls, address: str) -> "Address": - host, port = address.rsplit(":", 1) - return cls(host, int(port)) - - class GameLaunchOptions(NamedTuple): """Additional options used to configure the FA lobby""" diff --git a/tests/integration_tests/test_server.py b/tests/integration_tests/test_server.py index fdeefcdb9..8442524cb 100644 --- a/tests/integration_tests/test_server.py +++ b/tests/integration_tests/test_server.py @@ -323,7 +323,7 @@ async def test_host_coop_game(lobby_server): @pytest.mark.parametrize("command", ["game_host", "game_join"]) -async def test_server_ban_prevents_hosting(lobby_server, database, command): +async def test_server_ban_prevents_playing(lobby_server, database, command): """ Players who are banned while they are online, should immediately be prevented from joining or hosting games until their ban expires. diff --git a/tests/integration_tests/test_servercontext.py b/tests/integration_tests/test_servercontext.py index 1b2753601..37b862fdd 100644 --- a/tests/integration_tests/test_servercontext.py +++ b/tests/integration_tests/test_servercontext.py @@ -13,18 +13,14 @@ @pytest.fixture -def mock_server(event_loop): +def mock_connection(event_loop): class MockServer: def __init__(self): - self.protocol, self.peername, self.user_agent = None, None, None + self.protocol = None + self.peername = None + self.user_agent = None self.on_connection_lost = CoroutineMock() - async def on_connection_made(self, protocol, peername): - self.protocol = protocol - self.peername = peername - self.protocol.writer.write_eof() - self.protocol.reader.feed_eof() - async def on_message_received(self, msg): pass @@ -32,16 +28,24 @@ async def on_message_received(self, msg): @pytest.fixture -async def mock_context(mock_server): - ctx = ServerContext("TestServer", lambda: mock_server) +async def mock_context(mock_connection): + def connection_factory(protocol, peername): + mock_connection.protocol = protocol + mock_connection.peername = peername + mock_connection.protocol.writer.write_eof() + mock_connection.protocol.reader.feed_eof() + return mock_connection + ctx = ServerContext("TestServer", connection_factory) yield await ctx.listen("127.0.0.1", None), ctx ctx.close() @pytest.fixture async def context(): - def make_connection() -> LobbyConnection: + def make_connection(protocol, address) -> LobbyConnection: return LobbyConnection( + protocol, + address, database=mock.Mock(), geoip=mock.Mock(), game_service=mock.Mock(), @@ -55,17 +59,17 @@ def make_connection() -> LobbyConnection: ctx.close() -async def test_serverside_abort(event_loop, mock_context, mock_server): +async def test_serverside_abort(event_loop, mock_context, mock_connection): srv, ctx = mock_context (reader, writer) = await asyncio.open_connection(*srv.sockets[0].getsockname()) proto = QDataStreamProtocol(reader, writer) await proto.send_message({"some_junk": True}) await exhaust_callbacks(event_loop) - mock_server.on_connection_lost.assert_any_call() + mock_connection.on_connection_lost.assert_any_call() -async def test_connection_broken_external(context, mock_server): +async def test_connection_broken_external(context, mock_connection): """ When the connection breaks while the server is calling protocol.send from somewhere other than the main read - response loop. Make sure that this diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py index 9cf6ed10b..0c7800adc 100644 --- a/tests/unit_tests/conftest.py +++ b/tests/unit_tests/conftest.py @@ -25,10 +25,11 @@ def game_connection( event_loop ): conn = GameConnection( + asynctest.create_autospec(QDataStreamProtocol), + ("localhost", 8001), database=database, game=game, player=players.hosting, - protocol=asynctest.create_autospec(QDataStreamProtocol), player_service=player_service, games=game_service ) diff --git a/tests/unit_tests/test_asyncio_extensions.py b/tests/unit_tests/core/test_asyncio_extensions.py similarity index 98% rename from tests/unit_tests/test_asyncio_extensions.py rename to tests/unit_tests/core/test_asyncio_extensions.py index d9ad42e60..3ce081767 100644 --- a/tests/unit_tests/test_asyncio_extensions.py +++ b/tests/unit_tests/core/test_asyncio_extensions.py @@ -3,7 +3,7 @@ import pytest from asynctest import CoroutineMock -from server.asyncio_extensions import ( +from server.core.asyncio_extensions import ( gather_without_exceptions, synchronized, synchronizedmethod diff --git a/tests/unit_tests/core/test_connection.py b/tests/unit_tests/core/test_connection.py new file mode 100644 index 000000000..d985aa72c --- /dev/null +++ b/tests/unit_tests/core/test_connection.py @@ -0,0 +1,89 @@ +import mock +import pytest + +from server.core import RouteError +from server.core.connection import Connection, handler + + +@pytest.mark.asyncio +async def test_basic(): + foo = mock.Mock() + bar = mock.Mock() + + class TestConnection(Connection): + @handler("foo") + async def handle_foo(self, message): + foo(message) + + @handler("bar") + async def handle_bar(self, message): + bar(message) + + conn = TestConnection(mock.Mock(), mock.Mock()) + + # Static dispatch + await conn.handle_foo({"command": "foo"}) + foo.assert_called_once_with({"command": "foo"}) + await conn.handle_bar({"command": "bar"}) + bar.assert_called_once_with({"command": "bar"}) + + foo.reset_mock() + bar.reset_mock() + + # Dynamic dispatch + await conn.on_message_received({"command": "foo"}) + foo.assert_called_once_with({"command": "foo"}) + await conn.on_message_received({"command": "bar"}) + bar.assert_called_once_with({"command": "bar"}) + + +@pytest.mark.asyncio +async def test_inheritance(): + foo = mock.Mock() + bar = mock.Mock() + foo2 = mock.Mock() + baz = mock.Mock() + + class Base(Connection): + @handler("foo") + async def handle_foo(self, message): + foo(message) + + @handler("bar") + async def handle_bar(self, message): + bar(message) + + class Child(Base): + @handler("foo") + async def handle_foo_2(self, message): + foo2(message) + + @handler("baz") + async def handle_baz(self, message): + baz(message) + + base = Base(mock.Mock(), mock.Mock()) + child = Child(mock.Mock(), mock.Mock()) + + await base.on_message_received({"command": "foo"}) + foo.assert_called_once_with({"command": "foo"}) + foo2.assert_not_called() + foo.reset_mock() + + await child.on_message_received({"command": "foo"}) + foo.assert_not_called() + foo2.assert_called_once_with({"command": "foo"}) + + await base.on_message_received({"command": "bar"}) + bar.assert_called_once_with({"command": "bar"}) + bar.reset_mock() + + await child.on_message_received({"command": "bar"}) + bar.assert_called_once_with({"command": "bar"}) + + with pytest.raises(RouteError): + await base.on_message_received({"command": "baz"}) + baz.assert_not_called() + + await child.on_message_received({"command": "baz"}) + baz.assert_called_once_with({"command": "baz"}) diff --git a/tests/unit_tests/core/test_router.py b/tests/unit_tests/core/test_router.py new file mode 100644 index 000000000..423e01ab1 --- /dev/null +++ b/tests/unit_tests/core/test_router.py @@ -0,0 +1,142 @@ +from textwrap import dedent + +import mock +import pytest + +from server.core.router import RouteError, Router, SearchTree + + +@pytest.fixture +def router(): + return Router("command") + + +def test_router_basic(router): + async def handle_foo(*args): + pass + + router.register_func(handle_foo, "foo") + + @router.register("bar") + def handle_bar(a): + pass + + @router.register("baz") + async def handle_baz(b, c): + pass + + assert router.dispatch({"command": "foo"}) is handle_foo + assert router.dispatch({"command": "bar"}) is handle_bar + assert router.dispatch({"command": "baz"}) is handle_baz + + with pytest.raises(RouteError): + router.dispatch({"command": "qux"}) + + +def test_router_filters(router): + @router.register("foo") + async def handle_foo(): + pass + + @router.register("foo", bar="hello") + async def handle_foo_hello(): + pass + + @router.register(bar="hello") + async def handle_bar(): + pass + + assert router.dispatch({"command": "foo"}) is handle_foo + assert router.dispatch({"command": "foo", "bar": "hello"}) is handle_foo_hello + assert router.dispatch({"bar": "hello"}) is handle_bar + + +def test_router_error(): + router = Router() + + with pytest.raises(RuntimeError): + router.register_func(mock.Mock(), "foo") + + +def test_filters(router): + @router.register("foo", filter="hello") + def handle_hello(): + pass + + @router.register("foo", filter=10) + def handle_ten(): + pass + + @router.register("foo") + def handle_foo(): + pass + + assert router.dispatch({"command": "foo", "filter": "hello"}) is handle_hello + assert router.dispatch({"command": "foo", "filter": 10}) is handle_ten + assert router.dispatch({"command": "foo", "filter": "world"}) is handle_foo + + +def test_search_tree(): + tree = SearchTree() + + target_game_mock = mock.Mock() + command_foo_mock = mock.Mock() + foo_hello_mock = mock.Mock() + foo_ten_mock = mock.Mock() + command_bar_mock = mock.Mock() + + tree.insert(target_game_mock, {"target": "game"}) + tree.insert(command_foo_mock, {"command": "foo"}) + tree.insert(foo_hello_mock, {"command": "foo", "filter": "hello"}) + tree.insert(foo_ten_mock, {"command": "foo", "filter": 10}) + tree.insert(command_bar_mock, {"command": "bar"}) + + assert tree[{"target": "game"}] is target_game_mock + assert tree[{"target": "game", "command": "foo"}] is target_game_mock + assert tree[{"command": "foo"}] is command_foo_mock + assert tree[{"command": "foo", "filter": "hello"}] is foo_hello_mock + assert tree[{"command": "foo", "filter": 10}] is foo_ten_mock + assert tree[{"command": "foo", "filter": "world"}] is command_foo_mock + assert tree[{"command": "bar"}] is command_bar_mock + assert tree.get({}) is None + assert tree.get({"unknown": "message"}) is None + + with pytest.raises(KeyError): + tree[{}] + + with pytest.raises(KeyError): + tree[{"unknown": "message"}] + + +def test_search_tree_repr(): + tree = SearchTree() + + tree.insert("game_handler", {"target": "game"}) + tree.insert("foo_handler", {"command": "foo"}) + tree.insert("foo_hello_handler", {"command": "foo", "filter": "hello"}) + tree.insert("bar_handler", {"command": "bar"}) + + assert repr(tree) == dedent(""" + handler: None + nodes: + key: target + values: + value: game + handler: game_handler + nodes: + + key: command + values: + value: foo + handler: foo_handler + nodes: + key: filter + values: + value: hello + handler: foo_hello_handler + nodes: + + value: bar + handler: bar_handler + nodes: + """.strip("\n")) diff --git a/tests/unit_tests/core/test_service.py b/tests/unit_tests/core/test_service.py new file mode 100644 index 000000000..1d63033bf --- /dev/null +++ b/tests/unit_tests/core/test_service.py @@ -0,0 +1,21 @@ +import mock + +from server.core import Service + + +def test_service_registry(): + with mock.patch("server.core.service.service_registry", {}) as registry: + class Foo(Service): + pass + + assert registry["foo"] is Foo + assert registry == {"foo": Foo} + + +def test_service_registry_name_override(): + with mock.patch("server.core.service.service_registry", {}) as registry: + class Foo(Service, name="FooService"): + pass + + assert registry["FooService"] is Foo + assert registry == {"FooService": Foo} diff --git a/tests/unit_tests/test_types.py b/tests/unit_tests/core/test_types.py similarity index 87% rename from tests/unit_tests/test_types.py rename to tests/unit_tests/core/test_types.py index 20bfb5b83..c346622f6 100644 --- a/tests/unit_tests/test_types.py +++ b/tests/unit_tests/core/test_types.py @@ -1,4 +1,4 @@ -from server.types import Address +from server.core.typedefs import Address def test_address_from_string(): diff --git a/tests/unit_tests/test_gameconnection.py b/tests/unit_tests/test_gameconnection.py index 1364adae4..e1b266522 100644 --- a/tests/unit_tests/test_gameconnection.py +++ b/tests/unit_tests/test_gameconnection.py @@ -49,7 +49,7 @@ async def test_disconnect_all_peers( disconnect_done = mock.Mock() - async def fake_send_dc(player_id): + async def fake_send_dc(command, *args): await asyncio.sleep(1) # Take some time disconnect_done.success() return "OK" @@ -57,11 +57,11 @@ async def fake_send_dc(player_id): # Set up a peer that will disconnect without error ok_disconnect = asynctest.create_autospec(GameConnection) ok_disconnect.state = GameConnectionState.CONNECTED_TO_HOST - ok_disconnect.send_DisconnectFromPeer = fake_send_dc + ok_disconnect.send_gpgnet_message = fake_send_dc # Set up a peer that will throw an exception fail_disconnect = asynctest.create_autospec(GameConnection) - fail_disconnect.send_DisconnectFromPeer.return_value = Exception("Test exception") + fail_disconnect.send_gpgnet_message.side_effect = Exception("Test exception") fail_disconnect.state = GameConnectionState.CONNECTED_TO_HOST # Add the peers to the game @@ -78,7 +78,9 @@ async def test_connect_to_peer(game_connection): await game_connection.connect_to_peer(peer) - peer.send_ConnectToPeer.assert_called_once() + peer.send_gpgnet_message.assert_called_with( + "ConnectToPeer", "Paula_Bean", 1, False + ) async def test_connect_to_peer_disconnected(game_connection): @@ -86,13 +88,13 @@ async def test_connect_to_peer_disconnected(game_connection): await game_connection.connect_to_peer(None) peer = asynctest.create_autospec(GameConnection) - peer.send_ConnectToPeer.side_effect = DisconnectedError("Test error") + peer.send_gpgnet_message.side_effect = DisconnectedError("Test error") # The client disconnects right as we send the message await game_connection.connect_to_peer(peer) -async def test_handle_action_GameState_idle_adds_connection( +async def test_on_message_received_GameState_idle_adds_connection( game: Game, game_connection: GameConnection, players @@ -101,12 +103,15 @@ async def test_handle_action_GameState_idle_adds_connection( game_connection.player = players.hosting game_connection.game = game - await game_connection.handle_action("GameState", ["Idle"]) + await game_connection.on_message_received({ + "command": "GameState", + "args": ["Idle"] + }) game.add_game_connection.assert_called_with(game_connection) -async def test_handle_action_GameState_idle_non_searching_player_aborts( +async def test_on_message_received_GameState_idle_non_searching_player_aborts( game_connection: GameConnection, players ): @@ -115,12 +120,15 @@ async def test_handle_action_GameState_idle_non_searching_player_aborts( game_connection.abort = CoroutineMock() players.hosting.state = PlayerState.IDLE - await game_connection.handle_action("GameState", ["Idle"]) + await game_connection.on_message_received({ + "command": "GameState", + "args": ["Idle"] + }) game_connection.abort.assert_any_call() -async def test_handle_action_GameState_lobby_sends_HostGame( +async def test_on_message_received_GameState_lobby_sends_HostGame( game: Game, game_connection: GameConnection, event_loop, @@ -130,13 +138,16 @@ async def test_handle_action_GameState_lobby_sends_HostGame( game.map_file_path = "maps/some_map.zip" game.map_folder_name = "some_map" - await game_connection.handle_action("GameState", ["Lobby"]) + await game_connection.on_message_received({ + "command": "GameState", + "args": ["Lobby"] + }) await exhaust_callbacks(event_loop) - assert_message_sent(game_connection, "HostGame", [game.map_folder_name]) + assert_message_sent(game_connection, "HostGame", (game.map_folder_name,)) -async def test_handle_action_GameState_lobby_calls_ConnectToHost( +async def test_on_message_received_GameState_lobby_calls_ConnectToHost( game: Game, game_connection: GameConnection, event_loop, @@ -150,13 +161,16 @@ async def test_handle_action_GameState_lobby_calls_ConnectToHost( game.map_file_path = "maps/some_map.zip" game.map_folder_name = "some_map" - await game_connection.handle_action("GameState", ["Lobby"]) + await game_connection.on_message_received({ + "command": "GameState", + "args": ["Lobby"] + }) await exhaust_callbacks(event_loop) game_connection.connect_to_host.assert_called_with(players.hosting.game_connection) -async def test_handle_action_GameState_lobby_calls_ConnectToPeer( +async def test_on_message_received_GameState_lobby_calls_ConnectToPeer( game: Game, game_connection: GameConnection, event_loop, @@ -176,7 +190,10 @@ async def test_handle_action_GameState_lobby_calls_ConnectToPeer( players.peer.game_connection = peer_conn game.connections = [peer_conn] - await game_connection.handle_action("GameState", ["Lobby"]) + await game_connection.on_message_received({ + "command": "GameState", + "args": ["Lobby"] + }) await exhaust_callbacks(event_loop) game_connection.connect_to_peer.assert_called_with(peer_conn) @@ -198,13 +215,16 @@ async def test_handle_lobby_state_handles_GameError( real_game.host = players.hosting real_game.state = GameState.ENDED - await game_connection.handle_action("GameState", ["Lobby"]) + await game_connection.on_message_received({ + "command": "GameState", + "args": ["Lobby"] + }) await exhaust_callbacks(event_loop) game_connection.abort.assert_called_once() -async def test_handle_action_GameState_lobby_calls_abort( +async def test_on_message_received_GameState_lobby_calls_abort( game: Game, game_connection: GameConnection, event_loop, @@ -219,13 +239,16 @@ async def test_handle_action_GameState_lobby_calls_abort( game.map_file_path = "maps/some_map.zip" game.map_folder_name = "some_map" - await game_connection.handle_action("GameState", ["Lobby"]) + await game_connection.on_message_received({ + "command": "GameState", + "args": ["Lobby"] + }) await exhaust_callbacks(event_loop) game_connection.abort.assert_called_once() -async def test_handle_action_GameState_launching_calls_launch( +async def test_on_message_received_GameState_launching_calls_launch( game: Game, game_connection: GameConnection, players @@ -234,64 +257,99 @@ async def test_handle_action_GameState_launching_calls_launch( game_connection.game = game game.launch = CoroutineMock() - await game_connection.handle_action("GameState", ["Launching"]) + await game_connection.on_message_received({ + "command": "GameState", + "args": ["Launching"] + }) game.launch.assert_any_call() -async def test_handle_action_GameState_ended_calls_on_connection_lost( +async def test_on_message_received_GameState_ended_calls_on_connection_lost( game_connection: GameConnection ): game_connection.on_connection_lost = CoroutineMock() - await game_connection.handle_action("GameState", ["Ended"]) + await game_connection.on_message_received({ + "command": "GameState", + "args": ["Ended"] + }) game_connection.on_connection_lost.assert_called_once_with() -async def test_handle_action_PlayerOption(game: Game, game_connection: GameConnection): - await game_connection.handle_action("PlayerOption", [1, "Color", 2]) +async def test_on_message_received_PlayerOption(game: Game, game_connection: GameConnection): + await game_connection.on_message_received({ + "command": "PlayerOption", + "args": [1, "Color", 2] + }) game.set_player_option.assert_called_once_with(1, "Color", 2) -async def test_handle_action_PlayerOption_malformed_no_raise(game_connection: GameConnection): - await game_connection.handle_action("PlayerOption", [1, "Sheeo", "Color", 2]) +async def test_on_message_received_PlayerOption_malformed_no_raise(game_connection: GameConnection): + await game_connection.on_message_received({ + "command": "PlayerOption", + "args": [1, "Sheeo", "Color", 2] + }) # Shouldn't raise an exception -async def test_handle_action_PlayerOption_not_host( +async def test_on_message_received_PlayerOption_not_host( game: Game, game_connection: GameConnection, players ): game_connection.player = players.joining - await game_connection.handle_action("PlayerOption", [1, "Color", 2]) + await game_connection.on_message_received({ + "command": "PlayerOption", + "args": [1, "Color", 2] + }) game.set_player_option.assert_not_called() -async def test_handle_action_GameMods(game: Game, game_connection: GameConnection): - await game_connection.handle_action("GameMods", ["uids", "foo baz"]) +async def test_on_message_received_GameMods(game: Game, game_connection: GameConnection): + await game_connection.on_message_received({ + "command": "GameMods", + "args": ["uids", "foo baz"] + }) assert game.mods == {"baz": "test-mod2", "foo": "test-mod"} -async def test_handle_action_GameMods_activated(game: Game, game_connection: GameConnection): +async def test_on_message_received_GameMods_activated(game: Game, game_connection: GameConnection): game.mods = {"a": "b"} - await game_connection.handle_action("GameMods", ["activated", 0]) + await game_connection.on_message_received({ + "command": "GameMods", + "args": ["activated", 0] + }) assert game.mods == {} - await game_connection.handle_action("GameMods", ["activated", "0"]) + await game_connection.on_message_received({ + "command": "GameMods", + "args": ["activated", "0"] + }) assert game.mods == {} -async def test_handle_action_GameMods_not_host( +async def test_on_message_received_GameMods_unknown(game_connection: GameConnection): + # No exceptions raised + await game_connection.on_message_received({ + "command": "GameMods", + "args": ["unknown", 0] + }) + + +async def test_on_message_received_GameMods_not_host( game: Game, game_connection: GameConnection, players ): game_connection.player = players.joining mods = game.mods - await game_connection.handle_action("GameMods", ["uids", "foo baz"]) + await game_connection.on_message_received({ + "command": "GameMods", + "args": ["uids", "foo baz"] + }) assert game.mods == mods -async def test_handle_action_GameMods_post_launch_updates_played_cache( +async def test_on_message_received_GameMods_post_launch_updates_played_cache( game: Game, game_connection: GameConnection, database @@ -299,8 +357,14 @@ async def test_handle_action_GameMods_post_launch_updates_played_cache( game.launch = CoroutineMock() game.remove_game_connection = CoroutineMock() - await game_connection.handle_action("GameMods", ["uids", "foo bar EA040F8E-857A-4566-9879-0D37420A5B9D"]) - await game_connection.handle_action("GameState", ["Launching"]) + await game_connection.on_message_received({ + "command": "GameMods", + "args": ["uids", "foo bar EA040F8E-857A-4566-9879-0D37420A5B9D"] + }) + await game_connection.on_message_received({ + "command": "GameState", + "args": ["Launching"] + }) async with database.acquire() as conn: result = await conn.execute("select `played` from table_mod where uid=%s", ("EA040F8E-857A-4566-9879-0D37420A5B9D", )) @@ -308,97 +372,151 @@ async def test_handle_action_GameMods_post_launch_updates_played_cache( assert 2 == row[0] -async def test_handle_action_AIOption(game: Game, game_connection: GameConnection): - await game_connection.handle_action("AIOption", ["QAI", "StartSpot", 1]) +async def test_on_message_received_AIOption(game: Game, game_connection: GameConnection): + await game_connection.on_message_received({ + "command": "AIOption", + "args": ["QAI", "StartSpot", 1] + }) game.set_ai_option.assert_called_once_with("QAI", "StartSpot", 1) -async def test_handle_action_AIOption_not_host( +async def test_on_message_received_AIOption_not_host( game: Game, game_connection: GameConnection, players ): game_connection.player = players.joining - await game_connection.handle_action("AIOption", ["QAI", "StartSpot", 1]) + await game_connection.on_message_received({ + "command": "AIOption", + "args": ["QAI", "StartSpot", 1] + }) game.set_ai_option.assert_not_called() -async def test_handle_action_ClearSlot(game: Game, game_connection: GameConnection): - await game_connection.handle_action("ClearSlot", [1]) +async def test_on_message_received_ClearSlot(game: Game, game_connection: GameConnection): + await game_connection.on_message_received({ + "command": "ClearSlot", + "args": [1] + }) game.clear_slot.assert_called_once_with(1) - await game_connection.handle_action("ClearSlot", ["1"]) + await game_connection.on_message_received({ + "command": "ClearSlot", + "args": ["1"] + }) game.clear_slot.assert_called_with(1) -async def test_handle_action_ClearSlot_not_host( +async def test_on_message_received_ClearSlot_not_host( game: Game, game_connection: GameConnection, players ): game_connection.player = players.joining - await game_connection.handle_action("ClearSlot", [1]) + await game_connection.on_message_received({ + "command": "ClearSlot", + "args": [1] + }) game.clear_slot.assert_not_called() -async def test_handle_action_GameResult_calls_add_result(game: Game, game_connection: GameConnection): +async def test_on_message_received_GameResult_calls_add_result(game: Game, game_connection: GameConnection): game_connection.connect_to_host = CoroutineMock() - await game_connection.handle_action("GameResult", [0, "score -5"]) + await game_connection.on_message_received({ + "command": "GameResult", + "args": [0, "score -5"] + }) game.add_result.assert_called_once_with(game_connection.player.id, 0, "score", -5) -async def test_handle_action_GameOption(game: Game, game_connection: GameConnection): +async def test_on_message_received_GameOption(game: Game, game_connection: GameConnection): game.gameOptions = {"AIReplacement": "Off"} - await game_connection.handle_action("GameOption", ["Victory", "sandbox"]) + await game_connection.on_message_received({ + "command": "GameOption", + "args": ["Victory", "sandbox"] + }) assert game.gameOptions["Victory"] == Victory.SANDBOX - await game_connection.handle_action("GameOption", ["AIReplacement", "On"]) + await game_connection.on_message_received({ + "command": "GameOption", + "args": ["AIReplacement", "On"] + }) assert game.gameOptions["AIReplacement"] == "On" - await game_connection.handle_action("GameOption", ["Slots", "7"]) + await game_connection.on_message_received({ + "command": "GameOption", + "args": ["Slots", "7"] + }) assert game.max_players == 7 # I don't know what these paths actually look like - await game_connection.handle_action("GameOption", ["ScenarioFile", "C:\\Maps\\Some_Map"]) + await game_connection.on_message_received({ + "command": "GameOption", + "args": ["ScenarioFile", "C:\\Maps\\Some_Map"] + }) assert game.map_file_path == "maps/some_map.zip" - await game_connection.handle_action("GameOption", ["Title", "All welcome"]) + await game_connection.on_message_received({ + "command": "GameOption", + "args": ["Title", "All welcome"] + }) assert game.name == game.sanitize_name("All welcome") - await game_connection.handle_action("GameOption", ["ArbitraryKey", "ArbitraryValue"]) + await game_connection.on_message_received({ + "command": "GameOption", + "args": ["ArbitraryKey", "ArbitraryValue"] + }) assert game.gameOptions["ArbitraryKey"] == "ArbitraryValue" -async def test_handle_action_GameOption_not_host( +async def test_on_message_received_GameOption_not_host( game: Game, game_connection: GameConnection, players ): game_connection.player = players.joining game.gameOptions = {"Victory": "asdf"} - await game_connection.handle_action("GameOption", ["Victory", "sandbox"]) + await game_connection.on_message_received({ + "command": "GameOption", + "args": ["Victory", "sandbox"] + }) assert game.gameOptions == {"Victory": "asdf"} async def test_json_stats(game_connection: GameConnection, game_stats_service, players, game): game_stats_service.process_game_stats = mock.Mock() - await game_connection.handle_action("JsonStats", ['{"stats": {}}']) + await game_connection.on_message_received({ + "command": "JsonStats", + "args": ['{"stats": {}}'] + }) game.report_army_stats.assert_called_once_with('{"stats": {}}') -async def test_handle_action_EnforceRating(game: Game, game_connection: GameConnection): - await game_connection.handle_action("EnforceRating", []) +async def test_on_message_received_EnforceRating(game: Game, game_connection: GameConnection): + await game_connection.on_message_received({ + "command": "EnforceRating", + "args": [] + }) assert game.enforce_rating is True -async def test_handle_action_TeamkillReport(game: Game, game_connection: GameConnection, database): +async def test_on_message_received_TeamkillReport(game: Game, game_connection: GameConnection, database): game.launch = CoroutineMock() - await game_connection.handle_action("TeamkillReport", ["200", "2", "Dostya", "3", "Rhiza"]) + await game_connection.on_message_received({ + "command": "TeamkillReport", + "args": ["200", "2", "Dostya", "3", "Rhiza"] + }) async with database.acquire() as conn: - result = await conn.execute("select game_id,id from moderation_report where reporter_id=2 and game_id=%s and game_incident_timecode=200", - game.id) + result = await conn.execute( + "select game_id,id from moderation_report where reporter_id=2 and game_id=%s and game_incident_timecode=200", + game.id + ) report = await result.fetchone() assert report is None -async def test_handle_action_TeamkillHappened(game: Game, game_connection: GameConnection, database): + +async def test_on_message_received_TeamkillHappened(game: Game, game_connection: GameConnection, database): game.launch = CoroutineMock() - await game_connection.handle_action("TeamkillHappened", ["200", "2", "Dostya", "3", "Rhiza"]) + await game_connection.on_message_received({ + "command": "TeamkillHappened", + "args": ["200", "2", "Dostya", "3", "Rhiza"] + }) async with database.acquire() as conn: result = await conn.execute("select game_id from teamkills where victim=2 and teamkiller=3 and game_id=%s and gametime=200", @@ -407,38 +525,47 @@ async def test_handle_action_TeamkillHappened(game: Game, game_connection: GameC assert game.id == row[0] -async def test_handle_action_TeamkillHappened_AI(game: Game, game_connection: GameConnection, database): +async def test_on_message_received_TeamkillHappened_AI(game: Game, game_connection: GameConnection, database): # Should fail with a sql constraint error if this isn't handled correctly game_connection.abort = CoroutineMock() - await game_connection.handle_action("TeamkillHappened", ["200", 0, "Dostya", "0", "Rhiza"]) + await game_connection.on_message_received({ + "command": "TeamkillHappened", + "args": ["200", 0, "Dostya", "0", "Rhiza"] + }) game_connection.abort.assert_not_called() -async def test_handle_action_GameEnded_ends_sim( +async def test_on_message_received_GameEnded_ends_sim( game: Game, game_connection: GameConnection ): game.ended = False - await game_connection.handle_action("GameEnded", []) + await game_connection.on_message_received({ + "command": "GameEnded", + "args": [] + }) assert game_connection.finished_sim game.check_sim_end.assert_called_once() game.on_game_end.assert_not_called() -async def test_handle_action_GameEnded_ends_game( +async def test_on_message_received_GameEnded_ends_game( game: Game, game_connection: GameConnection ): game.ended = True - await game_connection.handle_action("GameEnded", []) + await game_connection.on_message_received({ + "command": "GameEnded", + "args": [] + }) assert game_connection.finished_sim game.check_sim_end.assert_called_once() game.on_game_end.assert_called_once() -async def test_handle_action_OperationComplete(ugame: Game, game_connection: GameConnection, database): +async def test_on_message_received_OperationComplete(ugame: Game, game_connection: GameConnection, database): """ Sends an OperationComplete action to handle action and verifies that the `coop_leaderboard` table is updated accordingly. @@ -452,19 +579,22 @@ async def test_handle_action_OperationComplete(ugame: Game, game_connection: Gam secondary = 1 time_taken = "09:08:07.654321" - await game_connection.handle_action("OperationComplete", ["1", secondary, time_taken]) + await game_connection.on_message_received({ + "command": "OperationComplete", + "args": ["1", secondary, time_taken] + }) async with database.acquire() as conn: result = await conn.execute( "SELECT secondary, gameuid from `coop_leaderboard` where gameuid=%s", - ugame.id) - + ugame.id + ) row = await result.fetchone() assert (secondary, ugame.id) == (row[0], row[1]) -async def test_handle_action_OperationComplete_invalid(ugame: Game, game_connection: GameConnection, database): +async def test_on_message_received_OperationComplete_invalid(ugame: Game, game_connection: GameConnection, database): """ Sends an OperationComplete action to handle action and verifies that the `coop_leaderboard` table is updated accordingly. @@ -478,19 +608,22 @@ async def test_handle_action_OperationComplete_invalid(ugame: Game, game_connect secondary = 1 time_taken = "09:08:07.654321" - await game_connection.handle_action("OperationComplete", ["1", secondary, time_taken]) + await game_connection.on_message_received({ + "command": "OperationComplete", + "args": ["1", secondary, time_taken] + }) async with database.acquire() as conn: result = await conn.execute( "SELECT secondary, gameuid from `coop_leaderboard` where gameuid=%s", - ugame.id) - + ugame.id + ) row = await result.fetchone() assert row is None -async def test_handle_action_IceMsg( +async def test_on_message_received_IceMsg( game_connection: GameConnection, player_service, player_factory @@ -498,7 +631,10 @@ async def test_handle_action_IceMsg( peer = player_factory(player_id=2) peer.game_connection = asynctest.create_autospec(GameConnection) player_service[peer.id] = peer - await game_connection.handle_action("IceMsg", [2, "the message"]) + await game_connection.on_message_received({ + "command": "IceMsg", + "args": [2, "the message"] + }) peer.game_connection.send.assert_called_once_with({ "command": "IceMsg", @@ -506,14 +642,17 @@ async def test_handle_action_IceMsg( }) -async def test_handle_action_IceMsg_for_non_existent_player( +async def test_on_message_received_IceMsg_for_non_existent_player( game_connection: GameConnection, ): # No exceptions raised - await game_connection.handle_action("IceMsg", [3826, "the message"]) + await game_connection.on_message_received({ + "command": "IceMsg", + "args": [3826, "the message"] + }) -async def test_handle_action_IceMsg_for_non_connected( +async def test_on_message_received_IceMsg_for_non_connected( game_connection: GameConnection, player_service, player_factory @@ -522,7 +661,19 @@ async def test_handle_action_IceMsg_for_non_connected( del peer.game_connection player_service[peer.id] = peer # No exceptions raised - await game_connection.handle_action("IceMsg", [2, "the message"]) + await game_connection.on_message_received({ + "command": "IceMsg", + "args": [2, "the message"] + }) + + +async def test_on_message_received_desync(game_connection: GameConnection): + game_connection.game.desyncs = 0 + await game_connection.on_message_received({ + "command": "Desync", + "args": [] + }) + assert game_connection.game.desyncs == 1 @pytest.mark.parametrize("action", ( @@ -530,18 +681,25 @@ async def test_handle_action_IceMsg_for_non_connected( "Bottleneck", "BottleneckCleared", "Disconnected", - "Chat", "GameFull" )) -async def test_handle_action_ignored(game_connection: GameConnection, action): +async def test_on_message_received_ignored(game_connection: GameConnection, action): # No exceptions raised - await game_connection.handle_action(action, []) + await game_connection.on_message_received({"command": action, "args": []}) -async def test_handle_action_invalid(game_connection: GameConnection): +async def test_on_message_received_Chat_ignored(game_connection: GameConnection): + # No exceptions raised + await game_connection.on_message_received({"command": "Chat", "args": ["Test"]}) + + +async def test_on_message_received_invalid(game_connection: GameConnection): game_connection.abort = CoroutineMock() - await game_connection.handle_action("ThisDoesntExist", [1, 2, 3]) + await game_connection.on_message_received({ + "command": "ThisDoesntExist", + "args": [1, 2, 3] + }) game_connection.abort.assert_not_called() game_connection.protocol.send_message.assert_not_called() diff --git a/tests/unit_tests/test_lobbyconnection.py b/tests/unit_tests/test_lobbyconnection.py index 00c1cc4fb..891fd0954 100644 --- a/tests/unit_tests/test_lobbyconnection.py +++ b/tests/unit_tests/test_lobbyconnection.py @@ -11,6 +11,7 @@ from server.abc.base_game import InitMode from server.config import config +from server.core.typedefs import Address from server.db.models import ban, friends_and_foes from server.game_service import GameService from server.gameconnection import GameConnection @@ -24,7 +25,6 @@ from server.players import PlayerState from server.protocol import DisconnectedError, QDataStreamProtocol from server.rating import InclusiveRange, RatingType -from server.types import Address pytestmark = pytest.mark.asyncio @@ -97,6 +97,8 @@ def lobbyconnection( mock_nts_client ): lc = LobbyConnection( + mock_protocol, + Address("127.0.0.1", 1234), database=database, geoip=mock_geoip, game_service=mock_games, @@ -106,10 +108,8 @@ def lobbyconnection( ) lc.player = mock_player - lc.protocol = mock_protocol lc.player_service.fetch_player_data = CoroutineMock() - lc.peer_address = Address("127.0.0.1", 1234) - lc._authenticated = True + lc.authenticated = True return lc @@ -141,7 +141,7 @@ async def start_app(): async def test_unauthenticated_calls_abort(lobbyconnection, test_game_info): - lobbyconnection._authenticated = False + lobbyconnection.authenticated = False lobbyconnection.abort = CoroutineMock() await lobbyconnection.on_message_received({ @@ -163,7 +163,9 @@ async def test_bad_command_calls_abort(lobbyconnection): }) lobbyconnection.send.assert_called_once_with({"command": "invalid"}) - lobbyconnection.abort.assert_called_once_with("Error processing command") + lobbyconnection.abort.assert_called_once_with( + "Error processing message: {'command': 'this_isnt_real'}" + ) async def test_command_pong_does_nothing(lobbyconnection): @@ -561,7 +563,7 @@ async def test_command_admin_closeFA(lobbyconnection, player_factory): "user_id": tuna.id }) - tuna.lobby_connection.send.assert_any_call({ + tuna.lobby_connection.write.assert_any_call({ "command": "notice", "style": "kill", }) @@ -569,7 +571,7 @@ async def test_command_admin_closeFA(lobbyconnection, player_factory): async def test_game_subscription(lobbyconnection: LobbyConnection): game = Mock() - game.handle_action = CoroutineMock() + game.on_message_received = CoroutineMock() lobbyconnection.game_connection = game await lobbyconnection.on_message_received({ @@ -578,7 +580,11 @@ async def test_game_subscription(lobbyconnection: LobbyConnection): "target": "game" }) - game.handle_action.assert_called_with("test", ["foo", 42]) + game.on_message_received.assert_called_with({ + "command": "test", + "args": ["foo", 42], + "target": "game" + }) async def test_command_avatar_list(mocker, lobbyconnection: LobbyConnection): @@ -774,7 +780,7 @@ async def test_game_connection_not_restored_if_no_such_game_exists(lobbyconnecti lobbyconnection.send.assert_any_call({ "command": "notice", "style": "info", - "text": "The game you were connected to does no longer exist" + "text": "The game you were connected to no longer exists" })