Skip to content

Commit

Permalink
feat: implicit async ws
Browse files Browse the repository at this point in the history
  • Loading branch information
tavallaie committed Aug 4, 2024
1 parent 8c9ffe6 commit 5d2ef01
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 33 deletions.
59 changes: 31 additions & 28 deletions connectiva/protocols/websocket_protocol.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# connectiva/protocols/websocket_protocol.py

import asyncio
import websockets
import json
import logging
from typing import Dict, Any
from typing import Dict, Any, Tuple
from connectiva import CommunicationMethod, Message


class WebSocketProtocol(CommunicationMethod):
"""
WebSocket protocol that can operate as both a server and client.
Expand All @@ -16,21 +19,35 @@ def __init__(self, **kwargs):
self.logger = logging.getLogger(self.__class__.__name__)
self.websocket = None
self.server = None
self.loop = asyncio.get_event_loop()

def _parse_websocket_url(self) -> Tuple[str, int]:
"""
Parse the WebSocket URL to extract the host and port.
"""
try:
if self.endpoint.startswith("ws://") or self.endpoint.startswith("wss://"):
address = self.endpoint.split("//")[1]
host, port = address.split(":")
return host, int(port)
raise ValueError("Invalid WebSocket URL format")
except Exception as e:
raise ValueError(f"Error parsing WebSocket URL: {e}")

async def start_server(self):
async def _start_server(self):
"""
Starts the WebSocket server.
"""
host, port = self._parse_websocket_url(self.endpoint)
host, port = self._parse_websocket_url()
self.logger.info(f"Starting WebSocket server on {self.endpoint}...")
try:
self.server = await websockets.serve(self.handler, host, port)
self.server = await websockets.serve(self._server_handler, host, port)
self.logger.info("WebSocket server started.")
await self.server.wait_closed()
except Exception as e:
self.logger.error(f"Failed to start WebSocket server: {e}")

async def connect_async(self):
async def _connect_async(self):
"""
Connects to a WebSocket server.
"""
Expand All @@ -41,7 +58,7 @@ async def connect_async(self):
except Exception as e:
self.logger.error(f"Failed to connect to WebSocket: {e}")

async def handler(self, websocket, path):
async def _server_handler(self, websocket, path):
"""
Handles incoming WebSocket connections.
"""
Expand All @@ -59,13 +76,13 @@ def connect(self):
Starts the server or connects as a client based on the mode.
"""
if self.mode == "server":
asyncio.get_event_loop().run_until_complete(self.start_server())
self.loop.run_until_complete(self._start_server())
elif self.mode == "client":
asyncio.get_event_loop().run_until_complete(self.connect_async())
self.loop.run_until_complete(self._connect_async())
else:
self.logger.error("Invalid mode specified. Use 'client' or 'server'.")

async def send_async(self, message: Message) -> Dict[str, Any]:
async def _send_async(self, message: Message) -> Dict[str, Any]:
"""
Sends a message via WebSocket.
"""
Expand All @@ -83,12 +100,12 @@ def send(self, message: Message) -> Dict[str, Any]:
Unified method to send a message.
"""
if self.mode == "client":
return asyncio.get_event_loop().run_until_complete(self.send_async(message))
return self.loop.run_until_complete(self._send_async(message))
else:
self.logger.error("Sending directly from server mode is not supported.")
return {"error": "Invalid operation in server mode"}

async def receive_async(self) -> Message:
async def _receive_async(self) -> Message:
"""
Receives a message via WebSocket.
"""
Expand All @@ -106,12 +123,12 @@ def receive(self) -> Message:
Unified method to receive a message.
"""
if self.mode == "client":
return asyncio.get_event_loop().run_until_complete(self.receive_async())
return self.loop.run_until_complete(self._receive_async())
else:
self.logger.error("Receiving directly from server mode is not supported.")
return Message(action="error", data={}, metadata={"error": "Invalid operation in server mode"})

async def disconnect_async(self):
async def _disconnect_async(self):
"""
Disconnects the WebSocket connection.
"""
Expand All @@ -124,23 +141,9 @@ def disconnect(self):
Unified method to disconnect.
"""
if self.mode == "client":
asyncio.get_event_loop().run_until_complete(self.disconnect_async())
self.loop.run_until_complete(self._disconnect_async())
elif self.mode == "server" and self.server:
self.server.close()
self.logger.info("WebSocket server stopped.")
else:
self.logger.error("No active connection to disconnect.")

@staticmethod
def _parse_websocket_url(endpoint: str):
"""
Parse the WebSocket URL to extract the host and port.
"""
try:
if endpoint.startswith("ws://") or endpoint.startswith("wss://"):
address = endpoint.split("//")[1]
host, port = address.split(":")
return host, int(port)
raise ValueError("Invalid WebSocket URL format")
except Exception as e:
raise ValueError(f"Error parsing WebSocket URL: {e}")
11 changes: 6 additions & 5 deletions tests/test_websocket_protocol.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest
import asyncio
import logging
import threading
from connectiva import Connectiva, Message

class TestWebSocketProtocolWithConnectiva(unittest.TestCase):
Expand All @@ -15,29 +15,30 @@ def setUpClass(cls):
mode="server",
log=True
)
cls.loop = asyncio.get_event_loop()
cls.loop.run_in_executor(None, cls.server.connect)
cls.server_thread = threading.Thread(target=cls.server.connect)
cls.server_thread.start()

# Start the WebSocket client using Connectiva
cls.client = Connectiva(
endpoint="ws://localhost:8765",
mode="client",
log=True
)
cls.loop.run_until_complete(cls.client.connect_async())
cls.client.connect()

@classmethod
def tearDownClass(cls):
# Disconnect both client and server
cls.client.disconnect()
cls.server.disconnect()
cls.server_thread.join()

def test_send_receive(self):
# Test sending and receiving messages
message = Message(action="send", data={"content": "Hello WebSocket!"})
response = self.client.send(message)
self.assertEqual(response["status"], "sent", "Message should be sent successfully.")

received_message = self.client.receive()
self.assertEqual(received_message.data["received"], message.data["content"], "Received message should match sent message.")

Expand Down

0 comments on commit 5d2ef01

Please sign in to comment.