From 8b4bfec6d55ef1a07903227ad0105f597f257d26 Mon Sep 17 00:00:00 2001 From: Andrew Kozin Date: Tue, 18 Jul 2023 15:06:17 +0100 Subject: [PATCH] Implement Deserialize trait for ServerAddr --- async-nats/src/lib.rs | 48 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 46 insertions(+), 2 deletions(-) diff --git a/async-nats/src/lib.rs b/async-nats/src/lib.rs index 200f88f7e..2cce37c58 100644 --- a/async-nats/src/lib.rs +++ b/async-nats/src/lib.rs @@ -128,7 +128,7 @@ use tracing::{debug, error}; use core::fmt; use std::collections::HashMap; -use std::fmt::Display; +use std::fmt::{Display, Formatter}; use std::iter; use std::net::{SocketAddr, ToSocketAddrs}; use std::option; @@ -141,7 +141,8 @@ use tokio::time::{interval, Duration, Interval, MissedTickBehavior}; use url::{Host, Url}; use bytes::Bytes; -use serde::{Deserialize, Serialize}; +use serde::de::{self, Visitor}; +use serde::{Deserialize, Deserializer, Serialize}; use serde_repr::{Deserialize_repr, Serialize_repr}; use tokio::io; use tokio::sync::{mpsc, oneshot}; @@ -1252,6 +1253,29 @@ pub trait ToServerAddrs { fn to_server_addrs(&self) -> io::Result; } +struct ServerAddrVisitor; + +impl<'de> Visitor<'de> for ServerAddrVisitor { + type Value = ServerAddr; + + fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result { + formatter.write_str("a valid NATS server address") + } + + fn visit_str(self, v: &str) -> Result where E: de::Error { + ServerAddr::from_str(v).map_err(|e| de::Error::custom(e.to_string())) + } +} + +impl<'de> Deserialize<'de> for ServerAddr { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_str(ServerAddrVisitor) + } +} + impl ToServerAddrs for ServerAddr { type Iter = option::IntoIter; fn to_server_addrs(&self) -> io::Result { @@ -1336,4 +1360,24 @@ mod tests { let address = ServerAddr::from_str("nats://example.com").unwrap(); assert_eq!(address.host(), "example.com") } + + #[test] + fn deserialize_valid_server_address() { + let serialized = "\"nats://example.com\""; + let address = ServerAddr::from_str("nats://example.com").unwrap(); + + assert_eq!( + serde_json::from_str::(serialized).unwrap(), + address + ); + } + + #[test] + fn deserialize_invalid_server_address() { + let serialized = "\"this is not the address\""; + let result = serde_json::from_str::(serialized); + + assert!(result.is_err()); + assert!(format!("{}", result.unwrap_err()).contains("NATS server URL is invalid")); + } }