Add RabbitMQ tests to GitHub CI #85

Closed · wants to merge 8 commits
895 changes: 554 additions & 341 deletions .github/workflows/ci.yml

Large diffs are not rendered by default.

97 changes: 76 additions & 21 deletions src/AMSWorkflow/ams/rmq.py
@@ -16,7 +16,6 @@
import json
import pika


class AMSMessage(object):
"""
Represents a RabbitMQ incoming message from AMSLib.
@@ -28,6 +27,24 @@ class AMSMessage(object):
def __init__(self, body: str):
self.body = body

self.num_elements = None
self.hsize = None
self.dtype_byte = None
self.mpi_rank = None
self.domain_name_size = None
self.domain_names = []
self.input_dim = None
self.output_dim = None

def __str__(self):
dt = "float" if self.dtype_byte == 4 else 8
if not self.dtype_byte:
dt = None
return f"AMSMessage(domain={self.domain_names}, #mpi={self.mpi_rank}, num_elements={self.num_elements}, datatype={dt}, input_dim={self.input_dim}, output_dim={self.output_dim})"

def __repr__(self):
return self.__str__()

def header_format(self) -> str:
"""
This string represents the AMS format in Python pack format:
@@ -110,6 +127,15 @@ def _parse_header(self, body: str) -> dict:
res["dsize"] = int(res["datatype"]) * int(res["num_element"]) * (int(res["input_dim"]) + int(res["output_dim"]))
res["msg_size"] = hsize + res["dsize"]
res["multiple_msg"] = len(body) != res["msg_size"]

self.num_elements = int(res["num_element"])
self.hsize = int(res["hsize"])
self.dtype_byte = int(res["datatype"])
self.mpi_rank = int(res["mpirank"])
self.domain_name_size = int(res["domain_size"])
self.input_dim = int(res["input_dim"])
self.output_dim = int(res["output_dim"])

return res

def _parse_data(self, body: str, header_info: dict) -> Tuple[str, np.array, np.array]:
@@ -144,30 +170,37 @@ def _decode(self, body: str) -> Tuple[np.array]:
input = []
output = []
# Multiple AMS messages could be packed in one RMQ message
# TODO: we should manage potential multiple messages per AMSMessage better
while body:
header_info = self._parse_header(body)
domain_name, temp_input, temp_output = self._parse_data(body, header_info)
# print(f"MSG: {domain_name} input shape {temp_input.shape} outpute shape {temp_output.shape}")
# total size of byte we read for that message
chunk_size = header_info["hsize"] + header_info["dsize"] + header_info["domain_size"]
input.append(temp_input)
output.append(temp_output)
# We remove the current message and keep going
body = body[chunk_size:]
self.domain_names.append(domain_name)
return domain_name, np.concatenate(input), np.concatenate(output)

def decode(self) -> Tuple[str, np.array, np.array]:
return self._decode(self.body)

def default_ams_callback(method, properties, body):
"""Simple callback that decode incoming message assuming they are AMS binary messages"""
return AMSMessage(body)

class AMSChannel:
"""
A wrapper around Pika RabbitMQ channel
"""

def __init__(self, connection, q_name, logger: logging.Logger = None):
def __init__(self, connection, q_name, callback: Optional[Callable] = None, logger: Optional[logging.Logger] = None):
self.connection = connection
self.q_name = q_name
self.logger = logger if logger else logging.getLogger(__name__)
self.callback = callback if callback else self.default_callback

def __enter__(self):
self.open()
@@ -176,9 +209,9 @@ def __enter__(self):
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()

@staticmethod
def callback(method, properties, body):
return body.decode("utf-8")
def default_callback(self, method, properties, body):
""" Simple callback that return the message received"""
return body

def open(self):
self.channel = self.connection.channel()
@@ -187,18 +220,19 @@ def open(self):
def close(self):
self.channel.close()

def receive(self, n_msg: int = None, accum_msg=list()):
def receive(self, n_msg: int = None, timeout: int = None, accum_msg = list()):
"""
Consume messages from the queue and post-process them by calling the callback.
@param n_msg The number of messages to receive.
- if n_msg is None, this call blocks forever and processes every message that arrives
- if n_msg = 1 for example, this function blocks until one message has been processed.
@param timeout If None, wait indefinitely; otherwise timeout in seconds
@return a list containing all received messages
"""

if self.channel and self.channel.is_open:
self.logger.info(
f"Starting to consume messages from queue={self.q_name}, routing_key={self.routing_key} ..."
f"Starting to consume messages from queue={self.q_name} ..."
)
# we will consume only n_msg and requeue all other messages
# if there are more messages in the queue.
@@ -207,11 +241,15 @@ def receive(self, n_msg: int = None, accum_msg=list()):
n_msg = max(n_msg, 0)
message_consumed = 0
# Consume n_msg messages and break out
for method_frame, properties, body in self.channel.consume(self.q_name):
for method_frame, properties, body in self.channel.consume(self.q_name, inactivity_timeout=timeout):
if (method_frame, properties, body) == (None, None, None):
self.logger.info(f"Timeout after {timeout} seconds")
self.channel.cancel()
break
# Call the callback on the message parts
try:
accum_msg.append(
BlockingClient.callback(
self.callback(
method_frame,
properties,
body,
@@ -223,23 +261,24 @@ def receive(self, n_msg: int = None, accum_msg=list()):
finally:
# Acknowledge the message even on failure
self.channel.basic_ack(delivery_tag=method_frame.delivery_tag)
message_consumed += 1
self.logger.warning(
f"Consumed message {message_consumed+1}/{method_frame.delivery_tag} (exchange={method_frame.exchange}, routing_key={method_frame.routing_key})"
f"Consumed message {message_consumed}/{method_frame.delivery_tag} (exchange=\'{method_frame.exchange}\', routing_key={method_frame.routing_key})"
)
message_consumed += 1
# Escape out of the loop after nb_msg messages
if message_consumed == n_msg:
# Cancel the consumer and return any pending messages
self.channel.cancel()
break
return accum_msg

def send(self, text: str):
def send(self, text: str, exchange : str = ""):
"""
Send a message
@param text The text to send
@param exchange Exchange to use
"""
self.channel.basic_publish(exchange="", routing_key=self.q_name, body=text)
self.channel.basic_publish(exchange=exchange, routing_key=self.q_name, body=text)
return

def get_messages(self):
@@ -250,26 +289,42 @@ def purge(self):
if self.channel and self.channel.is_open:
self.channel.queue_purge(self.q_name)


class BlockingClient:
"""
BlockingClient is a class that manages a simple blocking RMQ client lifecycle.
"""

def __init__(self, host, port, vhost, user, password, cert, logger: logging.Logger = None):
def __init__(
self,
host: str,
port: int,
vhost: str,
user: str,
password: str,
cert: Optional[str] = None,
callback: Optional[Callable] = None,
logger: Optional[logging.Logger] = None
):
# CA Cert, can be generated with (where $REMOTE_HOST and $REMOTE_PORT can be found in the JSON file):
# openssl s_client -connect $REMOTE_HOST:$REMOTE_PORT -showcerts < /dev/null 2>/dev/null | sed -ne '/-BEGIN CERTIFICATE-/,/-END CERTIFICATE-/p' > rmq-pds.crt
self.logger = logger if logger else logging.getLogger(__name__)
self.cert = cert
self.context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
self.context.verify_mode = ssl.CERT_REQUIRED
self.context.check_hostname = False
self.context.load_verify_locations(self.cert)

if self.cert is None or self.cert == "":
ssl_options = None
else:
self.context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
self.context.verify_mode = ssl.CERT_REQUIRED
self.context.check_hostname = False
self.context.load_verify_locations(self.cert)
ssl_options = pika.SSLOptions(self.context)

self.host = host
self.vhost = vhost
self.port = port
self.user = user
self.password = password
self.callback = callback

self.credentials = pika.PlainCredentials(self.user, self.password)

@@ -278,7 +333,7 @@ def __init__(self, host, port, vhost, user, password, cert, logger: logging.Logg
port=self.port,
virtual_host=self.vhost,
credentials=self.credentials,
ssl_options=pika.SSLOptions(self.context),
ssl_options=ssl_options,
)

def __enter__(self):
@@ -290,7 +345,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):

def connect(self, queue):
"""Connect to the queue"""
return AMSChannel(self.connection, queue)
return AMSChannel(self.connection, queue, self.callback)


class AsyncConsumer(object):
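Taken together, the rmq.py changes let a test consume AMS binary messages with an optional TLS certificate, a custom callback, and an inactivity timeout. A minimal usage sketch follows; it is not part of the PR, the host, credentials, queue name, and certificate path are placeholders, and it assumes that entering BlockingClient opens the underlying pika connection and returns the client, as its context-manager methods suggest.

from ams.rmq import BlockingClient, default_ams_callback  # module path assumed

client = BlockingClient(
    host="rabbitmq.example.org",    # placeholder host
    port=5671,                      # placeholder TLS port
    vhost="/",
    user="ams",
    password="secret",
    cert="rmq-pds.crt",             # with None or "" the client now connects without TLS
    callback=default_ams_callback,  # wrap each body in an AMSMessage
)
with client as c:                   # assumed to open the underlying pika connection
    with c.connect("test-queue") as channel:
        # Stop after 2 messages or after 30 seconds of inactivity (new timeout parameter)
        messages = channel.receive(n_msg=2, timeout=30)
        for msg in messages:
            domain, inputs, outputs = msg.decode()
            print(msg)              # AMSMessage.__str__ summarizes the header fields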
18 changes: 18 additions & 0 deletions src/AMSWorkflow/ams/util.py
@@ -5,9 +5,11 @@

import datetime
import socket
import subprocess
import uuid
from pathlib import Path

from typing import Tuple

def get_unique_fn():
# Randomly generate the output file name. We use the uuid4 function with the socket name and the current
@@ -20,6 +22,22 @@ def get_unique_fn():
]
return "_".join(fn)

def generate_tls_certificate(host: str, port: int) -> Tuple[bool,str]:
"""Generate TLS certificate for RabbitMQ

:param str host: The RabbitMQ hostname
:param int port: The RabbitMQ port

:rtype: Tuple[bool,str]
:return: a tuple whose boolean is True if the certificate was retrieved; the second element is the TLS certificate on success, otherwise stderr
"""
openssl = subprocess.run(["openssl", "s_client", "-connect", f"{host}:{port}", "-showcerts"], input=b"", capture_output=True)
if openssl.returncode != 0:
return False, openssl.stderr.decode().strip()
sed = subprocess.run(["sed", "-ne", r"/-BEGIN CERTIFICATE-/,/-END CERTIFICATE-/p"], input=openssl.stdout, capture_output=True)
if sed.returncode != 0:
return False, sed.stderr.decode().strip()
return True, sed.stdout.decode().strip()
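
A short usage sketch for the helper above (not from the PR; host, port, and output file name are placeholders):

from pathlib import Path
from ams.util import generate_tls_certificate  # module path assumed

ok, cert_or_err = generate_tls_certificate("rabbitmq.example.org", 5671)
if ok:
    Path("rmq-pds.crt").write_text(cert_or_err + "\n")  # save the PEM block for the RMQ client
else:
    print(f"Could not retrieve certificate: {cert_or_err}")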

def mkdir(root_path, fn):
_tmp = root_path / Path(fn)
33 changes: 31 additions & 2 deletions tests/AMSlib/CMakeLists.txt
@@ -33,6 +33,22 @@ function(JSON_TESTS db_type)
unset(JSON_FP)
endfunction()

function(CHECK_RMQ_CONFIG file)
# Read the JSON file.
file(READ ${file} MY_JSON_STRING)
message(STATUS "RabbitMQ config ${file}")

string(JSON DB_CONF GET ${MY_JSON_STRING} db)
string(JSON DB_CONF GET ${DB_CONF} rmq_config)
string(JSON RMQ_HOST GET ${DB_CONF} "service-host")
string(JSON RMQ_PORT GET ${DB_CONF} "service-port")

if(NOT "${RMQ_HOST}" STREQUAL "" AND NOT "${RMQ_PORT}" STREQUAL "0")
message(STATUS "RabbitMQ config ${file}: ${RMQ_HOST}:${RMQ_PORT}")
else()
message(WARNING "RabbitMQ config file ${file} looks empty! Make sure to fill these fields before running the tests")
endif()
endfunction()
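
For reference, CHECK_RMQ_CONFIG only inspects the two nested keys under db.rmq_config. A placeholder sketch of a matching rmq.json, showing only the fields this function reads (the actual template is json_configs/rmq.json.in):

{
  "db": {
    "rmq_config": {
      "service-host": "rabbitmq.example.org",
      "service-port": 5671
    }
  }
}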

function(INTEGRATION_TEST_ENV)
JSON_TESTS("csv")
@@ -43,12 +59,23 @@ function(INTEGRATION_TEST_ENV)
add_test(NAME AMSEndToEndFromJSON::DuqMean::DuqMax::Double::DB::hdf5-debug::HOST COMMAND bash -c "AMS_OBJECTS=${JSON_FP} ${CMAKE_CURRENT_BINARY_DIR}/ams_end_to_end_env 0 8 9 \"double\" 1 1024 app_uq_mean_debug app_uq_max_debug;AMS_OBJECTS=${JSON_FP} python3 ${CMAKE_CURRENT_SOURCE_DIR}/verify_ete.py 0 8 9 \"double\" 1024 app_uq_mean_debug app_uq_max_debug")
unset(JSON_FP)
endif()
endfunction()

function(INTEGRATION_TEST_RMQ)
if (WITH_RMQ)
configure_file("${CMAKE_CURRENT_SOURCE_DIR}/json_configs/rmq.json.in" "rmq.json" @ONLY)
if(EXISTS "${CMAKE_CURRENT_BINARY_DIR}/rmq.json")
# If file exists we do not overwrite it
message(STATUS "Ctest will use ${CMAKE_CURRENT_BINARY_DIR}/rmq.json as RabbitMQ configuration for testing. Make sure RabbitMQ parameters are valid.")
else()
message(STATUS "Copying empty configuration to ${CMAKE_CURRENT_BINARY_DIR}/rmq.json")
configure_file("${CMAKE_CURRENT_SOURCE_DIR}/json_configs/rmq.json.in" "rmq.json" @ONLY)
endif()
set(JSON_FP "${CMAKE_CURRENT_BINARY_DIR}/rmq.json")
CHECK_RMQ_CONFIG(${JSON_FP})
add_test(NAME AMSEndToEndFromJSON::NoModel::Double::DB::rmq::HOST COMMAND bash -c "AMS_OBJECTS=${JSON_FP} ${CMAKE_CURRENT_BINARY_DIR}/ams_rmq 0 8 9 \"double\" 2 1024; AMS_OBJECTS=${JSON_FP} python3 ${CMAKE_CURRENT_SOURCE_DIR}/verify_rmq.py 0 8 9 \"double\" 2 1024")
endif()
endfunction()


function (INTEGRATION_TEST)
#######################################################
# TEST: output format
@@ -186,6 +213,8 @@ endif()
INTEGRATION_TEST()
BUILD_TEST(ams_end_to_end_env ams_ete_env.cpp)
INTEGRATION_TEST_ENV()
BUILD_TEST(ams_rmq ams_rmq_env.cpp)
INTEGRATION_TEST_RMQ()


# UQ Tests