Skip to content

Commit

Permalink
Fix the surrogate model update (was not triggered before) and fix issue
Browse files Browse the repository at this point in the history
#83 (empty RabbitMQ exchange and/or routing key fields led to AMSlib crashing)

Signed-off-by: Loic Pottier <[email protected]>
  • Loading branch information
lpottier committed Feb 28, 2025
1 parent 1347ce6 commit 27db0eb
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 75 deletions.
11 changes: 11 additions & 0 deletions src/AMSlib/AMS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,17 @@ class AMSWrap
if (rmq_entry.contains("rabbitmq-cert"))
rmq_cert = getEntry<std::string>(rmq_entry, "rabbitmq-cert");

CFATAL(AMS,
(exchange == "" || routing_key == "") && update_surrogate,
"Found empty RMQ exchange / routing-key, model update is not possible. "
"Please provide a RMQ exchange or deactivate surrogate model "
"update.")

if(exchange == "" || routing_key == "") {
WARNING(AMS, "Found empty RMQ exchange or routing-key, deactivating model update")
update_surrogate = false;
}

auto &DB = ams::db::DBManager::getInstance();
DB.instantiate_rmq_db(port,
host,
Expand Down
50 changes: 29 additions & 21 deletions src/AMSlib/wf/basedb.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1593,11 +1593,13 @@ class RMQInterface
std::shared_ptr<RMQConsumer> _consumer;
/** @brief Thread in charge of the consumer */
std::thread _consumer_thread;
/** @brief True if connected to RabbitMQ */
bool connected;
/** @brief True if publisher is connected to RabbitMQ */
bool _publisher_connected;
/** @brief True if consumer is connected to RabbitMQ */
bool _consumer_connected;

public:
RMQInterface() : connected(false), _rId(0) {}
RMQInterface() : _publisher_connected(false), _consumer_connected(false), _rId(0) {}

/**
* @brief Connect to a RabbitMQ server
Expand All @@ -1612,9 +1614,9 @@ class RMQInterface
* @param[in] outbound_queue Name of the queue on which AMSlib publishes (send) messages
* @param[in] exchange Exchange for incoming messages
* @param[in] routing_key Routing key for incoming messages (must match what the AMS Python side is using)
* @return True if connection succeeded
* @return True, True if connection succeeded for both publisher/consumer
*/
bool connect(std::string rmq_name,
std::pair<bool, bool> connect(std::string rmq_name,
std::string rmq_password,
std::string rmq_user,
std::string rmq_vhost,
Expand All @@ -1623,13 +1625,29 @@ class RMQInterface
std::string rmq_cert,
std::string outbound_queue,
std::string exchange,
std::string routing_key);
std::string routing_key,
bool update_surrogate);

/**
* @brief Check if the RabbitMQ connection is connected.
* @return True if connected
*/
bool isConnected() const { return connected; }
bool isPublisherConnected() const { return _publisher_connected; }

/**
* @brief Check if the RabbitMQ connection is connected.
* @return True if connected
*/
bool isConsumerConnected() const { return _consumer_connected; }

/**
* @brief Check if at least one RabbitMQ connection is connected.
* @return True if connected
*/
bool isConnected() const
{
return isPublisherConnected() || isConsumerConnected();
}

/**
* @brief Set the internal ID of the interface (usually MPI rank).
Expand Down Expand Up @@ -1666,18 +1684,7 @@ class RMQInterface
CALIPER(CALI_MARK_BEGIN("STORE_RMQ");)
AMSMessage msg(_msg_tag, _rId, domain_name, num_elements, inputs, outputs);

if (!_publisher->connectionValid()) {
connected = false;
restartPublisher();
bool status = _publisher->waitToEstablish(100, 10);
if (!status) {
_publisher->stop();
_publisher_thread.join();
FATAL(RMQInterface,
"Could not establish publisher RabbitMQ connection");
}
connected = true;
}
if (!_publisher->connectionValid()) restartPublisher();
_publisher->publish(std::move(msg));
_msg_tag++;
CALIPER(CALI_MARK_END("STORE_RMQ");)
Expand Down Expand Up @@ -1719,7 +1726,7 @@ class RMQInterface

~RMQInterface()
{
if (connected) close();
if (isConnected()) close();
}
};

Expand Down Expand Up @@ -2098,7 +2105,8 @@ class DBManager
rmq_cert,
outbound_queue,
exchange,
routing_key);
routing_key,
update_surrogate);
#else
FATAL(DBManager,
"Requsted RMQ database but AMS is not built with such support "
Expand Down
123 changes: 69 additions & 54 deletions src/AMSlib/wf/rmqdb.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -917,16 +917,17 @@ int RMQPublisher::msgAcknowledged() const

bool RMQPublisher::close(unsigned ms, int repeat)
{
_handler->flush();
_connection->close(false);
return _handler->waitToClose(ms, repeat);
if (_handler) _handler->flush();
if (_connection) _connection->close(false);
if (_handler) return _handler->waitToClose(ms, repeat);
return false;
}

/**
* RMQInterface
*/

bool RMQInterface::connect(std::string rmq_name,
std::pair<bool, bool> RMQInterface::connect(std::string rmq_name,
std::string rmq_password,
std::string rmq_user,
std::string rmq_vhost,
Expand All @@ -935,7 +936,8 @@ bool RMQInterface::connect(std::string rmq_name,
std::string rmq_cert,
std::string outbound_queue,
std::string exchange,
std::string routing_key)
std::string routing_key,
bool update_surrogate)
{
_queue_sender = outbound_queue;
_exchange = exchange;
Expand Down Expand Up @@ -967,77 +969,90 @@ bool RMQInterface::connect(std::string rmq_name,
_publisher_thread.join();
FATAL(RabbitMQInterface, "Could not establish connection");
}
_publisher_connected = true;

_consumer = std::make_shared<RMQConsumer>(
_rId, *_address, _cacert, _exchange, _routing_key);
_consumer_thread = std::thread([&]() { _consumer->start(); });
if (update_surrogate) {
_consumer = std::make_shared<RMQConsumer>(
_rId, *_address, _cacert, _exchange, _routing_key);
_consumer_thread = std::thread([&]() { _consumer->start(); });

if (!_consumer->waitToEstablish(100, 10)) {
_consumer->stop();
_consumer_thread.join();
FATAL(RabbitMQDB, "Could not establish consumer connection");
if (!_consumer->waitToEstablish(100, 10)) {
_consumer->stop();
_consumer_thread.join();
FATAL(RabbitMQDB, "Could not establish consumer connection");
}
_consumer_connected = true;
}

connected = true;
return connected;
return std::make_pair(_publisher_connected, _consumer_connected);
}

void RMQInterface::restartPublisher()
{
CALIPER(CALI_MARK_BEGIN("RMQ_RESTART_PUBLISHER");)
std::vector<AMSMessage> messages = _publisher->getMsgBuffer();

AMSMessage& msg_min =
*(std::min_element(messages.begin(),
messages.end(),
[](const AMSMessage& a, const AMSMessage& b) {
return a.id() < b.id();
}));
if (_publisher->connectionValid()) return;

DBG(RMQPublisher,
"[r%d] we have %lu buffered messages that will get re-send "
"(starting from msg #%d).",
_rId,
messages.size(),
msg_min.id())
CALIPER(CALI_MARK_BEGIN("RMQ_RESTART_PUBLISHER");)
std::vector<AMSMessage> messages = _publisher->getMsgBuffer();
_publisher_connected = false;
if (messages.size() > 0) {
AMSMessage& msg_min =
*(std::min_element(messages.begin(),
messages.end(),
[](const AMSMessage& a, const AMSMessage& b) {
return a.id() < b.id();
}));

DBG(RMQInterface,
"[r%d] we have %lu buffered messages that will get re-send "
"(starting from msg #%d).",
_rId,
messages.size(),
msg_min.id())
}

// Stop the faulty publisher
_publisher->close(100, 10);
_publisher->stop();
_publisher_thread.join();
if (_publisher_thread.joinable()) _publisher_thread.join();
_publisher.reset();
connected = false;

_publisher = std::make_shared<RMQPublisher>(
_rId, *_address, _cacert, _queue_sender, std::move(messages));
_publisher_thread = std::thread([&]() { _publisher->start(); });
connected = true;

if (!_publisher->waitToEstablish(100, 10)) {
_publisher->stop();
if (_publisher_thread.joinable()) _publisher_thread.join();
FATAL(RMQInterface, "Could not re-establish publisher connection (timeout)");
}
_publisher_connected = true;
CALIPER(CALI_MARK_END("RMQ_RESTART_PUBLISHER");)
}

void RMQInterface::close()
{
if (!_publisher_thread.joinable() || !_consumer_thread.joinable()) {
DBG(RMQInterface, "Threads are not joinable")
return;
if (isPublisherConnected()) {
bool status = _publisher->close(100, 10);
CWARNING(RMQInterface,
!status,
"Could not gracefully close publisher TCP connection")

DBG(RMQInterface, "Number of messages sent: %d", _msg_tag)
DBG(RMQInterface,
"Number of unacknowledged messages are %d",
_publisher->unacknowledged())
_publisher->stop();
if (_publisher_thread.joinable()) _publisher_thread.join();
_publisher_connected = false;
}
bool status = _publisher->close(100, 10);
CWARNING(RabbitMQDB,
!status,
"Could not gracefully close publisher TCP connection")

DBG(RabbitMQInterface, "Number of messages sent: %d", _msg_tag)
DBG(RabbitMQInterface,
"Number of unacknowledged messages are %d",
_publisher->unacknowledged())
_publisher->stop();
_publisher_thread.join();

status = _consumer->close(100, 10);
CWARNING(RabbitMQDB,
!status,
"Could not gracefully close consumer TCP connection")
_consumer->stop();
_consumer_thread.join();

connected = false;
if (isConsumerConnected()) {
bool status = _consumer->close(100, 10);
CWARNING(RabbitMQDB,
!status,
"Could not gracefully close consumer TCP connection")
_consumer->stop();
if (_consumer_thread.joinable()) _consumer_thread.join();
_consumer_connected = false;
}
}

0 comments on commit 27db0eb

Please sign in to comment.