Skip to content

Commit

Permalink
Merge pull request #28 from JPHutchins/fix/message-deserialization
Browse files Browse the repository at this point in the history
fix: add smp_data field used to set the bytes if deserializing
  • Loading branch information
JPHutchins authored Jul 24, 2024
2 parents 4796bcf + ff47002 commit d1d3b9b
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions smp/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,22 @@ class _MessageBase(ABC, BaseModel):
header: smpheader.Header = None # type: ignore
version: smpheader.Version = smpheader.Version.V2
sequence: int = None # type: ignore
smp_data: bytes = None # type: ignore

def __bytes__(self) -> bytes:
return self._bytes
return self.smp_data

@property
def BYTES(self) -> bytes:
return self._bytes
return self.smp_data

@classmethod
def loads(cls: Type[T], data: bytes) -> T:
"""Deserialize the SMP message."""
message = cls(
header=smpheader.Header.loads(data[: smpheader.Header.SIZE]),
**cast(dict, cbor2.loads(data[smpheader.Header.SIZE :])),
smp_data=data,
)
if message.header is None: # pragma: no cover
raise ValueError
Expand All @@ -75,10 +77,11 @@ def load(cls: Type[T], header: smpheader.Header, data: dict) -> T:
def model_post_init(self, _: None) -> None:
data_bytes = cbor2.dumps(
self.model_dump(
exclude_unset=True, exclude={'header', 'version', 'sequence'}, exclude_none=True
exclude_unset=True,
exclude={'header', 'version', 'sequence', 'smp_data'},
exclude_none=True,
)
)
self._bytes: bytes
if self.header is None: # create the header
object.__setattr__(
self,
Expand All @@ -95,7 +98,7 @@ def model_post_init(self, _: None) -> None:
)
object.__setattr__(self, 'sequence', self.header.sequence)
else: # validate the header and update version & sequence
if self.header.length != len(data_bytes):
if self.smp_data is None and self.header.length != len(data_bytes):
raise SMPMalformed(
f"header.length {self.header.length} != len(data_bytes) {len(data_bytes)}"
)
Expand All @@ -111,7 +114,8 @@ def model_post_init(self, _: None) -> None:
"from the provided header."
)
object.__setattr__(self, 'version', self.header.version)
self._bytes = self.header.BYTES + data_bytes
if self.smp_data is None:
object.__setattr__(self, 'smp_data', bytes(self.header) + data_bytes)


class Request(_MessageBase, ABC):
Expand Down

0 comments on commit d1d3b9b

Please sign in to comment.