From e3783dbf847aad0c98ae124a51dd9eff98f9500b Mon Sep 17 00:00:00 2001 From: Alon Burg Date: Sun, 28 Aug 2022 16:27:14 +0300 Subject: [PATCH] Add username to on_init --- mysql/src/lib.rs | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/mysql/src/lib.rs b/mysql/src/lib.rs index 265ae13..340c325 100644 --- a/mysql/src/lib.rs +++ b/mysql/src/lib.rs @@ -193,6 +193,7 @@ pub trait AsyncMysqlShim { async fn on_init<'a>( &'a mut self, _: &'a str, + _: &'a str, _: InitWriter<'a, W>, ) -> Result<(), Self::Error> { Ok(()) @@ -220,6 +221,7 @@ const AUTH_PLUGIN_DATA_PART_1_LENGTH: usize = 8; pub struct AsyncMysqlIntermediary { pub(crate) client_capabilities: CapabilityFlags, process_use_statement_on_query: bool, + username: Option, shim: B, reader: packet_reader::PacketReader, writer: packet_writer::PacketWriter, @@ -248,6 +250,7 @@ where let r = packet_reader::PacketReader::new(input_stream); let w = packet_writer::PacketWriter::new(output_stream); let mut mi = AsyncMysqlIntermediary { + username: None, client_capabilities: CapabilityFlags::from_bits_truncate(0), process_use_statement_on_query: opts.process_use_statement_on_query, shim, @@ -423,13 +426,19 @@ where self.writer.flush_all().await?; return Err(io::Error::new(io::ErrorKind::PermissionDenied, err_msg).into()); } + self.username = Some( + String::from_utf8(handshake.username) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?, + ); if let Some(Ok(db)) = handshake.db.as_ref().map(|x| std::str::from_utf8(x)) { let w = InitWriter { client_capabilities: self.client_capabilities, writer: &mut self.writer, }; - self.shim.on_init(db, w).await?; + self.shim + .on_init(db, self.username.as_ref().unwrap().as_ref(), w) + .await?; } else { writers::write_ok_packet( &mut self.writer, @@ -499,7 +508,9 @@ where let schema = ::std::str::from_utf8(&q[b"USE ".len()..]) .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; let schema = schema.trim().trim_end_matches(';').trim_matches('`'); - self.shim.on_init(schema, w).await?; + self.shim + .on_init(schema, self.username.as_ref().unwrap().as_ref(), w) + .await?; } else { let w = QueryResultWriter::new( &mut self.writer, @@ -599,6 +610,7 @@ where ::std::str::from_utf8(schema).map_err(|e| { io::Error::new(io::ErrorKind::InvalidData, e) })?, + self.username.as_ref().unwrap().as_ref(), w, ) .await?;