diff --git a/server/matchmaker/__init__.py b/server/matchmaker/__init__.py index 63068ebdf..031824957 100644 --- a/server/matchmaker/__init__.py +++ b/server/matchmaker/__init__.py @@ -5,6 +5,7 @@ games, currently just used for 1v1 ``ladder``. """ from .map_pool import MapPool +from .match_offer import MatchOffer, OfferTimeoutError from .matchmaker_queue import MatchmakerQueue from .pop_timer import PopTimer from .search import CombinedSearch, OnMatchedCallback, Search @@ -12,7 +13,9 @@ __all__ = ( "CombinedSearch", "MapPool", + "MatchOffer", "MatchmakerQueue", + "OfferTimeoutError", "OnMatchedCallback", "PopTimer", "Search", diff --git a/server/matchmaker/match_offer.py b/server/matchmaker/match_offer.py new file mode 100644 index 000000000..3abb6edea --- /dev/null +++ b/server/matchmaker/match_offer.py @@ -0,0 +1,85 @@ +import asyncio +from datetime import datetime +from typing import Generator, Iterable + +from ..players import Player + + +class OfferTimeoutError(asyncio.TimeoutError): + pass + + +class MatchOffer(object): + """ + Track which players are ready for a match to begin. + + Once a player has become ready, they cannot become unready again. State + changes are eagerly broadcast to other players in the MatchOffer. + """ + + def __init__(self, players: Iterable[Player], expires_at: datetime): + self.expires_at = expires_at + self._players_ready = {player: False for player in players} + self.future = asyncio.Future() + + def get_unready_players(self) -> Generator[Player, None, None]: + return ( + player for player, ready in self._players_ready.items() + if not ready + ) + + def get_ready_players(self) -> Generator[Player, None, None]: + return ( + player for player, ready in self._players_ready.items() + if ready + ) + + def ready_player(self, player: Player) -> None: + """ + Mark a player as ready. + + Broadcasts the state change to other players. + """ + if self._players_ready[player]: + # This client's state is probably out of date + player.write_message({ + "command": "match_info", + **self.to_dict(), + "ready": True + }) + else: + self._players_ready[player] = True + self.write_broadcast_update() + + if all(self._players_ready.values()) and not self.future.done(): + self.future.set_result(True) + + async def wait_ready(self) -> None: + """Wait for all players to have readied up.""" + timeout = (self.expires_at - datetime.now()).total_seconds() + if timeout <= 0: + raise OfferTimeoutError() + + try: + await asyncio.wait_for(self.future, timeout=timeout) + except asyncio.TimeoutError: + raise OfferTimeoutError() + + def write_broadcast_update(self) -> None: + """Queue the `match_info` message to be sent to all players in the + MatchOffer.""" + info = self.to_dict() + for player, ready in self._players_ready.items(): + player.write_message({ + "command": "match_info", + **info, + "ready": ready + }) + + def to_dict(self) -> dict: + return { + "expires_at": self.expires_at.isoformat(), + "players_total": len(self._players_ready), + # Works because `True` is counted as 1 and `False` as 0 + "players_ready": sum(self._players_ready.values()) + } diff --git a/tests/unit_tests/test_match_offer.py b/tests/unit_tests/test_match_offer.py new file mode 100644 index 000000000..598fc8052 --- /dev/null +++ b/tests/unit_tests/test_match_offer.py @@ -0,0 +1,72 @@ +from datetime import datetime, timedelta + +import mock +import pytest + +from server.matchmaker import MatchOffer, OfferTimeoutError +from tests.utils import fast_forward + + +@pytest.fixture +def offer(player_factory): + return MatchOffer( + [player_factory(player_id=i) for i in range(5)], + datetime(2020, 1, 31, 14, 30, 36) + ) + + +def test_match_offer_api(offer): + + assert offer.to_dict() == { + "expires_at": "2020-01-31T14:30:36", + "players_total": 5, + "players_ready": 0 + } + + assert len(list(offer.get_ready_players())) == 0 + assert len(list(offer.get_unready_players())) == 5 + + +def test_broadcast_called_on_ready(offer): + offer.write_broadcast_update = mock.Mock() + player = next(offer.get_unready_players()) + + offer.ready_player(player) + + offer.write_broadcast_update.assert_called_once() + + +def test_ready_player_bad_key(offer, player_factory): + with pytest.raises(KeyError): + offer.ready_player(player_factory(player_id=42)) + + +@pytest.mark.asyncio +async def test_wait_ready_timeout(offer): + with pytest.raises(OfferTimeoutError): + await offer.wait_ready() + + +@pytest.mark.asyncio +@fast_forward(5) +async def test_wait_ready_timeout_some_ready(offer): + offer.expires_at = datetime.now() + timedelta(seconds=5) + + players = offer.get_unready_players() + p1, p2 = next(players), next(players) + + offer.ready_player(p1) + offer.ready_player(p2) + + with pytest.raises(OfferTimeoutError): + await offer.wait_ready() + + +@pytest.mark.asyncio +async def test_wait_ready(offer): + offer.expires_at = datetime.now() + timedelta(seconds=5) + + for player in offer.get_unready_players(): + offer.ready_player(player) + + await offer.wait_ready()