diff --git a/src/completion.rs b/src/completion.rs index bdf60fb..2803c0c 100644 --- a/src/completion.rs +++ b/src/completion.rs @@ -255,14 +255,86 @@ pub struct InsertReplaceEdit { /// The range if the replace is requested. pub replace: Range, } - -#[derive(Debug, Eq, PartialEq, Clone, Deserialize, Serialize)] +#[derive(Debug, Eq, PartialEq, Clone, Serialize)] #[serde(untagged)] pub enum CompletionTextEdit { Edit(TextEdit), InsertAndReplace(InsertReplaceEdit), } +impl<'de> Deserialize<'de> for CompletionTextEdit { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + #[derive(Deserialize)] + #[serde(field_identifier, rename_all = "camelCase")] + enum Field { + NewText, + Insert, + Replace, + Range, + } + + struct CompletionTextEditVisitor; + + impl<'de> serde::de::Visitor<'de> for CompletionTextEditVisitor { + type Value = CompletionTextEdit; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("struct CompletionTextEdit") + } + + fn visit_map(self, mut map: V) -> Result + where + V: serde::de::MapAccess<'de>, + { + let mut new_text = None; + let mut insert = None; + let mut replace = None; + let mut range = None; + + while let Some(key) = map.next_key()? { + match key { + Field::NewText => { + new_text = Some(map.next_value()?); + } + Field::Insert => { + insert = Some(map.next_value()?); + } + Field::Replace => { + replace = Some(map.next_value()?); + } + Field::Range => { + range = Some(map.next_value()?); + } + } + } + if let Some(range) = range { + Ok(CompletionTextEdit::Edit(TextEdit { + new_text: new_text + .ok_or_else(|| serde::de::Error::missing_field("newText"))?, + range, + })) + } else if let (Some(new_text), Some(insert), Some(replace)) = + (new_text, insert, replace) + { + Ok(CompletionTextEdit::InsertAndReplace(InsertReplaceEdit { + new_text, + insert, + replace, + })) + } else { + Err(serde::de::Error::custom("missing required fields")) + } + } + } + + const FIELDS: &[&str] = &["newText", "insert", "replace", "range"]; + deserializer.deserialize_struct("CompletionTextEdit", FIELDS, CompletionTextEditVisitor) + } +} + impl From for CompletionTextEdit { fn from(edit: TextEdit) -> Self { CompletionTextEdit::Edit(edit) @@ -340,13 +412,50 @@ pub struct CompletionRegistrationOptions { pub completion_options: CompletionOptions, } -#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] +#[derive(Debug, PartialEq, Clone, Serialize)] #[serde(untagged)] pub enum CompletionResponse { Array(Vec), List(CompletionList), } +impl<'de> Deserialize<'de> for CompletionResponse { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct CompletionResponseVisitor; + + impl<'de> serde::de::Visitor<'de> for CompletionResponseVisitor { + type Value = CompletionResponse; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("an array or a struct for CompletionResponse") + } + + fn visit_seq(self, seq: A) -> Result + where + A: serde::de::SeqAccess<'de>, + { + let items: Vec = + Deserialize::deserialize(serde::de::value::SeqAccessDeserializer::new(seq))?; + Ok(CompletionResponse::Array(items)) + } + + fn visit_map(self, map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + let list: CompletionList = + Deserialize::deserialize(serde::de::value::MapAccessDeserializer::new(map))?; + Ok(CompletionResponse::List(list)) + } + } + + deserializer.deserialize_any(CompletionResponseVisitor) + } +} + impl From> for CompletionResponse { fn from(items: Vec) -> Self { CompletionResponse::Array(items)