Skip to content

Commit

Permalink
Merge pull request #17 from remopas/dev
Browse files Browse the repository at this point in the history
big performance upgrade by using rawio
  • Loading branch information
remkop22 authored May 19, 2022
2 parents afb88c3 + f920310 commit d063d06
Show file tree
Hide file tree
Showing 5 changed files with 276 additions and 258 deletions.
1 change: 1 addition & 0 deletions src/binread/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

from .format import Format, formatclass
from .reader import ByteReader
from .types import Array, String, Bool, Char, Tuple, Bytes
from .types import I8, I16, I32, I64
from .types import U8, U16, U32, U64
Expand Down
147 changes: 51 additions & 96 deletions src/binread/format.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""This module contains the main classes used in binread."""

from io import RawIOBase
from .reader import ByteOrder, ByteReader
from abc import ABC, abstractmethod
from dataclasses import dataclass
from struct import unpack
from typing import Any, Callable, Dict, Tuple, Union, Optional
import sys
from typing import Any, Callable, Dict, SupportsBytes, Union, Optional


try:
Expand All @@ -13,15 +13,6 @@
from typing_extensions import Literal


ByteOrder = Literal["little", "big", "native"]
"""Specifies the endianness. `native` equals `sys.byteorder`."""


class NotEnoughBytes(Exception):
def __init__(self, msg: str = "not enough bytes") -> None:
super().__init__(msg)


class FieldType(ABC):
"""Abstract base class of all field types. Can be used to create a custom field type.
Expand All @@ -35,43 +26,30 @@ def __init__(
byteorder: Optional[ByteOrder] = None,
to: Optional[Callable] = None,
):
self._byteorder: Optional[ByteOrder] = byteorder
self._default_byteorder: ByteOrder = "native"
self.byteorder: Optional[ByteOrder] = byteorder
self.default_byteorder: ByteOrder = "native"
self.to = to

@abstractmethod
def extract(self, data: bytes, fields: Dict[str, Any]) -> Tuple[Any, int]:
"""Extracts the required bytes to construct this field.
Args:
data: The buffer to read.
fields: Any previous read fields used as context.
Returns:
(Any, int): The field value that is constructed and the bytes read.
Raises:
NotEnoughBytes: If not enough bytes are provided to construct this field.
"""
def extract(self, data: ByteReader, fields: Dict[str, Any]) -> Any:
pass

def read_field(self, data: bytes, fields: Dict[str, Any]) -> Tuple[Any, int]:
value, size = self.extract(data, fields)
def read_field(self, data: ByteReader, fields: Dict[str, Any]) -> Any:
value = self.extract(data, fields)

if self.to:
value = self.to(value)

return value, size
return value

def _inheriting_byteorder(self) -> ByteOrder:
if self._byteorder:
return self._byteorder
def inheriting_byteorder(self) -> ByteOrder:
if self.byteorder:
return self.byteorder
else:
return self._default_byteorder
return self.default_byteorder

@staticmethod
def _to_instance(field: Union["FieldType", type]) -> Optional["FieldType"]:
def to_instance(field: Union["FieldType", type]) -> Optional["FieldType"]:
if isinstance(field, FieldType):
return field
elif (
Expand All @@ -85,95 +63,70 @@ def _to_instance(field: Union["FieldType", type]) -> Optional["FieldType"]:
else:
return None

def byteorder(self) -> Literal["little", "big"]:
if self._byteorder:
byteorder = self._byteorder
else:
byteorder = self._default_byteorder

if byteorder == "little" or byteorder == "big":
return byteorder
else:
return sys.byteorder


class Integer(FieldType):
def __init__(self, size: int, signed: bool, *args, **kwargs):
super().__init__(*args, **kwargs)
self.signed = signed
self._size = size
self.size = size

def extract(self, data: bytes, fields: Dict[str, Any]):
if self._size > len(data):
raise NotEnoughBytes()

return (
int.from_bytes(data[: self._size], self.byteorder(), signed=self.signed),
self._size,
def extract(self, data: ByteReader, fields: Dict[str, Any]) -> int:
return data.read_integer(
self.size, self.byteorder or self.default_byteorder, self.signed
)


class Float(FieldType):
def __init__(self, size: Literal[2, 4, 8], *args, **kwargs):
super().__init__(*args, **kwargs)
self._size = size

def extract(self, data: bytes, fields: Dict[str, Any]) -> Tuple[Any, int]:
if self._size > len(data):
raise NotEnoughBytes()

byteorder = self.byteorder()
if byteorder == "little":
char = "<"
else:
char = ">"

if self._size == 2:
char += "e"
elif self._size == 4:
char += "f"
elif self._size == 8:
char += "d"
else:
raise Exception("invalid float size, must be either 2, 4 or 8")
self.size: Literal[2, 4, 8] = size

return *unpack(char, data[: self._size]), self._size
def extract(self, data: ByteReader, fields: Dict[str, Any]) -> float:
return data.read_float(self.size, self.byteorder or self.default_byteorder)


class Format(FieldType):
def __init__(self, fields: Dict[str, Union[FieldType, type]], *args, **kwargs):
super().__init__(*args, **kwargs)
self._fields: dict[str, FieldType] = {}
self.fields: dict[str, FieldType] = {}

byteorder = self.inheriting_byteorder()

byteorder = self._inheriting_byteorder()
for name, field in fields.items():
field = self._to_instance(field)
field = self.to_instance(field)

if not field:
raise Exception(f"unknown field type '{field}' with key '{name}'")

field._default_byteorder = byteorder
self._fields[name] = field
field.default_byteorder = byteorder
self.fields[name] = field

def extract(self, data: bytes, fields: Dict[str, Any]) -> Tuple[Any, int]:
value, bytes_read = self.read(data, allow_leftover=True, return_bytes=True)
return value, bytes_read # type: ignore
def extract(self, data: ByteReader, fields: Dict[str, Any]) -> Dict[str, Any]:
return self.read(data, allow_leftover=True)

def read(
self, data: bytes, allow_leftover: bool = False, return_bytes: bool = False
self,
byte_data: Union[ByteReader, RawIOBase, bytes, SupportsBytes],
allow_leftover: bool = False,
return_bytes: bool = False,
) -> Dict[str, Any]:
result = {}
total = 0
for name, field in self._fields.items():
result[name], bytes_read = field.read_field(data, result)
data = data[bytes_read:]
total += bytes_read

if len(data) != 0 and not allow_leftover:
if not isinstance(byte_data, ByteReader):
data = ByteReader(byte_data)
else:
data = byte_data

start_pos = data.tell()

for name, field in self.fields.items():
result[name] = field.read_field(data, result)

if (not data.is_eof()) and (not allow_leftover):
raise Exception("left over bytes")

if return_bytes:
return result, total # type: ignore
return result, data.tell() - start_pos # type: ignore
else:
return result

Expand All @@ -196,14 +149,16 @@ def read(self, *args, **kwargs) -> Dict[str, Any]:
def formatclass(*args, **kwargs):
with_args = True

if len(args) == 1 and isinstance(args[0], Callable):
if len(args) == 1 and isinstance(args[0], type):
with_args = False
kwargs = {}
cls = args[0]
args = []

def decorator(cls):
def decorator(cls) -> type:
fields = {}
for name, field in cls.__dict__.items():
field = FieldType._to_instance(field)
field = FieldType.to_instance(field)

if field:
fields[name] = field
Expand All @@ -229,4 +184,4 @@ def read(*args, **kwargs):
if with_args:
return decorator
else:
return decorator(args[0])
return decorator(cls)
75 changes: 75 additions & 0 deletions src/binread/reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import sys
from abc import ABC, abstractmethod
from io import SEEK_END, BytesIO, IOBase, RawIOBase
from struct import unpack
from typing import SupportsBytes, Union


try:
from typing import Literal
except ImportError:
from typing_extensions import Literal

ByteOrder = Literal["little", "big", "native"]


def normalize_byteorder(byteorder: "ByteOrder") -> Literal["little", "big"]:
    """Resolve a byte-order name to a concrete endianness.

    Args:
        byteorder: ``"little"``, ``"big"`` or ``"native"``.

    Returns:
        ``sys.byteorder`` when ``byteorder`` is ``"native"``; otherwise the
        argument itself, unchanged.
    """
    return sys.byteorder if byteorder == "native" else byteorder


class NotEnoughBytes(Exception):
    """Error signalling that a read asked for more bytes than were available."""

    def __init__(self, msg: str = "not enough bytes") -> None:
        # Forward the (possibly defaulted) message to the base exception.
        Exception.__init__(self, msg)


class ByteReader(ABC):
    """Sequential reader that decodes primitive values from a binary source.

    The source is either an existing binary stream, which is used as-is and
    advanced in place, or an in-memory bytes-like object, which is wrapped
    in a ``BytesIO``.
    """

    # struct format character for each supported float width (in bytes).
    _FLOAT_CODES = {2: "e", 4: "f", 8: "d"}

    def __init__(self, data: Union[bytes, SupportsBytes, RawIOBase]):
        """Wrap ``data`` so it can be consumed through the ``read_*`` methods.

        Args:
            data: A binary stream (any ``io.IOBase`` instance, which covers
                raw streams as well as ``BytesIO`` and buffered files), a
                ``bytes``/``bytearray``/``memoryview`` buffer, or an object
                implementing ``__bytes__``.

        Raises:
            TypeError: If ``data`` is none of the supported kinds.
        """
        if isinstance(data, IOBase):
            # BytesIO and buffered file objects derive from BufferedIOBase,
            # not RawIOBase, so test against the common base class.
            self.data = data
        elif isinstance(data, (bytes, bytearray, memoryview)):
            self.data = BytesIO(data)
        elif isinstance(data, SupportsBytes):
            self.data = BytesIO(bytes(data))
        else:
            # Fail here instead of leaving ``self.data`` unset and raising an
            # AttributeError on the first read.
            raise TypeError(
                f"cannot read bytes from object of type '{type(data).__name__}'"
            )

    def tell(self) -> int:
        """Return the current position of the underlying stream."""
        return self.data.tell()

    def is_eof(self) -> bool:
        """Return True when the stream position is at or past the end.

        The end is located by seeking to it and then restoring the previous
        position, so the underlying stream has to be seekable.
        """
        current_pos = self.tell()
        self.data.seek(0, SEEK_END)
        end_pos = self.tell()
        self.data.seek(current_pos)
        return current_pos >= end_pos

    def read_integer(self, size: int, byteorder: "ByteOrder", signed: bool) -> int:
        """Read ``size`` bytes and decode them as an integer.

        Args:
            size: Number of bytes to consume.
            byteorder: ``"little"``, ``"big"`` or ``"native"``.
            signed: Whether the bytes are interpreted as two's complement.

        Raises:
            NotEnoughBytes: If fewer than ``size`` bytes are available.
        """
        return int.from_bytes(
            self.read_bytes(size), normalize_byteorder(byteorder), signed=signed
        )

    def read_float(self, size: Literal[2, 4, 8], byteorder: "ByteOrder") -> float:
        """Read ``size`` bytes and decode them as a floating point number.

        Args:
            size: Width of the float in bytes: 2, 4 or 8.
            byteorder: ``"little"``, ``"big"`` or ``"native"``.

        Raises:
            ValueError: If ``size`` is not 2, 4 or 8. Nothing is consumed.
            NotEnoughBytes: If fewer than ``size`` bytes are available.
        """
        code = self._FLOAT_CODES.get(size)
        if code is None:
            raise ValueError("invalid float size, must be either 2, 4 or 8")

        prefix = "<" if normalize_byteorder(byteorder) == "little" else ">"
        return unpack(prefix + code, self.read_bytes(size))[0]

    def read_bytes(self, size: int) -> bytes:
        """Read exactly ``size`` bytes from the stream.

        Args:
            size: Number of bytes to consume; ``0`` yields ``b""``.

        Raises:
            ValueError: If ``size`` is negative.
            NotEnoughBytes: If the stream stops supplying data before
                ``size`` bytes have been read.
        """
        if size < 0:
            raise ValueError("size must not be negative")

        chunks = []
        remaining = size
        # A raw stream may hand back fewer bytes than requested without being
        # exhausted, so keep reading until the request is satisfied.
        while remaining > 0:
            chunk = self.data.read(remaining)
            if not chunk:
                # b"" means end of stream; None means a non-blocking stream
                # has nothing available right now.
                raise NotEnoughBytes()
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)
Loading

0 comments on commit d063d06

Please sign in to comment.