-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #6 from tavallaie/websocket
ws init with tests
- Loading branch information
Showing
4 changed files
with
253 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
name: WebSocket Tests | ||
|
||
on: | ||
push: | ||
branches: | ||
- main | ||
paths: | ||
- 'connectiva/protocols/websocket_protocol.py' | ||
- 'tests/test_websocket_protocol.py' | ||
- 'pyproject.toml' | ||
pull_request: | ||
branches: | ||
- main | ||
paths: | ||
- 'connectiva/protocols/websocket_protocol.py' | ||
- 'tests/test_websocket_protocol.py' | ||
- 'pyproject.toml' | ||
|
||
jobs: | ||
test-websocket: | ||
runs-on: ubuntu-latest | ||
strategy: | ||
matrix: | ||
python-version: ["3.8", "3.9", "3.10", "3.11"] | ||
|
||
steps: | ||
- name: Checkout code | ||
uses: actions/checkout@v4 | ||
|
||
- name: Set up Python | ||
uses: actions/setup-python@v5 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
|
||
- name: Install Poetry | ||
run: | | ||
python -m pip install --upgrade pip | ||
python -m pip install poetry | ||
- name: Add Poetry to Path | ||
run: echo "export PATH=\"$HOME/.local/bin:\$PATH\"" >> $GITHUB_ENV | ||
|
||
- name: Install dependencies | ||
run: poetry install | ||
|
||
- name: Run WebSocket Tests | ||
run: | | ||
echo "Running WebSocket tests on Python ${{ matrix.python-version }}..." | ||
poetry run python -m unittest discover -s tests -p 'test_websocket_protocol.py' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,47 +1,155 @@ | ||
# connectiva/protocols/websocket_protocol.py | ||
|
||
import asyncio | ||
import websockets | ||
import json | ||
from typing import Dict, Any | ||
from ..interfaces import CommunicationMethod | ||
from ..message import Message | ||
import logging | ||
from typing import Dict, Any, Tuple | ||
from connectiva import CommunicationMethod, Message | ||
|
||
class WebSocketProtocol(CommunicationMethod): | ||
""" | ||
WebSocket communication class. | ||
WebSocket protocol that can operate as both a server and client. | ||
""" | ||
|
||
def __init__(self, **kwargs): | ||
self.websocket_url = kwargs.get("websocket_url") | ||
self.ws = None | ||
self.mode = kwargs.get("mode", "client") # "client" or "server" | ||
self.endpoint = kwargs.get("endpoint", "ws://localhost:8765") | ||
self.logger = logging.getLogger(self.__class__.__name__) | ||
self.websocket = None | ||
self.server = None | ||
self.loop = asyncio.get_event_loop() | ||
|
||
def connect(self): | ||
print(f"Connecting to WebSocket at {self.websocket_url}...") | ||
def _parse_websocket_url(self) -> Tuple[str, int]: | ||
""" | ||
Parse the WebSocket URL to extract the host and port. | ||
""" | ||
try: | ||
self.ws = websockets.create_connection(self.websocket_url) | ||
print("Connected to WebSocket!") | ||
if self.endpoint.startswith("ws://") or self.endpoint.startswith("wss://"): | ||
address = self.endpoint.split("//")[1] | ||
host, port = address.split(":") | ||
return host, int(port) | ||
raise ValueError("Invalid WebSocket URL format") | ||
except Exception as e: | ||
print(f"Failed to connect to WebSocket: {e}") | ||
raise ValueError(f"Error parsing WebSocket URL: {e}") | ||
|
||
def send(self, message: Message) -> Dict[str, Any]: | ||
print("Sending message via WebSocket...") | ||
async def _start_server(self): | ||
""" | ||
Starts the WebSocket server. | ||
""" | ||
host, port = self._parse_websocket_url() | ||
self.logger.info(f"Starting WebSocket server on {self.endpoint}...") | ||
try: | ||
self.server = await websockets.serve(self._server_handler, host, port) | ||
self.logger.info("WebSocket server started.") | ||
await self.server.wait_closed() | ||
except Exception as e: | ||
self.logger.error(f"Failed to start WebSocket server: {e}") | ||
|
||
async def _connect_async(self): | ||
""" | ||
Connects to a WebSocket server. | ||
""" | ||
self.logger.info(f"Connecting to WebSocket at {self.endpoint}...") | ||
try: | ||
self.ws.send(json.dumps(message.__dict__)) | ||
print("Message sent successfully!") | ||
self.websocket = await websockets.connect(self.endpoint) | ||
self.logger.info("Connected to WebSocket!") | ||
except Exception as e: | ||
self.logger.error(f"Failed to connect to WebSocket: {e}") | ||
|
||
async def _server_handler(self, websocket, path): | ||
""" | ||
Handles incoming WebSocket connections. | ||
""" | ||
self.logger.info("Client connected.") | ||
try: | ||
async for message in websocket: | ||
self.logger.info(f"Received message: {message}") | ||
|
||
# Echo the message back with the expected structure | ||
received_data = json.loads(message) | ||
response = json.dumps({ | ||
"action": "response", | ||
"data": {"received": received_data["data"]["content"]} | ||
}) | ||
|
||
await websocket.send(response) | ||
except websockets.exceptions.ConnectionClosed as e: | ||
self.logger.info(f"Client disconnected: {e}") | ||
|
||
def connect(self): | ||
""" | ||
Starts the server or connects as a client based on the mode. | ||
""" | ||
if self.mode == "server": | ||
asyncio.run(self._start_server()) | ||
elif self.mode == "client": | ||
asyncio.run(self._connect_async()) | ||
else: | ||
self.logger.error("Invalid mode specified. Use 'client' or 'server'.") | ||
|
||
async def _send_async(self, message: Message) -> Dict[str, Any]: | ||
""" | ||
Sends a message via WebSocket. | ||
""" | ||
self.logger.info("Sending message via WebSocket...") | ||
try: | ||
await self.websocket.send(json.dumps(message.__dict__)) | ||
self.logger.info("Message sent successfully!") | ||
return {"status": "sent"} | ||
except Exception as e: | ||
print(f"Failed to send message: {e}") | ||
self.logger.error(f"Failed to send message: {e}") | ||
return {"error": str(e)} | ||
|
||
def receive(self) -> Message: | ||
print("Receiving message via WebSocket...") | ||
def send(self, message: Message) -> Dict[str, Any]: | ||
""" | ||
Unified method to send a message. | ||
""" | ||
if self.mode == "client": | ||
return asyncio.run(self._send_async(message)) | ||
else: | ||
self.logger.error("Sending directly from server mode is not supported.") | ||
return {"error": "Invalid operation in server mode"} | ||
|
||
async def _receive_async(self) -> Message: | ||
""" | ||
Receives a message via WebSocket. | ||
""" | ||
self.logger.info("Receiving message via WebSocket...") | ||
try: | ||
message = self.ws.recv() | ||
print("Message received successfully!") | ||
message = await self.websocket.recv() | ||
self.logger.info("Message received successfully!") | ||
return Message(action="receive", data=json.loads(message)) | ||
except Exception as e: | ||
print(f"Failed to receive message: {e}") | ||
self.logger.error(f"Failed to receive message: {e}") | ||
return Message(action="error", data={}, metadata={"error": str(e)}) | ||
|
||
def receive(self) -> Message: | ||
""" | ||
Unified method to receive a message. | ||
""" | ||
if self.mode == "client": | ||
return asyncio.run(self._receive_async()) | ||
else: | ||
self.logger.error("Receiving directly from server mode is not supported.") | ||
return Message(action="error", data={}, metadata={"error": "Invalid operation in server mode"}) | ||
|
||
async def _disconnect_async(self): | ||
""" | ||
Disconnects the WebSocket connection. | ||
""" | ||
self.logger.info("Disconnecting from WebSocket...") | ||
if self.websocket: | ||
await self.websocket.close() | ||
|
||
def disconnect(self): | ||
print("Disconnecting from WebSocket...") | ||
if self.ws: | ||
self.ws.close() | ||
""" | ||
Unified method to disconnect. | ||
""" | ||
if self.mode == "client": | ||
asyncio.run(self._disconnect_async()) | ||
elif self.mode == "server" and self.server: | ||
self.server.close() | ||
self.logger.info("WebSocket server stopped.") | ||
else: | ||
self.logger.error("No active connection to disconnect.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# tests/test_websocket_protocol.py | ||
|
||
import unittest | ||
import logging | ||
import threading | ||
import nest_asyncio | ||
import asyncio | ||
from connectiva import Connectiva, Message | ||
|
||
# Apply the nest_asyncio patch | ||
nest_asyncio.apply() | ||
|
||
class TestWebSocketProtocolWithConnectiva(unittest.TestCase): | ||
@classmethod | ||
def setUpClass(cls): | ||
# Set up logging | ||
logging.basicConfig(level=logging.DEBUG) | ||
|
||
# Start the WebSocket server using Connectiva | ||
cls.server = Connectiva( | ||
endpoint="ws://localhost:8765", | ||
mode="server", | ||
log=True | ||
) | ||
cls.server_thread = threading.Thread(target=cls.server.connect) | ||
cls.server_thread.start() | ||
|
||
# Start the WebSocket client using Connectiva | ||
cls.client = Connectiva( | ||
endpoint="ws://localhost:8765", | ||
mode="client", | ||
log=True | ||
) | ||
cls.client.connect() | ||
|
||
@classmethod | ||
def tearDownClass(cls): | ||
# Disconnect both client and server | ||
cls.client.disconnect() | ||
cls.server.disconnect() | ||
cls.server_thread.join() | ||
|
||
def test_send_receive(self): | ||
# Test sending and receiving messages | ||
message = Message(action="send", data={"content": "Hello WebSocket!"}) | ||
response = self.client.send(message) | ||
self.assertEqual(response["status"], "sent", "Message should be sent successfully.") | ||
|
||
received_message = self.client.receive() | ||
# Check for the 'received' key in the echoed message | ||
self.assertEqual(received_message.data["data"]["received"], message.data["content"], "Received message should match sent message.") | ||
|
||
def test_server_client_communication(self): | ||
# Test communication between server and client | ||
message = Message(action="send", data={"content": "Server-Client Test"}) | ||
response = self.client.send(message) | ||
self.assertEqual(response["status"], "sent", "Message should be sent successfully.") | ||
|
||
received_message = self.client.receive() | ||
# Check for the 'received' key in the echoed message | ||
self.assertEqual(received_message.data["data"]["received"], message.data["content"], "Server should echo the message back to the client.") | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |