Skip to content

Commit

Permalink
Add username to on_init
Browse files Browse the repository at this point in the history
  • Loading branch information
burgalon committed Aug 28, 2022
1 parent a3df8cd commit 14310a7
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions mysql/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ pub trait AsyncMysqlShim<W: Send> {
async fn on_init<'a>(
&'a mut self,
_: &'a str,
_: &'a str,
_: InitWriter<'a, W>,
) -> Result<(), Self::Error> {
Ok(())
Expand Down Expand Up @@ -220,6 +221,7 @@ const AUTH_PLUGIN_DATA_PART_1_LENGTH: usize = 8;
pub struct AsyncMysqlIntermediary<B, S: AsyncRead + Unpin, W> {
pub(crate) client_capabilities: CapabilityFlags,
process_use_statement_on_query: bool,
username: Option<String>,
shim: B,
reader: packet_reader::PacketReader<S>,
writer: packet_writer::PacketWriter<W>,
Expand Down Expand Up @@ -248,6 +250,7 @@ where
let r = packet_reader::PacketReader::new(input_stream);
let w = packet_writer::PacketWriter::new(output_stream);
let mut mi = AsyncMysqlIntermediary {
username: None,
client_capabilities: CapabilityFlags::from_bits_truncate(0),
process_use_statement_on_query: opts.process_use_statement_on_query,
shim,
Expand Down Expand Up @@ -423,13 +426,16 @@ where
self.writer.flush_all().await?;
return Err(io::Error::new(io::ErrorKind::PermissionDenied, err_msg).into());
}
self.username = Some(String::from_utf8(handshake.username).map_err(|e| {
io::Error::new(io::ErrorKind::InvalidData, e)
})?);

if let Some(Ok(db)) = handshake.db.as_ref().map(|x| std::str::from_utf8(x)) {
let w = InitWriter {
client_capabilities: self.client_capabilities,
writer: &mut self.writer,
};
self.shim.on_init(db, w).await?;
self.shim.on_init(db, self.username.as_ref().unwrap().as_ref(), w).await?;
} else {
writers::write_ok_packet(
&mut self.writer,
Expand Down Expand Up @@ -499,7 +505,7 @@ where
let schema = ::std::str::from_utf8(&q[b"USE ".len()..])
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
let schema = schema.trim().trim_end_matches(';').trim_matches('`');
self.shim.on_init(schema, w).await?;
self.shim.on_init(schema, self.username.as_ref().unwrap().as_ref(), w).await?;
} else {
let w = QueryResultWriter::new(
&mut self.writer,
Expand Down Expand Up @@ -599,6 +605,7 @@ where
::std::str::from_utf8(schema).map_err(|e| {
io::Error::new(io::ErrorKind::InvalidData, e)
})?,
self.username.as_ref().unwrap().as_ref(),
w,
)
.await?;
Expand Down

0 comments on commit 14310a7

Please sign in to comment.