diff options
Diffstat (limited to 'src/config.rs')
-rw-r--r-- | src/config.rs | 119 |
1 files changed, 119 insertions, 0 deletions
diff --git a/src/config.rs b/src/config.rs index 3a22085..dcc44d7 100644 --- a/src/config.rs +++ b/src/config.rs @@ -222,6 +222,12 @@ pub struct Server { )] pub allowed_login_methods: std::collections::HashSet<crate::protocol::AuthType>, + + #[serde(deserialize_with = "uid", default)] + pub uid: Option<users::uid_t>, + + #[serde(deserialize_with = "gid", default)] + pub gid: Option<users::gid_t>, } impl Server { @@ -307,6 +313,8 @@ impl Default for Server { read_timeout: default_read_timeout(), tls_identity_file: None, allowed_login_methods: default_allowed_login_methods(), + uid: None, + gid: None, } } } @@ -410,6 +418,117 @@ fn default_allowed_login_methods( crate::protocol::AuthType::iter().collect() } +fn uid<'a, D>( + deserializer: D, +) -> std::result::Result<Option<users::uid_t>, D::Error> +where + D: serde::de::Deserializer<'a>, +{ + struct StringOrInt; + + impl<'a> serde::de::Visitor<'a> for StringOrInt { + type Value = Option<u32>; + + fn expecting( + &self, + formatter: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + formatter.write_str("string or int") + } + + fn visit_str<E>( + self, + value: &str, + ) -> std::result::Result<Self::Value, E> + where + E: serde::de::Error, + { + Ok(Some( + users::get_user_by_name(value) + .context(crate::error::UnknownUser { name: value }) + .map_err(serde::de::Error::custom)? + .uid(), + )) + } + + fn visit_u32<E>( + self, + value: u32, + ) -> std::result::Result<Self::Value, E> + where + E: serde::de::Error, + { + if users::get_user_by_uid(value).is_none() { + return Err(serde::de::Error::custom(Error::UnknownUid { + uid: value, + })); + } + Ok(Some(value)) + } + } + + deserializer.deserialize_any(StringOrInt) +} + +fn gid<'a, D>( + deserializer: D, +) -> std::result::Result<Option<users::gid_t>, D::Error> +where + D: serde::de::Deserializer<'a>, +{ + struct StringOrInt; + + impl<'a> serde::de::Visitor<'a> for StringOrInt { + type Value = Option<u32>; + + fn expecting( + &self, + formatter: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + formatter.write_str("string or int") + } + + fn visit_none<E>(self) -> std::result::Result<Self::Value, E> + where + E: serde::de::Error, + { + Ok(None) + } + + fn visit_str<E>( + self, + value: &str, + ) -> std::result::Result<Self::Value, E> + where + E: serde::de::Error, + { + Ok(Some( + users::get_group_by_name(value) + .context(crate::error::UnknownGroup { name: value }) + .map_err(serde::de::Error::custom)? + .gid(), + )) + } + + fn visit_u32<E>( + self, + value: u32, + ) -> std::result::Result<Self::Value, E> + where + E: serde::de::Error, + { + if users::get_group_by_gid(value).is_none() { + return Err(serde::de::Error::custom(Error::UnknownGid { + gid: value, + })); + } + Ok(Some(value)) + } + } + + deserializer.deserialize_any(StringOrInt) +} + #[derive(serde::Deserialize, Debug)] pub struct Command { #[serde(default = "default_buffer_size")] |