diff --git a/src/ws.rs b/src/ws.rs index 08c1bae..b0a54e8 100644 --- a/src/ws.rs +++ b/src/ws.rs @@ -3,7 +3,7 @@ use axum::extract::ws::{Message, WebSocket}; use std::{net::SocketAddr}; use std::collections::HashMap; use std::sync::Arc; -use std::sync::mpsc::{channel, Sender}; +use tokio::sync::mpsc::{channel, Sender}; use chrono::{DateTime, Utc}; use redis::aio::Connection; use redis::AsyncCommands; @@ -12,7 +12,7 @@ use crate::{AppState, TASK_DELAY, TASK_WORKING, WebTask}; use futures::{sink::SinkExt, stream::StreamExt}; #[derive(Clone)] -enum WsMsg { +pub enum WsMsg { STR(String), BYT(Vec) } @@ -63,7 +63,9 @@ impl MessageExecutor { if let Some(the_who) = who_is { let mut sender_map = self.sender_map.lock().await; if let Some(mut sender) = sender_map.get(&the_who){ - sender.send(WsMsg::STR(resp)); + if sender.send(WsMsg::STR(resp)).await.is_err(){ + println!("发送错误"); + } sender_map.remove(&the_who); } } @@ -76,12 +78,12 @@ impl MessageExecutor { pub async fn handle_socket(mut state: AppState, mut socket: WebSocket, who: SocketAddr) { let mut redis_conn = state.redis_client.get_async_connection().await.expect("Redis连接失败"); let (mut sender, mut receiver) = socket.split(); - let (mpsc_tx, mpsc_rx) = channel::(); + let (mpsc_tx, mut mpsc_rx) = channel::(1024); state.ws_executor.bind_sender(who, mpsc_tx).await; let mut send_task = tokio::spawn(async move { loop { - if let Ok(msg) = mpsc_rx.recv() { + if let Some(msg) = mpsc_rx.recv().await { if let WsMsg::STR(str_resp) = msg { if sender.send(Message::Text(str_resp)).await.is_err() { println!("向ws client写入消息错误");