diff --git a/rsp1570serial/connection.py b/rsp1570serial/connection.py index 20d4e34..226279a 100644 --- a/rsp1570serial/connection.py +++ b/rsp1570serial/connection.py @@ -3,7 +3,7 @@ import logging from rsp1570serial.commands import encode_command, encode_volume_direct_command from rsp1570serial.messages import decode_message_stream -from serial import PARITY_NONE, STOPBITS_ONE +from serial import PARITY_NONE, STOPBITS_ONE, SerialException from serial_asyncio import open_serial_connection import uuid import weakref @@ -11,6 +11,10 @@ _LOGGER = logging.getLogger(__name__) +class RotelAmpConnConnectionError(Exception): + pass + + class RotelAmpConn: """ Basic connection to a Rotel Amp @@ -26,13 +30,16 @@ def __init__(self, serial_port): async def open(self): if not self.is_open: - self.reader, self.writer = await open_serial_connection( - url=self.serial_port, - baudrate=115200, - timeout=None, - parity=PARITY_NONE, - stopbits=STOPBITS_ONE, - ) + try: + self.reader, self.writer = await open_serial_connection( + url=self.serial_port, + baudrate=115200, + timeout=None, + parity=PARITY_NONE, + stopbits=STOPBITS_ONE, + ) + except SerialException as exc: + raise RotelAmpConnConnectionError(str(exc)) from exc self.is_open = True def close(self): diff --git a/rsp1570serial/tests/test_connection.py b/rsp1570serial/tests/test_connection.py index 7f7bcc3..3a0e1ab 100644 --- a/rsp1570serial/tests/test_connection.py +++ b/rsp1570serial/tests/test_connection.py @@ -1,7 +1,7 @@ import asyncio import aiounittest from contextlib import asynccontextmanager -from rsp1570serial.connection import SharedRotelAmpConn +from rsp1570serial.connection import SharedRotelAmpConn, RotelAmpConnConnectionError from rsp1570serial.emulator import ( make_message_handler, create_device, @@ -125,18 +125,21 @@ async def test_multi_clients1(self): async def test_connection_failure1(self): shared_conn = SharedRotelAmpConn(f"socket://:{BAD_TEST_PORT}") - with self.assertRaises(SerialException): - # Connection refused by host + # Connection refused by host + with self.assertRaises(RotelAmpConnConnectionError) as cm: await shared_conn.open() + self.assertIsInstance(cm.exception.__cause__, SerialException) async def test_connection_failure2(self): shared_conn = SharedRotelAmpConn(f"socket://192.168.51.1:{TEST_PORT}") - with self.assertRaises(SerialException): - # Timed out due to made up IP address + # Timed out due to made up IP address + with self.assertRaises(RotelAmpConnConnectionError) as cm: await shared_conn.open() + self.assertIsInstance(cm.exception.__cause__, SerialException) async def test_connection_failure3(self): shared_conn = SharedRotelAmpConn(f"socket://made_up_hostname:{TEST_PORT}") - with self.assertRaises(SerialException): - # getaddrinfo failed + # getaddrinfo failed + with self.assertRaises(RotelAmpConnConnectionError) as cm: await shared_conn.open() + self.assertIsInstance(cm.exception.__cause__, SerialException)