Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement reading and writing boolean values #555

Merged
merged 9 commits into from
Jun 25, 2023
44 changes: 35 additions & 9 deletions mcstatus/protocol/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,16 @@
from ctypes import c_uint32 as unsigned_int32
from ctypes import c_uint64 as unsigned_int64
from ipaddress import ip_address
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast

import asyncio_dgram

from mcstatus.address import Address

if TYPE_CHECKING:
from typing_extensions import Self, SupportsIndex, TypeAlias
from types import TracebackType

from typing_extensions import Literal, Self, SupportsIndex, TypeAlias

BytesConvertable: TypeAlias = "SupportsIndex | Iterable[SupportsIndex]"

Expand Down Expand Up @@ -244,7 +246,7 @@ def __repr__(self) -> str:
@staticmethod
def _unpack(format_: str, data: bytes) -> int:
"""Unpack data as bytes with format in big-endian."""
return struct.unpack(">" + format_, bytes(data))[0]
return cast(int, struct.unpack(">" + format_, bytes(data))[0])

def read_varint(self) -> int:
"""Read varint from ``self`` and return it.
Expand Down Expand Up @@ -312,7 +314,7 @@ def read_ulong(self) -> int:

def read_bool(self) -> bool:
"""Return `True` or `False`. Read 1 byte."""
return self._unpack("?", self.read(1)) == 1
return cast(bool, self._unpack("?", self.read(1)))

def read_buffer(self) -> "Connection":
"""Read a varint for length, then return a new connection from length read bytes."""
Expand All @@ -337,7 +339,7 @@ def __repr__(self) -> str:
@staticmethod
def _unpack(format_: str, data: bytes) -> int:
"""Unpack data as bytes with format in big-endian."""
return struct.unpack(">" + format_, bytes(data))[0]
return cast(int, struct.unpack(">" + format_, bytes(data))[0])

async def read_varint(self) -> int:
"""Read varint from ``self`` and return it.
Expand Down Expand Up @@ -405,7 +407,7 @@ async def read_ulong(self) -> int:

async def read_bool(self) -> bool:
"""Return `True` or `False`. Read 1 byte."""
return self._unpack("?", await self.read(1)) == 1
return cast(bool, self._unpack("?", await self.read(1)))

async def read_buffer(self) -> Connection:
"""Read a varint for length, then return a new connection from length read bytes."""
Expand Down Expand Up @@ -526,8 +528,16 @@ def close(self) -> None:
def __enter__(self) -> Self:
return self

def __exit__(self, *_) -> None:
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> Literal[False]:
self.close()
# Return false, we don't want to suppress
# exceptions raised in the context
return False


class TCPSocketConnection(SocketConnection):
Expand Down Expand Up @@ -637,8 +647,16 @@ async def __aenter__(self) -> Self:
await self.connect()
return self

async def __aexit__(self, *_) -> None:
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> Literal[False]:
self.close()
# Return false, we don't want to suppress
# exceptions raised in the context
return False


class UDPAsyncSocketConnection(BaseAsyncConnection):
Expand Down Expand Up @@ -683,5 +701,13 @@ async def __aenter__(self) -> Self:
await self.connect()
return self

async def __aexit__(self, *_) -> None:
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> Literal[False]:
self.close()
# Return false, we don't want to suppress
# exceptions raised in the context
return False