diff options
Diffstat (limited to 'src/config.rs')
-rw-r--r-- | src/config.rs | 82 |
1 files changed, 82 insertions, 0 deletions
diff --git a/src/config.rs b/src/config.rs index f9b6588..3a22085 100644 --- a/src/config.rs +++ b/src/config.rs @@ -513,3 +513,85 @@ impl Default for Ttyrec { fn default_ttyrec_filename() -> String { DEFAULT_TTYREC_FILENAME.to_string() } + +pub fn oauth_configs<'a, D>( + deserializer: D, +) -> std::result::Result< + std::collections::HashMap< + crate::protocol::AuthType, + crate::oauth::Config, + >, + D::Error, +> +where + D: serde::de::Deserializer<'a>, +{ + let configs = + <std::collections::HashMap<String, OauthConfig>>::deserialize( + deserializer, + )?; + let mut ret = std::collections::HashMap::new(); + for (key, config) in configs { + let auth_type = crate::protocol::AuthType::try_from(key.as_str()) + .map_err(serde::de::Error::custom)?; + let real_config = match auth_type { + crate::protocol::AuthType::RecurseCenter => { + let client_id = config + .client_id + .context(crate::error::OauthMissingConfiguration { + field: "client_id", + auth_type, + }) + .map_err(serde::de::Error::custom)?; + let client_secret = config + .client_secret + .context(crate::error::OauthMissingConfiguration { + field: "client_secret", + auth_type, + }) + .map_err(serde::de::Error::custom)?; + crate::oauth::RecurseCenter::config( + &client_id, + &client_secret, + ) + } + ty if !ty.is_oauth() => { + return Err(Error::AuthTypeNotOauth { ty: auth_type }) + .map_err(serde::de::Error::custom); + } + _ => unreachable!(), + }; + ret.insert(auth_type, real_config); + } + Ok(ret) +} + +#[derive(serde::Deserialize, Debug)] +struct OauthConfig { + #[serde(default)] + client_id: Option<String>, + + #[serde(default)] + client_secret: Option<String>, + + #[serde(deserialize_with = "url", default)] + auth_url: Option<url::Url>, + + #[serde(deserialize_with = "url", default)] + token_url: Option<url::Url>, + + #[serde(deserialize_with = "url", default)] + redirect_url: Option<url::Url>, +} + +fn url<'a, D>( + deserializer: D, +) -> std::result::Result<Option<url::Url>, D::Error> +where + D: serde::de::Deserializer<'a>, +{ + Ok(<Option<String>>::deserialize(deserializer)? + .map(|s| url::Url::parse(&s)) + .transpose() + .map_err(serde::de::Error::custom)?) +} |