Skip to content

Commit

Permalink
Merge pull request #3 from remopas/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
remkop22 authored May 17, 2022
2 parents 837f0bd + 5a84efc commit 5e6d10d
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 37 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
# Binread

Read any binary format.
28 changes: 19 additions & 9 deletions src/binread/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(
def extract(self, data: bytes, fields: Dict[str, Any]) -> Tuple[Any, int]:
pass

def read(self, data: bytes, fields: Dict[str, Any]) -> Tuple[Any, int]:
def read_field(self, data: bytes, fields: Dict[str, Any]) -> Tuple[Any, int]:
value, size = self.extract(data, fields)

if self.to:
Expand Down Expand Up @@ -99,25 +99,35 @@ def extract(self, data: bytes, fields: Dict[str, Any]) -> Tuple[Any, int]:
return *unpack(char, data[: self._size]), self._size


class Format:
def __init__(self, fields: Dict[str, Union[FieldType, type]]):
class Format(FieldType):
def __init__(self, fields: Dict[str, Union[FieldType, type]], *args, **kwargs):
super().__init__(*args, **kwargs)
self.fields: dict[str, FieldType] = {}

for name, field in fields.items():
if isinstance(field, FieldType):
self.fields[name] = field
elif issubclass(field, FieldType) and field != FieldType:
self.fields[name] = field()
self.fields[name] = field() # type: ignore

def read(self, data: bytes, allow_leftover: bool = False) -> Dict[str, Any]:
def extract(self, data: bytes, fields: Dict[str, Any]) -> Tuple[Any, int]:
value, bytes_read = self.read(data, allow_leftover=True, return_bytes=True)
return value, bytes_read # type: ignore

def read(
self, data: bytes, allow_leftover: bool = False, return_bytes: bool = False
) -> Union[Dict[str, Any], Tuple[Dict[str, Any], int]]:
result = {}
total = 0
for name, field in self.fields.items():
result[name], bytes_read = field.read(data, result)
result[name], bytes_read = field.read_field(data, result)
data = data[bytes_read:]
total += bytes_read

if len(data) != 0 and not allow_leftover:
print(data)
print(result)
raise Exception("left over bytes")

return result
if return_bytes:
return result, total
else:
return result
24 changes: 20 additions & 4 deletions src/binread/types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .format import Integer, FieldType, Float
from .format import Format, Integer, FieldType, Float
from typing import Any, Tuple, Type, Union, Dict


Expand Down Expand Up @@ -104,7 +104,7 @@ def extract_with_length(
result = [None] * length
total = 0
for i in range(length):
result[i], bytes_read = self.element.read(data, {})
result[i], bytes_read = self.element.read_field(data, {})
data = data[bytes_read:]
total += bytes_read
return result, total
Expand All @@ -116,13 +116,12 @@ def extract_with_terminator(
total = 0

while not data.startswith(terminator):
value, bytes_read = self.element.read(data, {})
value, bytes_read = self.element.read_field(data, {})
result.append(value)
data = data[bytes_read:]
total += bytes_read

total += len(terminator)
print(total)

return result, total

Expand All @@ -134,3 +133,20 @@ def extract(self, data: bytes, fields: Dict[str, Any]) -> Tuple[Any, int]:
return self.extract_with_terminator(data, fields, self._terminator)
else:
raise Exception("array must either have a length or a terminator")


class String(Array):
    """Array of ``U8`` elements whose extracted values are decoded as text.

    ``encoding`` names the codec passed to ``bytes.decode``. Any further
    positional or keyword arguments are forwarded unchanged to
    ``Array.__init__`` after the ``U8`` element type (presumably the
    array's ``length`` / ``terminator`` options -- confirm against ``Array``).
    """

    def __init__(self, encoding: str = "utf-8", *args, **kwargs):
        # The element type is fixed to U8; everything else is the caller's.
        super().__init__(U8, *args, **kwargs)
        self.encoding = encoding

    def extract(self, data: bytes, fields: Dict[str, Any]) -> Tuple[Any, int]:
        """Return ``(decoded_text, bytes_read)`` for the string at the start of ``data``."""
        # Let Array do the element-wise read, then reassemble the byte values.
        elements, consumed = super().extract(data, fields)
        text = bytes(elements).decode(self.encoding)
        return text, consumed


82 changes: 58 additions & 24 deletions tests/binread/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,42 +6,42 @@

class TestTypes(unittest.TestCase):
def test_f64_little(self):
value, bytes_read = binread.F64(byteorder="little").read(
value, bytes_read = binread.F64(byteorder="little").read_field(
struct.pack("<d", 5.2), {}
)
self.assertAlmostEqual(value, 5.2)
self.assertEqual(bytes_read, 8)

def test_f64_big(self):
value, bytes_read = binread.F64(byteorder="big").read(
value, bytes_read = binread.F64(byteorder="big").read_field(
struct.pack(">d", 5.2), {}
)
self.assertAlmostEqual(value, 5.2)
self.assertEqual(bytes_read, 8)

def test_f32_little(self):
value, bytes_read = binread.F32(byteorder="little").read(
value, bytes_read = binread.F32(byteorder="little").read_field(
struct.pack("<f", 5.2), {}
)
self.assertAlmostEqual(value, 5.2, places=6)
self.assertEqual(bytes_read, 4)

def test_f32_big(self):
value, bytes_read = binread.F32(byteorder="big").read(
value, bytes_read = binread.F32(byteorder="big").read_field(
struct.pack(">f", 5.2), {}
)
self.assertAlmostEqual(value, 5.2, places=6)
self.assertEqual(bytes_read, 4)

def test_f16_little(self):
value, bytes_read = binread.F16(byteorder="little").read(
value, bytes_read = binread.F16(byteorder="little").read_field(
struct.pack("<e", 5.2), {}
)
self.assertAlmostEqual(value, 5.2, places=2)
self.assertEqual(bytes_read, 2)

def test_f16_big(self):
value, bytes_read = binread.F16(byteorder="big").read(
value, bytes_read = binread.F16(byteorder="big").read_field(
struct.pack(">e", 5.2), {}
)
self.assertAlmostEqual(value, 5.2, places=2)
Expand All @@ -50,111 +50,145 @@ def test_f16_big(self):
def test_u64_big(self):
self.assertEqual(
(2**64 - 1, 8),
binread.U64(byteorder="big").read(struct.pack(">Q", 2**64 - 1), {}),
binread.U64(byteorder="big").read_field(struct.pack(">Q", 2**64 - 1), {}),
)

def test_u64_little(self):
self.assertEqual(
(2**64 - 1, 8),
binread.U64(byteorder="little").read(struct.pack("<Q", 2**64 - 1), {}),
binread.U64(byteorder="little").read_field(
struct.pack("<Q", 2**64 - 1), {}
),
)

def test_u32_big(self):
self.assertEqual(
(2**32 - 1, 4),
binread.U32(byteorder="big").read(struct.pack(">L", 2**32 - 1), {}),
binread.U32(byteorder="big").read_field(struct.pack(">L", 2**32 - 1), {}),
)

def test_u32_little(self):
self.assertEqual(
(2**32 - 1, 4),
binread.U32(byteorder="little").read(struct.pack("<L", 2**32 - 1), {}),
binread.U32(byteorder="little").read_field(
struct.pack("<L", 2**32 - 1), {}
),
)

def test_u16_big(self):
self.assertEqual(
(2**16 - 1, 2),
binread.U16(byteorder="big").read(struct.pack(">H", 2**16 - 1), {}),
binread.U16(byteorder="big").read_field(struct.pack(">H", 2**16 - 1), {}),
)

def test_u16_little(self):
self.assertEqual(
(2**16 - 1, 2),
binread.U16(byteorder="little").read(struct.pack("<H", 2**16 - 1), {}),
binread.U16(byteorder="little").read_field(
struct.pack("<H", 2**16 - 1), {}
),
)

def test_u8_big(self):
self.assertEqual(
(2**8 - 1, 1),
binread.U8(byteorder="big").read(struct.pack(">B", 2**8 - 1), {}),
binread.U8(byteorder="big").read_field(struct.pack(">B", 2**8 - 1), {}),
)

def test_u8_little(self):
self.assertEqual(
(2**8 - 1, 1),
binread.U8(byteorder="little").read(struct.pack("<B", 2**8 - 1), {}),
binread.U8(byteorder="little").read_field(
struct.pack("<B", 2**8 - 1), {}
),
)

def test_i64_big(self):
self.assertEqual(
(-(2**63) + 1, 8),
binread.I64(byteorder="big").read(struct.pack(">q", -(2**63) + 1), {}),
binread.I64(byteorder="big").read_field(
struct.pack(">q", -(2**63) + 1), {}
),
)

def test_i64_little(self):
self.assertEqual(
(-(2**63) + 1, 8),
binread.I64(byteorder="little").read(struct.pack("<q", -(2**63) + 1), {}),
binread.I64(byteorder="little").read_field(
struct.pack("<q", -(2**63) + 1), {}
),
)

def test_i32_big(self):
self.assertEqual(
(-(2**31) + 1, 4),
binread.I32(byteorder="big").read(struct.pack(">l", -(2**31) + 1), {}),
binread.I32(byteorder="big").read_field(
struct.pack(">l", -(2**31) + 1), {}
),
)

def test_i32_little(self):
self.assertEqual(
(-(2**31) + 1, 4),
binread.I32(byteorder="little").read(struct.pack("<l", -(2**31) + 1), {}),
binread.I32(byteorder="little").read_field(
struct.pack("<l", -(2**31) + 1), {}
),
)

def test_i16_big(self):
self.assertEqual(
(-(2**15) + 1, 2),
binread.I16(byteorder="big").read(struct.pack(">h", -(2**15) + 1), {}),
binread.I16(byteorder="big").read_field(
struct.pack(">h", -(2**15) + 1), {}
),
)

def test_i16_little(self):
self.assertEqual(
(-(2**15) + 1, 2),
binread.I16(byteorder="little").read(struct.pack("<h", -(2**15) + 1), {}),
binread.I16(byteorder="little").read_field(
struct.pack("<h", -(2**15) + 1), {}
),
)

def test_i8_big(self):
self.assertEqual(
(-(2**7) + 1, 1),
binread.I8(byteorder="big").read(struct.pack(">b", -(2**7) + 1), {}),
binread.I8(byteorder="big").read_field(
struct.pack(">b", -(2**7) + 1), {}
),
)

def test_i8_little(self):
self.assertEqual(
(-(2**7) + 1, 1),
binread.I8(byteorder="little").read(struct.pack("<b", -(2**7) + 1), {}),
binread.I8(byteorder="little").read_field(
struct.pack("<b", -(2**7) + 1), {}
),
)

def test_array(self):
self.assertEqual(
([1, 2, 3, 4, 5], 5 * 2),
binread.Array(element=binread.U16, length=5).read(
binread.Array(element=binread.U16, length=5).read_field(
struct.pack("5H", 1, 2, 3, 4, 5), {}
),
)

def test_array_variable_terminated(self):
self.assertEqual(
([1, 2, 3, 4], 5 * 2),
binread.Array(element=binread.U16, terminator=b"\x00\x00").read(
binread.Array(element=binread.U16, terminator=b"\x00\x00").read_field(
struct.pack("6H", 1, 2, 3, 4, 0, 10), {}
),
)

def test_array_struct(self):
    """An Array whose element is a Format yields one dict per element."""
    # Each record is a Format holding a single five-element U16 array.
    record_format = binread.Format({"array": binread.Array(binread.U16, length=5)})
    records = binread.Array(record_format, length=2)

    # Two back-to-back records: 10 unsigned shorts in total.
    payload = struct.pack("10H", 1, 2, 3, 4, 5, 1, 2, 3, 4, 5)
    value = records.read_field(payload, {})

    expected = (
        [{"array": [1, 2, 3, 4, 5]}, {"array": [1, 2, 3, 4, 5]}],
        10 * 2,
    )
    self.assertEqual(expected, value)

0 comments on commit 5e6d10d

Please sign in to comment.