-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
226 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
name: WebSocket Tests | ||
|
||
on: | ||
push: | ||
branches: | ||
- main | ||
paths: | ||
- 'connectiva/protocols/websocket_protocol.py' | ||
- 'tests/test_websocket_protocol.py' | ||
- 'pyproject.toml' | ||
pull_request: | ||
branches: | ||
- main | ||
paths: | ||
- 'connectiva/protocols/websocket_protocol.py' | ||
- 'tests/test_websocket_protocol.py' | ||
- 'pyproject.toml' | ||
|
||
jobs: | ||
test-websocket: | ||
runs-on: ubuntu-latest | ||
strategy: | ||
matrix: | ||
python-version: ["3.8", "3.9", "3.10", "3.11"] | ||
|
||
steps: | ||
- name: Checkout code | ||
uses: actions/checkout@v4 | ||
|
||
- name: Set up Python | ||
uses: actions/setup-python@v5 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
|
||
- name: Install Poetry | ||
run: | | ||
python -m pip install --upgrade pip | ||
python -m pip install poetry | ||
- name: Add Poetry to Path | ||
run: echo "export PATH=\"$HOME/.local/bin:\$PATH\"" >> $GITHUB_ENV | ||
|
||
- name: Install dependencies | ||
run: poetry install | ||
|
||
- name: Run WebSocket Tests | ||
run: | | ||
echo "Running WebSocket tests on Python ${{ matrix.python-version }}..." | ||
poetry run python -m unittest discover -s tests -p 'test_websocket_protocol.py' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,47 +1,146 @@ | ||
import asyncio | ||
import websockets | ||
import json | ||
import logging | ||
from typing import Dict, Any | ||
from ..interfaces import CommunicationMethod | ||
from ..message import Message | ||
from connectiva import CommunicationMethod, Message | ||
|
||
class WebSocketProtocol(CommunicationMethod): | ||
""" | ||
WebSocket communication class. | ||
WebSocket protocol that can operate as both a server and client. | ||
""" | ||
|
||
def __init__(self, **kwargs): | ||
self.websocket_url = kwargs.get("websocket_url") | ||
self.ws = None | ||
self.mode = kwargs.get("mode", "client") # "client" or "server" | ||
self.endpoint = kwargs.get("endpoint", "ws://localhost:8765") | ||
self.logger = logging.getLogger(self.__class__.__name__) | ||
self.websocket = None | ||
self.server = None | ||
|
||
def connect(self): | ||
print(f"Connecting to WebSocket at {self.websocket_url}...") | ||
async def start_server(self): | ||
""" | ||
Starts the WebSocket server. | ||
""" | ||
host, port = self._parse_websocket_url(self.endpoint) | ||
self.logger.info(f"Starting WebSocket server on {self.endpoint}...") | ||
try: | ||
self.ws = websockets.create_connection(self.websocket_url) | ||
print("Connected to WebSocket!") | ||
self.server = await websockets.serve(self.handler, host, port) | ||
self.logger.info("WebSocket server started.") | ||
await self.server.wait_closed() | ||
except Exception as e: | ||
print(f"Failed to connect to WebSocket: {e}") | ||
self.logger.error(f"Failed to start WebSocket server: {e}") | ||
|
||
def send(self, message: Message) -> Dict[str, Any]: | ||
print("Sending message via WebSocket...") | ||
async def connect_async(self): | ||
""" | ||
Connects to a WebSocket server. | ||
""" | ||
self.logger.info(f"Connecting to WebSocket at {self.endpoint}...") | ||
try: | ||
self.websocket = await websockets.connect(self.endpoint) | ||
self.logger.info("Connected to WebSocket!") | ||
except Exception as e: | ||
self.logger.error(f"Failed to connect to WebSocket: {e}") | ||
|
||
async def handler(self, websocket, path): | ||
""" | ||
Handles incoming WebSocket connections. | ||
""" | ||
self.logger.info("Client connected.") | ||
try: | ||
async for message in websocket: | ||
self.logger.info(f"Received message: {message}") | ||
response = json.dumps({"action": "response", "data": {"received": message}}) | ||
await websocket.send(response) | ||
except websockets.exceptions.ConnectionClosed as e: | ||
self.logger.info(f"Client disconnected: {e}") | ||
|
||
def connect(self): | ||
""" | ||
Starts the server or connects as a client based on the mode. | ||
""" | ||
if self.mode == "server": | ||
asyncio.get_event_loop().run_until_complete(self.start_server()) | ||
elif self.mode == "client": | ||
asyncio.get_event_loop().run_until_complete(self.connect_async()) | ||
else: | ||
self.logger.error("Invalid mode specified. Use 'client' or 'server'.") | ||
|
||
async def send_async(self, message: Message) -> Dict[str, Any]: | ||
""" | ||
Sends a message via WebSocket. | ||
""" | ||
self.logger.info("Sending message via WebSocket...") | ||
try: | ||
self.ws.send(json.dumps(message.__dict__)) | ||
print("Message sent successfully!") | ||
await self.websocket.send(json.dumps(message.__dict__)) | ||
self.logger.info("Message sent successfully!") | ||
return {"status": "sent"} | ||
except Exception as e: | ||
print(f"Failed to send message: {e}") | ||
self.logger.error(f"Failed to send message: {e}") | ||
return {"error": str(e)} | ||
|
||
def receive(self) -> Message: | ||
print("Receiving message via WebSocket...") | ||
def send(self, message: Message) -> Dict[str, Any]: | ||
""" | ||
Unified method to send a message. | ||
""" | ||
if self.mode == "client": | ||
return asyncio.get_event_loop().run_until_complete(self.send_async(message)) | ||
else: | ||
self.logger.error("Sending directly from server mode is not supported.") | ||
return {"error": "Invalid operation in server mode"} | ||
|
||
async def receive_async(self) -> Message: | ||
""" | ||
Receives a message via WebSocket. | ||
""" | ||
self.logger.info("Receiving message via WebSocket...") | ||
try: | ||
message = self.ws.recv() | ||
print("Message received successfully!") | ||
message = await self.websocket.recv() | ||
self.logger.info("Message received successfully!") | ||
return Message(action="receive", data=json.loads(message)) | ||
except Exception as e: | ||
print(f"Failed to receive message: {e}") | ||
self.logger.error(f"Failed to receive message: {e}") | ||
return Message(action="error", data={}, metadata={"error": str(e)}) | ||
|
||
def receive(self) -> Message: | ||
""" | ||
Unified method to receive a message. | ||
""" | ||
if self.mode == "client": | ||
return asyncio.get_event_loop().run_until_complete(self.receive_async()) | ||
else: | ||
self.logger.error("Receiving directly from server mode is not supported.") | ||
return Message(action="error", data={}, metadata={"error": "Invalid operation in server mode"}) | ||
|
||
async def disconnect_async(self): | ||
""" | ||
Disconnects the WebSocket connection. | ||
""" | ||
self.logger.info("Disconnecting from WebSocket...") | ||
if self.websocket: | ||
await self.websocket.close() | ||
|
||
def disconnect(self): | ||
print("Disconnecting from WebSocket...") | ||
if self.ws: | ||
self.ws.close() | ||
""" | ||
Unified method to disconnect. | ||
""" | ||
if self.mode == "client": | ||
asyncio.get_event_loop().run_until_complete(self.disconnect_async()) | ||
elif self.mode == "server" and self.server: | ||
self.server.close() | ||
self.logger.info("WebSocket server stopped.") | ||
else: | ||
self.logger.error("No active connection to disconnect.") | ||
|
||
@staticmethod | ||
def _parse_websocket_url(endpoint: str): | ||
""" | ||
Parse the WebSocket URL to extract the host and port. | ||
""" | ||
try: | ||
if endpoint.startswith("ws://") or endpoint.startswith("wss://"): | ||
address = endpoint.split("//")[1] | ||
host, port = address.split(":") | ||
return host, int(port) | ||
raise ValueError("Invalid WebSocket URL format") | ||
except Exception as e: | ||
raise ValueError(f"Error parsing WebSocket URL: {e}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import unittest | ||
import asyncio | ||
import logging | ||
from connectiva import Connectiva, Message | ||
|
||
class TestWebSocketProtocolWithConnectiva(unittest.TestCase): | ||
@classmethod | ||
def setUpClass(cls): | ||
# Set up logging | ||
logging.basicConfig(level=logging.DEBUG) | ||
|
||
# Start the WebSocket server using Connectiva | ||
cls.server = Connectiva( | ||
endpoint="ws://localhost:8765", | ||
mode="server", | ||
log=True | ||
) | ||
cls.loop = asyncio.get_event_loop() | ||
cls.loop.run_in_executor(None, cls.server.connect) | ||
|
||
# Start the WebSocket client using Connectiva | ||
cls.client = Connectiva( | ||
endpoint="ws://localhost:8765", | ||
mode="client", | ||
log=True | ||
) | ||
cls.loop.run_until_complete(cls.client.connect_async()) | ||
|
||
@classmethod | ||
def tearDownClass(cls): | ||
# Disconnect both client and server | ||
cls.client.disconnect() | ||
cls.server.disconnect() | ||
|
||
def test_send_receive(self): | ||
# Test sending and receiving messages | ||
message = Message(action="send", data={"content": "Hello WebSocket!"}) | ||
response = self.client.send(message) | ||
self.assertEqual(response["status"], "sent", "Message should be sent successfully.") | ||
|
||
received_message = self.client.receive() | ||
self.assertEqual(received_message.data["received"], message.data["content"], "Received message should match sent message.") | ||
|
||
def test_server_client_communication(self): | ||
# Test communication between server and client | ||
message = Message(action="send", data={"content": "Server-Client Test"}) | ||
response = self.client.send(message) | ||
self.assertEqual(response["status"], "sent", "Message should be sent successfully.") | ||
|
||
received_message = self.client.receive() | ||
self.assertEqual(received_message.data["received"], message.data["content"], "Server should echo the message back to the client.") | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |