From ce48e85996f81a9a20846217bb5c150f7745b082 Mon Sep 17 00:00:00 2001 From: Jesse Luehrs Date: Sun, 20 Oct 2019 12:01:24 -0400 Subject: configure oauth through the config file too --- src/cmd/server.rs | 22 ++++++++++++ src/config.rs | 82 +++++++++++++++++++++++++++++++++++++++++++++ src/error.rs | 19 +++++++++++ src/oauth.rs | 4 +++ src/oauth/recurse_center.rs | 23 ++++++------- src/server.rs | 48 ++++++++++++-------------- src/server/tls.rs | 5 +++ 7 files changed, 163 insertions(+), 40 deletions(-) diff --git a/src/cmd/server.rs b/src/cmd/server.rs index bc9b985..6e98fe5 100644 --- a/src/cmd/server.rs +++ b/src/cmd/server.rs @@ -5,6 +5,16 @@ use std::io::Read as _; pub struct Config { #[serde(default)] server: crate::config::Server, + + #[serde( + rename = "oauth", + deserialize_with = "crate::config::oauth_configs", + default + )] + oauth_configs: std::collections::HashMap< + crate::protocol::AuthType, + crate::oauth::Config, + >, } impl crate::config::Config for Config { @@ -24,6 +34,7 @@ impl crate::config::Config for Config { self.server.read_timeout, tls_identity_file, self.server.allowed_login_methods.clone(), + self.oauth_configs.clone(), )? } else { create_server( @@ -31,6 +42,7 @@ impl crate::config::Config for Config { self.server.buffer_size, self.server.read_timeout, self.server.allowed_login_methods.clone(), + self.oauth_configs.clone(), )? }; tokio::run(futures::future::lazy(move || { @@ -66,6 +78,10 @@ fn create_server( allowed_login_methods: std::collections::HashSet< crate::protocol::AuthType, >, + oauth_configs: std::collections::HashMap< + crate::protocol::AuthType, + crate::oauth::Config, + >, ) -> Result<( Box + Send>, Box + Send>, @@ -86,6 +102,7 @@ fn create_server( read_timeout, sock_r, allowed_login_methods, + oauth_configs, ); Ok((Box::new(acceptor), Box::new(server))) } @@ -98,6 +115,10 @@ fn create_server_tls( allowed_login_methods: std::collections::HashSet< crate::protocol::AuthType, >, + oauth_configs: std::collections::HashMap< + crate::protocol::AuthType, + crate::oauth::Config, + >, ) -> Result<( Box + Send>, Box + Send>, @@ -134,6 +155,7 @@ fn create_server_tls( read_timeout, sock_r, allowed_login_methods, + oauth_configs, ); Ok((Box::new(acceptor), Box::new(server))) } diff --git a/src/config.rs b/src/config.rs index f9b6588..3a22085 100644 --- a/src/config.rs +++ b/src/config.rs @@ -513,3 +513,85 @@ impl Default for Ttyrec { fn default_ttyrec_filename() -> String { DEFAULT_TTYREC_FILENAME.to_string() } + +pub fn oauth_configs<'a, D>( + deserializer: D, +) -> std::result::Result< + std::collections::HashMap< + crate::protocol::AuthType, + crate::oauth::Config, + >, + D::Error, +> +where + D: serde::de::Deserializer<'a>, +{ + let configs = + >::deserialize( + deserializer, + )?; + let mut ret = std::collections::HashMap::new(); + for (key, config) in configs { + let auth_type = crate::protocol::AuthType::try_from(key.as_str()) + .map_err(serde::de::Error::custom)?; + let real_config = match auth_type { + crate::protocol::AuthType::RecurseCenter => { + let client_id = config + .client_id + .context(crate::error::OauthMissingConfiguration { + field: "client_id", + auth_type, + }) + .map_err(serde::de::Error::custom)?; + let client_secret = config + .client_secret + .context(crate::error::OauthMissingConfiguration { + field: "client_secret", + auth_type, + }) + .map_err(serde::de::Error::custom)?; + crate::oauth::RecurseCenter::config( + &client_id, + &client_secret, + ) + } + ty if !ty.is_oauth() => { + return Err(Error::AuthTypeNotOauth { ty: auth_type }) + .map_err(serde::de::Error::custom); + } + _ => unreachable!(), + }; + ret.insert(auth_type, real_config); + } + Ok(ret) +} + +#[derive(serde::Deserialize, Debug)] +struct OauthConfig { + #[serde(default)] + client_id: Option, + + #[serde(default)] + client_secret: Option, + + #[serde(deserialize_with = "url", default)] + auth_url: Option, + + #[serde(deserialize_with = "url", default)] + token_url: Option, + + #[serde(deserialize_with = "url", default)] + redirect_url: Option, +} + +fn url<'a, D>( + deserializer: D, +) -> std::result::Result, D::Error> +where + D: serde::de::Deserializer<'a>, +{ + Ok(>::deserialize(deserializer)? + .map(|s| url::Url::parse(&s)) + .transpose() + .map_err(serde::de::Error::custom)?) +} diff --git a/src/error.rs b/src/error.rs index c5506c2..eeca1e8 100644 --- a/src/error.rs +++ b/src/error.rs @@ -4,9 +4,18 @@ pub enum Error { #[snafu(display("failed to accept: {}", source))] Acceptor { source: tokio::io::Error }, + #[snafu(display( + "oauth configuration for auth type {:?} not found", + ty + ))] + AuthTypeMissingOauthConfig { ty: crate::protocol::AuthType }, + #[snafu(display("auth type {:?} not allowed", ty))] AuthTypeNotAllowed { ty: crate::protocol::AuthType }, + #[snafu(display("auth type {:?} does not use oauth", ty))] + AuthTypeNotOauth { ty: crate::protocol::AuthType }, + #[snafu(display("failed to bind to {}: {}", address, source))] Bind { address: std::net::SocketAddr, @@ -127,6 +136,16 @@ pub enum Error { ))] NotAFileName { path: String }, + #[snafu(display( + "missing oauth configuration item {} for auth type {}", + field, + auth_type.name(), + ))] + OauthMissingConfiguration { + field: String, + auth_type: crate::protocol::AuthType, + }, + #[snafu(display("failed to open file {}: {}", filename, source))] OpenFile { filename: String, diff --git a/src/oauth.rs b/src/oauth.rs index 78ca392..0840585 100644 --- a/src/oauth.rs +++ b/src/oauth.rs @@ -5,6 +5,9 @@ use std::io::Read as _; mod recurse_center; pub use recurse_center::RecurseCenter; +// this needs to be fixed because we listen for it in a hardcoded place +pub const REDIRECT_URL: &str = "http://localhost:44141/oauth"; + pub trait Oauth { fn client(&self) -> &oauth2::basic::BasicClient; fn user_id(&self) -> &str; @@ -127,6 +130,7 @@ fn cache_refresh_token( Box::new(fut) } +#[derive(Debug, Clone)] pub struct Config { client_id: String, client_secret: String, diff --git a/src/oauth/recurse_center.rs b/src/oauth/recurse_center.rs index b12b968..2b9f7f7 100644 --- a/src/oauth/recurse_center.rs +++ b/src/oauth/recurse_center.rs @@ -6,13 +6,15 @@ pub struct RecurseCenter { } impl RecurseCenter { - pub fn new( - client_id: &str, - client_secret: &str, - redirect_url: url::Url, - user_id: &str, - ) -> Self { - let config = super::Config { + pub fn new(config: super::Config, user_id: &str) -> Self { + Self { + client: config.into_basic_client(), + user_id: user_id.to_string(), + } + } + + pub fn config(client_id: &str, client_secret: &str) -> super::Config { + super::Config { client_id: client_id.to_string(), client_secret: client_secret.to_string(), auth_url: url::Url::parse( @@ -21,12 +23,7 @@ impl RecurseCenter { .unwrap(), token_url: url::Url::parse("https://www.recurse.com/oauth/token") .unwrap(), - redirect_url, - }; - - Self { - client: config.into_basic_client(), - user_id: user_id.to_string(), + redirect_url: url::Url::parse(super::REDIRECT_URL).unwrap(), } } } diff --git a/src/server.rs b/src/server.rs index b49a564..8c965a9 100644 --- a/src/server.rs +++ b/src/server.rs @@ -330,6 +330,10 @@ pub struct Server< connections: std::collections::HashMap>, rate_limiter: ratelimit_meter::KeyedRateLimiter>, allowed_auth_types: std::collections::HashSet, + oauth_configs: std::collections::HashMap< + crate::protocol::AuthType, + crate::oauth::Config, + >, } impl @@ -342,6 +346,10 @@ impl allowed_auth_types: std::collections::HashSet< crate::protocol::AuthType, >, + oauth_configs: std::collections::HashMap< + crate::protocol::AuthType, + crate::oauth::Config, + >, ) -> Self { let sock_stream = sock_r .map(move |s| Connection::new(s, buffer_size)) @@ -356,6 +364,7 @@ impl std::time::Duration::from_secs(60), ), allowed_auth_types, + oauth_configs, } } @@ -398,34 +407,19 @@ impl )); } oauth if oauth.is_oauth() => { + let config = self.oauth_configs.get(&ty).context( + crate::error::AuthTypeMissingOauthConfig { ty }, + )?; let (refresh, client) = match oauth { - crate::protocol::Auth::RecurseCenter { id } => { - // XXX this needs some kind of real configuration - // system - let client_id = - std::env::var("TT_RECURSE_CENTER_CLIENT_ID") - .unwrap(); - let client_secret = - std::env::var("TT_RECURSE_CENTER_CLIENT_SECRET") - .unwrap(); - let redirect_url = - std::env::var("TT_RECURSE_CENTER_REDIRECT_URL") - .unwrap(); - let redirect_url = - url::Url::parse(&redirect_url).unwrap(); - - ( - id.is_some(), - Box::new(crate::oauth::RecurseCenter::new( - &client_id, - &client_secret, - redirect_url, - &id.clone().unwrap_or_else(|| { - format!("{}", uuid::Uuid::new_v4()) - }), - )), - ) - } + crate::protocol::Auth::RecurseCenter { id } => ( + id.is_some(), + Box::new(crate::oauth::RecurseCenter::new( + config.clone(), + &id.clone().unwrap_or_else(|| { + format!("{}", uuid::Uuid::new_v4()) + }), + )), + ), _ => unreachable!(), }; diff --git a/src/server/tls.rs b/src/server/tls.rs index a05608a..d8eb088 100644 --- a/src/server/tls.rs +++ b/src/server/tls.rs @@ -20,6 +20,10 @@ impl Server { allowed_login_methods: std::collections::HashSet< crate::protocol::AuthType, >, + oauth_configs: std::collections::HashMap< + crate::protocol::AuthType, + crate::oauth::Config, + >, ) -> Self { let (tls_sock_w, tls_sock_r) = tokio::sync::mpsc::channel(100); Self { @@ -28,6 +32,7 @@ impl Server { read_timeout, tls_sock_r, allowed_login_methods, + oauth_configs, ), sock_r, sock_w: tls_sock_w, -- cgit v1.2.3-54-g00ecf