From 2cb768bddc715c32e740a067652e7200c8d344f3 Mon Sep 17 00:00:00 2001 From: Eugeen Sablin Date: Sat, 10 Nov 2018 19:23:16 +0300 Subject: support reading enums from config --- src/de.rs | 146 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 142 insertions(+), 4 deletions(-) (limited to 'src/de.rs') diff --git a/src/de.rs b/src/de.rs index 89d0c0c..e465d60 100644 --- a/src/de.rs +++ b/src/de.rs @@ -5,7 +5,7 @@ use std::borrow::Cow; use std::collections::hash_map::Drain; use std::collections::HashMap; use std::iter::Peekable; -use value::{Value, ValueKind, ValueWithKey}; +use value::{Value, ValueKind, ValueWithKey, Table}; // TODO: Use a macro or some other magic to reduce the code duplication here @@ -113,9 +113,22 @@ impl<'de> de::Deserializer<'de> for ValueWithKey<'de> { } } + fn deserialize_enum( + self, + name: &'static str, + variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + // FIXME: find a way to extend_with_key + visitor.visit_enum(EnumAccess{ value: self.0, name: name, variants: variants }) + } + forward_to_deserialize_any! { char seq - bytes byte_buf map struct unit enum newtype_struct + bytes byte_buf map struct unit newtype_struct identifier ignored_any unit_struct tuple_struct tuple } } @@ -231,9 +244,21 @@ impl<'de> de::Deserializer<'de> for Value { visitor.visit_newtype_struct(self) } + fn deserialize_enum( + self, + name: &'static str, + variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + visitor.visit_enum(EnumAccess{ value: self, name: name, variants: variants }) + } + forward_to_deserialize_any! { char seq - bytes byte_buf map struct unit enum + bytes byte_buf map struct unit identifier ignored_any unit_struct tuple_struct tuple } } @@ -334,6 +359,107 @@ impl<'de> de::MapAccess<'de> for MapAccess { } } +struct EnumAccess { + value: Value, + name: &'static str, + variants: &'static [&'static str], +} + +impl EnumAccess { + fn variant_deserializer(&self, name: &String) -> Result { + self.variants + .iter() + .find(|&s| s.to_lowercase() == name.to_lowercase()) + .map(|&s| StrDeserializer(s)) + .ok_or(self.no_constructor_error(name)) + } + + fn table_deserializer(&self, table: &Table) -> Result { + if table.len() == 1 { + self.variant_deserializer(table.iter().next().unwrap().0) + } else { + Err(self.structural_error()) + } + } + + fn no_constructor_error(&self, supposed_variant: &str) -> ConfigError { + ConfigError::Message(format!( + "enum {} does not have variant constructor {}", + self.name, supposed_variant + )) + } + + fn structural_error(&self) -> ConfigError { + ConfigError::Message(format!( + "value of enum {} should be represented by either string or table with exactly one key", + self.name + )) + } +} + +impl<'de> de::EnumAccess<'de> for EnumAccess { + type Error = ConfigError; + type Variant = Self; + + fn variant_seed(self, seed: V) -> Result<(V::Value, Self::Variant)> + where + V: de::DeserializeSeed<'de>, + { + let value = { + let deserializer = match self.value.kind { + ValueKind::String(ref s) => self.variant_deserializer(s), + ValueKind::Table(ref t) => self.table_deserializer(&t), + _ => Err(self.structural_error()), + }?; + seed.deserialize(deserializer)? + }; + + Ok((value, self)) + } +} + +impl<'de> de::VariantAccess<'de> for EnumAccess { + type Error = ConfigError; + + fn unit_variant(self) -> Result<()> { + Ok(()) + } + + fn newtype_variant_seed(self, seed: T) -> Result + where + T: de::DeserializeSeed<'de>, + { + match self.value.kind { + ValueKind::Table(t) => seed.deserialize(t.into_iter().next().unwrap().1), + _ => unreachable!(), + } + } + + fn tuple_variant(self, _len: usize, visitor: V) -> Result + where + V: de::Visitor<'de>, + { + match self.value.kind { + ValueKind::Table(t) => de::Deserializer::deserialize_seq(t.into_iter().next().unwrap().1, visitor), + _ => unreachable!(), + } + } + + fn struct_variant( + self, + _fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + match self.value.kind { + ValueKind::Table(t) => de::Deserializer::deserialize_map(t.into_iter().next().unwrap().1, visitor), + _ => unreachable!(), + } + } +} + impl<'de> de::Deserializer<'de> for Config { type Error = ConfigError; @@ -438,9 +564,21 @@ impl<'de> de::Deserializer<'de> for Config { } } + fn deserialize_enum( + self, + name: &'static str, + variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: de::Visitor<'de>, + { + visitor.visit_enum(EnumAccess{ value: self.cache, name: name, variants: variants }) + } + forward_to_deserialize_any! { char seq - bytes byte_buf map struct unit enum newtype_struct + bytes byte_buf map struct unit newtype_struct identifier ignored_any unit_struct tuple_struct tuple } } -- cgit v1.2.3