Skip to content

Commit

Permalink
[rust]sync client settings periodically (apache#691)
Browse files Browse the repository at this point in the history
* [rust]sync client settings periodically

* fix ugly import

* reuse existing update_settings function

* simplify code

* fix: avoid unwrap in main code

Signed-off-by: Li Zhanhui <[email protected]>

* fix: fix tests

Signed-off-by: Li Zhanhui <[email protected]>

---------

Signed-off-by: Li Zhanhui <[email protected]>
Co-authored-by: Li Zhanhui <[email protected]>
  • Loading branch information
glcrazier and lizhanhui authored Mar 12, 2024
1 parent 95a8863 commit 430ed16
Show file tree
Hide file tree
Showing 7 changed files with 192 additions and 123 deletions.
164 changes: 104 additions & 60 deletions rust/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@ use parking_lot::Mutex;
use prost_types::Duration;
use slog::{debug, error, info, o, warn, Logger};
use tokio::select;
use tokio::sync::{mpsc, oneshot};
use tokio::sync::{mpsc, oneshot, RwLock};
use tokio::time::Instant;

use crate::conf::ClientOption;
use crate::conf::{ClientOption, SettingsAware};
use crate::error::{ClientError, ErrorKind};
use crate::model::common::{ClientType, Endpoints, Route, RouteStatus, SendReceipt};
use crate::model::message::AckMessageEntry;
Expand All @@ -44,14 +45,14 @@ use crate::pb::{
use crate::session::SessionManager;
use crate::session::{RPCClient, Session};

pub(crate) struct Client {
pub(crate) struct Client<S> {
logger: Logger,
option: ClientOption,
session_manager: Arc<SessionManager>,
route_table: Mutex<HashMap<String /* topic */, RouteStatus>>,
id: String,
access_endpoints: Endpoints,
settings: TelemetryCommand,
settings: Arc<RwLock<S>>,
telemetry_command_tx: Option<mpsc::Sender<pb::telemetry_command::Command>>,
shutdown_tx: Option<oneshot::Sender<()>>,
}
Expand All @@ -68,7 +69,10 @@ const OPERATION_SEND_MESSAGE: &str = "client.send_message";
const OPERATION_RECEIVE_MESSAGE: &str = "client.receive_message";
const OPERATION_ACK_MESSAGE: &str = "client.ack_message";

impl Debug for Client {
impl<S> Debug for Client<S>
where
S: SettingsAware + 'static,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Client")
.field("id", &self.id)
Expand All @@ -79,11 +83,14 @@ impl Debug for Client {
}

#[automock]
impl Client {
impl<S> Client<S>
where
S: SettingsAware + 'static + Send + Sync,
{
pub(crate) fn new(
logger: &Logger,
option: ClientOption,
settings: TelemetryCommand,
settings: Arc<RwLock<S>>,
) -> Result<Self, ClientError> {
let id = Self::generate_client_id();
let endpoints = Endpoints::from_url(option.access_url())
Expand Down Expand Up @@ -131,12 +138,16 @@ impl Client {
.await
.map_err(|error| error.with_operation(OPERATION_CLIENT_START))?;

let settings = Arc::clone(&self.settings);
tokio::spawn(async move {
rpc_client.is_started();
let mut interval = tokio::time::interval(std::time::Duration::from_secs(30));
let seconds_30 = std::time::Duration::from_secs(30);
let mut heartbeat_interval = tokio::time::interval(seconds_30);
let mut sync_settings_interval =
tokio::time::interval_at(Instant::now() + seconds_30, seconds_30);
loop {
select! {
_ = interval.tick() => {
_ = heartbeat_interval.tick() => {
let sessions = session_manager.get_all_sessions().await;
if sessions.is_err() {
error!(
Expand All @@ -159,7 +170,7 @@ impl Client {
continue;
}
let result =
Self::handle_response_status(response.unwrap().status, OPERATION_HEARTBEAT);
handle_response_status(response.unwrap().status, OPERATION_HEARTBEAT);
if result.is_err() {
error!(
logger,
Expand All @@ -171,13 +182,34 @@ impl Client {
debug!(logger,"send heartbeat to server success, peer={}",peer);
}
},
_ = sync_settings_interval.tick() => {
let sessions = session_manager.get_all_sessions().await;
if sessions.is_err() {
error!(logger, "sync settings failed: failed to get sessions: {}", sessions.unwrap_err());
continue;
}
for mut session in sessions.unwrap() {
let command;
{
command = settings.read().await.build_telemetry_command();
}
let peer = session.peer().to_string();
let result = session.update_settings(command).await;
if result.is_err() {
error!(logger, "sync settings failed: failed to call rpc: {}", result.unwrap_err());
continue;
}
debug!(logger, "sync settings success, peer = {}", peer);
}

},
_ = &mut shutdown_rx => {
info!(logger, "receive shutdown signal, stop heartbeat task.");
info!(logger, "receive shutdown signal, stop heartbeat and telemetry tasks.");
break;
}
}
}
info!(logger, "heartbeat task is stopped");
info!(logger, "heartbeat and telemetry task were stopped");
});
Ok(())
}
Expand Down Expand Up @@ -206,7 +238,7 @@ impl Client {
resource_namespace: self.option.namespace.to_string(),
});
let response = rpc_client.notify_shutdown(NotifyClientTerminationRequest { group });
Self::handle_response_status(response.await?.status, OPERATION_CLIENT_SHUTDOWN)?;
handle_response_status(response.await?.status, OPERATION_CLIENT_SHUTDOWN)?;
self.session_manager.shutdown().await;
Ok(())
}
Expand Down Expand Up @@ -234,13 +266,17 @@ impl Client {
)
}

async fn build_telemetry_command(&self) -> TelemetryCommand {
self.settings.read().await.build_telemetry_command()
}

pub(crate) async fn get_session(&self) -> Result<Session, ClientError> {
self.check_started(OPERATION_GET_SESSION)?;
let session = self
.session_manager
.get_or_create_session(
&self.access_endpoints,
self.settings.clone(),
self.build_telemetry_command().await,
self.telemetry_command_tx.clone().unwrap(),
)
.await?;
Expand All @@ -255,37 +291,13 @@ impl Client {
.session_manager
.get_or_create_session(
endpoints,
self.settings.clone(),
self.build_telemetry_command().await,
self.telemetry_command_tx.clone().unwrap(),
)
.await?;
Ok(session)
}

pub(crate) fn handle_response_status(
status: Option<Status>,
operation: &'static str,
) -> Result<(), ClientError> {
if status.is_none() {
return Err(ClientError::new(
ErrorKind::Server,
"server do not return status, this may be a bug",
operation,
));
}

let status = status.unwrap();
let status_code = Code::from_i32(status.code).unwrap();
if !status_code.eq(&Code::Ok) {
return Err(
ClientError::new(ErrorKind::Server, "server return an error", operation)
.with_context("code", status_code.as_str_name())
.with_context("message", status.message),
);
}
Ok(())
}

pub(crate) fn topic_route_from_cache(&self, topic: &str) -> Option<Arc<Route>> {
self.route_table.lock().get(topic).and_then(|route_status| {
if let RouteStatus::Found(route) = route_status {
Expand Down Expand Up @@ -325,7 +337,7 @@ impl Client {
};

let response = rpc_client.query_route(request).await?;
Self::handle_response_status(response.status, OPERATION_QUERY_ROUTE)?;
handle_response_status(response.status, OPERATION_QUERY_ROUTE)?;

let route = Route {
index: AtomicUsize::new(0),
Expand Down Expand Up @@ -454,7 +466,7 @@ impl Client {
) -> Result<Vec<SendReceipt>, ClientError> {
let request = SendMessageRequest { messages };
let response = rpc_client.send_message(request).await?;
Self::handle_response_status(response.status, OPERATION_SEND_MESSAGE)?;
handle_response_status(response.status, OPERATION_SEND_MESSAGE)?;

Ok(response
.entries
Expand Down Expand Up @@ -512,7 +524,7 @@ impl Client {
if status.code() == Code::MessageNotFound {
return Ok(vec![]);
}
Self::handle_response_status(Some(status), OPERATION_RECEIVE_MESSAGE)?;
handle_response_status(Some(status), OPERATION_RECEIVE_MESSAGE)?;
}
Content::Message(message) => {
messages.push(message);
Expand Down Expand Up @@ -560,7 +572,7 @@ impl Client {
entries,
};
let response = rpc_client.ack_message(request).await?;
Self::handle_response_status(response.status, OPERATION_ACK_MESSAGE)?;
handle_response_status(response.status, OPERATION_ACK_MESSAGE)?;
Ok(response.entries)
}

Expand Down Expand Up @@ -605,11 +617,31 @@ impl Client {
message_id,
};
let response = rpc_client.change_invisible_duration(request).await?;
Self::handle_response_status(response.status, OPERATION_ACK_MESSAGE)?;
handle_response_status(response.status, OPERATION_ACK_MESSAGE)?;
Ok(response.receipt_handle)
}
}

pub fn handle_response_status(
status: Option<Status>,
operation: &'static str,
) -> Result<(), ClientError> {
let status = status.ok_or(ClientError::new(
ErrorKind::Server,
"server do not return status, this may be a bug",
operation,
))?;

if status.code != Code::Ok as i32 {
return Err(
ClientError::new(ErrorKind::Server, "server return an error", operation)
.with_context("code", format!("{}", status.code))
.with_context("message", status.message),
);
}
Ok(())
}

#[cfg(test)]
pub(crate) mod tests {
use std::sync::atomic::{AtomicUsize, Ordering};
Expand All @@ -624,7 +656,6 @@ pub(crate) mod tests {
use crate::error::{ClientError, ErrorKind};
use crate::log::terminal_logger;
use crate::model::common::{ClientType, Route};
use crate::pb::receive_message_response::Content;
use crate::pb::{
AckMessageEntry, AckMessageResponse, ChangeInvisibleDurationResponse, Code,
FilterExpression, HeartbeatResponse, Message, MessageQueue, QueryRouteResponse,
Expand All @@ -637,7 +668,16 @@ pub(crate) mod tests {
// The lock is used to prevent the mocking static function at same time during parallel testing.
pub(crate) static MTX: Lazy<Mutex<()>> = Lazy::new(|| Mutex::new(()));

fn new_client_for_test() -> Client {
#[derive(Default)]
struct MockSettings {}

impl SettingsAware for MockSettings {
fn build_telemetry_command(&self) -> TelemetryCommand {
TelemetryCommand::default()
}
}

fn new_client_for_test() -> Client<MockSettings> {
Client {
logger: terminal_logger(),
option: ClientOption {
Expand All @@ -646,24 +686,24 @@ pub(crate) mod tests {
},
session_manager: Arc::new(SessionManager::default()),
route_table: Mutex::new(HashMap::new()),
id: Client::generate_client_id(),
id: Client::<MockSettings>::generate_client_id(),
access_endpoints: Endpoints::from_url("http://localhost:8081").unwrap(),
settings: TelemetryCommand::default(),
settings: Arc::new(RwLock::new(MockSettings::default())),
telemetry_command_tx: None,
shutdown_tx: None,
}
}

fn new_client_with_session_manager(session_manager: SessionManager) -> Client {
fn new_client_with_session_manager(session_manager: SessionManager) -> Client<MockSettings> {
let (tx, _) = mpsc::channel(16);
Client {
logger: terminal_logger(),
option: ClientOption::default(),
session_manager: Arc::new(session_manager),
route_table: Mutex::new(HashMap::new()),
id: Client::generate_client_id(),
id: Client::<MockSettings>::generate_client_id(),
access_endpoints: Endpoints::from_url("http://localhost:8081").unwrap(),
settings: TelemetryCommand::default(),
settings: Arc::new(RwLock::new(MockSettings::default())),
telemetry_command_tx: Some(tx),
shutdown_tx: None,
}
Expand All @@ -684,7 +724,7 @@ pub(crate) mod tests {
Client::new(
&terminal_logger(),
ClientOption::default(),
TelemetryCommand::default(),
Arc::new(RwLock::new(MockSettings::default())),
)?;
Ok(())
}
Expand Down Expand Up @@ -728,8 +768,8 @@ pub(crate) mod tests {
}

#[test]
fn handle_response_status() {
let result = Client::handle_response_status(None, "test");
fn test_handle_response_status() {
let result = handle_response_status(None, "test");
assert!(result.is_err(), "should return error when status is None");
let result = result.unwrap_err();
assert_eq!(result.kind, ErrorKind::Server);
Expand All @@ -739,7 +779,7 @@ pub(crate) mod tests {
);
assert_eq!(result.operation, "test");

let result = Client::handle_response_status(
let result = handle_response_status(
Some(Status {
code: Code::BadRequest as i32,
message: "test failed".to_string(),
Expand All @@ -757,12 +797,12 @@ pub(crate) mod tests {
assert_eq!(
result.context,
vec![
("code", "BAD_REQUEST".to_string()),
("code", format!("{}", Code::BadRequest as i32)),
("message", "test failed".to_string()),
]
);

let result = Client::handle_response_status(
let result = handle_response_status(
Some(Status {
code: Code::Ok as i32,
message: "test success".to_string(),
Expand Down Expand Up @@ -897,9 +937,13 @@ pub(crate) mod tests {
mock.expect_heartbeat()
.return_once(|_| Box::pin(futures::future::ready(response)));

let send_result =
Client::heart_beat_inner(mock, &Some("group".to_string()), "", &ClientType::Producer)
.await;
let send_result = Client::<MockSettings>::heart_beat_inner(
mock,
&Some("group".to_string()),
"",
&ClientType::Producer,
)
.await;
assert!(send_result.is_ok());
}

Expand Down
Loading

0 comments on commit 430ed16

Please sign in to comment.