From 2a356ceae81449fd68389df0e2bc83bfee65158e Mon Sep 17 00:00:00 2001 From: gatsik <74517072+Gatsik@users.noreply.github.com> Date: Mon, 23 May 2022 22:57:27 +0300 Subject: [PATCH] Round floats before encoding --- server/config.py | 4 ++-- server/protocol/protocol.py | 35 ++++++++-------------------- tests/integration_tests/test_game.py | 4 ++-- tests/unit_tests/test_protocol.py | 6 ++--- 4 files changed, 17 insertions(+), 32 deletions(-) diff --git a/server/config.py b/server/config.py index 910b0524b..af2a76df6 100644 --- a/server/config.py +++ b/server/config.py @@ -124,8 +124,8 @@ def __init__(self): self.QUEUE_POP_DESIRED_MATCHES = 2.5 # How many previous queue sizes to consider self.QUEUE_POP_TIME_MOVING_AVG_SIZE = 5 - # The number of decimal places to use for float serialization - self.JSON_FLOAT_DIGITS_PRECISION = 2 + # The maximum number of decimal places to use for float serialization + self.JSON_FLOAT_MAX_DIGITS = 2 self._defaults = { key: value for key, value in vars(self).items() if key.isupper() diff --git a/server/protocol/protocol.py b/server/protocol/protocol.py index 37bea8a42..4b51b90c9 100644 --- a/server/protocol/protocol.py +++ b/server/protocol/protocol.py @@ -10,32 +10,17 @@ class CustomJSONEncoder(json.JSONEncoder): - # taken from https://stackoverflow.com/a/60243503 + # taken from https://stackoverflow.com/a/53798633 def encode(self, o): - if isinstance(o, dict): - dict_content = [] - for key, value in o.items(): - # do not use precision for float dict keys - if (isinstance(key, int) or isinstance(key, float)): - new_key = f'"{key}"'.lower() - elif key is None: - new_key = '"null"' - elif isinstance(key, str): - new_key = self.encode(key) - else: - raise TypeError( - 'Keys must be str, int, float, bool or None, ' - f'not {key.__class__.__name__}', - ) - new_value = self.encode(value) - dict_content.append(f"{new_key}:{new_value}") - return "{" + ",".join(dict_content) + "}" - elif isinstance(o, (list, tuple)) and not isinstance(o, str): - return "[" + ",".join(map(self.encode, o)) + "]" - elif isinstance(o, float): - return f"{o:.{config.JSON_FLOAT_DIGITS_PRECISION}f}" - else: - return super().encode(o) + def round_floats(o): + if isinstance(o, float): + return round(o, config.JSON_FLOAT_MAX_DIGITS) + if isinstance(o, dict): + return {k: round_floats(v) for k, v in o.items()} + if isinstance(o, (list, tuple)): + return [round_floats(x) for x in o] + return o + return super().encode(round_floats(o)) json_encoder = CustomJSONEncoder(separators=(",", ":")) diff --git a/tests/integration_tests/test_game.py b/tests/integration_tests/test_game.py index a52d1f033..a7366bd60 100644 --- a/tests/integration_tests/test_game.py +++ b/tests/integration_tests/test_game.py @@ -310,7 +310,7 @@ async def test_game_ended_rates_game(lobby_server): async def test_game_ended_broadcasts_rating_update( lobby_server, channel, mocker, ): - mocker.patch("server.config.JSON_FLOAT_DIGITS_PRECISION", 4) + mocker.patch("server.config.JSON_FLOAT_MAX_DIGITS", 4) mq_proto_all = await connect_mq_consumer( lobby_server, channel, @@ -620,7 +620,7 @@ async def test_ladder_game_draw_bug(lobby_server, database, mocker): their own ACU in order to kill the enemy ACU and be awarded a victory instead of a draw. """ - mocker.patch("server.config.JSON_FLOAT_DIGITS_PRECISION", 13) + mocker.patch("server.config.JSON_FLOAT_MAX_DIGITS", 13) player1_id, proto1, player2_id, proto2 = await queue_players_for_matchmaking(lobby_server) msg1, msg2 = await asyncio.gather(*[ diff --git a/tests/unit_tests/test_protocol.py b/tests/unit_tests/test_protocol.py index 3a42d903c..15600f331 100644 --- a/tests/unit_tests/test_protocol.py +++ b/tests/unit_tests/test_protocol.py @@ -262,9 +262,9 @@ async def test_read_when_disconnected(protocol): def test_json_encoder_float_serialization(): - assert json_encoder.encode(123.0) == '123.00' + assert json_encoder.encode(123.0) == '123.0' assert json_encoder.encode(0.99) == '0.99' - assert json_encoder.encode(0.999) == '1.00' + assert json_encoder.encode(0.999) == '1.0' def test_json_encoder_encodes_server_messages(): @@ -305,7 +305,7 @@ def test_json_encoder_encodes_server_messages(): for queue in expected_matchmaker_info_dict["queues"]: queue["queue_pop_time_delta"] = round( - queue["queue_pop_time_delta"], config.JSON_FLOAT_DIGITS_PRECISION, + queue["queue_pop_time_delta"], config.JSON_FLOAT_MAX_DIGITS, ) assert new_encode(matchmaker_info_dict) == (