summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorKesavan Yogeswaran <hikes@google.com>2022-06-28 00:18:48 -0400
committerKesavan Yogeswaran <hikes@google.com>2022-06-28 23:34:17 -0400
commit7db2e8bfb46d9364ddc3419d3186b150141cc890 (patch)
tree3f9b94b879bbe6a73814d90402b77fed586038a0
parent8b41015dbb231a2e5e0be37698d49924a10b290f (diff)
Use TryInto for more permissive deserialization for integers
* Attempt to convert between integer types using `TryInto`-based conversions rather than blanket failing for some source and destination types. * Use `into_uint` instead of `into_int` in `Value` Deserialize implementations for unsigned integer types. Previously, we were converting from signed types to unsigned types using `as`, which can lead to surprise integer values conversions (#93). Fixes #352 and #93
-rw-r--r--src/de.rs8
-rw-r--r--src/value.rs109
-rw-r--r--tests/env.rs73
-rw-r--r--tests/integer_range.rs17
4 files changed, 158 insertions, 49 deletions
diff --git a/src/de.rs b/src/de.rs
index 9df347f..d1271b2 100644
--- a/src/de.rs
+++ b/src/de.rs
@@ -62,25 +62,25 @@ impl<'de> de::Deserializer<'de> for Value {
#[inline]
fn deserialize_u8<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
// FIXME: This should *fail* if the value does not fit in the requets integer type
- visitor.visit_u8(self.into_int()? as u8)
+ visitor.visit_u8(self.into_uint()? as u8)
}
#[inline]
fn deserialize_u16<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
// FIXME: This should *fail* if the value does not fit in the requets integer type
- visitor.visit_u16(self.into_int()? as u16)
+ visitor.visit_u16(self.into_uint()? as u16)
}
#[inline]
fn deserialize_u32<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
// FIXME: This should *fail* if the value does not fit in the requets integer type
- visitor.visit_u32(self.into_int()? as u32)
+ visitor.visit_u32(self.into_uint()? as u32)
}
#[inline]
fn deserialize_u64<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
// FIXME: This should *fail* if the value does not fit in the requets integer type
- visitor.visit_u64(self.into_int()? as u64)
+ visitor.visit_u64(self.into_uint()? as u64)
}
#[inline]
diff --git a/src/value.rs b/src/value.rs
index 1727e4c..c9536f7 100644
--- a/src/value.rs
+++ b/src/value.rs
@@ -1,3 +1,4 @@
+use std::convert::TryInto;
use std::fmt;
use std::fmt::Display;
@@ -269,21 +270,27 @@ impl Value {
pub fn into_int(self) -> Result<i64> {
match self.kind {
ValueKind::I64(value) => Ok(value),
- ValueKind::I128(value) => Err(ConfigError::invalid_type(
- self.origin,
- Unexpected::I128(value),
- "an signed 64 bit or less integer",
- )),
- ValueKind::U64(value) => Err(ConfigError::invalid_type(
- self.origin,
- Unexpected::U64(value),
- "an signed 64 bit or less integer",
- )),
- ValueKind::U128(value) => Err(ConfigError::invalid_type(
- self.origin,
- Unexpected::U128(value),
- "an signed 64 bit or less integer",
- )),
+ ValueKind::I128(value) => value.try_into().map_err(|_| {
+ ConfigError::invalid_type(
+ self.origin,
+ Unexpected::I128(value),
+ "an signed 64 bit or less integer",
+ )
+ }),
+ ValueKind::U64(value) => value.try_into().map_err(|_| {
+ ConfigError::invalid_type(
+ self.origin,
+ Unexpected::U64(value),
+ "an signed 64 bit or less integer",
+ )
+ }),
+ ValueKind::U128(value) => value.try_into().map_err(|_| {
+ ConfigError::invalid_type(
+ self.origin,
+ Unexpected::U128(value),
+ "an signed 64 bit or less integer",
+ )
+ }),
ValueKind::String(ref s) => {
match s.to_lowercase().as_ref() {
@@ -330,11 +337,13 @@ impl Value {
ValueKind::I64(value) => Ok(value.into()),
ValueKind::I128(value) => Ok(value),
ValueKind::U64(value) => Ok(value.into()),
- ValueKind::U128(value) => Err(ConfigError::invalid_type(
- self.origin,
- Unexpected::U128(value),
- "an signed 128 bit integer",
- )),
+ ValueKind::U128(value) => value.try_into().map_err(|_| {
+ ConfigError::invalid_type(
+ self.origin,
+ Unexpected::U128(value),
+ "an signed 128 bit integer",
+ )
+ }),
ValueKind::String(ref s) => {
match s.to_lowercase().as_ref() {
@@ -380,21 +389,27 @@ impl Value {
pub fn into_uint(self) -> Result<u64> {
match self.kind {
ValueKind::U64(value) => Ok(value),
- ValueKind::U128(value) => Err(ConfigError::invalid_type(
- self.origin,
- Unexpected::U128(value),
- "an unsigned 64 bit or less integer",
- )),
- ValueKind::I64(value) => Err(ConfigError::invalid_type(
- self.origin,
- Unexpected::I64(value),
- "an unsigned 64 bit or less integer",
- )),
- ValueKind::I128(value) => Err(ConfigError::invalid_type(
- self.origin,
- Unexpected::I128(value),
- "an unsigned 64 bit or less integer",
- )),
+ ValueKind::U128(value) => value.try_into().map_err(|_| {
+ ConfigError::invalid_type(
+ self.origin,
+ Unexpected::U128(value),
+ "an unsigned 64 bit or less integer",
+ )
+ }),
+ ValueKind::I64(value) => value.try_into().map_err(|_| {
+ ConfigError::invalid_type(
+ self.origin,
+ Unexpected::I64(value),
+ "an unsigned 64 bit or less integer",
+ )
+ }),
+ ValueKind::I128(value) => value.try_into().map_err(|_| {
+ ConfigError::invalid_type(
+ self.origin,
+ Unexpected::I128(value),
+ "an unsigned 64 bit or less integer",
+ )
+ }),
ValueKind::String(ref s) => {
match s.to_lowercase().as_ref() {
@@ -440,16 +455,20 @@ impl Value {
match self.kind {
ValueKind::U64(value) => Ok(value.into()),
ValueKind::U128(value) => Ok(value),
- ValueKind::I64(value) => Err(ConfigError::invalid_type(
- self.origin,
- Unexpected::I64(value),
- "an unsigned 128 bit or less integer",
- )),
- ValueKind::I128(value) => Err(ConfigError::invalid_type(
- self.origin,
- Unexpected::I128(value),
- "an unsigned 128 bit or less integer",
- )),
+ ValueKind::I64(value) => value.try_into().map_err(|_| {
+ ConfigError::invalid_type(
+ self.origin,
+ Unexpected::I64(value),
+ "an unsigned 128 bit or less integer",
+ )
+ }),
+ ValueKind::I128(value) => value.try_into().map_err(|_| {
+ ConfigError::invalid_type(
+ self.origin,
+ Unexpected::I128(value),
+ "an unsigned 128 bit or less integer",
+ )
+ }),
ValueKind::String(ref s) => {
match s.to_lowercase().as_ref() {
diff --git a/tests/env.rs b/tests/env.rs
index ad252e4..2ee67de 100644
--- a/tests/env.rs
+++ b/tests/env.rs
@@ -131,6 +131,39 @@ fn test_parse_int() {
}
#[test]
+fn test_parse_uint() {
+ // using a struct in an enum here to make serde use `deserialize_any`
+ #[derive(Deserialize, Debug)]
+ #[serde(tag = "tag")]
+ enum TestUintEnum {
+ Uint(TestUint),
+ }
+
+ #[derive(Deserialize, Debug)]
+ struct TestUint {
+ int_val: u32,
+ }
+
+ temp_env::with_var("INT_VAL", Some("42"), || {
+ let environment = Environment::default().try_parsing(true);
+
+ let config = Config::builder()
+ .set_default("tag", "Uint")
+ .unwrap()
+ .add_source(environment)
+ .build()
+ .unwrap();
+
+ let config: TestUintEnum = config.try_deserialize().unwrap();
+
+ assert!(matches!(
+ config,
+ TestUintEnum::Uint(TestUint { int_val: 42 })
+ ));
+ })
+}
+
+#[test]
fn test_parse_float() {
// using a struct in an enum here to make serde use `deserialize_any`
#[derive(Deserialize, Debug)]
@@ -535,3 +568,43 @@ fn test_parse_off_string() {
}
})
}
+
+#[test]
+fn test_parse_int_default() {
+ #[derive(Deserialize, Debug)]
+ struct TestInt {
+ int_val: i32,
+ }
+
+ let environment = Environment::default().try_parsing(true);
+
+ let config = Config::builder()
+ .set_default("int_val", 42_i32)
+ .unwrap()
+ .add_source(environment)
+ .build()
+ .unwrap();
+
+ let config: TestInt = config.try_deserialize().unwrap();
+ assert_eq!(config.int_val, 42);
+}
+
+#[test]
+fn test_parse_uint_default() {
+ #[derive(Deserialize, Debug)]
+ struct TestUint {
+ int_val: u32,
+ }
+
+ let environment = Environment::default().try_parsing(true);
+
+ let config = Config::builder()
+ .set_default("int_val", 42_u32)
+ .unwrap()
+ .add_source(environment)
+ .build()
+ .unwrap();
+
+ let config: TestUint = config.try_deserialize().unwrap();
+ assert_eq!(config.int_val, 42);
+}
diff --git a/tests/integer_range.rs b/tests/integer_range.rs
index c3e8839..7777ef2 100644
--- a/tests/integer_range.rs
+++ b/tests/integer_range.rs
@@ -33,3 +33,20 @@ fn nonwrapping_u32() {
let port: u32 = c.get("settings.port").unwrap();
assert_eq!(port, 66000);
}
+
+#[test]
+#[should_panic]
+fn invalid_signedness() {
+ let c = Config::builder()
+ .add_source(config::File::from_str(
+ r#"
+ [settings]
+ port = -1
+ "#,
+ config::FileFormat::Toml,
+ ))
+ .build()
+ .unwrap();
+
+ let _: u32 = c.get("settings.port").unwrap();
+}