Skip to content

Commit

Permalink
Merge pull request #6 from tavallaie/websocket
Browse files Browse the repository at this point in the history
ws init with tests
  • Loading branch information
tavallaie authored Aug 4, 2024
2 parents 2c0e638 + bf02156 commit 0fb71c6
Show file tree
Hide file tree
Showing 4 changed files with 253 additions and 25 deletions.
49 changes: 49 additions & 0 deletions .github/workflows/test-websocket.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
name: WebSocket Tests

on:
push:
branches:
- main
paths:
- 'connectiva/protocols/websocket_protocol.py'
- 'tests/test_websocket_protocol.py'
- 'pyproject.toml'
pull_request:
branches:
- main
paths:
- 'connectiva/protocols/websocket_protocol.py'
- 'tests/test_websocket_protocol.py'
- 'pyproject.toml'

jobs:
test-websocket:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]

steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Install Poetry
run: |
python -m pip install --upgrade pip
python -m pip install poetry
- name: Add Poetry to Path
run: echo "export PATH=\"$HOME/.local/bin:\$PATH\"" >> $GITHUB_ENV

- name: Install dependencies
run: poetry install

- name: Run WebSocket Tests
run: |
echo "Running WebSocket tests on Python ${{ matrix.python-version }}..."
poetry run python -m unittest discover -s tests -p 'test_websocket_protocol.py'
156 changes: 132 additions & 24 deletions connectiva/protocols/websocket_protocol.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,155 @@
# connectiva/protocols/websocket_protocol.py

import asyncio
import websockets
import json
from typing import Dict, Any
from ..interfaces import CommunicationMethod
from ..message import Message
import logging
from typing import Dict, Any, Tuple
from connectiva import CommunicationMethod, Message

class WebSocketProtocol(CommunicationMethod):
"""
WebSocket communication class.
WebSocket protocol that can operate as both a server and client.
"""

def __init__(self, **kwargs):
self.websocket_url = kwargs.get("websocket_url")
self.ws = None
self.mode = kwargs.get("mode", "client") # "client" or "server"
self.endpoint = kwargs.get("endpoint", "ws://localhost:8765")
self.logger = logging.getLogger(self.__class__.__name__)
self.websocket = None
self.server = None
self.loop = asyncio.get_event_loop()

def connect(self):
print(f"Connecting to WebSocket at {self.websocket_url}...")
def _parse_websocket_url(self) -> Tuple[str, int]:
"""
Parse the WebSocket URL to extract the host and port.
"""
try:
self.ws = websockets.create_connection(self.websocket_url)
print("Connected to WebSocket!")
if self.endpoint.startswith("ws://") or self.endpoint.startswith("wss://"):
address = self.endpoint.split("//")[1]
host, port = address.split(":")
return host, int(port)
raise ValueError("Invalid WebSocket URL format")
except Exception as e:
print(f"Failed to connect to WebSocket: {e}")
raise ValueError(f"Error parsing WebSocket URL: {e}")

def send(self, message: Message) -> Dict[str, Any]:
print("Sending message via WebSocket...")
async def _start_server(self):
"""
Starts the WebSocket server.
"""
host, port = self._parse_websocket_url()
self.logger.info(f"Starting WebSocket server on {self.endpoint}...")
try:
self.server = await websockets.serve(self._server_handler, host, port)
self.logger.info("WebSocket server started.")
await self.server.wait_closed()
except Exception as e:
self.logger.error(f"Failed to start WebSocket server: {e}")

async def _connect_async(self):
"""
Connects to a WebSocket server.
"""
self.logger.info(f"Connecting to WebSocket at {self.endpoint}...")
try:
self.ws.send(json.dumps(message.__dict__))
print("Message sent successfully!")
self.websocket = await websockets.connect(self.endpoint)
self.logger.info("Connected to WebSocket!")
except Exception as e:
self.logger.error(f"Failed to connect to WebSocket: {e}")

async def _server_handler(self, websocket, path):
"""
Handles incoming WebSocket connections.
"""
self.logger.info("Client connected.")
try:
async for message in websocket:
self.logger.info(f"Received message: {message}")

# Echo the message back with the expected structure
received_data = json.loads(message)
response = json.dumps({
"action": "response",
"data": {"received": received_data["data"]["content"]}
})

await websocket.send(response)
except websockets.exceptions.ConnectionClosed as e:
self.logger.info(f"Client disconnected: {e}")

def connect(self):
"""
Starts the server or connects as a client based on the mode.
"""
if self.mode == "server":
asyncio.run(self._start_server())
elif self.mode == "client":
asyncio.run(self._connect_async())
else:
self.logger.error("Invalid mode specified. Use 'client' or 'server'.")

async def _send_async(self, message: Message) -> Dict[str, Any]:
"""
Sends a message via WebSocket.
"""
self.logger.info("Sending message via WebSocket...")
try:
await self.websocket.send(json.dumps(message.__dict__))
self.logger.info("Message sent successfully!")
return {"status": "sent"}
except Exception as e:
print(f"Failed to send message: {e}")
self.logger.error(f"Failed to send message: {e}")
return {"error": str(e)}

def receive(self) -> Message:
print("Receiving message via WebSocket...")
def send(self, message: Message) -> Dict[str, Any]:
"""
Unified method to send a message.
"""
if self.mode == "client":
return asyncio.run(self._send_async(message))
else:
self.logger.error("Sending directly from server mode is not supported.")
return {"error": "Invalid operation in server mode"}

async def _receive_async(self) -> Message:
"""
Receives a message via WebSocket.
"""
self.logger.info("Receiving message via WebSocket...")
try:
message = self.ws.recv()
print("Message received successfully!")
message = await self.websocket.recv()
self.logger.info("Message received successfully!")
return Message(action="receive", data=json.loads(message))
except Exception as e:
print(f"Failed to receive message: {e}")
self.logger.error(f"Failed to receive message: {e}")
return Message(action="error", data={}, metadata={"error": str(e)})

def receive(self) -> Message:
"""
Unified method to receive a message.
"""
if self.mode == "client":
return asyncio.run(self._receive_async())
else:
self.logger.error("Receiving directly from server mode is not supported.")
return Message(action="error", data={}, metadata={"error": "Invalid operation in server mode"})

async def _disconnect_async(self):
"""
Disconnects the WebSocket connection.
"""
self.logger.info("Disconnecting from WebSocket...")
if self.websocket:
await self.websocket.close()

def disconnect(self):
print("Disconnecting from WebSocket...")
if self.ws:
self.ws.close()
"""
Unified method to disconnect.
"""
if self.mode == "client":
asyncio.run(self._disconnect_async())
elif self.mode == "server" and self.server:
self.server.close()
self.logger.info("WebSocket server stopped.")
else:
self.logger.error("No active connection to disconnect.")
8 changes: 7 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,20 @@ readme = "README.md"
keywords = ["microservices", "communication", "REST", "gRPC", "Kafka", "WebSockets", "GraphQL"]

[tool.poetry.dependencies]
python = "^3.8"
python = ">=3.8.1"
requests = "^2.32.3"
grpcio = "^1.65.4"
pika = "^1.3.2"
websockets = "^12.0"
kafka-python-ng = "^2.2.2"


[tool.poetry.group.dev.dependencies]
ruff = "^0.5.6"
black = "^24.8.0"
flake8 = "^7.1.0"
nest-asyncio = "^1.6.0"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
65 changes: 65 additions & 0 deletions tests/test_websocket_protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# tests/test_websocket_protocol.py

import unittest
import logging
import threading
import nest_asyncio
import asyncio
from connectiva import Connectiva, Message

# Apply the nest_asyncio patch
nest_asyncio.apply()

class TestWebSocketProtocolWithConnectiva(unittest.TestCase):
@classmethod
def setUpClass(cls):
# Set up logging
logging.basicConfig(level=logging.DEBUG)

# Start the WebSocket server using Connectiva
cls.server = Connectiva(
endpoint="ws://localhost:8765",
mode="server",
log=True
)
cls.server_thread = threading.Thread(target=cls.server.connect)
cls.server_thread.start()

# Start the WebSocket client using Connectiva
cls.client = Connectiva(
endpoint="ws://localhost:8765",
mode="client",
log=True
)
cls.client.connect()

@classmethod
def tearDownClass(cls):
# Disconnect both client and server
cls.client.disconnect()
cls.server.disconnect()
cls.server_thread.join()

def test_send_receive(self):
# Test sending and receiving messages
message = Message(action="send", data={"content": "Hello WebSocket!"})
response = self.client.send(message)
self.assertEqual(response["status"], "sent", "Message should be sent successfully.")

received_message = self.client.receive()
# Check for the 'received' key in the echoed message
self.assertEqual(received_message.data["data"]["received"], message.data["content"], "Received message should match sent message.")

def test_server_client_communication(self):
# Test communication between server and client
message = Message(action="send", data={"content": "Server-Client Test"})
response = self.client.send(message)
self.assertEqual(response["status"], "sent", "Message should be sent successfully.")

received_message = self.client.receive()
# Check for the 'received' key in the echoed message
self.assertEqual(received_message.data["data"]["received"], message.data["content"], "Server should echo the message back to the client.")


if __name__ == "__main__":
unittest.main()

0 comments on commit 0fb71c6

Please sign in to comment.