Skip to content

Commit

Permalink
Merge pull request #69 from us-irs/improve-bytefield-impls
Browse files Browse the repository at this point in the history
Improve bytefield impls
  • Loading branch information
robamu authored Jan 23, 2024
2 parents f514035 + 21dbc77 commit 88f9ed9
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 24 deletions.
15 changes: 15 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,21 @@ and this project adheres to [Semantic Versioning](http://semver.org/).

# [unreleased]

# [v0.23.0] 2024-01-23

## Changed

- Explicitely disambigute `ByteFieldU<[8, 16, 32, 64]>.from_bytes` from
`UnsignedByteField.from_bytes` by renaming them to
`ByteFieldU<[8, 16, 32, 64].from_<[8, 16, 32, 64]>_bytes`. This might break calling code which
might now call `UnsignedByteField.from_bytes`.
- Improve `ByteFieldGenerator.from_int` and `ByteFieldGenerator.from_bytes` method. These
will now raise an exception if the passed value width in not in [1, 2, 4, 8].

## Added

- Added `ByteFieldU64` variant.

# [v0.22.0] 2023-12-22

## Changed
Expand Down
59 changes: 45 additions & 14 deletions spacepackets/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@ def get_printable_data_string(print_format: PrintFormats, data: bytes) -> str:
class IntByteConversion:
@staticmethod
def signed_struct_specifier(byte_num: int) -> str:
if byte_num not in [1, 2, 4, 8]:
raise ValueError("Invalid byte number, must be one of [1, 2, 4, 8]")
if byte_num == 1:
return "!b"
elif byte_num == 2:
Expand All @@ -63,13 +61,10 @@ def signed_struct_specifier(byte_num: int) -> str:
return "!i"
elif byte_num == 8:
return "!q"
raise ValueError("Invalid byte number, must be one of [1, 2, 4, 8]")

@staticmethod
def unsigned_struct_specifier(byte_num: int) -> str:
if byte_num not in [1, 2, 4, 8]:
raise ValueError(
f"invalid byte number {byte_num}, must be one of [1, 2, 4, 8]"
)
if byte_num == 1:
return "!B"
elif byte_num == 2:
Expand All @@ -78,6 +73,7 @@ def unsigned_struct_specifier(byte_num: int) -> str:
return "!I"
elif byte_num == 8:
return "!Q"
raise ValueError(f"invalid byte number {byte_num}, must be one of [1, 2, 4, 8]")

@staticmethod
def to_signed(byte_num: int, val: int) -> bytes:
Expand Down Expand Up @@ -221,11 +217,12 @@ def __int__(self):
def __len__(self):
return self._byte_len

def __eq__(self, other: Union[UnsignedByteField, bytes]):
def __eq__(self, other: object) -> bool:
if isinstance(other, UnsignedByteField):
return self.value == other.value and self.byte_len == other.byte_len
elif isinstance(other, bytes):
return self._val_as_bytes == other
raise TypeError(f"Cannot compare {self.__class__.__name__} to {other}")

def __hash__(self):
"""Makes all unsigned byte fields usable as dictionary keys"""
Expand All @@ -247,7 +244,7 @@ def __init__(self, val: int):
super().__init__(val, 1)

@classmethod
def from_bytes(cls, stream: bytes) -> ByteFieldU8:
def from_u8_bytes(cls, stream: bytes) -> ByteFieldU8:
if len(stream) < 1:
raise ValueError(
"Passed stream not large enough, should be at least 1 byte"
Expand All @@ -265,7 +262,7 @@ def __init__(self, val: int):
super().__init__(val, 2)

@classmethod
def from_bytes(cls, stream: bytes) -> ByteFieldU16:
def from_u16_bytes(cls, stream: bytes) -> ByteFieldU16:
if len(stream) < 2:
raise ValueError(
"Passed stream not large enough, should be at least 2 byte"
Expand All @@ -287,10 +284,10 @@ def __init__(self, val: int):
super().__init__(val, 4)

@classmethod
def from_bytes(cls, stream: bytes) -> ByteFieldU32:
def from_u32_bytes(cls, stream: bytes) -> ByteFieldU32:
if len(stream) < 4:
raise ValueError(
"Passed stream not large enough, should be at least 4 byte"
"passed stream not large enough, should be at least 4 bytes"
)
return cls(
struct.unpack(IntByteConversion.unsigned_struct_specifier(4), stream[0:4])[
Expand All @@ -302,26 +299,60 @@ def __str__(self):
return self.default_string("U32")


class ByteFieldU64(UnsignedByteField):
"""Concrete variant of a variable length byte field which has a length of 8 bytes"""

def __init__(self, val: int):
super().__init__(val, 8)

@classmethod
def from_u64_bytes(cls, stream: bytes) -> ByteFieldU64:
if len(stream) < 8:
raise ValueError(
"passed stream not large enough, should be at least 8 byte"
)
return cls(
struct.unpack(IntByteConversion.unsigned_struct_specifier(8), stream[0:8])[
0
]
)

def __str__(self):
return self.default_string("U64")


class ByteFieldGenerator:
"""Static helpers to create the U8, U16 and U32 byte field variants of unsigned byte fields"""

@staticmethod
def from_int(byte_len: int, val: int) -> UnsignedByteField:
"""Generate an :py:class:`UnsignedByteField` from the byte length and a value.
:raise ValueError: Byte length is not one of [1, 2, 4, 8]."""
if byte_len == 1:
return ByteFieldU8(val)
elif byte_len == 2:
return ByteFieldU16(val)
elif byte_len == 4:
return ByteFieldU32(val)
elif byte_len == 8:
return ByteFieldU64(val)
raise ValueError("invalid byte length")

@staticmethod
def from_bytes(byte_len: int, stream: bytes) -> UnsignedByteField:
"""Generate an :py:class:`UnsignedByteField` from a raw bytestream and a length.
:raise ValueError: Byte length is not one of [1, 2, 4, 8]."""
if byte_len == 1:
return ByteFieldU8.from_bytes(stream)
return ByteFieldU8.from_u8_bytes(stream)
elif byte_len == 2:
return ByteFieldU16.from_bytes(stream)
return ByteFieldU16.from_u16_bytes(stream)
elif byte_len == 4:
return ByteFieldU32.from_bytes(stream)
return ByteFieldU32.from_u32_bytes(stream)
elif byte_len == 8:
return ByteFieldU64.from_u64_bytes(stream)
raise ValueError("invalid byte length")


if __name__ == "__main__":
Expand Down
12 changes: 6 additions & 6 deletions tests/cfdp/pdus/test_ack_pdu.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ def test_ack_pdu(self):
self.check_fields_packet_0(ack_pdu=ack_pdu_unpacked)

pdu_conf = PduConfig(
transaction_seq_num=ByteFieldU32.from_bytes(
transaction_seq_num=ByteFieldU32.from_u32_bytes(
bytes([0x50, 0x00, 0x10, 0x01])
),
source_entity_id=ByteFieldU32.from_bytes(bytes([0x10, 0x00, 0x01, 0x02])),
dest_entity_id=ByteFieldU32.from_bytes(bytes([0x30, 0x00, 0x01, 0x03])),
source_entity_id=ByteFieldU32.from_u32_bytes(bytes([0x10, 0x00, 0x01, 0x02])),
dest_entity_id=ByteFieldU32.from_u32_bytes(bytes([0x30, 0x00, 0x01, 0x03])),
crc_flag=CrcFlag.WITH_CRC,
trans_mode=TransmissionMode.UNACKNOWLEDGED,
file_flag=LargeFileFlag.NORMAL,
Expand Down Expand Up @@ -152,15 +152,15 @@ def check_fields_packet_1(self, ack_pdu: AckPdu, with_crc: bool):
self.assertEqual(ack_pdu.transaction_status, TransactionStatus.ACTIVE)
self.assertEqual(
ack_pdu.pdu_file_directive.pdu_header.transaction_seq_num,
ByteFieldU32.from_bytes(bytes([0x50, 0x00, 0x10, 0x01])),
ByteFieldU32.from_u32_bytes(bytes([0x50, 0x00, 0x10, 0x01])),
)
self.assertEqual(
ack_pdu.pdu_file_directive.pdu_header.pdu_conf.source_entity_id,
ByteFieldU32.from_bytes(bytes([0x10, 0x00, 0x01, 0x02])),
ByteFieldU32.from_u32_bytes(bytes([0x10, 0x00, 0x01, 0x02])),
)
self.assertEqual(
ack_pdu.pdu_file_directive.pdu_header.pdu_conf.dest_entity_id,
ByteFieldU32.from_bytes(bytes([0x30, 0x00, 0x01, 0x03])),
ByteFieldU32.from_u32_bytes(bytes([0x30, 0x00, 0x01, 0x03])),
)
self.assertEqual(
ack_pdu.pdu_file_directive.pdu_header.pdu_conf.trans_mode,
Expand Down
2 changes: 1 addition & 1 deletion tests/cfdp/test_header.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def test_with_prompt_pdu(self):
self._switch_cfg()
self.pdu_conf.source_entity_id = ByteFieldU8(0)
self.pdu_conf.dest_entity_id = ByteFieldU8(0)
self.pdu_conf.transaction_seq_num = ByteFieldU16.from_bytes(bytes([0x00, 0x2C]))
self.pdu_conf.transaction_seq_num = ByteFieldU16.from_u16_bytes(bytes([0x00, 0x2C]))
prompt_pdu = PromptPdu(
response_required=ResponseRequired.KEEP_ALIVE, pdu_conf=self.pdu_conf
)
Expand Down
25 changes: 22 additions & 3 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
ByteFieldU8,
ByteFieldU16,
ByteFieldU32,
ByteFieldU64,
UnsignedByteField,
IntByteConversion,
)
Expand Down Expand Up @@ -70,15 +71,19 @@ def test_one_byte_invalid_gen(self):

def test_byte_field_u8_invalid_unpack(self):
with self.assertRaises(ValueError):
ByteFieldU8.from_bytes(bytes())
ByteFieldU8.from_u8_bytes(bytes())

def test_byte_field_u16_invalid_unpack(self):
with self.assertRaises(ValueError):
ByteFieldU16.from_bytes(bytes([1]))
ByteFieldU16.from_u16_bytes(bytes([1]))

def test_byte_field_u32_invalid_unpack(self):
with self.assertRaises(ValueError):
ByteFieldU32.from_bytes(bytes([1, 2, 3]))
ByteFieldU32.from_u32_bytes(bytes([1, 2, 3]))

def test_byte_field_u64_invalid_unpack(self):
with self.assertRaises(ValueError):
ByteFieldU64.from_u64_bytes(bytes([1, 2, 3, 4, 5]))

def test_two_byte_field_gen(self):
two_byte_test = ByteFieldGenerator.from_int(byte_len=2, val=0x1842)
Expand All @@ -92,6 +97,12 @@ def test_four_byte_field_gen(self):
four_byte_test = ByteFieldGenerator.from_bytes(4, four_byte_test.as_bytes)
self.assertEqual(ByteFieldU32(0x10101010), four_byte_test)

def test_eight_byte_field_gen(self):
eight_byte_test = ByteFieldGenerator.from_int(byte_len=8, val=0x1010101010)
self.assertEqual(ByteFieldU64(0x1010101010), eight_byte_test)
eight_byte_test = ByteFieldGenerator.from_bytes(8, eight_byte_test.as_bytes)
self.assertEqual(ByteFieldU64(0x1010101010), eight_byte_test)

def test_setting_from_raw(self):
one_byte_test = ByteFieldGenerator.from_int(byte_len=1, val=0x42)
one_byte_test.value = bytes([0x22])
Expand All @@ -116,6 +127,10 @@ def test_byte_int_converter_signed_four_byte(self):
raw = IntByteConversion.to_signed(byte_num=4, val=-7329093)
self.assertEqual(struct.unpack("!i", raw)[0], -7329093)

def test_byte_int_converter_signed_eight_byte(self):
raw = IntByteConversion.to_signed(byte_num=8, val=-7329093032932932)
self.assertEqual(struct.unpack("!q", raw)[0], -7329093032932932)

def test_one_byte_str(self):
byte_field = ByteFieldU8(22)
self.assertEqual(
Expand Down Expand Up @@ -148,3 +163,7 @@ def test_two_byte_hex_str(self):
def test_four_byte_hex_str(self):
byte_field = ByteFieldU32(255532)
self.assertEqual(byte_field.hex_str, f"{byte_field.value:#010x}")

def test_eight_byte_hex_str(self):
byte_field = ByteFieldU64(0x1010101010)
self.assertEqual(byte_field.hex_str, f"{byte_field.value:#018x}")

0 comments on commit 88f9ed9

Please sign in to comment.