diff --git a/madvr/madvr.py b/madvr/madvr.py index 4529ea3..9ab094c 100644 --- a/madvr/madvr.py +++ b/madvr/madvr.py @@ -6,6 +6,7 @@ from typing import Final, Union import re import time +import threading import socket from madvr.commands import ACKs, Footer, Commands, Enum, Connections from madvr.errors import AckError, RetryExceededError, HeartBeatError @@ -26,6 +27,7 @@ def __init__( self.port = port self.connect_timeout: int = connect_timeout self.logger = logger + self._lock = threading.Lock() # Const values self.MADVR_OK: Final = Connections.welcome.value @@ -46,7 +48,7 @@ def __init__( self.temp_gpu: int = 0 self.temp_hdmi: int = 0 self.temp_cpu: int = 0 - self.temp_mainboard: int = 0 + self.temp_mainboard: int = 0 # Outgoing signal self.outgoing_res = "" @@ -288,8 +290,8 @@ def send_command(self, command: str) -> str: self.client.send(cmd) try: - # Read the Ok\r\n - ack_reply = self.client.recv(4) + # Read everything because it can randomly include notifications ugh + ack_reply = self.client.recv(self.read_limit) self.logger.debug("Got ack from cmd: %s", ack_reply) # Don't read more if its informational if not is_info: @@ -377,17 +379,23 @@ def poll_status(self) -> None: """ Poll the status for attributes and write them to state """ - # send heartbeat so it doesnt close our connection - try: - self._send_heartbeat() - except (socket.timeout, socket.error, HeartBeatError) as err: - self.logger.error("Error getting update: %s", err) - - # Get incoming signal info - for cmd in ["GetIncomingSignalInfo", "GetAspectRatio", "GetTemperatures", "GetOutgoingSignalInfo"]: - res = self.send_command(cmd) - self.logger.debug("poll_status resp: %s", res) - self._process_notifications(res) + # lock so HA doesn't trip over itself + with self._lock: + try: + # send heartbeat so it doesnt close our connection + self._send_heartbeat() + # Get incoming signal info + for cmd in [ + "GetIncomingSignalInfo", + "GetAspectRatio", + "GetTemperatures", + "GetOutgoingSignalInfo", + ]: + res = self.send_command(cmd) + self.logger.debug("poll_status resp: %s", res) + self._process_notifications(res) + except (socket.timeout, socket.error, HeartBeatError) as err: + self.logger.error("Error getting update: %s", err) def _process_notifications(self, input_data: Union[bytes, str]) -> None: """ diff --git a/tests/testMadVR.py b/tests/testMadVR.py index 5380b8e..174a704 100644 --- a/tests/testMadVR.py +++ b/tests/testMadVR.py @@ -54,20 +54,23 @@ class TestFunctional(unittest.TestCase): def test_a_poll(self): # self.skipTest("") - print("running test_a_poll") self.assertEqual(madvr.incoming_res, "") madvr.poll_status() self.assertNotEqual(madvr.incoming_res, "") + self.assertNotEqual(madvr.outgoing_color_space, "") + print( + madvr.hdr_flag, + madvr.incoming_aspect_ratio, + madvr.incoming_frame_rate + ) - print("running test_b_aspect") signal = madvr.send_command("GetAspectRatio") self.assertNotEqual(signal, "Command not found") signal = madvr.send_command("KeyPress, UP") self.assertNotEqual(signal, "Command not found") - print("running test_c_command_notfound") fake_cmd = madvr.send_command("FakeCommand") self.assertEqual(fake_cmd, "Command not found") @@ -77,7 +80,6 @@ def test_a_poll(self): # print("running test_d_notifications") # madvr.read_notifications(True) - print("running test_e_menuopen") cmd = madvr.send_command("KeyPress, MENU") self.assertNotEqual(cmd, "Command not found") @@ -85,7 +87,6 @@ def test_a_poll(self): cmd = madvr.send_command("KeyPress, MENU") self.assertNotEqual(cmd, "Command not found") - print("running test_z_ConnClose") madvr.close_connection() self.assertEqual(True, madvr.is_closed)