Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

connections.ev3: Improve firmware update. #98

Merged
merged 2 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions pybricksdev/cli/flash.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,24 +382,28 @@ async def flash_ev3(firmware: bytes) -> None:
fw, hw = await bootloader.get_version()
print(f"hwid: {hw}")

ERASE_TICKS = 60

# Erasing doesn't have any feedback so we just use time for the progress
# bar. The operation runs on the EV3, so the time is the same for everyone.
async def tick(callback):
for _ in range(ERASE_TICKS):
await asyncio.sleep(1)
callback(1)
CHUNK = 8000
SPEED = 256000
for _ in range(len(firmware) // CHUNK):
await asyncio.sleep(CHUNK / SPEED)
callback(CHUNK)

print("Erasing memory...")
with logging_redirect_tqdm(), tqdm(total=ERASE_TICKS) as pbar:
await asyncio.gather(bootloader.erase_chip(), tick(pbar.update))
print("Erasing memory and preparing firmware download...")
with logging_redirect_tqdm(), tqdm(
total=len(firmware), unit="B", unit_scale=True
) as pbar:
await asyncio.gather(
bootloader.erase_and_begin_download(len(firmware)), tick(pbar.update)
)

print("Downloading firmware...")
with logging_redirect_tqdm(), tqdm(
total=len(firmware), unit="B", unit_scale=True
) as pbar:
await bootloader.download(0, firmware, pbar.update)
await bootloader.download(firmware, pbar.update)

print("Verifying...", end="", flush=True)
checksum = await bootloader.get_checksum(0, len(firmware))
Expand Down
60 changes: 39 additions & 21 deletions pybricksdev/connections/ev3.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,20 @@ def _send_command(self, command: Command, payload: Optional[bytes] = None) -> in

return message_number

def _receive_reply(self, command: Command, message_number: int) -> bytes:
def _receive_reply(
self, command: Command, message_number: int, force_length: int = 0
) -> bytes:
"""
Receive a reply from the EV3 bootloader.

Args:
command: The command that was sent.
message_number: The return value of :meth:`_send_command`.
force_length: Expected length, used only when it fails to unpack
normally. Some replies on USB 3.0 hosts contain
the original command written over the reply. This
means the header is bad, but the payload may be in
tact if you know what data to expect.

Returns:
The payload of the reply.
Expand All @@ -131,36 +138,41 @@ def _receive_reply(self, command: Command, message_number: int) -> bytes:
raise ReplyError(status)

if message_type != MessageType.SYSTEM_REPLY:
raise RuntimeError("unexpected message type: {message_type}")
if force_length:
return reply[7 : force_length + 2]
raise RuntimeError(f"unexpected message type: {message_type}")

if reply_command != command:
raise RuntimeError("command mismatch: {reply_command} != {command}")
raise RuntimeError(f"command mismatch: {reply_command} != {command}")

return reply[7 : length + 2]

def download_sync(
self,
address: int,
data: bytes,
progress: Optional[Callable[[int], None]] = None,
) -> None:
"""
Blocking version of :meth:`download`.
"""
param_data = struct.pack("<II", address, len(data))
num = self._send_command(Command.BEGIN_DOWNLOAD, param_data)
self._receive_reply(Command.BEGIN_DOWNLOAD, num)

completed = 0
for c in chunk(data, self._MAX_DATA_SIZE):
num = self._send_command(Command.DOWNLOAD_DATA, c)
self._receive_reply(Command.DOWNLOAD_DATA, num)
try:
completed += len(c)
self._receive_reply(Command.DOWNLOAD_DATA, num)
except RuntimeError as e:
# Allow exception only on the final chunk.
if completed != len(data):
raise e
print(e, ". Proceeding anyway.")

if progress:
progress(len(c))

async def download(
self,
address: int,
data: bytes,
progress: Optional[Callable[[int], None]] = None,
) -> None:
Expand All @@ -170,30 +182,31 @@ async def download(
This operation takes about 60 seconds for a full 16MB firmware file.

Args:
address: The starting address of where to write the data.
data: The data to write.
progress: Optional callback for indicating progress.
"""
return await asyncio.get_running_loop().run_in_executor(
None, self.download_sync, address, data, progress
None, self.download_sync, data, progress
)

def erase_chip_sync(self) -> None:
def erase_and_begin_download_sync(self, size) -> None:
"""
Blocking version of :meth:`erase_chip`.
Blocking version of :meth:`erase_and_begin_download`.
"""
num = self._send_command(Command.CHIP_ERASE)
self._receive_reply(Command.CHIP_ERASE, num)
param_data = struct.pack("<II", 0, size)
num = self._send_command(Command.BEGIN_DOWNLOAD_WITH_ERASE, param_data)
self._receive_reply(Command.BEGIN_DOWNLOAD_WITH_ERASE, num)

async def erase_chip(self) -> None:
async def erase_and_begin_download(self, size) -> None:
"""
Erases the external flash memory chip.
Erases the external flash memory chip by the amount required to
flash the new firmware. Also prepares firmware download.

This operation takes about 60 seconds.
Args:
size: How much to erase.
"""
return await asyncio.get_running_loop().run_in_executor(
None,
self.erase_chip_sync,
None, self.erase_and_begin_download_sync, size
)

def start_app_sync(self) -> None:
Expand Down Expand Up @@ -241,7 +254,12 @@ def get_version_sync(self) -> Tuple[int, int]:
Blocking version of :meth:`get_version`.
"""
num = self._send_command(Command.GET_VERSION)
payload = self._receive_reply(Command.GET_VERSION, num)
# On certain USB 3.0 systems, the brick reply contains the command
# we just sent written over it. This means we don't get the correct
# header and length info. Since the command here is smaller than the
# reply, the paypload does not get overwritten, so we can still get
# the version info since we know the expected reply size.
payload = self._receive_reply(Command.GET_VERSION, num, force_length=13)
return struct.unpack("<II", payload)

async def get_version(self) -> Tuple[int, int]:
Expand Down
Loading