Skip to content

Commit

Permalink
ws init with tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tavallaie committed Aug 4, 2024
1 parent 2c0e638 commit 8c9ffe6
Show file tree
Hide file tree
Showing 3 changed files with 226 additions and 23 deletions.
49 changes: 49 additions & 0 deletions .github/workflows/test-websocket.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
name: WebSocket Tests

on:
push:
branches:
- main
paths:
- 'connectiva/protocols/websocket_protocol.py'
- 'tests/test_websocket_protocol.py'
- 'pyproject.toml'
pull_request:
branches:
- main
paths:
- 'connectiva/protocols/websocket_protocol.py'
- 'tests/test_websocket_protocol.py'
- 'pyproject.toml'

jobs:
test-websocket:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]

steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Install Poetry
run: |
python -m pip install --upgrade pip
python -m pip install poetry
- name: Add Poetry to Path
run: echo "export PATH=\"$HOME/.local/bin:\$PATH\"" >> $GITHUB_ENV

- name: Install dependencies
run: poetry install

- name: Run WebSocket Tests
run: |
echo "Running WebSocket tests on Python ${{ matrix.python-version }}..."
poetry run python -m unittest discover -s tests -p 'test_websocket_protocol.py'
145 changes: 122 additions & 23 deletions connectiva/protocols/websocket_protocol.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,146 @@
import asyncio
import websockets
import json
import logging
from typing import Dict, Any
from ..interfaces import CommunicationMethod
from ..message import Message
from connectiva import CommunicationMethod, Message

class WebSocketProtocol(CommunicationMethod):
"""
WebSocket communication class.
WebSocket protocol that can operate as both a server and client.
"""

def __init__(self, **kwargs):
self.websocket_url = kwargs.get("websocket_url")
self.ws = None
self.mode = kwargs.get("mode", "client") # "client" or "server"
self.endpoint = kwargs.get("endpoint", "ws://localhost:8765")
self.logger = logging.getLogger(self.__class__.__name__)
self.websocket = None
self.server = None

def connect(self):
print(f"Connecting to WebSocket at {self.websocket_url}...")
async def start_server(self):
"""
Starts the WebSocket server.
"""
host, port = self._parse_websocket_url(self.endpoint)
self.logger.info(f"Starting WebSocket server on {self.endpoint}...")
try:
self.ws = websockets.create_connection(self.websocket_url)
print("Connected to WebSocket!")
self.server = await websockets.serve(self.handler, host, port)
self.logger.info("WebSocket server started.")
await self.server.wait_closed()
except Exception as e:
print(f"Failed to connect to WebSocket: {e}")
self.logger.error(f"Failed to start WebSocket server: {e}")

def send(self, message: Message) -> Dict[str, Any]:
print("Sending message via WebSocket...")
async def connect_async(self):
"""
Connects to a WebSocket server.
"""
self.logger.info(f"Connecting to WebSocket at {self.endpoint}...")
try:
self.websocket = await websockets.connect(self.endpoint)
self.logger.info("Connected to WebSocket!")
except Exception as e:
self.logger.error(f"Failed to connect to WebSocket: {e}")

async def handler(self, websocket, path):
"""
Handles incoming WebSocket connections.
"""
self.logger.info("Client connected.")
try:
async for message in websocket:
self.logger.info(f"Received message: {message}")
response = json.dumps({"action": "response", "data": {"received": message}})
await websocket.send(response)
except websockets.exceptions.ConnectionClosed as e:
self.logger.info(f"Client disconnected: {e}")

def connect(self):
"""
Starts the server or connects as a client based on the mode.
"""
if self.mode == "server":
asyncio.get_event_loop().run_until_complete(self.start_server())
elif self.mode == "client":
asyncio.get_event_loop().run_until_complete(self.connect_async())
else:
self.logger.error("Invalid mode specified. Use 'client' or 'server'.")

async def send_async(self, message: Message) -> Dict[str, Any]:
"""
Sends a message via WebSocket.
"""
self.logger.info("Sending message via WebSocket...")
try:
self.ws.send(json.dumps(message.__dict__))
print("Message sent successfully!")
await self.websocket.send(json.dumps(message.__dict__))
self.logger.info("Message sent successfully!")
return {"status": "sent"}
except Exception as e:
print(f"Failed to send message: {e}")
self.logger.error(f"Failed to send message: {e}")
return {"error": str(e)}

def receive(self) -> Message:
print("Receiving message via WebSocket...")
def send(self, message: Message) -> Dict[str, Any]:
"""
Unified method to send a message.
"""
if self.mode == "client":
return asyncio.get_event_loop().run_until_complete(self.send_async(message))
else:
self.logger.error("Sending directly from server mode is not supported.")
return {"error": "Invalid operation in server mode"}

async def receive_async(self) -> Message:
"""
Receives a message via WebSocket.
"""
self.logger.info("Receiving message via WebSocket...")
try:
message = self.ws.recv()
print("Message received successfully!")
message = await self.websocket.recv()
self.logger.info("Message received successfully!")
return Message(action="receive", data=json.loads(message))
except Exception as e:
print(f"Failed to receive message: {e}")
self.logger.error(f"Failed to receive message: {e}")
return Message(action="error", data={}, metadata={"error": str(e)})

def receive(self) -> Message:
"""
Unified method to receive a message.
"""
if self.mode == "client":
return asyncio.get_event_loop().run_until_complete(self.receive_async())
else:
self.logger.error("Receiving directly from server mode is not supported.")
return Message(action="error", data={}, metadata={"error": "Invalid operation in server mode"})

async def disconnect_async(self):
"""
Disconnects the WebSocket connection.
"""
self.logger.info("Disconnecting from WebSocket...")
if self.websocket:
await self.websocket.close()

def disconnect(self):
print("Disconnecting from WebSocket...")
if self.ws:
self.ws.close()
"""
Unified method to disconnect.
"""
if self.mode == "client":
asyncio.get_event_loop().run_until_complete(self.disconnect_async())
elif self.mode == "server" and self.server:
self.server.close()
self.logger.info("WebSocket server stopped.")
else:
self.logger.error("No active connection to disconnect.")

@staticmethod
def _parse_websocket_url(endpoint: str):
"""
Parse the WebSocket URL to extract the host and port.
"""
try:
if endpoint.startswith("ws://") or endpoint.startswith("wss://"):
address = endpoint.split("//")[1]
host, port = address.split(":")
return host, int(port)
raise ValueError("Invalid WebSocket URL format")
except Exception as e:
raise ValueError(f"Error parsing WebSocket URL: {e}")
55 changes: 55 additions & 0 deletions tests/test_websocket_protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import unittest
import asyncio
import logging
from connectiva import Connectiva, Message

class TestWebSocketProtocolWithConnectiva(unittest.TestCase):
@classmethod
def setUpClass(cls):
# Set up logging
logging.basicConfig(level=logging.DEBUG)

# Start the WebSocket server using Connectiva
cls.server = Connectiva(
endpoint="ws://localhost:8765",
mode="server",
log=True
)
cls.loop = asyncio.get_event_loop()
cls.loop.run_in_executor(None, cls.server.connect)

# Start the WebSocket client using Connectiva
cls.client = Connectiva(
endpoint="ws://localhost:8765",
mode="client",
log=True
)
cls.loop.run_until_complete(cls.client.connect_async())

@classmethod
def tearDownClass(cls):
# Disconnect both client and server
cls.client.disconnect()
cls.server.disconnect()

def test_send_receive(self):
# Test sending and receiving messages
message = Message(action="send", data={"content": "Hello WebSocket!"})
response = self.client.send(message)
self.assertEqual(response["status"], "sent", "Message should be sent successfully.")

received_message = self.client.receive()
self.assertEqual(received_message.data["received"], message.data["content"], "Received message should match sent message.")

def test_server_client_communication(self):
# Test communication between server and client
message = Message(action="send", data={"content": "Server-Client Test"})
response = self.client.send(message)
self.assertEqual(response["status"], "sent", "Message should be sent successfully.")

received_message = self.client.receive()
self.assertEqual(received_message.data["received"], message.data["content"], "Server should echo the message back to the client.")


if __name__ == "__main__":
unittest.main()

0 comments on commit 8c9ffe6

Please sign in to comment.