diff options
Diffstat (limited to 'src/bin/rbw-agent/actions.rs')
-rw-r--r-- | src/bin/rbw-agent/actions.rs | 331 |
1 files changed, 207 insertions, 124 deletions
diff --git a/src/bin/rbw-agent/actions.rs b/src/bin/rbw-agent/actions.rs index 1cc71c3..674442b 100644 --- a/src/bin/rbw-agent/actions.rs +++ b/src/bin/rbw-agent/actions.rs @@ -10,9 +10,7 @@ pub async fn register( let url_str = config_base_url().await?; let url = reqwest::Url::parse(&url_str) .context("failed to parse base url")?; - let host = if let Some(host) = url.host_str() { - host - } else { + let Some(host) = url.host_str() else { return Err(anyhow::anyhow!( "couldn't find host in rbw base url {}", url_str @@ -33,7 +31,7 @@ pub async fn register( let client_id = rbw::pinentry::getpin( &config_pinentry().await?, "API key client__id", - &format!("Log in to {}", host), + &format!("Log in to {host}"), err.as_deref(), tty, false, @@ -43,7 +41,7 @@ pub async fn register( let client_secret = rbw::pinentry::getpin( &config_pinentry().await?, "API key client__secret", - &format!("Log in to {}", host), + &format!("Log in to {host}"), err.as_deref(), tty, false, @@ -61,10 +59,9 @@ pub async fn register( message, }) .context("failed to log in to bitwarden instance"); - } else { - err_msg = Some(message); - continue; } + err_msg = Some(message); + continue; } Err(e) => { return Err(e) @@ -81,7 +78,7 @@ pub async fn register( pub async fn login( sock: &mut crate::sock::Sock, - state: std::sync::Arc<tokio::sync::RwLock<crate::agent::State>>, + state: std::sync::Arc<tokio::sync::Mutex<crate::agent::State>>, tty: Option<&str>, ) -> anyhow::Result<()> { let db = load_db().await.unwrap_or_else(|_| rbw::db::Db::new()); @@ -90,9 +87,7 @@ pub async fn login( let url_str = config_base_url().await?; let url = reqwest::Url::parse(&url_str) .context("failed to parse base url")?; - let host = if let Some(host) = url.host_str() { - host - } else { + let Some(host) = url.host_str() else { return Err(anyhow::anyhow!( "couldn't find host in rbw base url {}", url_str @@ -102,7 +97,7 @@ pub async fn login( let email = config_email().await?; let mut err_msg = None; - for i in 1_u8..=3 { + 'attempts: for i in 1_u8..=3 { let err = if i > 1 { // this unwrap is safe because we only ever continue the loop // if we have set err_msg @@ -113,7 +108,7 @@ pub async fn login( let password = rbw::pinentry::getpin( &config_pinentry().await?, "Master Password", - &format!("Log in to {}", host), + &format!("Log in to {host}"), err.as_deref(), tty, true, @@ -126,55 +121,70 @@ pub async fn login( Ok(( access_token, refresh_token, + kdf, iterations, + memory, + parallelism, protected_key, )) => { login_success( - sock, - state, + state.clone(), access_token, refresh_token, + kdf, iterations, + memory, + parallelism, protected_key, password, db, email, ) .await?; - break; + break 'attempts; } Err(rbw::error::Error::TwoFactorRequired { providers }) => { - if providers.contains( - &rbw::api::TwoFactorProviderType::Authenticator, - ) { - let ( - access_token, - refresh_token, - iterations, - protected_key, - ) = two_factor( - tty, - &email, - password.clone(), - rbw::api::TwoFactorProviderType::Authenticator, - ) - .await?; - login_success( - sock, - state, - access_token, - refresh_token, - iterations, - protected_key, - password, - db, - email, - ) - .await?; - break; - } else { - return Err(anyhow::anyhow!("TODO")); + let supported_types = vec![ + rbw::api::TwoFactorProviderType::Authenticator, + rbw::api::TwoFactorProviderType::Yubikey, + rbw::api::TwoFactorProviderType::Email, + ]; + + for provider in supported_types { + if providers.contains(&provider) { + let ( + access_token, + refresh_token, + kdf, + iterations, + memory, + parallelism, + protected_key, + ) = two_factor( + tty, + &email, + password.clone(), + provider, + ) + .await?; + login_success( + state.clone(), + access_token, + refresh_token, + kdf, + iterations, + memory, + parallelism, + protected_key, + password, + db, + email, + ) + .await?; + break 'attempts; + } } + return Err(anyhow::anyhow!("TODO")); } Err(rbw::error::Error::IncorrectPassword { message }) => { if i == 3 { @@ -182,10 +192,9 @@ pub async fn login( message, }) .context("failed to log in to bitwarden instance"); - } else { - err_msg = Some(message); - continue; } + err_msg = Some(message); + continue; } Err(e) => { return Err(e) @@ -205,7 +214,15 @@ async fn two_factor( email: &str, password: rbw::locked::Password, provider: rbw::api::TwoFactorProviderType, -) -> anyhow::Result<(String, String, u32, String)> { +) -> anyhow::Result<( + String, + String, + rbw::api::KdfType, + u32, + Option<u32>, + Option<u32>, + String, +)> { let mut err_msg = None; for i in 1_u8..=3 { let err = if i > 1 { @@ -217,11 +234,11 @@ async fn two_factor( }; let code = rbw::pinentry::getpin( &config_pinentry().await?, - "Authenticator App", - "Enter the 6 digit verification code from your authenticator app.", + provider.header(), + provider.message(), err.as_deref(), tty, - true, + provider.grab(), ) .await .context("failed to read code from pinentry")?; @@ -235,11 +252,22 @@ async fn two_factor( ) .await { - Ok((access_token, refresh_token, iterations, protected_key)) => { + Ok(( + access_token, + refresh_token, + kdf, + iterations, + memory, + parallelism, + protected_key, + )) => { return Ok(( access_token, refresh_token, + kdf, iterations, + memory, + parallelism, protected_key, )) } @@ -249,10 +277,9 @@ async fn two_factor( message, }) .context("failed to log in to bitwarden instance"); - } else { - err_msg = Some(message); - continue; } + err_msg = Some(message); + continue; } // can get this if the user passes an empty string Err(rbw::error::Error::TwoFactorRequired { .. }) => { @@ -262,10 +289,9 @@ async fn two_factor( message, }) .context("failed to log in to bitwarden instance"); - } else { - err_msg = Some(message); - continue; } + err_msg = Some(message); + continue; } Err(e) => { return Err(e) @@ -278,11 +304,13 @@ async fn two_factor( } async fn login_success( - sock: &mut crate::sock::Sock, - state: std::sync::Arc<tokio::sync::RwLock<crate::agent::State>>, + state: std::sync::Arc<tokio::sync::Mutex<crate::agent::State>>, access_token: String, refresh_token: String, + kdf: rbw::api::KdfType, iterations: u32, + memory: Option<u32>, + parallelism: Option<u32>, protected_key: String, password: rbw::locked::Password, mut db: rbw::db::Db, @@ -290,35 +318,37 @@ async fn login_success( ) -> anyhow::Result<()> { db.access_token = Some(access_token.to_string()); db.refresh_token = Some(refresh_token.to_string()); + db.kdf = Some(kdf); db.iterations = Some(iterations); + db.memory = memory; + db.parallelism = parallelism; db.protected_key = Some(protected_key.to_string()); save_db(&db).await?; - sync(sock, false).await?; + sync(None, state.clone()).await?; let db = load_db().await?; - let protected_private_key = - if let Some(protected_private_key) = db.protected_private_key { - protected_private_key - } else { - return Err(anyhow::anyhow!( - "failed to find protected private key in db" - )); - }; + let Some(protected_private_key) = db.protected_private_key else { + return Err(anyhow::anyhow!( + "failed to find protected private key in db" + )); + }; let res = rbw::actions::unlock( &email, &password, + kdf, iterations, + memory, + parallelism, &protected_key, &protected_private_key, &db.protected_org_keys, - ) - .await; + ); match res { Ok((keys, org_keys)) => { - let mut state = state.write().await; + let mut state = state.lock().await; state.priv_key = Some(keys); state.org_keys = Some(org_keys); } @@ -330,39 +360,40 @@ async fn login_success( pub async fn unlock( sock: &mut crate::sock::Sock, - state: std::sync::Arc<tokio::sync::RwLock<crate::agent::State>>, + state: std::sync::Arc<tokio::sync::Mutex<crate::agent::State>>, tty: Option<&str>, ) -> anyhow::Result<()> { - if state.read().await.needs_unlock() { + if state.lock().await.needs_unlock() { let db = load_db().await?; - let iterations = if let Some(iterations) = db.iterations { - iterations - } else { + let Some(kdf) = db.kdf else { + return Err(anyhow::anyhow!("failed to find kdf type in db")); + }; + + let Some(iterations) = db.iterations else { return Err(anyhow::anyhow!( "failed to find number of iterations in db" )); }; - let protected_key = if let Some(protected_key) = db.protected_key { - protected_key - } else { + + let memory = db.memory; + let parallelism = db.parallelism; + + let Some(protected_key) = db.protected_key else { return Err(anyhow::anyhow!( "failed to find protected key in db" )); }; - let protected_private_key = - if let Some(protected_private_key) = db.protected_private_key { - protected_private_key - } else { - return Err(anyhow::anyhow!( - "failed to find protected private key in db" - )); - }; + let Some(protected_private_key) = db.protected_private_key else { + return Err(anyhow::anyhow!( + "failed to find protected private key in db" + )); + }; let email = config_email().await?; let mut err_msg = None; - for i in 1u8..=3 { + for i in 1_u8..=3 { let err = if i > 1 { // this unwrap is safe because we only ever continue the loop // if we have set err_msg @@ -373,7 +404,10 @@ pub async fn unlock( let password = rbw::pinentry::getpin( &config_pinentry().await?, "Master Password", - "Unlock the local database", + &format!( + "Unlock the local database for '{}'", + rbw::dirs::profile() + ), err.as_deref(), tty, true, @@ -383,13 +417,14 @@ pub async fn unlock( match rbw::actions::unlock( &email, &password, + kdf, iterations, + memory, + parallelism, &protected_key, &protected_private_key, &db.protected_org_keys, - ) - .await - { + ) { Ok((keys, org_keys)) => { unlock_success(state, keys, org_keys).await?; break; @@ -400,10 +435,9 @@ pub async fn unlock( message, }) .context("failed to unlock database"); - } else { - err_msg = Some(message); - continue; } + err_msg = Some(message); + continue; } Err(e) => return Err(e).context("failed to unlock database"), } @@ -416,11 +450,11 @@ pub async fn unlock( } async fn unlock_success( - state: std::sync::Arc<tokio::sync::RwLock<crate::agent::State>>, + state: std::sync::Arc<tokio::sync::Mutex<crate::agent::State>>, keys: rbw::locked::Keys, org_keys: std::collections::HashMap<String, rbw::locked::Keys>, ) -> anyhow::Result<()> { - let mut state = state.write().await; + let mut state = state.lock().await; state.priv_key = Some(keys); state.org_keys = Some(org_keys); Ok(()) @@ -428,9 +462,9 @@ async fn unlock_success( pub async fn lock( sock: &mut crate::sock::Sock, - state: std::sync::Arc<tokio::sync::RwLock<crate::agent::State>>, + state: std::sync::Arc<tokio::sync::Mutex<crate::agent::State>>, ) -> anyhow::Result<()> { - state.write().await.clear(); + state.lock().await.clear(); respond_ack(sock).await?; @@ -439,10 +473,9 @@ pub async fn lock( pub async fn check_lock( sock: &mut crate::sock::Sock, - state: std::sync::Arc<tokio::sync::RwLock<crate::agent::State>>, - _tty: Option<&str>, + state: std::sync::Arc<tokio::sync::Mutex<crate::agent::State>>, ) -> anyhow::Result<()> { - if state.read().await.needs_unlock() { + if state.lock().await.needs_unlock() { return Err(anyhow::anyhow!("agent is locked")); } @@ -452,8 +485,8 @@ pub async fn check_lock( } pub async fn sync( - sock: &mut crate::sock::Sock, - ack: bool, + sock: Option<&mut crate::sock::Sock>, + state: std::sync::Arc<tokio::sync::Mutex<crate::agent::State>>, ) -> anyhow::Result<()> { let mut db = load_db().await?; @@ -482,7 +515,11 @@ pub async fn sync( db.entries = entries; save_db(&db).await?; - if ack { + if let Err(e) = subscribe_to_notifications(state.clone()).await { + eprintln!("failed to subscribe to notifications: {e}"); + } + + if let Some(sock) = sock { respond_ack(sock).await?; } @@ -491,14 +528,12 @@ pub async fn sync( pub async fn decrypt( sock: &mut crate::sock::Sock, - state: std::sync::Arc<tokio::sync::RwLock<crate::agent::State>>, + state: std::sync::Arc<tokio::sync::Mutex<crate::agent::State>>, cipherstring: &str, org_id: Option<&str>, ) -> anyhow::Result<()> { - let state = state.read().await; - let keys = if let Some(keys) = state.key(org_id) { - keys - } else { + let state = state.lock().await; + let Some(keys) = state.key(org_id) else { return Err(anyhow::anyhow!( "failed to find decryption keys in in-memory state" )); @@ -519,14 +554,12 @@ pub async fn decrypt( pub async fn encrypt( sock: &mut crate::sock::Sock, - state: std::sync::Arc<tokio::sync::RwLock<crate::agent::State>>, + state: std::sync::Arc<tokio::sync::Mutex<crate::agent::State>>, plaintext: &str, org_id: Option<&str>, ) -> anyhow::Result<()> { - let state = state.read().await; - let keys = if let Some(keys) = state.key(org_id) { - keys - } else { + let state = state.lock().await; + let Some(keys) = state.key(org_id) else { return Err(anyhow::anyhow!( "failed to find encryption keys in in-memory state" )); @@ -542,6 +575,25 @@ pub async fn encrypt( Ok(()) } +pub async fn clipboard_store( + sock: &mut crate::sock::Sock, + state: std::sync::Arc<tokio::sync::Mutex<crate::agent::State>>, + text: &str, +) -> anyhow::Result<()> { + state + .lock() + .await + .clipboard + .set_contents(text.to_owned()) + .map_err(|e| { + anyhow::anyhow!("couldn't store value to clipboard: {e}") + })?; + + respond_ack(sock).await?; + + Ok(()) +} + pub async fn version(sock: &mut crate::sock::Sock) -> anyhow::Result<()> { sock.send(&rbw::protocol::Response::Version { version: rbw::protocol::version(), @@ -579,11 +631,10 @@ async fn respond_encrypt( async fn config_email() -> anyhow::Result<String> { let config = rbw::config::Config::load_async().await?; - if let Some(email) = config.email { - Ok(email) - } else { - Err(anyhow::anyhow!("failed to find email address in config")) - } + config.email.map_or_else( + || Err(anyhow::anyhow!("failed to find email address in config")), + Ok, + ) } async fn load_db() -> anyhow::Result<rbw::db::Db> { @@ -617,3 +668,35 @@ async fn config_pinentry() -> anyhow::Result<String> { let config = rbw::config::Config::load_async().await?; Ok(config.pinentry) } + +pub async fn subscribe_to_notifications( + state: std::sync::Arc<tokio::sync::Mutex<crate::agent::State>>, +) -> anyhow::Result<()> { + if state.lock().await.notifications_handler.is_connected() { + return Ok(()); + } + + let config = rbw::config::Config::load_async() + .await + .context("Config is missing")?; + let email = config.email.clone().context("Config is missing email")?; + let db = rbw::db::Db::load_async(config.server_name().as_str(), &email) + .await?; + let access_token = + db.access_token.context("Error getting access token")?; + + let websocket_url = format!( + "{}/hub?access_token={}", + config.notifications_url(), + access_token + ) + .replace("https://", "wss://"); + + let mut state = state.lock().await; + state + .notifications_handler + .connect(websocket_url) + .await + .err() + .map_or_else(|| Ok(()), |err| Err(anyhow::anyhow!(err.to_string()))) +} |