Skip to content

Commit

Permalink
Implement Deserialize trait for ServerAddr
Browse files Browse the repository at this point in the history
  • Loading branch information
nepalez committed Jul 31, 2023
1 parent 78c0bd2 commit 8b4bfec
Showing 1 changed file with 46 additions and 2 deletions.
48 changes: 46 additions & 2 deletions async-nats/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ use tracing::{debug, error};

use core::fmt;
use std::collections::HashMap;
use std::fmt::Display;
use std::fmt::{Display, Formatter};
use std::iter;
use std::net::{SocketAddr, ToSocketAddrs};
use std::option;
Expand All @@ -141,7 +141,8 @@ use tokio::time::{interval, Duration, Interval, MissedTickBehavior};
use url::{Host, Url};

use bytes::Bytes;
use serde::{Deserialize, Serialize};
use serde::de::{self, Visitor};
use serde::{Deserialize, Deserializer, Serialize};
use serde_repr::{Deserialize_repr, Serialize_repr};
use tokio::io;
use tokio::sync::{mpsc, oneshot};
Expand Down Expand Up @@ -1252,6 +1253,29 @@ pub trait ToServerAddrs {
fn to_server_addrs(&self) -> io::Result<Self::Iter>;
}

struct ServerAddrVisitor;

impl<'de> Visitor<'de> for ServerAddrVisitor {
type Value = ServerAddr;

fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
formatter.write_str("a valid NATS server address")
}

fn visit_str<E>(self, v: &str) -> Result<Self::Value, E> where E: de::Error {
ServerAddr::from_str(v).map_err(|e| de::Error::custom(e.to_string()))
}
}

impl<'de> Deserialize<'de> for ServerAddr {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_str(ServerAddrVisitor)
}
}

impl ToServerAddrs for ServerAddr {
type Iter = option::IntoIter<ServerAddr>;
fn to_server_addrs(&self) -> io::Result<Self::Iter> {
Expand Down Expand Up @@ -1336,4 +1360,24 @@ mod tests {
let address = ServerAddr::from_str("nats://example.com").unwrap();
assert_eq!(address.host(), "example.com")
}

#[test]
fn deserialize_valid_server_address() {
let serialized = "\"nats://example.com\"";
let address = ServerAddr::from_str("nats://example.com").unwrap();

assert_eq!(
serde_json::from_str::<ServerAddr>(serialized).unwrap(),
address
);
}

#[test]
fn deserialize_invalid_server_address() {
let serialized = "\"this is not the address\"";
let result = serde_json::from_str::<ServerAddr>(serialized);

assert!(result.is_err());
assert!(format!("{}", result.unwrap_err()).contains("NATS server URL is invalid"));
}
}

0 comments on commit 8b4bfec

Please sign in to comment.