diff --git a/pybricksdev/ble/pybricks.py b/pybricksdev/ble/pybricks.py index cb058b9..44d35df 100644 --- a/pybricksdev/ble/pybricks.py +++ b/pybricksdev/ble/pybricks.py @@ -328,6 +328,12 @@ def _standard_uuid(short: int) -> str: .. availability:: Since Pybricks protocol v1.0.0. """ +DVC_NAME_UUID = _standard_uuid(0x2A00) +"""Standard Device Name UUID + +.. availability:: Since Pybricks protocol v1.0.0. +""" + FW_REV_UUID = _standard_uuid(0x2A26) """Standard Firmware Revision String characteristic UUID diff --git a/pybricksdev/cli/__init__.py b/pybricksdev/cli/__init__.py index 198f52b..f0f341e 100644 --- a/pybricksdev/cli/__init__.py +++ b/pybricksdev/cli/__init__.py @@ -171,10 +171,12 @@ def add_parser(self, subparsers: argparse._SubParsersAction): ) async def run(self, args: argparse.Namespace): - from ..ble import find_device + from usb.core import find as find_usb + + from ..ble import find_device as find_ble from ..connections.ev3dev import EV3Connection from ..connections.lego import REPLHub - from ..connections.pybricks import PybricksHubBLE + from ..connections.pybricks import PybricksHubBLE, PybricksHubUSB # Pick the right connection if args.conntype == "ssh": @@ -185,14 +187,28 @@ async def run(self, args: argparse.Namespace): device_or_address = socket.gethostbyname(args.name) hub = EV3Connection(device_or_address) + elif args.conntype == "ble": # It is a Pybricks Hub with BLE. Device name or address is given. print(f"Searching for {args.name or 'any hub with Pybricks service'}...") - device_or_address = await find_device(args.name) + device_or_address = await find_ble(args.name) hub = PybricksHubBLE(device_or_address) elif args.conntype == "usb": - hub = REPLHub() + + def is_pybricks_usb(dev): + return ( + (dev.idVendor == 0x0694) + and ((dev.idProduct == 0x0009) or (dev.idProduct == 0x0011)) + and dev.product.endswith("Pybricks") + ) + + device_or_address = find_usb(custom_match=is_pybricks_usb) + + if device_or_address is not None: + hub = PybricksHubUSB(device_or_address) + else: + hub = REPLHub() else: raise ValueError(f"Unknown connection type: {args.conntype}") diff --git a/pybricksdev/connections/pybricks.py b/pybricksdev/connections/pybricks.py index 3b69b60..c5dda82 100644 --- a/pybricksdev/connections/pybricks.py +++ b/pybricksdev/connections/pybricks.py @@ -7,6 +7,7 @@ import os import struct from typing import Awaitable, Callable, List, Optional, TypeVar +from uuid import UUID import reactivex.operators as op import semver @@ -17,10 +18,15 @@ from reactivex.subject import BehaviorSubject, Subject from tqdm.auto import tqdm from tqdm.contrib.logging import logging_redirect_tqdm +from usb.control import get_descriptor +from usb.core import Device as USBDevice +from usb.core import Endpoint, USBTimeoutError +from usb.util import ENDPOINT_IN, ENDPOINT_OUT, endpoint_direction, find_descriptor from ..ble.lwp3.bytecodes import HubKind from ..ble.nus import NUS_RX_UUID, NUS_TX_UUID from ..ble.pybricks import ( + DVC_NAME_UUID, FW_REV_UUID, PNP_ID_UUID, PYBRICKS_COMMAND_EVENT_UUID, @@ -705,3 +711,120 @@ async def write_gatt_char(self, uuid: str, data, response: bool) -> None: async def start_notify(self, uuid: str, callback: Callable) -> None: return await self._client.start_notify(uuid, callback) + + +class PybricksHubUSB(PybricksHub): + _device: USBDevice + _ep_in: Endpoint + _ep_out: Endpoint + _notify_callbacks = {} + _monitor_task: asyncio.Task + + def __init__(self, device: USBDevice): + super().__init__() + self._device = device + + async def _client_connect(self) -> bool: + self._device.set_configuration() + + # Save input and output endpoints + cfg = self._device.get_active_configuration() + intf = cfg[(0, 0)] + self._ep_in = find_descriptor( + intf, + custom_match=lambda e: endpoint_direction(e.bEndpointAddress) + == ENDPOINT_IN, + ) + self._ep_out = find_descriptor( + intf, + custom_match=lambda e: endpoint_direction(e.bEndpointAddress) + == ENDPOINT_OUT, + ) + + # Set write size to endpoint packet size minus length of UUID + self._max_write_size = self._ep_out.wMaxPacketSize - 16 + + # Get length of BOS descriptor + bos_descriptor = get_descriptor(self._device, 5, 0x0F, 0) + (ofst, _, bos_len, _) = struct.unpack(" bool: + self._monitor_task.cancel() + self._handle_disconnect() + + async def read_gatt_char(self, uuid: str) -> bytearray: + return None + + async def write_gatt_char(self, uuid: str, data, response: bool) -> None: + self._ep_out.write(UUID(uuid).bytes_le + data) + # TODO: Handle response + + async def start_notify(self, uuid: str, callback: Callable) -> None: + self._notify_callbacks[uuid] = callback + + async def _monitor_usb(self): + loop = asyncio.get_running_loop() + + while True: + msg = await loop.run_in_executor(None, self._read_usb) + + if msg is None: + continue + + if len(msg) > 16: + uuid = str(UUID(bytes_le=bytes(msg[0:16]))) + if uuid in self._notify_callbacks: + callback = self._notify_callbacks[uuid] + if callback: + callback(None, bytes(msg[16:])) + + def _read_usb(self): + try: + msg = self._ep_in.read(self._ep_in.wMaxPacketSize) + return msg + except USBTimeoutError: + return None