Skip to content

Commit

Permalink
ver-0.2.6 add websockets support
Browse files Browse the repository at this point in the history
add websockets support
  • Loading branch information
ipconfiger committed Jan 8, 2024
1 parent 136db5d commit 88161b7
Show file tree
Hide file tree
Showing 4 changed files with 242 additions and 9 deletions.
5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
axum = "0.6.19"
axum = { version = "0.6.19", features = ["ws"] }
axum-extra = { version = "0.9.1", features = ["typed-header"] }
clap = "3.2.6"
tokio = { version = "1.35.1", features = ["full"] }
serde = { version = "1.0", features = ["derive"] }
Expand All @@ -19,7 +20,7 @@ timer = "0.2.0"
lettre = { version = "0.10.4", default-features = false, features = ["builder", "smtp-transport", "rustls-tls"] }
dirs = "3.0.2"
reqwest = { version = "0.11.23", default-features = false, features = ["blocking", "json", "rustls-tls"] }

futures = "0.3"
rdkafka = { version = "0.33.2", default-features = false, features = ["cmake-build"] }

[dev-dependencies]
Expand Down
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,17 @@ Supported types of job execution include:
| kafka_resp_topic | 返回任务执行结果的Topic | Topic for response execute result |


Both real-time triggering and delayed triggering support HTTP and Kafka multi-channel access. By default, only HTTP is enabled. After setting the correct Kafka prefix parameters, tasks can be received from Kafka.
Both real-time triggering and delayed triggering support HTTP、WebSockets and Kafka multi-channel access. By default, only HTTP is enabled. After setting the correct Kafka prefix parameters, tasks can be received from Kafka.

实时触发和延迟触发均支持 HTTPKafka 多通道接入,默认只开启HTTP,在设置好正确的kafka前缀的参数后,即可从Kafka接收任务
实时触发和延迟触发均支持 HTTPKafka、WebSockets等 多通道接入,默认只开启HTTP和WebSockets,在设置好正确的kafka前缀的参数后,即可从Kafka接收任务

http 接收任务的例子(Sample for http):
curl -X POST http://127.0.0.1:8000/task_in_queue
-H 'Content-Type: application/json'
-d '{...}'

WebSockets连接的地址为:(Address for WebSockets incomming)
ws://ip:port/ws_task

The format of the JSON body for HTTP requests and the message topic for Kafka is consistent, both in JSON format, as defined below:

Expand Down
79 changes: 74 additions & 5 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod runner;
mod ws;

extern crate redis;
use std::sync::{Arc, Mutex};
Expand All @@ -8,11 +9,12 @@ use redis::{AsyncCommands, Commands};
use std::net::SocketAddr;
use std::net::IpAddr;
use std::net::Ipv4Addr;
use axum::{extract::{Json, path::Path as PathParam, State}, Router, routing::{post, get}, response::IntoResponse};
use axum::{extract::{Json, path::Path as PathParam, State, ws::{Message as WsMessage, WebSocket, WebSocketUpgrade}}, Router, routing::{post, get}, response::IntoResponse, ServiceExt};
use axum_extra::{headers, TypedHeader};
use std::{env, thread};

use std::path::{Path, PathBuf};
use std::str::FromStr;
use axum::extract::ConnectInfo;
use tokio::fs::File;
use tokio::io::{self, BufReader, AsyncBufReadExt};
use chrono::{DateTime, Utc, Local};
Expand All @@ -25,14 +27,16 @@ use rdkafka::config::ClientConfig;
use rdkafka::consumer::{Consumer, BaseConsumer};
use rdkafka::producer::{BaseProducer, BaseRecord, Producer};
use rdkafka::Message;
use crate::ws::{handle_socket, MessageExecutor};


#[derive(Clone)]
pub struct AppState {
pub config_path: String,
pub redis_client: redis::Client,
queue: QueueGroup,
pub config: AppConfig
pub config: AppConfig,
pub ws_executor: MessageExecutor
}

#[derive(Serialize, Deserialize, Debug)]
Expand Down Expand Up @@ -206,6 +210,34 @@ impl KafkaProducer {
}
}

#[derive(Clone)]
struct ResponseQueue {
queue: Arc<Mutex<VecDeque<(String, String)>>>
}

impl ResponseQueue {
fn new() -> ResponseQueue {
ResponseQueue{ queue: Arc::new(Mutex::new(VecDeque::new())) }
}

fn queue_resp(&mut self, task_id: String, resp: String) {
while let Ok(mut queue) = self.queue.lock() {
queue.push_back((task_id.clone(), resp.clone()));
}
}

fn wait_for(&mut self) -> Option<(String, String)> {
while let Ok(mut queue) = self.queue.lock() {
if let Some(resp) = queue.pop_front(){
return Some(resp);
}else{
return None;
}
}
None
}
}


const TASK_WRONG: &'static str = "task||wrong";

Expand Down Expand Up @@ -347,11 +379,15 @@ async fn main() {
let mut queue_group = QueueGroup::init_by_number(workers);

let client = redis::Client::open(redis).unwrap();
let mut ws_exec = MessageExecutor::new();

let mut resp_queue = ResponseQueue::new();

for thread_id in 0..workers {
let mut group = queue_group.clone();
let mut redis_connection = client.get_connection().unwrap();
let appconfig = appconfig.clone();
let mut resp_queue1 = resp_queue.clone();
thread::spawn(move || {
let mut pd = KafkaProducer::from_bootstrap(appconfig.kafka_servers.as_str(), appconfig.kafka_resp_topic.clone());
loop {
Expand Down Expand Up @@ -388,6 +424,10 @@ async fn main() {
if task.src_chn == "kafka" {
pd.sent(serde_json::json!({"result": "OK", "request_id": task.id.clone()}));
}
if task.src_chn == "ws" {
//ws_exec.response_for(task.id.clone(), serde_json::json!({"result": "OK", "request_id": task.id.clone()}));
resp_queue1.queue_resp(task_id.clone(), serde_json::to_string(&serde_json::json!({"result": "OK", "request_id": task.id.clone()})).unwrap());
}
}
}
redis_connection.srem::<&str, String, ()>(TASK_WORKING, task.id.clone()).expect("redis error");
Expand Down Expand Up @@ -528,6 +568,18 @@ async fn main() {
});
}

let mut resp_queue2 = resp_queue.clone();
let mut ws_exec1 = ws_exec.clone();
tokio::spawn(async move {
loop {
if let Some((tid, resp_str)) = resp_queue2.wait_for() {
ws_exec1.response_for(tid, resp_str).await;
}else{
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
}
}
});

if let Ok(mut main_conn) = client.get_async_connection().await {
if let Ok(working_ids) = main_conn.smembers::<&str, Vec<String>>(TASK_WORKING).await {
for tk_id in working_ids {
Expand All @@ -546,17 +598,19 @@ async fn main() {
config_path: cron_path.to_string(),
redis_client: client.clone(),
queue: queue_group.clone(),
config: appconfig.clone()
config: appconfig.clone(),
ws_executor: ws_exec.clone()
};
let app = Router::new()
.route("/task_in_queue", post(handler))
.route("/task_resp/:key", get(waiting))
.route("/ws_task", get(ws_handler))
.route("/sys_info", get(system_info_handler))
.with_state(app_state);

println!("server will start at 0.0.0.0:{}", port);
let serv = axum::Server::bind(& SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), int_port))
.serve(app.into_make_service())
.serve(app.into_make_service_with_connect_info::<SocketAddr>())
.await;
match serv {
Ok(_)=>{
Expand Down Expand Up @@ -625,6 +679,21 @@ async fn waiting_for_result(conn: &mut Connection, flag: String, waiting: usize)
}
}

async fn ws_handler(
State(state): State<AppState>,
ws: WebSocketUpgrade,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
) -> impl IntoResponse {
let addr = addr.clone();
println!("need to upgrade:{:?}", &addr);
ws.on_upgrade( move |socket| async move {
let socket = socket;
println!("socks {:?} connected, {}", &addr, state.queue.size);
//handle_socket(state.clone(), socket, addr)
handle_socket(state.clone(), socket, addr.clone()).await
})
}

async fn system_info_handler(State(mut state): State<AppState>) -> impl IntoResponse {
let mut conn = state.redis_client.get_async_connection().await.unwrap();
let working_keys = if let Ok(_working_keys) = conn.smembers::<String, Vec<String>>(TASK_WORKING.to_string()).await {
Expand Down
161 changes: 161 additions & 0 deletions src/ws.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
use std::ops::ControlFlow;
use axum::extract::ws::{Message, WebSocket};
use std::{net::SocketAddr};
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::mpsc::{channel, Sender};
use chrono::{DateTime, Utc};
use redis::aio::Connection;
use redis::AsyncCommands;
use tokio::sync::Mutex;
use crate::{AppState, TASK_DELAY, TASK_WORKING, WebTask};
use futures::{sink::SinkExt, stream::StreamExt};

#[derive(Clone)]
enum WsMsg {
STR(String),
BYT(Vec<u8>)
}

#[derive(Clone)]
pub struct MessageExecutor {
sender_map: Arc<Mutex<HashMap<SocketAddr, Sender<WsMsg>>>>,
req_map: Arc<Mutex<HashMap<String, SocketAddr>>>,
}

impl MessageExecutor {
pub fn new()->MessageExecutor {
MessageExecutor {
sender_map: Arc::new(Mutex::new(HashMap::new())),
req_map: Arc::new(Mutex::new(HashMap::new())),
}
}

pub async fn bind_sender(&mut self, who: SocketAddr, sender: Sender<WsMsg>) {
let mut sm = self.sender_map.lock().await;
if let Some(rs) = sm.insert(who, sender){
}
}

pub async fn clear_client(&mut self, who: SocketAddr) {
let mut sm = self.sender_map.lock().await;
if let Some(v) = sm.remove(&who) {
}
}

pub async fn bind_request_id(&mut self, task_id: String, who: SocketAddr) {
let mut req_map = self.req_map.lock().await;
if req_map.contains_key(&task_id){
req_map.remove(&task_id);
}
req_map.insert(task_id.clone(), who);
}

pub async fn response_for(&mut self, task_id: String, resp: String) {
let who_is = {
let req_map = self.req_map.lock().await;
if let Some(who) = req_map.get(&task_id) {
Some(who.clone())
} else {
None
}
};
if let Some(the_who) = who_is {
let mut sender_map = self.sender_map.lock().await;
if let Some(mut sender) = sender_map.get(&the_who){
sender.send(WsMsg::STR(resp));
sender_map.remove(&the_who);
}
}
}


}


pub async fn handle_socket(mut state: AppState, mut socket: WebSocket, who: SocketAddr) {
let mut redis_conn = state.redis_client.get_async_connection().await.expect("Redis连接失败");
let (mut sender, mut receiver) = socket.split();
let (mpsc_tx, mpsc_rx) = channel::<WsMsg>();
state.ws_executor.bind_sender(who, mpsc_tx).await;

let mut send_task = tokio::spawn(async move {
loop {
if let Ok(msg) = mpsc_rx.recv() {
if let WsMsg::STR(str_resp) = msg {
if sender.send(Message::Text(str_resp)).await.is_err() {
println!("向ws client写入消息错误");
return 1;
}
}
}else{
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}
}
});

let mut recv_task = tokio::spawn(async move {
let mut cnt = 0;
while let Some(Ok(msg)) = receiver.next().await {
cnt += 1;
// print message and break if instructed to do so
if process_message(msg, who, &mut state, &mut redis_conn).await.is_break() {
break;
}
}
cnt
});
tokio::select! {
rv_a = (&mut send_task) => {
match rv_a {
Ok(a) => println!("{a} messages sent to {who}"),
Err(a) => println!("Error sending messages {a:?}")
}
recv_task.abort();
},
rv_b = (&mut recv_task) => {
match rv_b {
Ok(b) => println!("Received {b} messages"),
Err(b) => println!("Error receiving messages {b:?}")
}
send_task.abort();
}
}
}

async fn process_message(msg: Message, who: SocketAddr, state: &mut AppState, redis_conn: &mut Connection) -> ControlFlow<(), ()> {
let msg_need_to_process = match msg {
Message::Text(t) => Some(t),
Message::Close(c) => {
if let Some(cf) = c {
println!(">>> {} sent close with code {} and reason `{}`", who, cf.code, cf.reason);
} else {
println!(">>> {who} somehow sent close message without CloseFrame");
}
return ControlFlow::Break(());
}
Message::Pong(v) => None,
Message::Ping(v) => None,
_=>None
};
if let Some(msg_str) = msg_need_to_process {
if let Ok(mut web_task) = serde_json::from_str::<WebTask>(msg_str.as_str()) {
let mut task = web_task.gen_task(state.queue.size);
let now: DateTime<Utc> = Utc::now();
let now_ts = now.timestamp();
task.src_chn = "ws".to_string();
let redis_payload = serde_json::to_string(&task).unwrap();
redis_conn.set::<String, String, ()>(task.id.clone(), redis_payload).await.expect("set error");
state.ws_executor.bind_request_id(task.id.clone(), who.clone()).await;
if web_task.delay == 0 {
redis_conn.sadd::<String, String, ()>(TASK_WORKING.to_string(), task.id.clone()).await.expect("set list error");
println!("开始分配Worker线程");
state.queue.dispatch_task(&task);
} else {
let delay_key = format!("{} {}", task.id, now_ts + web_task.delay as i64);
redis_conn.sadd::<String, String, ()>(TASK_DELAY.to_string(), delay_key.clone()).await.expect("set list error");
}
}
}
ControlFlow::Continue(())
}

0 comments on commit 88161b7

Please sign in to comment.