From c7312e513e8d9696fea66b198e3526bc0d9f6325 Mon Sep 17 00:00:00 2001 From: Jesse Luehrs Date: Wed, 18 Dec 2019 02:01:56 -0500 Subject: make the oauth client a struct instead of a trait --- teleterm/src/auth.rs | 1 + teleterm/src/auth/recurse_center.rs | 57 +++++++++++++++++ teleterm/src/config.rs | 2 +- teleterm/src/main.rs | 1 + teleterm/src/oauth.rs | 118 ++++++++++++++++++++++++----------- teleterm/src/oauth/recurse_center.rs | 90 -------------------------- teleterm/src/protocol.rs | 24 +++---- teleterm/src/server.rs | 60 +++++++++--------- 8 files changed, 182 insertions(+), 171 deletions(-) create mode 100644 teleterm/src/auth.rs create mode 100644 teleterm/src/auth/recurse_center.rs delete mode 100644 teleterm/src/oauth/recurse_center.rs diff --git a/teleterm/src/auth.rs b/teleterm/src/auth.rs new file mode 100644 index 0000000..3820f57 --- /dev/null +++ b/teleterm/src/auth.rs @@ -0,0 +1 @@ +pub mod recurse_center; diff --git a/teleterm/src/auth/recurse_center.rs b/teleterm/src/auth/recurse_center.rs new file mode 100644 index 0000000..203a0df --- /dev/null +++ b/teleterm/src/auth/recurse_center.rs @@ -0,0 +1,57 @@ +use crate::prelude::*; + +pub fn oauth_config( + client_id: &str, + client_secret: &str, + redirect_url: &url::Url, +) -> crate::oauth::Config { + crate::oauth::Config::new( + client_id.to_string(), + client_secret.to_string(), + url::Url::parse("https://www.recurse.com/oauth/authorize").unwrap(), + url::Url::parse("https://www.recurse.com/oauth/token").unwrap(), + redirect_url.clone(), + ) +} + +pub fn get_username( + access_token: &str, +) -> Box + Send> { + let fut = reqwest::r#async::Client::new() + .get("https://www.recurse.com/api/v1/profiles/me") + .bearer_auth(access_token) + .send() + .context(crate::error::GetRecurseCenterProfile) + .and_then(|mut res| res.json().context(crate::error::ParseJson)) + .map(|user: User| user.name()); + Box::new(fut) +} + +#[derive(serde::Deserialize)] +struct User { + name: String, + stints: Vec, +} + +#[derive(serde::Deserialize)] +struct Stint { + batch: Option, + start_date: String, +} + +#[derive(serde::Deserialize)] +struct Batch { + short_name: String, +} + +impl User { + fn name(&self) -> String { + let latest_stint = + self.stints.iter().max_by_key(|s| &s.start_date).unwrap(); + if let Some(batch) = &latest_stint.batch { + format!("{} ({})", self.name, batch.short_name) + } else { + self.name.to_string() + } + } +} diff --git a/teleterm/src/config.rs b/teleterm/src/config.rs index 95d5366..6590ee8 100644 --- a/teleterm/src/config.rs +++ b/teleterm/src/config.rs @@ -919,7 +919,7 @@ where let redirect_url = url::Url::parse(crate::oauth::CLI_REDIRECT_URL) .unwrap(); - crate::oauth::RecurseCenter::config( + crate::auth::recurse_center::oauth_config( &client_id, &client_secret, &redirect_url, diff --git a/teleterm/src/main.rs b/teleterm/src/main.rs index 2475981..6a88960 100644 --- a/teleterm/src/main.rs +++ b/teleterm/src/main.rs @@ -18,6 +18,7 @@ const _DUMMY_DEPENDENCY: &str = include_str!("../Cargo.toml"); mod prelude; mod async_stdin; +mod auth; mod client; mod cmd; mod config; diff --git a/teleterm/src/oauth.rs b/teleterm/src/oauth.rs index d76a1b7..c4ec4bb 100644 --- a/teleterm/src/oauth.rs +++ b/teleterm/src/oauth.rs @@ -1,40 +1,39 @@ use crate::prelude::*; use oauth2::TokenResponse as _; -mod recurse_center; -pub use recurse_center::RecurseCenter; - // this needs to be fixed because we listen for it in a hardcoded place pub const CLI_REDIRECT_URL: &str = "http://localhost:44141/oauth"; -pub trait Oauth { - fn client(&self) -> &oauth2::basic::BasicClient; - fn user_id(&self) -> &str; - fn name(&self) -> &str; +pub struct Oauth { + client: oauth2::basic::BasicClient, + user_id: String, +} - fn server_token_file( - &self, - must_exist: bool, - ) -> Option { - let name = format!("server-oauth-{}-{}", self.name(), self.user_id()); - crate::dirs::Dirs::new().data_file(&name, must_exist) +impl Oauth { + pub fn new(config: Config, user_id: String) -> Self { + let client = config.into_basic_client(); + Self { client, user_id } } - fn generate_authorize_url(&self) -> String { + pub fn generate_authorize_url(&self) -> String { let (auth_url, _) = self - .client() + .client .authorize_url(oauth2::CsrfToken::new_random) .url(); auth_url.to_string() } - fn get_access_token_from_auth_code( + pub fn user_id(&self) -> &str { + &self.user_id + } + + pub fn get_access_token_from_auth_code( &self, code: &str, ) -> Box + Send> { let token_cache_file = self.server_token_file(false).unwrap(); let fut = self - .client() + .client .exchange_code(oauth2::AuthorizationCode::new(code.to_string())) .request_future(oauth2::reqwest::future_http_client) .map_err(|e| { @@ -48,32 +47,61 @@ pub trait Oauth { Box::new(fut) } - fn get_access_token_from_refresh_token( - &self, - token: &str, + pub fn get_access_token_from_refresh_token( + self, ) -> Box + Send> { let token_cache_file = self.server_token_file(false).unwrap(); - let fut = self - .client() - .exchange_refresh_token(&oauth2::RefreshToken::new( - token.to_string(), - )) - .request_future(oauth2::reqwest::future_http_client) - .map_err(|e| { - let msg = stringify_oauth2_http_error(&e); - Error::ExchangeRefreshToken { msg } - }) - .and_then(|token| { - cache_refresh_token(token_cache_file, &token) - .map(move |_| token.access_token().secret().to_string()) - }); + let fut = load_refresh_token(&token_cache_file).and_then( + move |refresh_token| { + // XXX + let refresh_token = refresh_token.unwrap(); + self.client + .exchange_refresh_token(&oauth2::RefreshToken::new( + refresh_token, + )) + .request_future(oauth2::reqwest::future_http_client) + .map_err(|e| { + let msg = stringify_oauth2_http_error(&e); + Error::ExchangeRefreshToken { msg } + }) + .and_then(move |token| { + cache_refresh_token(token_cache_file, &token).map( + move |_| { + token.access_token().secret().to_string() + }, + ) + }) + }, + ); Box::new(fut) } - fn get_username_from_access_token( + pub fn server_token_file( &self, - token: &str, - ) -> Box + Send>; + must_exist: bool, + ) -> Option { + let name = format!("server-oauth-{}", self.user_id); + crate::dirs::Dirs::new().data_file(&name, must_exist) + } +} + +fn load_refresh_token( + token_cache_file: &std::path::Path, +) -> Box, Error = Error> + Send> { + let token_cache_file = token_cache_file.to_path_buf(); + Box::new( + tokio::fs::File::open(token_cache_file.clone()) + .with_context(move || crate::error::OpenFile { + filename: token_cache_file.to_string_lossy().to_string(), + }) + .and_then(|file| { + tokio::io::lines(std::io::BufReader::new(file)) + .into_future() + .map_err(|(e, _)| e) + .context(crate::error::ReadFile) + }) + .map(|(refresh_token, _)| refresh_token), + ) } fn cache_refresh_token( @@ -107,6 +135,22 @@ pub struct Config { } impl Config { + pub fn new( + client_id: String, + client_secret: String, + auth_url: url::Url, + token_url: url::Url, + redirect_url: url::Url, + ) -> Self { + Self { + client_id, + client_secret, + auth_url, + token_url, + redirect_url, + } + } + pub fn set_redirect_url(&mut self, url: url::Url) { self.redirect_url = url; } diff --git a/teleterm/src/oauth/recurse_center.rs b/teleterm/src/oauth/recurse_center.rs deleted file mode 100644 index 6ab41ba..0000000 --- a/teleterm/src/oauth/recurse_center.rs +++ /dev/null @@ -1,90 +0,0 @@ -use crate::prelude::*; - -pub struct RecurseCenter { - client: oauth2::basic::BasicClient, - user_id: String, -} - -impl RecurseCenter { - pub fn new(config: super::Config, user_id: &str) -> Self { - Self { - client: config.into_basic_client(), - user_id: user_id.to_string(), - } - } - - pub fn config( - client_id: &str, - client_secret: &str, - redirect_url: &url::Url, - ) -> super::Config { - super::Config { - client_id: client_id.to_string(), - client_secret: client_secret.to_string(), - auth_url: url::Url::parse( - "https://www.recurse.com/oauth/authorize", - ) - .unwrap(), - token_url: url::Url::parse("https://www.recurse.com/oauth/token") - .unwrap(), - redirect_url: redirect_url.clone(), - } - } -} - -impl super::Oauth for RecurseCenter { - fn client(&self) -> &oauth2::basic::BasicClient { - &self.client - } - - fn user_id(&self) -> &str { - &self.user_id - } - - fn name(&self) -> &str { - crate::protocol::AuthType::RecurseCenter.name() - } - - fn get_username_from_access_token( - &self, - token: &str, - ) -> Box + Send> { - let fut = reqwest::r#async::Client::new() - .get("https://www.recurse.com/api/v1/profiles/me") - .bearer_auth(token) - .send() - .context(crate::error::GetRecurseCenterProfile) - .and_then(|mut res| res.json().context(crate::error::ParseJson)) - .map(|user: User| user.name()); - Box::new(fut) - } -} - -#[derive(serde::Deserialize)] -struct User { - name: String, - stints: Vec, -} - -#[derive(serde::Deserialize)] -struct Stint { - batch: Option, - start_date: String, -} - -#[derive(serde::Deserialize)] -struct Batch { - short_name: String, -} - -impl User { - fn name(&self) -> String { - let latest_stint = - self.stints.iter().max_by_key(|s| &s.start_date).unwrap(); - if let Some(batch) = &latest_stint.batch { - format!("{} ({})", self.name, batch.short_name) - } else { - self.name.to_string() - } - } -} diff --git a/teleterm/src/protocol.rs b/teleterm/src/protocol.rs index dbb69a0..b30b1f2 100644 --- a/teleterm/src/protocol.rs +++ b/teleterm/src/protocol.rs @@ -144,20 +144,14 @@ impl AuthType { self, config: &crate::oauth::Config, id: Option<&str>, - ) -> Option> { - match self { - Self::RecurseCenter => { - Some(Box::new(crate::oauth::RecurseCenter::new( - config.clone(), - &id.map_or_else( - || format!("{}", uuid::Uuid::new_v4()), - std::string::ToString::to_string, - ), - ))) - } - ty if !ty.is_oauth() => None, - _ => unreachable!(), - } + ) -> Option { + Some(crate::oauth::Oauth::new( + config.clone(), + id.map_or_else( + || format!("{}", uuid::Uuid::new_v4()), + std::string::ToString::to_string, + ), + )) } } @@ -213,7 +207,7 @@ impl Auth { pub fn oauth_client( &self, config: &crate::oauth::Config, - ) -> Option> { + ) -> Option { self.auth_type().oauth_client(config, self.oauth_id()) } diff --git a/teleterm/src/server.rs b/teleterm/src/server.rs index 36186d1..6437c40 100644 --- a/teleterm/src/server.rs +++ b/teleterm/src/server.rs @@ -55,6 +55,7 @@ struct TerminalInfo { enum ConnectionState { Accepted, LoggingIn { + auth_type: crate::protocol::AuthType, term_info: TerminalInfo, }, LoggedIn { @@ -88,7 +89,17 @@ impl ConnectionState { } } - fn term_info(&mut self) -> Option<&TerminalInfo> { + fn auth_type(&self) -> Option { + match self { + Self::Accepted => None, + Self::LoggingIn { auth_type, .. } => Some(*auth_type), + Self::LoggedIn { .. } => None, + Self::Streaming { .. } => None, + Self::Watching { .. } => None, + } + } + + fn term_info(&self) -> Option<&TerminalInfo> { match self { Self::Accepted => None, Self::LoggingIn { term_info, .. } => Some(term_info), @@ -159,11 +170,13 @@ impl ConnectionState { fn login_oauth_start( &mut self, + auth_type: crate::protocol::AuthType, term_type: &str, size: crate::term::Size, ) { if let Self::Accepted = self { *self = Self::LoggingIn { + auth_type, term_info: TerminalInfo { term: term_type.to_string(), size, @@ -218,7 +231,7 @@ struct Connection< closed: bool, state: ConnectionState, last_activity: std::time::Instant, - oauth_client: Option>, + oauth_client: Option, } impl @@ -427,32 +440,19 @@ impl .context(crate::error::AuthTypeMissingOauthConfig { ty })?; let client = auth.oauth_client(config).unwrap(); - if let (Some(token_filename), true) = - (client.server_token_file(true), auth.oauth_id().is_some()) + if client.server_token_file(true).is_some() + && auth.oauth_id().is_some() { let term_type = term_type.to_string(); - let client = conn.oauth_client.take().unwrap(); - let fut = tokio::fs::File::open(token_filename.clone()) - .with_context(move || crate::error::OpenFile { - filename: token_filename.to_string_lossy().to_string(), - }) - .and_then(|file| { - tokio::io::lines(std::io::BufReader::new(file)) - .into_future() - .map_err(|(e, _)| e) - .context(crate::error::ReadFile) - }) - .and_then(|(refresh_token, _)| { - // XXX unwrap here isn't super safe - let refresh_token = refresh_token.unwrap(); - client - .get_access_token_from_refresh_token( - refresh_token.trim(), + let fut = client + .get_access_token_from_refresh_token() + .and_then(move |access_token| match ty { + crate::protocol::AuthType::RecurseCenter => { + crate::auth::recurse_center::get_username( + &access_token, ) - .and_then(move |access_token| { - client - .get_username_from_access_token(&access_token) - }) + } + _ => unreachable!(), }) .map(move |username| { ( @@ -470,7 +470,7 @@ impl } else { conn.oauth_client = Some(client); let client = conn.oauth_client.as_ref().unwrap(); - conn.state.login_oauth_start(term_type, size); + conn.state.login_oauth_start(ty, term_type, size); let authorize_url = client.generate_authorize_url(); let user_id = client.user_id().to_string(); conn.send_message(crate::protocol::Message::oauth_cli_request( @@ -644,11 +644,15 @@ impl } })?; + let ty = conn.state.auth_type().unwrap(); let term_info = conn.state.term_info().unwrap().clone(); let fut = client .get_access_token_from_auth_code(code) - .and_then(move |access_token| { - client.get_username_from_access_token(&access_token) + .and_then(move |access_token| match ty { + crate::protocol::AuthType::RecurseCenter => { + crate::auth::recurse_center::get_username(&access_token) + } + _ => unreachable!(), }) .map(|username| { ( -- cgit v1.2.3