Skip to content

Commit

Permalink
Abstract out the network transport format again
Browse files Browse the repository at this point in the history
  • Loading branch information
MetRonnie committed Feb 20, 2025
1 parent 8529d37 commit f751949
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 13 deletions.
9 changes: 8 additions & 1 deletion cylc/flow/network/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,14 @@ class ResponseDict(TypedDict, total=False):
back-compat issues."""


def load_server_response(message: str) -> 'ResponseDict':
def stringify(data: object) -> str:
"""Convert the structure holding a message to a JSON message string."""
# Abstract out the transport format in order to allow it to be changed
# in future.
return json.dumps(data)


def parse(message: str) -> 'ResponseDict':
"""Convert a JSON message string to dict with an added 'user' field."""
msg = json.loads(message)
if 'user' not in msg:
Expand Down
10 changes: 5 additions & 5 deletions cylc/flow/network/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
abstractmethod,
)
import asyncio
import json
import os
from shutil import which
import socket
Expand Down Expand Up @@ -51,7 +50,8 @@
from cylc.flow.network import (
ZMQSocketBase,
get_location,
load_server_response,
parse,
stringify,
)
from cylc.flow.network.client_factory import CommsMeth
from cylc.flow.network.server import PB_METHOD_MAP
Expand Down Expand Up @@ -306,7 +306,7 @@ async def async_request(
if req_meta:
msg['meta'].update(req_meta)
LOG.debug('zmq:send %s', msg)
message = json.dumps(msg)
message = stringify(msg)
self.socket.send_string(message)

# receive response
Expand All @@ -326,7 +326,7 @@ async def async_request(
if command in PB_METHOD_MAP:
return res

response: ResponseDict = load_server_response(res.decode())
response: ResponseDict = parse(res.decode())

try:
return response['data']
Expand All @@ -338,7 +338,7 @@ async def async_request(
f"{response}"
)
wflow_cylc_ver = response.get('cylc_version')
if wflow_cylc_ver:
if wflow_cylc_ver and wflow_cylc_ver != CYLC_VERSION:
error += (
f"\n(Workflow is running in Cylc {wflow_cylc_ver})"
)
Expand Down
10 changes: 5 additions & 5 deletions cylc/flow/network/replier.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Server for workflow runtime API."""

import json
from queue import Queue
from typing import (
TYPE_CHECKING,
Expand All @@ -30,7 +29,8 @@
)
from cylc.flow.network import (
ZMQSocketBase,
load_server_response,
parse,
stringify,
)


Expand Down Expand Up @@ -116,7 +116,7 @@ def listener(self) -> None:
res: ResponseDict
response: bytes
try:
message = load_server_response(msg)
message = parse(msg)
except Exception as exc: # purposefully catch generic exception
# failed to decode message, possibly resulting from failed
# authentication
Expand All @@ -126,7 +126,7 @@ def listener(self) -> None:
'error': {'message': str(exc)},
'cylc_version': CYLC_VERSION,
}
response = json.dumps(res).encode()
response = stringify(res).encode()
else:
# success case - serve the request
res = self.server.receiver(message)
Expand All @@ -135,5 +135,5 @@ def listener(self) -> None:
if isinstance(data, bytes):
response = data
else:
response = json.dumps(res).encode()
response = stringify(res).encode()
self.socket.send(response) # type: ignore[union-attr]
4 changes: 2 additions & 2 deletions tests/integration/network/test_replier.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import pytest

from cylc.flow import __version__ as CYLC_VERSION
from cylc.flow.network import load_server_response
from cylc.flow.network import parse
from cylc.flow.network.client import WorkflowRuntimeClient
from cylc.flow.scheduler import Scheduler

Expand All @@ -33,7 +33,7 @@ async def test_listener(one: Scheduler, start):
# (without directly calling listener):
client = WorkflowRuntimeClient(one.workflow)
client.socket.send_string(r'Not JSON')
res = load_server_response(
res = parse(
(await client.socket.recv()).decode()
)
assert res['error']
Expand Down

0 comments on commit f751949

Please sign in to comment.