aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJesse Luehrs <doy@tozt.net>2019-12-18 02:01:56 -0500
committerJesse Luehrs <doy@tozt.net>2019-12-18 02:01:56 -0500
commitc7312e513e8d9696fea66b198e3526bc0d9f6325 (patch)
tree046668edc5db35054a3e7891784b1e1d9fafeb15
parentee5de447d977999e7f5b4b41685d74ce28d63535 (diff)
downloadteleterm-c7312e513e8d9696fea66b198e3526bc0d9f6325.tar.gz
teleterm-c7312e513e8d9696fea66b198e3526bc0d9f6325.zip
make the oauth client a struct instead of a trait
-rw-r--r--teleterm/src/auth.rs1
-rw-r--r--teleterm/src/auth/recurse_center.rs57
-rw-r--r--teleterm/src/config.rs2
-rw-r--r--teleterm/src/main.rs1
-rw-r--r--teleterm/src/oauth.rs118
-rw-r--r--teleterm/src/oauth/recurse_center.rs90
-rw-r--r--teleterm/src/protocol.rs24
-rw-r--r--teleterm/src/server.rs60
8 files changed, 182 insertions, 171 deletions
diff --git a/teleterm/src/auth.rs b/teleterm/src/auth.rs
new file mode 100644
index 0000000..3820f57
--- /dev/null
+++ b/teleterm/src/auth.rs
@@ -0,0 +1 @@
+pub mod recurse_center;
diff --git a/teleterm/src/auth/recurse_center.rs b/teleterm/src/auth/recurse_center.rs
new file mode 100644
index 0000000..203a0df
--- /dev/null
+++ b/teleterm/src/auth/recurse_center.rs
@@ -0,0 +1,57 @@
+use crate::prelude::*;
+
+pub fn oauth_config(
+ client_id: &str,
+ client_secret: &str,
+ redirect_url: &url::Url,
+) -> crate::oauth::Config {
+ crate::oauth::Config::new(
+ client_id.to_string(),
+ client_secret.to_string(),
+ url::Url::parse("https://www.recurse.com/oauth/authorize").unwrap(),
+ url::Url::parse("https://www.recurse.com/oauth/token").unwrap(),
+ redirect_url.clone(),
+ )
+}
+
+pub fn get_username(
+ access_token: &str,
+) -> Box<dyn futures::Future<Item = String, Error = Error> + Send> {
+ let fut = reqwest::r#async::Client::new()
+ .get("https://www.recurse.com/api/v1/profiles/me")
+ .bearer_auth(access_token)
+ .send()
+ .context(crate::error::GetRecurseCenterProfile)
+ .and_then(|mut res| res.json().context(crate::error::ParseJson))
+ .map(|user: User| user.name());
+ Box::new(fut)
+}
+
+#[derive(serde::Deserialize)]
+struct User {
+ name: String,
+ stints: Vec<Stint>,
+}
+
+#[derive(serde::Deserialize)]
+struct Stint {
+ batch: Option<Batch>,
+ start_date: String,
+}
+
+#[derive(serde::Deserialize)]
+struct Batch {
+ short_name: String,
+}
+
+impl User {
+ fn name(&self) -> String {
+ let latest_stint =
+ self.stints.iter().max_by_key(|s| &s.start_date).unwrap();
+ if let Some(batch) = &latest_stint.batch {
+ format!("{} ({})", self.name, batch.short_name)
+ } else {
+ self.name.to_string()
+ }
+ }
+}
diff --git a/teleterm/src/config.rs b/teleterm/src/config.rs
index 95d5366..6590ee8 100644
--- a/teleterm/src/config.rs
+++ b/teleterm/src/config.rs
@@ -919,7 +919,7 @@ where
let redirect_url =
url::Url::parse(crate::oauth::CLI_REDIRECT_URL)
.unwrap();
- crate::oauth::RecurseCenter::config(
+ crate::auth::recurse_center::oauth_config(
&client_id,
&client_secret,
&redirect_url,
diff --git a/teleterm/src/main.rs b/teleterm/src/main.rs
index 2475981..6a88960 100644
--- a/teleterm/src/main.rs
+++ b/teleterm/src/main.rs
@@ -18,6 +18,7 @@ const _DUMMY_DEPENDENCY: &str = include_str!("../Cargo.toml");
mod prelude;
mod async_stdin;
+mod auth;
mod client;
mod cmd;
mod config;
diff --git a/teleterm/src/oauth.rs b/teleterm/src/oauth.rs
index d76a1b7..c4ec4bb 100644
--- a/teleterm/src/oauth.rs
+++ b/teleterm/src/oauth.rs
@@ -1,40 +1,39 @@
use crate::prelude::*;
use oauth2::TokenResponse as _;
-mod recurse_center;
-pub use recurse_center::RecurseCenter;
-
// this needs to be fixed because we listen for it in a hardcoded place
pub const CLI_REDIRECT_URL: &str = "http://localhost:44141/oauth";
-pub trait Oauth {
- fn client(&self) -> &oauth2::basic::BasicClient;
- fn user_id(&self) -> &str;
- fn name(&self) -> &str;
+pub struct Oauth {
+ client: oauth2::basic::BasicClient,
+ user_id: String,
+}
- fn server_token_file(
- &self,
- must_exist: bool,
- ) -> Option<std::path::PathBuf> {
- let name = format!("server-oauth-{}-{}", self.name(), self.user_id());
- crate::dirs::Dirs::new().data_file(&name, must_exist)
+impl Oauth {
+ pub fn new(config: Config, user_id: String) -> Self {
+ let client = config.into_basic_client();
+ Self { client, user_id }
}
- fn generate_authorize_url(&self) -> String {
+ pub fn generate_authorize_url(&self) -> String {
let (auth_url, _) = self
- .client()
+ .client
.authorize_url(oauth2::CsrfToken::new_random)
.url();
auth_url.to_string()
}
- fn get_access_token_from_auth_code(
+ pub fn user_id(&self) -> &str {
+ &self.user_id
+ }
+
+ pub fn get_access_token_from_auth_code(
&self,
code: &str,
) -> Box<dyn futures::Future<Item = String, Error = Error> + Send> {
let token_cache_file = self.server_token_file(false).unwrap();
let fut = self
- .client()
+ .client
.exchange_code(oauth2::AuthorizationCode::new(code.to_string()))
.request_future(oauth2::reqwest::future_http_client)
.map_err(|e| {
@@ -48,32 +47,61 @@ pub trait Oauth {
Box::new(fut)
}
- fn get_access_token_from_refresh_token(
- &self,
- token: &str,
+ pub fn get_access_token_from_refresh_token(
+ self,
) -> Box<dyn futures::Future<Item = String, Error = Error> + Send> {
let token_cache_file = self.server_token_file(false).unwrap();
- let fut = self
- .client()
- .exchange_refresh_token(&oauth2::RefreshToken::new(
- token.to_string(),
- ))
- .request_future(oauth2::reqwest::future_http_client)
- .map_err(|e| {
- let msg = stringify_oauth2_http_error(&e);
- Error::ExchangeRefreshToken { msg }
- })
- .and_then(|token| {
- cache_refresh_token(token_cache_file, &token)
- .map(move |_| token.access_token().secret().to_string())
- });
+ let fut = load_refresh_token(&token_cache_file).and_then(
+ move |refresh_token| {
+ // XXX
+ let refresh_token = refresh_token.unwrap();
+ self.client
+ .exchange_refresh_token(&oauth2::RefreshToken::new(
+ refresh_token,
+ ))
+ .request_future(oauth2::reqwest::future_http_client)
+ .map_err(|e| {
+ let msg = stringify_oauth2_http_error(&e);
+ Error::ExchangeRefreshToken { msg }
+ })
+ .and_then(move |token| {
+ cache_refresh_token(token_cache_file, &token).map(
+ move |_| {
+ token.access_token().secret().to_string()
+ },
+ )
+ })
+ },
+ );
Box::new(fut)
}
- fn get_username_from_access_token(
+ pub fn server_token_file(
&self,
- token: &str,
- ) -> Box<dyn futures::Future<Item = String, Error = Error> + Send>;
+ must_exist: bool,
+ ) -> Option<std::path::PathBuf> {
+ let name = format!("server-oauth-{}", self.user_id);
+ crate::dirs::Dirs::new().data_file(&name, must_exist)
+ }
+}
+
+fn load_refresh_token(
+ token_cache_file: &std::path::Path,
+) -> Box<dyn futures::Future<Item = Option<String>, Error = Error> + Send> {
+ let token_cache_file = token_cache_file.to_path_buf();
+ Box::new(
+ tokio::fs::File::open(token_cache_file.clone())
+ .with_context(move || crate::error::OpenFile {
+ filename: token_cache_file.to_string_lossy().to_string(),
+ })
+ .and_then(|file| {
+ tokio::io::lines(std::io::BufReader::new(file))
+ .into_future()
+ .map_err(|(e, _)| e)
+ .context(crate::error::ReadFile)
+ })
+ .map(|(refresh_token, _)| refresh_token),
+ )
}
fn cache_refresh_token(
@@ -107,6 +135,22 @@ pub struct Config {
}
impl Config {
+ pub fn new(
+ client_id: String,
+ client_secret: String,
+ auth_url: url::Url,
+ token_url: url::Url,
+ redirect_url: url::Url,
+ ) -> Self {
+ Self {
+ client_id,
+ client_secret,
+ auth_url,
+ token_url,
+ redirect_url,
+ }
+ }
+
pub fn set_redirect_url(&mut self, url: url::Url) {
self.redirect_url = url;
}
diff --git a/teleterm/src/oauth/recurse_center.rs b/teleterm/src/oauth/recurse_center.rs
deleted file mode 100644
index 6ab41ba..0000000
--- a/teleterm/src/oauth/recurse_center.rs
+++ /dev/null
@@ -1,90 +0,0 @@
-use crate::prelude::*;
-
-pub struct RecurseCenter {
- client: oauth2::basic::BasicClient,
- user_id: String,
-}
-
-impl RecurseCenter {
- pub fn new(config: super::Config, user_id: &str) -> Self {
- Self {
- client: config.into_basic_client(),
- user_id: user_id.to_string(),
- }
- }
-
- pub fn config(
- client_id: &str,
- client_secret: &str,
- redirect_url: &url::Url,
- ) -> super::Config {
- super::Config {
- client_id: client_id.to_string(),
- client_secret: client_secret.to_string(),
- auth_url: url::Url::parse(
- "https://www.recurse.com/oauth/authorize",
- )
- .unwrap(),
- token_url: url::Url::parse("https://www.recurse.com/oauth/token")
- .unwrap(),
- redirect_url: redirect_url.clone(),
- }
- }
-}
-
-impl super::Oauth for RecurseCenter {
- fn client(&self) -> &oauth2::basic::BasicClient {
- &self.client
- }
-
- fn user_id(&self) -> &str {
- &self.user_id
- }
-
- fn name(&self) -> &str {
- crate::protocol::AuthType::RecurseCenter.name()
- }
-
- fn get_username_from_access_token(
- &self,
- token: &str,
- ) -> Box<dyn futures::Future<Item = String, Error = Error> + Send> {
- let fut = reqwest::r#async::Client::new()
- .get("https://www.recurse.com/api/v1/profiles/me")
- .bearer_auth(token)
- .send()
- .context(crate::error::GetRecurseCenterProfile)
- .and_then(|mut res| res.json().context(crate::error::ParseJson))
- .map(|user: User| user.name());
- Box::new(fut)
- }
-}
-
-#[derive(serde::Deserialize)]
-struct User {
- name: String,
- stints: Vec<Stint>,
-}
-
-#[derive(serde::Deserialize)]
-struct Stint {
- batch: Option<Batch>,
- start_date: String,
-}
-
-#[derive(serde::Deserialize)]
-struct Batch {
- short_name: String,
-}
-
-impl User {
- fn name(&self) -> String {
- let latest_stint =
- self.stints.iter().max_by_key(|s| &s.start_date).unwrap();
- if let Some(batch) = &latest_stint.batch {
- format!("{} ({})", self.name, batch.short_name)
- } else {
- self.name.to_string()
- }
- }
-}
diff --git a/teleterm/src/protocol.rs b/teleterm/src/protocol.rs
index dbb69a0..b30b1f2 100644
--- a/teleterm/src/protocol.rs
+++ b/teleterm/src/protocol.rs
@@ -144,20 +144,14 @@ impl AuthType {
self,
config: &crate::oauth::Config,
id: Option<&str>,
- ) -> Option<Box<dyn crate::oauth::Oauth + Send>> {
- match self {
- Self::RecurseCenter => {
- Some(Box::new(crate::oauth::RecurseCenter::new(
- config.clone(),
- &id.map_or_else(
- || format!("{}", uuid::Uuid::new_v4()),
- std::string::ToString::to_string,
- ),
- )))
- }
- ty if !ty.is_oauth() => None,
- _ => unreachable!(),
- }
+ ) -> Option<crate::oauth::Oauth> {
+ Some(crate::oauth::Oauth::new(
+ config.clone(),
+ id.map_or_else(
+ || format!("{}", uuid::Uuid::new_v4()),
+ std::string::ToString::to_string,
+ ),
+ ))
}
}
@@ -213,7 +207,7 @@ impl Auth {
pub fn oauth_client(
&self,
config: &crate::oauth::Config,
- ) -> Option<Box<dyn crate::oauth::Oauth + Send>> {
+ ) -> Option<crate::oauth::Oauth> {
self.auth_type().oauth_client(config, self.oauth_id())
}
diff --git a/teleterm/src/server.rs b/teleterm/src/server.rs
index 36186d1..6437c40 100644
--- a/teleterm/src/server.rs
+++ b/teleterm/src/server.rs
@@ -55,6 +55,7 @@ struct TerminalInfo {
enum ConnectionState {
Accepted,
LoggingIn {
+ auth_type: crate::protocol::AuthType,
term_info: TerminalInfo,
},
LoggedIn {
@@ -88,7 +89,17 @@ impl ConnectionState {
}
}
- fn term_info(&mut self) -> Option<&TerminalInfo> {
+ fn auth_type(&self) -> Option<crate::protocol::AuthType> {
+ match self {
+ Self::Accepted => None,
+ Self::LoggingIn { auth_type, .. } => Some(*auth_type),
+ Self::LoggedIn { .. } => None,
+ Self::Streaming { .. } => None,
+ Self::Watching { .. } => None,
+ }
+ }
+
+ fn term_info(&self) -> Option<&TerminalInfo> {
match self {
Self::Accepted => None,
Self::LoggingIn { term_info, .. } => Some(term_info),
@@ -159,11 +170,13 @@ impl ConnectionState {
fn login_oauth_start(
&mut self,
+ auth_type: crate::protocol::AuthType,
term_type: &str,
size: crate::term::Size,
) {
if let Self::Accepted = self {
*self = Self::LoggingIn {
+ auth_type,
term_info: TerminalInfo {
term: term_type.to_string(),
size,
@@ -218,7 +231,7 @@ struct Connection<
closed: bool,
state: ConnectionState,
last_activity: std::time::Instant,
- oauth_client: Option<Box<dyn crate::oauth::Oauth + Send>>,
+ oauth_client: Option<crate::oauth::Oauth>,
}
impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + 'static>
@@ -427,32 +440,19 @@ impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + 'static>
.context(crate::error::AuthTypeMissingOauthConfig { ty })?;
let client = auth.oauth_client(config).unwrap();
- if let (Some(token_filename), true) =
- (client.server_token_file(true), auth.oauth_id().is_some())
+ if client.server_token_file(true).is_some()
+ && auth.oauth_id().is_some()
{
let term_type = term_type.to_string();
- let client = conn.oauth_client.take().unwrap();
- let fut = tokio::fs::File::open(token_filename.clone())
- .with_context(move || crate::error::OpenFile {
- filename: token_filename.to_string_lossy().to_string(),
- })
- .and_then(|file| {
- tokio::io::lines(std::io::BufReader::new(file))
- .into_future()
- .map_err(|(e, _)| e)
- .context(crate::error::ReadFile)
- })
- .and_then(|(refresh_token, _)| {
- // XXX unwrap here isn't super safe
- let refresh_token = refresh_token.unwrap();
- client
- .get_access_token_from_refresh_token(
- refresh_token.trim(),
+ let fut = client
+ .get_access_token_from_refresh_token()
+ .and_then(move |access_token| match ty {
+ crate::protocol::AuthType::RecurseCenter => {
+ crate::auth::recurse_center::get_username(
+ &access_token,
)
- .and_then(move |access_token| {
- client
- .get_username_from_access_token(&access_token)
- })
+ }
+ _ => unreachable!(),
})
.map(move |username| {
(
@@ -470,7 +470,7 @@ impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + 'static>
} else {
conn.oauth_client = Some(client);
let client = conn.oauth_client.as_ref().unwrap();
- conn.state.login_oauth_start(term_type, size);
+ conn.state.login_oauth_start(ty, term_type, size);
let authorize_url = client.generate_authorize_url();
let user_id = client.user_id().to_string();
conn.send_message(crate::protocol::Message::oauth_cli_request(
@@ -644,11 +644,15 @@ impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + 'static>
}
})?;
+ let ty = conn.state.auth_type().unwrap();
let term_info = conn.state.term_info().unwrap().clone();
let fut = client
.get_access_token_from_auth_code(code)
- .and_then(move |access_token| {
- client.get_username_from_access_token(&access_token)
+ .and_then(move |access_token| match ty {
+ crate::protocol::AuthType::RecurseCenter => {
+ crate::auth::recurse_center::get_username(&access_token)
+ }
+ _ => unreachable!(),
})
.map(|username| {
(