Skip to content

Commit

Permalink
Flatten + untagged enums don't really work
Browse files Browse the repository at this point in the history
  • Loading branch information
Jake-Shadle committed Sep 12, 2024
1 parent 5a9304e commit 97c641d
Showing 1 changed file with 77 additions and 6 deletions.
83 changes: 77 additions & 6 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,85 @@ pub(crate) const BACKOFF_INITIAL_DELAY: Duration = Duration::from_millis(500);

base64_serde_type!(pub Base64Standard, base64::engine::general_purpose::STANDARD);

#[derive(Clone, Debug, Deserialize, Serialize, JsonSchema)]
#[derive(Clone, Debug, Serialize, JsonSchema)]
#[serde(untagged)]
#[cfg_attr(test, derive(PartialEq))]
pub enum DatacenterConfig {
NonAgent {
#[serde(default)]
datacenters: Watch<DatacenterMap>,
},
Agent {
#[serde(default)]
icao_code: Slot<IcaoCode>,
#[serde(default)]
qcmp_port: Slot<u16>,
},
NonAgent {
#[serde(default)]
datacenters: Watch<DatacenterMap>,
},
}

impl<'de> Deserialize<'de> for DatacenterConfig {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct Visitor;

impl<'de> serde::de::Visitor<'de> for Visitor {
type Value = DatacenterConfig;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("DatacenterConfig")
}

fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: serde::de::MapAccess<'de>,
{
let mut icao_code = Option::<IcaoCode>::None;
let mut qcmp_port = Option::<u16>::None;
let mut datacenters = None;

while let Some(key) = map.next_key::<String>()? {
match key.as_str() {
"icao_code" => {
icao_code = Some(map.next_value()?);
}
"qcmp_port" => {
qcmp_port = Some(map.next_value()?);
}
"datacenters" => {
datacenters = Some(map.next_value()?);
}
unknown => {
return Err(serde::de::Error::unknown_field(
unknown,
&["icao_code", "qcmp_port", "datacenters"],
));
}
}
}

if icao_code.is_some() || qcmp_port.is_some() {
if datacenters.is_some() {
return Err(serde::de::Error::custom(
"icao_code or qcmp_port are not compatible with datacenters",
));
}

Ok(DatacenterConfig::Agent {
icao_code: Slot::new(icao_code.unwrap_or_default()),
qcmp_port: Slot::new(qcmp_port.unwrap_or_default()),
})
} else {
Ok(DatacenterConfig::NonAgent {
datacenters: Watch::new(datacenters.unwrap_or_default()),
})
}
}
}

deserializer.deserialize_map(Visitor)
}
}

/// Configuration for a component
Expand All @@ -80,7 +145,6 @@ pub struct Config {
pub id: Slot<String>,
#[serde(default)]
pub version: Slot<Version>,
#[serde(flatten)]
pub datacenter: DatacenterConfig,
}

Expand Down Expand Up @@ -768,6 +832,13 @@ impl<'de> Deserialize<'de> for IcaoCode {
f.write_str("a 4-character, uppercase, alphabetical ASCII ICAO code")
}

fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
v.parse().map_err(serde::de::Error::custom)
}

fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E>
where
E: serde::de::Error,
Expand Down

0 comments on commit 97c641d

Please sign in to comment.