diff --git a/nats/src/connector.rs b/nats/src/connector.rs index 02948d8f6..92ae1a5ab 100644 --- a/nats/src/connector.rs +++ b/nats/src/connector.rs @@ -13,8 +13,11 @@ use lazy_static::__Deref; use parking_lot::{Mutex, MutexGuard}; +use serde::de::{self, Visitor}; +use serde::{Deserialize, Deserializer}; use std::collections::HashMap; use std::convert::TryFrom; +use std::fmt::{self, Formatter}; use std::io::prelude::*; use std::io::{self, BufReader, Error, ErrorKind}; use std::net::{Shutdown, SocketAddr, TcpStream, ToSocketAddrs}; @@ -748,6 +751,26 @@ impl IntoServerList for io::Result> { } } +struct ServerAddressVisitor; + +impl<'de> Visitor<'de> for ServerAddressVisitor { + type Value = ServerAddress; + + fn expecting<'a>(&self, formatter: &mut Formatter<'a>) -> fmt::Result { + formatter.write_str("a valid URL") + } + + fn visit_str(self, value: &str) -> Result { + ServerAddress::from_str(value).map_err(|err| E::custom(format!("{}", err))) + } +} + +impl<'de> Deserialize<'de> for ServerAddress { + fn deserialize>(deserializer: D) -> Result { + deserializer.deserialize_str(ServerAddressVisitor) + } +} + #[cfg(test)] mod tests { use super::*; @@ -789,4 +812,24 @@ mod tests { matches!(address.auth(), AuthStyle::UserPass(username, password) if &username == "myuser" && &password == "mypass") ); } + + #[test] + fn deserialize_valid_server_address() { + let serialized = "\"nats://myuser:mypass@localhost:4222\""; + let address = ServerAddress::from_str("nats://myuser:mypass@localhost:4222").unwrap(); + + assert_eq!( + serde_json::from_str::(serialized).unwrap(), + address + ); + } + + #[test] + fn deserialize_invalid_server_address() { + let serialized = "\"this is not the address\""; + let result = serde_json::from_str::(serialized); + + assert!(result.is_err()); + assert!(format!("{}", result.unwrap_err()).contains("NATS server URL is invalid")); + } }