From cc7e084c31473f714d953969d17192a55bbb5761 Mon Sep 17 00:00:00 2001 From: Fanda Vacek Date: Tue, 3 Sep 2024 14:51:43 +0200 Subject: [PATCH] Implement tunneling support --- src/main.rs | 207 ++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 184 insertions(+), 23 deletions(-) diff --git a/src/main.rs b/src/main.rs index 736da92..27dce02 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,27 +1,32 @@ +use std::collections::BTreeMap; +use std::future::Future; use async_std::io::BufReader; -use async_std::net::TcpStream; +use async_std::net::{TcpListener, TcpStream}; use async_std::os::unix::net::UnixStream; -use async_std::{io, task}; +use async_std::{channel, io, task}; use clap::Parser; use futures::io::BufWriter; -use futures::AsyncReadExt; +use futures::{select, AsyncReadExt, FutureExt}; use futures::AsyncWriteExt; use log::*; use shvrpc::client::LoginParams; use shvrpc::framerw::{FrameReader, FrameWriter}; -use shvrpc::rpcmessage::RqId; +use shvrpc::rpcmessage::{RqId}; use shvrpc::serialrw::{SerialFrameReader, SerialFrameWriter}; use shvrpc::streamrw::{StreamFrameReader, StreamFrameWriter}; use shvrpc::util::{login_from_url, parse_log_verbosity}; use shvrpc::{client, RpcMessage, RpcMessageMetaTags}; use simple_logger::SimpleLogger; use url::Url; +use async_std::channel::{Sender}; +use async_std::stream::StreamExt; +use shvrpc::rpcframe::RpcFrame; #[cfg(feature = "readline")] use crossterm::tty::IsTty; #[cfg(feature = "readline")] use rustyline_async::ReadlineEvent; -use shvproto::RpcValue; +use shvproto::{Map, RpcValue}; #[cfg(feature = "readline")] use std::io::Write; @@ -43,6 +48,9 @@ struct Opts { /// Output format: [ cpon | chainpack | simple | value | "Placeholders {PATH} {METHOD} {VALUE} in any number and combination in custom string." ] #[arg(short = 'o', long = "output-format", default_value = "cpon")] output_format: String, + /// Create TCP tunnel, SSH like syntax, example: -L 2222:some.host.org:22 + #[arg(short = 'L', long)] + tunnel: Option, /// Verbose mode (module, .) #[arg(short, long)] verbose: Option, @@ -74,6 +82,16 @@ fn is_tty() -> bool { #[cfg(not(feature = "readline"))] return false; } +fn spawn_and_log_error(fut: F) -> task::JoinHandle<()> +where + F: Future> + Send + 'static, +{ + task::spawn(async move { + if let Err(e) = fut.await { + error!("{}", e) + } + }) +} pub(crate) fn main() -> Result { let opts = Opts::parse(); @@ -94,10 +112,9 @@ pub(crate) fn main() -> Result { // let rpc_timeout = Duration::from_millis(DEFAULT_RPC_TIMEOUT_MSEC); let url = Url::parse(&opts.url)?; - task::block_on(try_main(&url, &opts)) + task::block_on(try_main(&url, opts)) } - -async fn make_call(url: &Url, opts: &Opts) -> Result { +async fn login(url: &Url) -> shvrpc::Result<(BoxedFrameReader, BoxedFrameWriter)> { // Establish a connection let mut reset_session = false; let (mut frame_reader, mut frame_writer) = match url.scheme() { @@ -153,6 +170,22 @@ async fn make_call(url: &Url, opts: &Opts) -> Result { //frame_writer.send_frame(frame.expect("frame")).await?; client::login(&mut *frame_reader, &mut *frame_writer, &login_params).await?; info!("Connected to broker."); + Ok((frame_reader, frame_writer)) +} +async fn send_request( + frame_writer: &mut (dyn FrameWriter + Send), + path: &str, + method: &str, + param: &str, +) -> shvrpc::Result { + let param = if param.is_empty() { + None + } else { + Some(RpcValue::from_cpon(param)?) + }; + frame_writer.send_request(path, method, param).await +} +async fn make_call(mut frame_reader: BoxedFrameReader, mut frame_writer: BoxedFrameWriter, opts: &Opts) -> Result { async fn print_resp( stdout: &mut io::Stdout, resp: &RpcMessage, @@ -222,19 +255,6 @@ async fn make_call(url: &Url, opts: &Opts) -> Result { Ok(stdout.flush().await?) } - async fn send_request( - frame_writer: &mut (dyn FrameWriter + Send), - path: &str, - method: &str, - param: &str, - ) -> shvrpc::Result { - let param = if param.is_empty() { - None - } else { - Some(RpcValue::from_cpon(param)?) - }; - frame_writer.send_request(path, method, param).await - } fn parse_line(line: &str) -> std::result::Result<(&str, &str, &str), String> { let line = line.trim(); let method_ix = match line.find(':') { @@ -365,9 +385,150 @@ async fn make_call(url: &Url, opts: &Opts) -> Result { Ok(()) } +async fn make_tunnel(mut frame_reader: BoxedFrameReader, mut frame_writer: BoxedFrameWriter, opts: &Opts) -> Result { + let mut tunnel = opts.tunnel.as_ref().unwrap().split(':'); + let local_port = tunnel.next().ok_or("Local port must be specified")?; + let host = tunnel.next().ok_or("Host must be specified")?; + let remote_port = tunnel.next().ok_or("Remote port must be specified")?; + let host = format!("{host}:{remote_port}"); + let local_port = local_port.parse::()?; + enum RpcReaderCmd { + RegisterResponse(RqId, Sender, bool), + UnregisterResponse(RqId), + } + let (reader_cmd_sender, reader_cmd_receiver) = channel::unbounded::(); + spawn_and_log_error(async move { + struct PendingCall { + sender: Sender, + one_shot: bool, + } + let mut pending_calls = BTreeMap::::new(); + let mut get_frame_fut = frame_reader.receive_frame().fuse(); + loop { + select! { + frame = get_frame_fut => { + match frame { + Ok(frame) => { + let rqid = frame.request_id().unwrap_or_default(); + let drop_it = if let Some(pc) = pending_calls.get(&rqid) { + pc.sender.send(frame).await?; + pc.one_shot + } else { + false + }; + if drop_it { + pending_calls.remove(&rqid); + } + drop(get_frame_fut); + get_frame_fut = frame_reader.receive_frame().fuse(); + } + Err(e) => { + info!("RPC socket read error: {e}"); + break; + } + } + } + msg = reader_cmd_receiver.recv().fuse() => { + match msg { + Ok(msg) => { + match msg { + RpcReaderCmd::RegisterResponse(rqid, sender, one_shot) => { + pending_calls.insert(rqid, PendingCall {sender, one_shot}); + } + RpcReaderCmd::UnregisterResponse(rqid) => { + pending_calls.remove(&rqid); + } + } + } + Err(e) => { + error!("Read get frame message error: {e}"); + break; + } + } + } + } + } + shvrpc::Result::Ok(()) + }); + let (writer_sender, writer_receiver) = channel::unbounded::(); + spawn_and_log_error(async move { + loop { + let frame = writer_receiver.recv().await?; + frame_writer.send_frame(frame).await? + } + }); + info!("Starting TCP server"); + let listener = TcpListener::bind(format!("127.0.0.1:{local_port}")).await?; + let mut incoming = listener.incoming(); -async fn try_main(url: &Url, opts: &Opts) -> Result { - match make_call(url, opts).await { + while let Some(stream) = incoming.next().await { + let stream = stream?; + info!("New connection from {:?}", stream.local_addr()); + info!("Creating tunnel"); + //let tunid = call(&mut *frame_reader, &mut *frame_writer, ".app/tunnel", "create", Some(tun_opts.into())).await?.as_str().to_owned(); + let host = host.clone(); + let reader_cmd_sender = reader_cmd_sender.clone(); + let writer_sender = writer_sender.clone(); + spawn_and_log_error(async move { + let tunid = { + let tun_opts = Map::from([("host".into(), host.into())]); + let rq = RpcMessage::new_request(".app/tunnel", "create", Some(tun_opts.into())); + let rqid = rq.request_id().unwrap(); + let (sender, receiver) = channel::unbounded::(); + reader_cmd_sender.send(RpcReaderCmd::RegisterResponse(rqid, sender, true)).await?; + writer_sender.send(rq.to_frame()?).await?; + let resp = receiver.recv().await?; + resp.to_rpcmesage()?.result()?.as_str().to_owned() + }; + let rq = RpcMessage::new_request(&format!(".app/tunnel/{tunid}"), "write", None); + let rqid = rq.request_id().unwrap(); + let (sender, receiver) = channel::unbounded::(); + reader_cmd_sender.send(RpcReaderCmd::RegisterResponse(rqid, sender, false)).await?; + writer_sender.send(rq.to_frame()?).await?; + let (mut sock_reader, mut sock_writer) = stream.split(); + let mut sock_read_buff: [u8; 1024] = [0; 1024]; + loop { + select! { + n = sock_reader.read(&mut sock_read_buff).fuse() => { + let n = n?; + if n == 0 { + info!("Tunnel client socket closed"); + break; + } + let data = RpcValue::from(&sock_read_buff[0 .. n]); + let rq = RpcMessage::new_request(&format!(".app/tunnel/{tunid}"), "write", Some(data)); + writer_sender.send(rq.to_frame()?).await?; + } + frame = receiver.recv().fuse() => { + match frame { + Ok(frame) => { + let resp = frame.to_rpcmesage()?; + let data = resp.result()?.as_blob(); + sock_writer.write_all(data).await?; + sock_writer.flush().await?; + } + Err(e) => { + error!("Get response receiver error: {e}"); + break; + } + } + } + } + } + reader_cmd_sender.send(RpcReaderCmd::UnregisterResponse(rqid)).await?; + Ok(()) + }); + } + Ok(()) +} +async fn try_main(url: &Url, opts: Opts) -> Result { + let (frame_reader, frame_writer) = login(url).await?; + let res = if opts.tunnel.is_some() { + make_tunnel(frame_reader, frame_writer, &opts).await + } else { + make_call(frame_reader, frame_writer, &opts).await + }; + match res { Ok(_) => Ok(()), Err(err) => { eprintln!("{err}");