Skip to content

Commit

Permalink
fixed array struct element (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
remkop22 committed May 18, 2022
1 parent 4c487b4 commit 289c2b2
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 70 deletions.
6 changes: 3 additions & 3 deletions src/binread/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

from .format import Format, format
from .types import Array, String, Bool, Char, Bytes
from .format import Format, formatclass
from .types import Array, String, Bool, Char, Tuple, Bytes
from .types import I8, I16, I32, I64
from .types import U8, U16, U32, U64
from .types import F16, F32, F64
from .types import F16, F32, F64
102 changes: 61 additions & 41 deletions src/binread/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


ByteOrder = Literal["little", "big", "native"]
"""Specifies the endiannes. `native` equals `sys.byteorder`."""
"""Specifies the endianness. `native` equals `sys.byteorder`."""


class NotEnoughBytes(Exception):
Expand All @@ -26,7 +26,7 @@ class FieldType(ABC):
"""Abstract base class of all field types. Can be used to create a custom field type.
Args:
byteorder: specifies the endiannes of this type.
byteorder: specifies the endianness of this type.
to: specifies a callable to transform the extracted data.
"""

Expand Down Expand Up @@ -64,6 +64,27 @@ def read_field(self, data: bytes, fields: Dict[str, Any]) -> Tuple[Any, int]:

return value, size

def _inheriting_byteorder(self) -> ByteOrder:
if self._byteorder:
return self._byteorder
else:
return self._default_byteorder

@staticmethod
def _to_instance(field: Union["FieldType", type]) -> Optional["FieldType"]:
if isinstance(field, FieldType):
return field
elif (
isinstance(field, type)
and issubclass(field, FieldType)
and field != FieldType
):
return field() # type: ignore
elif hasattr(field, "_field_type"):
getattr(field, "_field_type")
else:
return None

def byteorder(self) -> Literal["little", "big"]:
if self._byteorder:
byteorder = self._byteorder
Expand Down Expand Up @@ -124,26 +145,15 @@ def __init__(self, fields: Dict[str, Union[FieldType, type]], *args, **kwargs):
super().__init__(*args, **kwargs)
self._fields: dict[str, FieldType] = {}

if self._byteorder:
byteorder = self._byteorder
else:
byteorder = self._default_byteorder

byteorder = self._inheriting_byteorder()
for name, field in fields.items():
if isinstance(field, FieldType):
self._fields[name] = field
elif (
isinstance(field, type)
and issubclass(field, FieldType)
and field != FieldType
):
self._fields[name] = field() # type: ignore
elif hasattr(field, "_field_type"):
self._fields[name] = getattr(field, "_field_type")
else:
field = self._to_instance(field)

if not field:
raise Exception(f"unknown field type '{field}' with key '{name}'")

self._fields[name]._default_byteorder = byteorder
field._default_byteorder = byteorder
self._fields[name] = field

def extract(self, data: bytes, fields: Dict[str, Any]) -> Tuple[Any, int]:
value, bytes_read = self.read(data, allow_leftover=True, return_bytes=True)
Expand All @@ -168,34 +178,44 @@ def read(
return result


def format(cls: type) -> type:
def formatclass(*args, **kwargs):
with_args = True

fields = {}
for name, field in cls.__dict__.items():
if _is_field_type(field):
fields[name] = field
if len(args) == 0 and isinstance(args[0], Callable):
with_args = False
args = []
kwargs = {}

for name in fields.keys():
cls.__annotations__[name] = Any
def decorator(cls):
fields = {}
for name, field in cls.__dict__.items():
field = FieldType._to_instance(field)

cls = dataclass(cls)
if field:
fields[name] = field

fmt = Format(fields)
setattr(cls, "_field_type", fmt)
for name in fields.keys():
cls.__annotations__[name] = Any

@staticmethod
def read(*args, **kwargs):
field_dict = fmt.read(*args, **kwargs)
return cls(**field_dict)
cls = dataclass(cls)

fmt = Format(fields, *args, **kwargs)
setattr(cls, "_field_type", fmt)

setattr(cls, "read", read)
@staticmethod
def read(*args, **kwargs):
field_dict = fmt.read(*args, **kwargs)

if isinstance(field_dict, tuple):
return cls(**field_dict[0]), field_dict[1] # type: ignore
else:
return cls(**field_dict)

return cls
setattr(cls, "read", read)

return cls

def _is_field_type(obj: Any) -> bool:
return (
isinstance(obj, FieldType)
or (isinstance(obj, type) and issubclass(obj, FieldType))
or hasattr(obj, "_field_type")
)
if with_args:
return decorator
else:
return decorator(args[0])
41 changes: 15 additions & 26 deletions src/binread/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def extract(self, data: bytes, fields: Dict[str, Any]):
class Array(FieldType):
def __init__(
self,
element: Union[FieldType, Type[FieldType]],
element: Union[FieldType, Type],
length: Union[int, str, Callable[[Dict[str, Any]], int], None] = None,
length_bytes: Union[int, str, Callable[[Dict[str, Any]], int], None] = None,
terminator: Union[bytes, None] = None,
Expand All @@ -81,21 +81,14 @@ def __init__(
):
super().__init__(*args, **kwargs)

if self._byteorder:
byteorder = self._byteorder
else:
byteorder = self._default_byteorder

if isinstance(element, FieldType):
element._default_byteorder = byteorder
self.element = element
elif issubclass(element, FieldType) and element != FieldType:
element_instance = element()
element_instance._default_byteorder = byteorder
self.element = element_instance
else:
element = self._to_instance(element)

if not element:
raise Exception(f"invalid array element {element}")

element._default_byteorder = self._inheriting_byteorder()
self.element = element

self._length = length
self._length_bytes = length_bytes
self._terminator = terminator
Expand Down Expand Up @@ -206,22 +199,18 @@ def extract_with_terminator(
class Tuple(FieldType):
def __init__(self, fields: Iterable[Union[FieldType, type]], *args, **kwargs):
super().__init__(*args, **kwargs)
self._fields = []
self._fields = [] # type: ignore
byteorder = self._inheriting_byteorder()

for field in fields:
if isinstance(field, FieldType):
self._fields.append(field)
elif (
isinstance(field, type)
and issubclass(field, FieldType)
and field != FieldType
):
self._fields.append(field()) # type: ignore
elif hasattr(field, "_field_type"):
self._fields.append(getattr(field, "_field_type"))
else:
field = self._to_instance(field)

if not field:
raise Exception(f"unknown field type '{field}'")

field._default_byteorder = byteorder
self._fields.append(field)

self._fields: TupleType[FieldType] = tuple(self._fields)

def extract(self, data: bytes, fields: Dict[str, Any]) -> TupleType[Any, int]:
Expand Down

0 comments on commit 289c2b2

Please sign in to comment.