Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement reading and writing boolean values #555

Merged
merged 9 commits into from
Jun 25, 2023
56 changes: 49 additions & 7 deletions mcstatus/protocol/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,16 @@
from ctypes import c_uint32 as unsigned_int32
from ctypes import c_uint64 as unsigned_int64
from ipaddress import ip_address
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast

import asyncio_dgram

from mcstatus.address import Address

if TYPE_CHECKING:
from typing_extensions import Self, SupportsIndex, TypeAlias
from types import TracebackType

from typing_extensions import Literal, Self, SupportsIndex, TypeAlias

BytesConvertable: TypeAlias = "SupportsIndex | Iterable[SupportsIndex]"

Expand Down Expand Up @@ -122,6 +124,10 @@ def write_ulong(self, value: int) -> None:
"""Write 8 bytes for value ``0 - 18446744073709551613 (2 ** 64 - 1)``."""
self.write(self._pack("Q", value))

def write_bool(self, value: bool) -> None:
"""Write 1 byte for boolean `True` or `False`"""
self.write(self._pack("?", value))

def write_buffer(self, buffer: "Connection") -> None:
"""Flush buffer, then write a varint of the length of the buffer's data, then write buffer data."""
data = buffer.flush()
Expand Down Expand Up @@ -214,6 +220,10 @@ async def write_ulong(self, value: int) -> None:
"""Write 8 bytes for value ``0 - 18446744073709551613 (2 ** 64 - 1)``."""
await self.write(self._pack("Q", value))

async def write_bool(self, value: bool) -> None:
"""Write 1 byte for boolean `True` or `False`"""
await self.write(self._pack("?", value))

async def write_buffer(self, buffer: "Connection") -> None:
"""Flush buffer, then write a varint of the length of the buffer's data, then write buffer data."""
data = buffer.flush()
Expand All @@ -236,7 +246,7 @@ def __repr__(self) -> str:
@staticmethod
def _unpack(format_: str, data: bytes) -> int:
"""Unpack data as bytes with format in big-endian."""
return struct.unpack(">" + format_, bytes(data))[0]
return cast(int, struct.unpack(">" + format_, bytes(data))[0])

def read_varint(self) -> int:
"""Read varint from ``self`` and return it.
Expand Down Expand Up @@ -302,6 +312,10 @@ def read_ulong(self) -> int:
"""Return ``0 - 18446744073709551613 (2 ** 64 - 1)``. Read 8 bytes."""
return self._unpack("Q", self.read(8))

def read_bool(self) -> bool:
"""Return `True` or `False`. Read 1 byte."""
return cast(bool, self._unpack("?", self.read(1)))

def read_buffer(self) -> "Connection":
"""Read a varint for length, then return a new connection from length read bytes."""
length = self.read_varint()
Expand All @@ -325,7 +339,7 @@ def __repr__(self) -> str:
@staticmethod
def _unpack(format_: str, data: bytes) -> int:
"""Unpack data as bytes with format in big-endian."""
return struct.unpack(">" + format_, bytes(data))[0]
return cast(int, struct.unpack(">" + format_, bytes(data))[0])

async def read_varint(self) -> int:
"""Read varint from ``self`` and return it.
Expand Down Expand Up @@ -391,6 +405,10 @@ async def read_ulong(self) -> int:
"""Return ``0 - 18446744073709551613 (2 ** 64 - 1)``. Read 8 bytes."""
return self._unpack("Q", await self.read(8))

async def read_bool(self) -> bool:
"""Return `True` or `False`. Read 1 byte."""
return cast(bool, self._unpack("?", await self.read(1)))

async def read_buffer(self) -> Connection:
"""Read a varint for length, then return a new connection from length read bytes."""
length = await self.read_varint()
Expand Down Expand Up @@ -510,8 +528,16 @@ def close(self) -> None:
def __enter__(self) -> Self:
return self

def __exit__(self, *_) -> None:
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> Literal[False]:
self.close()
# Return false, we don't want to suppress
# exceptions raised in the context
return False


class TCPSocketConnection(SocketConnection):
Expand Down Expand Up @@ -621,8 +647,16 @@ async def __aenter__(self) -> Self:
await self.connect()
return self

async def __aexit__(self, *_) -> None:
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> Literal[False]:
self.close()
# Return false, we don't want to suppress
# exceptions raised in the context
return False


class UDPAsyncSocketConnection(BaseAsyncConnection):
Expand Down Expand Up @@ -667,5 +701,13 @@ async def __aenter__(self) -> Self:
await self.connect()
return self

async def __aexit__(self, *_) -> None:
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> Literal[False]:
self.close()
# Return false, we don't want to suppress
# exceptions raised in the context
return False
20 changes: 20 additions & 0 deletions tests/protocol/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,26 @@ def test_write_ulong_positive(self):

assert self.connection.flush() == bytearray.fromhex("8000000000000000")

def test_read_bool_true(self):
self.connection.receive(bytearray.fromhex("01"))

assert self.connection.read_bool() is True

def test_write_bool_true(self):
self.connection.write_bool(True)

assert self.connection.flush() == bytearray.fromhex("01")

def test_read_bool_false(self):
self.connection.receive(bytearray.fromhex("00"))

assert self.connection.read_bool() is False

def test_write_bool_false(self):
self.connection.write_bool(False)

assert self.connection.flush() == bytearray.fromhex("00")

def test_read_buffer(self):
self.connection.receive(bytearray.fromhex("027FAA"))
buffer = self.connection.read_buffer()
Expand Down