diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/actions.rs | 123 | ||||
-rw-r--r-- | src/api.rs | 556 | ||||
-rw-r--r-- | src/base64.rs | 15 | ||||
-rw-r--r-- | src/bin/rbw-agent/actions.rs | 331 | ||||
-rw-r--r-- | src/bin/rbw-agent/agent.rs | 205 | ||||
-rw-r--r-- | src/bin/rbw-agent/daemon.rs | 59 | ||||
-rw-r--r-- | src/bin/rbw-agent/debugger.rs | 2 | ||||
-rw-r--r-- | src/bin/rbw-agent/main.rs | 31 | ||||
-rw-r--r-- | src/bin/rbw-agent/notifications.rs | 174 | ||||
-rw-r--r-- | src/bin/rbw-agent/sock.rs | 5 | ||||
-rw-r--r-- | src/bin/rbw-agent/timeout.rs | 66 | ||||
-rw-r--r-- | src/bin/rbw/actions.rs | 48 | ||||
-rw-r--r-- | src/bin/rbw/commands.rs | 1687 | ||||
-rw-r--r-- | src/bin/rbw/main.rs | 230 | ||||
-rw-r--r-- | src/cipherstring.rs | 118 | ||||
-rw-r--r-- | src/config.rs | 119 | ||||
-rw-r--r-- | src/db.rs | 8 | ||||
-rw-r--r-- | src/dirs.rs | 56 | ||||
-rw-r--r-- | src/edit.rs | 19 | ||||
-rw-r--r-- | src/error.rs | 37 | ||||
-rw-r--r-- | src/identity.rs | 53 | ||||
-rw-r--r-- | src/lib.rs | 19 | ||||
-rw-r--r-- | src/locked.rs | 14 | ||||
-rw-r--r-- | src/pinentry.rs | 29 | ||||
-rw-r--r-- | src/protocol.rs | 7 | ||||
-rw-r--r-- | src/pwgen.rs | 3 |
26 files changed, 3115 insertions, 899 deletions
diff --git a/src/actions.rs b/src/actions.rs index 02ec854..7ee1fa4 100644 --- a/src/actions.rs +++ b/src/actions.rs @@ -4,11 +4,11 @@ pub async fn register( email: &str, apikey: crate::locked::ApiKey, ) -> Result<()> { - let config = crate::config::Config::load_async().await?; - let client = - crate::api::Client::new(&config.base_url(), &config.identity_url()); + let (client, config) = api_client_async().await?; - client.register(email, &config.device_id, &apikey).await?; + client + .register(email, &crate::config::device_id(&config).await?, &apikey) + .await?; Ok(()) } @@ -18,40 +18,70 @@ pub async fn login( password: crate::locked::Password, two_factor_token: Option<&str>, two_factor_provider: Option<crate::api::TwoFactorProviderType>, -) -> Result<(String, String, u32, String)> { - let config = crate::config::Config::load_async().await?; - let client = - crate::api::Client::new(&config.base_url(), &config.identity_url()); +) -> Result<( + String, + String, + crate::api::KdfType, + u32, + Option<u32>, + Option<u32>, + String, +)> { + let (client, config) = api_client_async().await?; + let (kdf, iterations, memory, parallelism) = + client.prelogin(email).await?; - let iterations = client.prelogin(email).await?; - let identity = - crate::identity::Identity::new(email, &password, iterations)?; + let identity = crate::identity::Identity::new( + email, + &password, + kdf, + iterations, + memory, + parallelism, + )?; let (access_token, refresh_token, protected_key) = client .login( email, - &config.device_id, + &crate::config::device_id(&config).await?, &identity.master_password_hash, two_factor_token, two_factor_provider, ) .await?; - Ok((access_token, refresh_token, iterations, protected_key)) + Ok(( + access_token, + refresh_token, + kdf, + iterations, + memory, + parallelism, + protected_key, + )) } -pub async fn unlock( +pub fn unlock<S: std::hash::BuildHasher>( email: &str, password: &crate::locked::Password, + kdf: crate::api::KdfType, iterations: u32, + memory: Option<u32>, + parallelism: Option<u32>, protected_key: &str, protected_private_key: &str, - protected_org_keys: &std::collections::HashMap<String, String>, + protected_org_keys: &std::collections::HashMap<String, String, S>, ) -> Result<( crate::locked::Keys, std::collections::HashMap<String, crate::locked::Keys>, )> { - let identity = - crate::identity::Identity::new(email, password, iterations)?; + let identity = crate::identity::Identity::new( + email, + password, + kdf, + iterations, + memory, + parallelism, + )?; let protected_key = crate::cipherstring::CipherString::new(protected_key)?; @@ -119,9 +149,7 @@ async fn sync_once( std::collections::HashMap<String, String>, Vec<crate::db::Entry>, )> { - let config = crate::config::Config::load_async().await?; - let client = - crate::api::Client::new(&config.base_url(), &config.identity_url()); + let (client, _) = api_client_async().await?; client.sync(access_token).await } @@ -145,10 +173,8 @@ fn add_once( notes: Option<&str>, folder_id: Option<&str>, ) -> Result<()> { - let config = crate::config::Config::load()?; - let client = - crate::api::Client::new(&config.base_url(), &config.identity_url()); - client.add(access_token, name, data, notes, folder_id.as_deref())?; + let (client, _) = api_client()?; + client.add(access_token, name, data, notes, folder_id)?; Ok(()) } @@ -159,6 +185,7 @@ pub fn edit( org_id: Option<&str>, name: &str, data: &crate::db::EntryData, + fields: &[crate::db::Field], notes: Option<&str>, folder_uuid: Option<&str>, history: &[crate::db::HistoryEntry], @@ -170,6 +197,7 @@ pub fn edit( org_id, name, data, + fields, notes, folder_uuid, history, @@ -183,19 +211,19 @@ fn edit_once( org_id: Option<&str>, name: &str, data: &crate::db::EntryData, + fields: &[crate::db::Field], notes: Option<&str>, folder_uuid: Option<&str>, history: &[crate::db::HistoryEntry], ) -> Result<()> { - let config = crate::config::Config::load()?; - let client = - crate::api::Client::new(&config.base_url(), &config.identity_url()); + let (client, _) = api_client()?; client.edit( access_token, id, org_id, name, data, + fields, notes, folder_uuid, history, @@ -214,9 +242,7 @@ pub fn remove( } fn remove_once(access_token: &str, id: &str) -> Result<()> { - let config = crate::config::Config::load()?; - let client = - crate::api::Client::new(&config.base_url(), &config.identity_url()); + let (client, _) = api_client()?; client.remove(access_token, id)?; Ok(()) } @@ -231,9 +257,7 @@ pub fn list_folders( } fn list_folders_once(access_token: &str) -> Result<Vec<(String, String)>> { - let config = crate::config::Config::load()?; - let client = - crate::api::Client::new(&config.base_url(), &config.identity_url()); + let (client, _) = api_client()?; client.folders(access_token) } @@ -248,9 +272,7 @@ pub fn create_folder( } fn create_folder_once(access_token: &str, name: &str) -> Result<String> { - let config = crate::config::Config::load()?; - let client = - crate::api::Client::new(&config.base_url(), &config.identity_url()); + let (client, _) = api_client()?; client.create_folder(access_token, name) } @@ -300,15 +322,32 @@ where } fn exchange_refresh_token(refresh_token: &str) -> Result<String> { - let config = crate::config::Config::load()?; - let client = - crate::api::Client::new(&config.base_url(), &config.identity_url()); + let (client, _) = api_client()?; client.exchange_refresh_token(refresh_token) } async fn exchange_refresh_token_async(refresh_token: &str) -> Result<String> { - let config = crate::config::Config::load_async().await?; - let client = - crate::api::Client::new(&config.base_url(), &config.identity_url()); + let (client, _) = api_client()?; client.exchange_refresh_token_async(refresh_token).await } + +fn api_client() -> Result<(crate::api::Client, crate::config::Config)> { + let config = crate::config::Config::load()?; + let client = crate::api::Client::new( + &config.base_url(), + &config.identity_url(), + config.client_cert_path(), + ); + Ok((client, config)) +} + +async fn api_client_async( +) -> Result<(crate::api::Client, crate::config::Config)> { + let config = crate::config::Config::load_async().await?; + let client = crate::api::Client::new( + &config.base_url(), + &config.identity_url(), + config.client_cert_path(), + ); + Ok((client, config)) +} @@ -1,9 +1,15 @@ +// serde_repr generates some as conversions that we can't seem to silence from +// here, unfortunately +#![allow(clippy::as_conversions)] + use crate::prelude::*; use crate::json::{ DeserializeJsonWithPath as _, DeserializeJsonWithPathAsync as _, }; +use tokio::io::AsyncReadExt as _; + #[derive( serde_repr::Serialize_repr, serde_repr::Deserialize_repr, @@ -35,7 +41,7 @@ impl std::fmt::Display for UriMatchType { RegularExpression => "regular_expression", Never => "never", }; - write!(f, "{}", s) + write!(f, "{s}") } } @@ -51,6 +57,33 @@ pub enum TwoFactorProviderType { WebAuthn = 7, } +impl TwoFactorProviderType { + #[must_use] + pub fn message(&self) -> &str { + match *self { + Self::Authenticator => "Enter the 6 digit verification code from your authenticator app.", + Self::Yubikey => "Insert your Yubikey and push the button.", + Self::Email => "Enter the PIN you received via email.", + _ => "Enter the code." + } + } + + #[must_use] + pub fn header(&self) -> &str { + match *self { + Self::Authenticator => "Authenticator App", + Self::Yubikey => "Yubikey", + Self::Email => "Email Code", + _ => "Two Factor Authentication", + } + } + + #[must_use] + pub fn grab(&self) -> bool { + !matches!(self, Self::Email) + } +} + impl<'de> serde::Deserialize<'de> for TwoFactorProviderType { fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error> where @@ -107,7 +140,7 @@ impl std::convert::TryFrom<u64> for TwoFactorProviderType { 6 => Ok(Self::OrganizationDuo), 7 => Ok(Self::WebAuthn), _ => Err(Error::InvalidTwoFactorProvider { - ty: format!("{}", ty), + ty: format!("{ty}"), }), } } @@ -131,6 +164,96 @@ impl std::str::FromStr for TwoFactorProviderType { } } +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum KdfType { + Pbkdf2 = 0, + Argon2id = 1, +} + +impl<'de> serde::Deserialize<'de> for KdfType { + fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + struct KdfTypeVisitor; + impl<'de> serde::de::Visitor<'de> for KdfTypeVisitor { + type Value = KdfType; + + fn expecting( + &self, + formatter: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + formatter.write_str("kdf id") + } + + fn visit_str<E>( + self, + value: &str, + ) -> std::result::Result<Self::Value, E> + where + E: serde::de::Error, + { + value.parse().map_err(serde::de::Error::custom) + } + + fn visit_u64<E>( + self, + value: u64, + ) -> std::result::Result<Self::Value, E> + where + E: serde::de::Error, + { + std::convert::TryFrom::try_from(value) + .map_err(serde::de::Error::custom) + } + } + + deserializer.deserialize_any(KdfTypeVisitor) + } +} + +impl std::convert::TryFrom<u64> for KdfType { + type Error = Error; + + fn try_from(ty: u64) -> Result<Self> { + match ty { + 0 => Ok(Self::Pbkdf2), + 1 => Ok(Self::Argon2id), + _ => Err(Error::InvalidKdfType { + ty: format!("{ty}"), + }), + } + } +} + +impl std::str::FromStr for KdfType { + type Err = Error; + + fn from_str(ty: &str) -> Result<Self> { + match ty { + "0" => Ok(Self::Pbkdf2), + "1" => Ok(Self::Argon2id), + _ => Err(Error::InvalidKdfType { ty: ty.to_string() }), + } + } +} + +impl serde::Serialize for KdfType { + fn serialize<S>( + &self, + serializer: S, + ) -> std::result::Result<S::Ok, S::Error> + where + S: serde::Serializer, + { + let s = match self { + Self::Pbkdf2 => "0", + Self::Argon2id => "1", + }; + serializer.serialize_str(s) + } +} + #[derive(serde::Serialize, Debug)] struct PreloginReq { email: String, @@ -138,10 +261,14 @@ struct PreloginReq { #[derive(serde::Deserialize, Debug)] struct PreloginRes { - #[serde(rename = "Kdf")] - kdf: u32, - #[serde(rename = "KdfIterations")] + #[serde(rename = "Kdf", alias = "kdf")] + kdf: KdfType, + #[serde(rename = "KdfIterations", alias = "kdfIterations")] kdf_iterations: u32, + #[serde(rename = "KdfMemory", alias = "kdfMemory")] + kdf_memory: Option<u32>, + #[serde(rename = "KdfParallelism", alias = "kdfParallelism")] + kdf_parallelism: Option<u32>, } #[derive(serde::Serialize, Debug)] @@ -169,10 +296,8 @@ struct ConnectPasswordReq { #[derive(serde::Deserialize, Debug)] struct ConnectPasswordRes { access_token: String, - expires_in: u32, - token_type: String, refresh_token: String, - #[serde(rename = "Key")] + #[serde(rename = "Key", alias = "key")] key: String, } @@ -180,15 +305,15 @@ struct ConnectPasswordRes { struct ConnectErrorRes { error: String, error_description: Option<String>, - #[serde(rename = "ErrorModel")] + #[serde(rename = "ErrorModel", alias = "errorModel")] error_model: Option<ConnectErrorResErrorModel>, - #[serde(rename = "TwoFactorProviders")] + #[serde(rename = "TwoFactorProviders", alias = "twoFactorProviders")] two_factor_providers: Option<Vec<TwoFactorProviderType>>, } #[derive(serde::Deserialize, Debug)] struct ConnectErrorResErrorModel { - #[serde(rename = "Message")] + #[serde(rename = "Message", alias = "message")] message: String, } @@ -202,46 +327,43 @@ struct ConnectRefreshTokenReq { #[derive(serde::Deserialize, Debug)] struct ConnectRefreshTokenRes { access_token: String, - expires_in: u32, - token_type: String, - refresh_token: String, } #[derive(serde::Deserialize, Debug)] struct SyncRes { - #[serde(rename = "Ciphers")] + #[serde(rename = "Ciphers", alias = "ciphers")] ciphers: Vec<SyncResCipher>, - #[serde(rename = "Profile")] + #[serde(rename = "Profile", alias = "profile")] profile: SyncResProfile, - #[serde(rename = "Folders")] + #[serde(rename = "Folders", alias = "folders")] folders: Vec<SyncResFolder>, } #[derive(serde::Serialize, serde::Deserialize, Debug, Clone)] struct SyncResCipher { - #[serde(rename = "Id")] + #[serde(rename = "Id", alias = "id")] id: String, - #[serde(rename = "FolderId")] + #[serde(rename = "FolderId", alias = "folderId")] folder_id: Option<String>, - #[serde(rename = "OrganizationId")] + #[serde(rename = "OrganizationId", alias = "organizationId")] organization_id: Option<String>, - #[serde(rename = "Name")] + #[serde(rename = "Name", alias = "name")] name: String, - #[serde(rename = "Login")] + #[serde(rename = "Login", alias = "login")] login: Option<CipherLogin>, - #[serde(rename = "Card")] + #[serde(rename = "Card", alias = "card")] card: Option<CipherCard>, - #[serde(rename = "Identity")] + #[serde(rename = "Identity", alias = "identity")] identity: Option<CipherIdentity>, - #[serde(rename = "SecureNote")] + #[serde(rename = "SecureNote", alias = "secureNote")] secure_note: Option<CipherSecureNote>, - #[serde(rename = "Notes")] + #[serde(rename = "Notes", alias = "notes")] notes: Option<String>, - #[serde(rename = "PasswordHistory")] + #[serde(rename = "PasswordHistory", alias = "passwordHistory")] password_history: Option<Vec<SyncResPasswordHistory>>, - #[serde(rename = "Fields")] - fields: Option<Vec<SyncResField>>, - #[serde(rename = "DeletedDate")] + #[serde(rename = "Fields", alias = "fields")] + fields: Option<Vec<CipherField>>, + #[serde(rename = "DeletedDate", alias = "deletedDate")] deleted_date: Option<String>, } @@ -253,32 +375,37 @@ impl SyncResCipher { if self.deleted_date.is_some() { return None; } - let history = if let Some(history) = &self.password_history { - history - .iter() - .filter_map(|entry| { - // Gets rid of entries with a non-existent password - entry.password.clone().map(|p| crate::db::HistoryEntry { - last_used_date: entry.last_used_date.clone(), - password: p, - }) - }) - .collect() - } else { - vec![] - }; + let history = + self.password_history + .as_ref() + .map_or_else(Vec::new, |history| { + history + .iter() + .filter_map(|entry| { + // Gets rid of entries with a non-existent + // password + entry.password.clone().map(|p| { + crate::db::HistoryEntry { + last_used_date: entry + .last_used_date + .clone(), + password: p, + } + }) + }) + .collect() + }); - let (folder, folder_id) = if let Some(folder_id) = &self.folder_id { - let mut folder_name = None; - for folder in folders { - if &folder.id == folder_id { - folder_name = Some(folder.name.clone()); + let (folder, folder_id) = + self.folder_id.as_ref().map_or((None, None), |folder_id| { + let mut folder_name = None; + for folder in folders { + if &folder.id == folder_id { + folder_name = Some(folder.name.clone()); + } } - } - (folder_name, Some(folder_id)) - } else { - (None, None) - }; + (folder_name, Some(folder_id)) + }); let data = if let Some(login) = &self.login { crate::db::EntryData::Login { username: login.username.clone(), @@ -332,17 +459,17 @@ impl SyncResCipher { } else { return None; }; - let fields = if let Some(fields) = &self.fields { + let fields = self.fields.as_ref().map_or_else(Vec::new, |fields| { fields .iter() .map(|field| crate::db::Field { + ty: field.ty, name: field.name.clone(), value: field.value.clone(), + linked_id: field.linked_id, }) .collect() - } else { - vec![] - }; + }); Some(crate::db::Entry { id: self.id.clone(), org_id: self.organization_id.clone(), @@ -359,104 +486,173 @@ impl SyncResCipher { #[derive(serde::Deserialize, Debug)] struct SyncResProfile { - #[serde(rename = "Key")] + #[serde(rename = "Key", alias = "key")] key: String, - #[serde(rename = "PrivateKey")] + #[serde(rename = "PrivateKey", alias = "privateKey")] private_key: String, - #[serde(rename = "Organizations")] + #[serde(rename = "Organizations", alias = "organizations")] organizations: Vec<SyncResProfileOrganization>, } #[derive(serde::Deserialize, Debug)] struct SyncResProfileOrganization { - #[serde(rename = "Id")] + #[serde(rename = "Id", alias = "id")] id: String, - #[serde(rename = "Key")] + #[serde(rename = "Key", alias = "key")] key: String, } #[derive(serde::Deserialize, Debug, Clone)] struct SyncResFolder { - #[serde(rename = "Id")] + #[serde(rename = "Id", alias = "id")] id: String, - #[serde(rename = "Name")] + #[serde(rename = "Name", alias = "name")] name: String, } #[derive(serde::Serialize, serde::Deserialize, Debug, Clone)] struct CipherLogin { - #[serde(rename = "Username")] + #[serde(rename = "Username", alias = "username")] username: Option<String>, - #[serde(rename = "Password")] + #[serde(rename = "Password", alias = "password")] password: Option<String>, - #[serde(rename = "Totp")] + #[serde(rename = "Totp", alias = "totp")] totp: Option<String>, - #[serde(rename = "Uris")] + #[serde(rename = "Uris", alias = "uris")] uris: Option<Vec<CipherLoginUri>>, } #[derive(serde::Serialize, serde::Deserialize, Debug, Clone)] struct CipherLoginUri { - #[serde(rename = "Uri")] + #[serde(rename = "Uri", alias = "uri")] uri: Option<String>, - #[serde(rename = "Match")] + #[serde(rename = "Match", alias = "match")] match_type: Option<UriMatchType>, } #[derive(serde::Serialize, serde::Deserialize, Debug, Clone)] struct CipherCard { - #[serde(rename = "CardholderName")] + #[serde(rename = "CardholderName", alias = "cardHolderName")] cardholder_name: Option<String>, - #[serde(rename = "Number")] + #[serde(rename = "Number", alias = "number")] number: Option<String>, - #[serde(rename = "Brand")] + #[serde(rename = "Brand", alias = "brand")] brand: Option<String>, - #[serde(rename = "ExpMonth")] + #[serde(rename = "ExpMonth", alias = "expMonth")] exp_month: Option<String>, - #[serde(rename = "ExpYear")] + #[serde(rename = "ExpYear", alias = "expYear")] exp_year: Option<String>, - #[serde(rename = "Code")] + #[serde(rename = "Code", alias = "code")] code: Option<String>, } #[derive(serde::Serialize, serde::Deserialize, Debug, Clone)] struct CipherIdentity { - #[serde(rename = "Title")] + #[serde(rename = "Title", alias = "title")] title: Option<String>, - #[serde(rename = "FirstName")] + #[serde(rename = "FirstName", alias = "firstName")] first_name: Option<String>, - #[serde(rename = "MiddleName")] + #[serde(rename = "MiddleName", alias = "middleName")] middle_name: Option<String>, - #[serde(rename = "LastName")] + #[serde(rename = "LastName", alias = "lastName")] last_name: Option<String>, - #[serde(rename = "Address1")] + #[serde(rename = "Address1", alias = "address1")] address1: Option<String>, - #[serde(rename = "Address2")] + #[serde(rename = "Address2", alias = "address2")] address2: Option<String>, - #[serde(rename = "Address3")] + #[serde(rename = "Address3", alias = "address3")] address3: Option<String>, - #[serde(rename = "City")] + #[serde(rename = "City", alias = "city")] city: Option<String>, - #[serde(rename = "State")] + #[serde(rename = "State", alias = "state")] state: Option<String>, - #[serde(rename = "PostalCode")] + #[serde(rename = "PostalCode", alias = "postalCode")] postal_code: Option<String>, - #[serde(rename = "Country")] + #[serde(rename = "Country", alias = "country")] country: Option<String>, - #[serde(rename = "Phone")] + #[serde(rename = "Phone", alias = "phone")] phone: Option<String>, - #[serde(rename = "Email")] + #[serde(rename = "Email", alias = "email")] email: Option<String>, - #[serde(rename = "SSN")] + #[serde(rename = "SSN", alias = "ssn")] ssn: Option<String>, - #[serde(rename = "LicenseNumber")] + #[serde(rename = "LicenseNumber", alias = "licenseNumber")] license_number: Option<String>, - #[serde(rename = "PassportNumber")] + #[serde(rename = "PassportNumber", alias = "passportNumber")] passport_number: Option<String>, - #[serde(rename = "Username")] + #[serde(rename = "Username", alias = "username")] username: Option<String>, } +#[derive( + serde_repr::Serialize_repr, + serde_repr::Deserialize_repr, + Debug, + Clone, + Copy, + PartialEq, + Eq, +)] +#[repr(u16)] +pub enum FieldType { + Text = 0, + Hidden = 1, + Boolean = 2, + Linked = 3, +} + +#[derive( + serde_repr::Serialize_repr, + serde_repr::Deserialize_repr, + Debug, + Clone, + Copy, + PartialEq, + Eq, +)] +#[repr(u16)] +pub enum LinkedIdType { + LoginUsername = 100, + LoginPassword = 101, + CardCardholderName = 300, + CardExpMonth = 301, + CardExpYear = 302, + CardCode = 303, + CardBrand = 304, + CardNumber = 305, + IdentityTitle = 400, + IdentityMiddleName = 401, + IdentityAddress1 = 402, + IdentityAddress2 = 403, + IdentityAddress3 = 404, + IdentityCity = 405, + IdentityState = 406, + IdentityPostalCode = 407, + IdentityCountry = 408, + IdentityCompany = 409, + IdentityEmail = 410, + IdentityPhone = 411, + IdentitySsn = 412, + IdentityUsername = 413, + IdentityPassportNumber = 414, + IdentityLicenseNumber = 415, + IdentityFirstName = 416, + IdentityLastName = 417, + IdentityFullName = 418, +} + +#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)] +struct CipherField { + #[serde(rename = "Type", alias = "type")] + ty: FieldType, + #[serde(rename = "Name", alias = "name")] + name: Option<String>, + #[serde(rename = "Value", alias = "value")] + value: Option<String>, + #[serde(rename = "LinkedId", alias = "linkedId")] + linked_id: Option<LinkedIdType>, +} + // this is just a name and some notes, both of which are already on the cipher // object #[derive(serde::Serialize, serde::Deserialize, Debug, Clone)] @@ -464,22 +660,12 @@ struct CipherSecureNote {} #[derive(serde::Serialize, serde::Deserialize, Debug, Clone)] struct SyncResPasswordHistory { - #[serde(rename = "LastUsedDate")] + #[serde(rename = "LastUsedDate", alias = "lastUsedDate")] last_used_date: String, - #[serde(rename = "Password")] + #[serde(rename = "Password", alias = "password")] password: Option<String>, } -#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)] -struct SyncResField { - #[serde(rename = "Type")] - ty: u32, - #[serde(rename = "Name")] - name: Option<String>, - #[serde(rename = "Value")] - value: Option<String>, -} - #[derive(serde::Serialize, Debug)] struct CiphersPostReq { #[serde(rename = "type")] @@ -508,6 +694,7 @@ struct CiphersPutReq { login: Option<CipherLogin>, card: Option<CipherCard>, identity: Option<CipherIdentity>, + fields: Vec<CipherField>, #[serde(rename = "secureNote")] secure_note: Option<CipherSecureNote>, #[serde(rename = "passwordHistory")] @@ -530,15 +717,15 @@ struct CiphersPutReqHistory { #[derive(serde::Deserialize, Debug)] struct FoldersRes { - #[serde(rename = "Data")] + #[serde(rename = "Data", alias = "data")] data: Vec<FoldersResData>, } #[derive(serde::Deserialize, Debug)] struct FoldersResData { - #[serde(rename = "Id")] + #[serde(rename = "Id", alias = "id")] id: String, - #[serde(rename = "Name")] + #[serde(rename = "Name", alias = "name")] name: String, } @@ -551,21 +738,70 @@ struct FoldersPostReq { pub struct Client { base_url: String, identity_url: String, + client_cert_path: Option<std::path::PathBuf>, } impl Client { - pub fn new(base_url: &str, identity_url: &str) -> Self { + #[must_use] + pub fn new( + base_url: &str, + identity_url: &str, + client_cert_path: Option<&std::path::Path>, + ) -> Self { Self { base_url: base_url.to_string(), identity_url: identity_url.to_string(), + client_cert_path: client_cert_path + .map(std::path::Path::to_path_buf), + } + } + + async fn reqwest_client(&self) -> Result<reqwest::Client> { + if let Some(client_cert_path) = self.client_cert_path.as_ref() { + let mut buf = Vec::new(); + let mut f = tokio::fs::File::open(client_cert_path) + .await + .map_err(|e| Error::LoadClientCert { + source: e, + file: client_cert_path.clone(), + })?; + f.read_to_end(&mut buf).await.map_err(|e| { + Error::LoadClientCert { + source: e, + file: client_cert_path.clone(), + } + })?; + let pem = reqwest::Identity::from_pem(&buf) + .map_err(|e| Error::CreateReqwestClient { source: e })?; + Ok(reqwest::Client::builder() + .user_agent(format!( + "{}/{}", + env!("CARGO_PKG_NAME"), + env!("CARGO_PKG_VERSION") + )) + .identity(pem) + .build() + .map_err(|e| Error::CreateReqwestClient { source: e })?) + } else { + Ok(reqwest::Client::builder() + .user_agent(format!( + "{}/{}", + env!("CARGO_PKG_NAME"), + env!("CARGO_PKG_VERSION") + )) + .build() + .map_err(|e| Error::CreateReqwestClient { source: e })?) } } - pub async fn prelogin(&self, email: &str) -> Result<u32> { + pub async fn prelogin( + &self, + email: &str, + ) -> Result<(KdfType, u32, Option<u32>, Option<u32>)> { let prelogin = PreloginReq { email: email.to_string(), }; - let client = reqwest::Client::new(); + let client = self.reqwest_client().await?; let res = client .post(&self.api_url("/accounts/prelogin")) .json(&prelogin) @@ -573,7 +809,12 @@ impl Client { .await .map_err(|source| Error::Reqwest { source })?; let prelogin_res: PreloginRes = res.json_with_path().await?; - Ok(prelogin_res.kdf_iterations) + Ok(( + prelogin_res.kdf, + prelogin_res.kdf_iterations, + prelogin_res.kdf_memory, + prelogin_res.kdf_parallelism, + )) } pub async fn register( @@ -596,22 +837,34 @@ impl Client { device_type: 8, device_identifier: device_id.to_string(), device_name: "rbw".to_string(), - device_push_token: "".to_string(), + device_push_token: String::new(), two_factor_token: None, two_factor_provider: None, }; - let client = reqwest::Client::new(); + let client = self.reqwest_client().await?; let res = client .post(&self.identity_url("/connect/token")) .form(&connect_req) .send() .await .map_err(|source| Error::Reqwest { source })?; - if let reqwest::StatusCode::OK = res.status() { + if res.status() == reqwest::StatusCode::OK { Ok(()) } else { let code = res.status().as_u16(); - Err(classify_login_error(&res.json_with_path().await?, code)) + match res.text().await { + Ok(body) => match body.clone().json_with_path() { + Ok(json) => Err(classify_login_error(&json, code)), + Err(e) => { + log::warn!("{e}: {body}"); + Err(Error::RequestFailed { status: code }) + } + }, + Err(e) => { + log::warn!("failed to read response body: {e}"); + Err(Error::RequestFailed { status: code }) + } + } } } @@ -626,30 +879,30 @@ impl Client { let connect_req = ConnectPasswordReq { grant_type: "password".to_string(), username: email.to_string(), - password: Some(base64::encode(password_hash.hash())), + password: Some(crate::base64::encode(password_hash.hash())), scope: "api offline_access".to_string(), client_id: "desktop".to_string(), client_secret: None, device_type: 8, device_identifier: device_id.to_string(), device_name: "rbw".to_string(), - device_push_token: "".to_string(), + device_push_token: String::new(), two_factor_token: two_factor_token .map(std::string::ToString::to_string), two_factor_provider: two_factor_provider.map(|ty| ty as u32), }; - let client = reqwest::Client::new(); + let client = self.reqwest_client().await?; let res = client .post(&self.identity_url("/connect/token")) .form(&connect_req) .header( "auth-email", - base64::encode_config(email, base64::URL_SAFE_NO_PAD), + crate::base64::encode_url_safe_no_pad(email), ) .send() .await .map_err(|source| Error::Reqwest { source })?; - if let reqwest::StatusCode::OK = res.status() { + if res.status() == reqwest::StatusCode::OK { let connect_res: ConnectPasswordRes = res.json_with_path().await?; Ok(( @@ -659,7 +912,19 @@ impl Client { )) } else { let code = res.status().as_u16(); - Err(classify_login_error(&res.json_with_path().await?, code)) + match res.text().await { + Ok(body) => match body.clone().json_with_path() { + Ok(json) => Err(classify_login_error(&json, code)), + Err(e) => { + log::warn!("{e}: {body}"); + Err(Error::RequestFailed { status: code }) + } + }, + Err(e) => { + log::warn!("failed to read response body: {e}"); + Err(Error::RequestFailed { status: code }) + } + } } } @@ -672,10 +937,10 @@ impl Client { std::collections::HashMap<String, String>, Vec<crate::db::Entry>, )> { - let client = reqwest::Client::new(); + let client = self.reqwest_client().await?; let res = client .get(&self.api_url("/sync")) - .header("Authorization", format!("Bearer {}", access_token)) + .header("Authorization", format!("Bearer {access_token}")) .send() .await .map_err(|source| Error::Reqwest { source })?; @@ -816,8 +1081,8 @@ impl Client { } let client = reqwest::blocking::Client::new(); let res = client - .post(&self.api_url("/ciphers")) - .header("Authorization", format!("Bearer {}", access_token)) + .post(self.api_url("/ciphers")) + .header("Authorization", format!("Bearer {access_token}")) .json(&req) .send() .map_err(|source| Error::Reqwest { source })?; @@ -839,12 +1104,18 @@ impl Client { org_id: Option<&str>, name: &str, data: &crate::db::EntryData, + fields: &[crate::db::Field], notes: Option<&str>, folder_uuid: Option<&str>, history: &[crate::db::HistoryEntry], ) -> Result<()> { let mut req = CiphersPutReq { - ty: 1, + ty: match data { + crate::db::EntryData::Login { .. } => 1, + crate::db::EntryData::SecureNote { .. } => 2, + crate::db::EntryData::Card { .. } => 3, + crate::db::EntryData::Identity { .. } => 4, + }, folder_id: folder_uuid.map(std::string::ToString::to_string), organization_id: org_id.map(std::string::ToString::to_string), name: name.to_string(), @@ -853,6 +1124,15 @@ impl Client { card: None, identity: None, secure_note: None, + fields: fields + .iter() + .map(|field| CipherField { + ty: field.ty, + name: field.name.clone(), + value: field.value.clone(), + linked_id: field.linked_id, + }) + .collect(), password_history: history .iter() .map(|entry| CiphersPutReqHistory { @@ -949,8 +1229,8 @@ impl Client { } let client = reqwest::blocking::Client::new(); let res = client - .put(&self.api_url(&format!("/ciphers/{}", id))) - .header("Authorization", format!("Bearer {}", access_token)) + .put(self.api_url(&format!("/ciphers/{id}"))) + .header("Authorization", format!("Bearer {access_token}")) .json(&req) .send() .map_err(|source| Error::Reqwest { source })?; @@ -968,8 +1248,8 @@ impl Client { pub fn remove(&self, access_token: &str, id: &str) -> Result<()> { let client = reqwest::blocking::Client::new(); let res = client - .delete(&self.api_url(&format!("/ciphers/{}", id))) - .header("Authorization", format!("Bearer {}", access_token)) + .delete(self.api_url(&format!("/ciphers/{id}"))) + .header("Authorization", format!("Bearer {access_token}")) .send() .map_err(|source| Error::Reqwest { source })?; match res.status() { @@ -989,8 +1269,8 @@ impl Client { ) -> Result<Vec<(String, String)>> { let client = reqwest::blocking::Client::new(); let res = client - .get(&self.api_url("/folders")) - .header("Authorization", format!("Bearer {}", access_token)) + .get(self.api_url("/folders")) + .header("Authorization", format!("Bearer {access_token}")) .send() .map_err(|source| Error::Reqwest { source })?; match res.status() { @@ -1021,8 +1301,8 @@ impl Client { }; let client = reqwest::blocking::Client::new(); let res = client - .post(&self.api_url("/folders")) - .header("Authorization", format!("Bearer {}", access_token)) + .post(self.api_url("/folders")) + .header("Authorization", format!("Bearer {access_token}")) .json(&req) .send() .map_err(|source| Error::Reqwest { source })?; @@ -1051,7 +1331,7 @@ impl Client { }; let client = reqwest::blocking::Client::new(); let res = client - .post(&self.identity_url("/connect/token")) + .post(self.identity_url("/connect/token")) .form(&connect_req) .send() .map_err(|source| Error::Reqwest { source })?; @@ -1068,7 +1348,7 @@ impl Client { client_id: "desktop".to_string(), refresh_token: refresh_token.to_string(), }; - let client = reqwest::Client::new(); + let client = self.reqwest_client().await?; let res = client .post(&self.identity_url("/connect/token")) .form(&connect_req) diff --git a/src/base64.rs b/src/base64.rs new file mode 100644 index 0000000..86971bc --- /dev/null +++ b/src/base64.rs @@ -0,0 +1,15 @@ +use base64::Engine as _; + +pub fn encode<T: AsRef<[u8]>>(input: T) -> String { + base64::engine::general_purpose::STANDARD.encode(input) +} + +pub fn encode_url_safe_no_pad<T: AsRef<[u8]>>(input: T) -> String { + base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(input) +} + +pub fn decode<T: AsRef<[u8]>>( + input: T, +) -> Result<Vec<u8>, base64::DecodeError> { + base64::engine::general_purpose::STANDARD.decode(input) +} diff --git a/src/bin/rbw-agent/actions.rs b/src/bin/rbw-agent/actions.rs index 1cc71c3..674442b 100644 --- a/src/bin/rbw-agent/actions.rs +++ b/src/bin/rbw-agent/actions.rs @@ -10,9 +10,7 @@ pub async fn register( let url_str = config_base_url().await?; let url = reqwest::Url::parse(&url_str) .context("failed to parse base url")?; - let host = if let Some(host) = url.host_str() { - host - } else { + let Some(host) = url.host_str() else { return Err(anyhow::anyhow!( "couldn't find host in rbw base url {}", url_str @@ -33,7 +31,7 @@ pub async fn register( let client_id = rbw::pinentry::getpin( &config_pinentry().await?, "API key client__id", - &format!("Log in to {}", host), + &format!("Log in to {host}"), err.as_deref(), tty, false, @@ -43,7 +41,7 @@ pub async fn register( let client_secret = rbw::pinentry::getpin( &config_pinentry().await?, "API key client__secret", - &format!("Log in to {}", host), + &format!("Log in to {host}"), err.as_deref(), tty, false, @@ -61,10 +59,9 @@ pub async fn register( message, }) .context("failed to log in to bitwarden instance"); - } else { - err_msg = Some(message); - continue; } + err_msg = Some(message); + continue; } Err(e) => { return Err(e) @@ -81,7 +78,7 @@ pub async fn register( pub async fn login( sock: &mut crate::sock::Sock, - state: std::sync::Arc<tokio::sync::RwLock<crate::agent::State>>, + state: std::sync::Arc<tokio::sync::Mutex<crate::agent::State>>, tty: Option<&str>, ) -> anyhow::Result<()> { let db = load_db().await.unwrap_or_else(|_| rbw::db::Db::new()); @@ -90,9 +87,7 @@ pub async fn login( let url_str = config_base_url().await?; let url = reqwest::Url::parse(&url_str) .context("failed to parse base url")?; - let host = if let Some(host) = url.host_str() { - host - } else { + let Some(host) = url.host_str() else { return Err(anyhow::anyhow!( "couldn't find host in rbw base url {}", url_str @@ -102,7 +97,7 @@ pub async fn login( let email = config_email().await?; let mut err_msg = None; - for i in 1_u8..=3 { + 'attempts: for i in 1_u8..=3 { let err = if i > 1 { // this unwrap is safe because we only ever continue the loop // if we have set err_msg @@ -113,7 +108,7 @@ pub async fn login( let password = rbw::pinentry::getpin( &config_pinentry().await?, "Master Password", - &format!("Log in to {}", host), + &format!("Log in to {host}"), err.as_deref(), tty, true, @@ -126,55 +121,70 @@ pub async fn login( Ok(( access_token, refresh_token, + kdf, iterations, + memory, + parallelism, protected_key, )) => { login_success( - sock, - state, + state.clone(), access_token, refresh_token, + kdf, iterations, + memory, + parallelism, protected_key, password, db, email, ) .await?; - break; + break 'attempts; } Err(rbw::error::Error::TwoFactorRequired { providers }) => { - if providers.contains( - &rbw::api::TwoFactorProviderType::Authenticator, - ) { - let ( - access_token, - refresh_token, - iterations, - protected_key, - ) = two_factor( - tty, - &email, - password.clone(), - rbw::api::TwoFactorProviderType::Authenticator, - ) - .await?; - login_success( - sock, - state, - access_token, - refresh_token, - iterations, - protected_key, - password, - db, - email, - ) - .await?; - break; - } else { - return Err(anyhow::anyhow!("TODO")); + let supported_types = vec![ + rbw::api::TwoFactorProviderType::Authenticator, + rbw::api::TwoFactorProviderType::Yubikey, + rbw::api::TwoFactorProviderType::Email, + ]; + + for provider in supported_types { + if providers.contains(&provider) { + let ( + access_token, + refresh_token, + kdf, + iterations, + memory, + parallelism, + protected_key, + ) = two_factor( + tty, + &email, + password.clone(), + provider, + ) + .await?; + login_success( + state.clone(), + access_token, + refresh_token, + kdf, + iterations, + memory, + parallelism, + protected_key, + password, + db, + email, + ) + .await?; + break 'attempts; + } } + return Err(anyhow::anyhow!("TODO")); } Err(rbw::error::Error::IncorrectPassword { message }) => { if i == 3 { @@ -182,10 +192,9 @@ pub async fn login( message, }) .context("failed to log in to bitwarden instance"); - } else { - err_msg = Some(message); - continue; } + err_msg = Some(message); + continue; } Err(e) => { return Err(e) @@ -205,7 +214,15 @@ async fn two_factor( email: &str, password: rbw::locked::Password, provider: rbw::api::TwoFactorProviderType, -) -> anyhow::Result<(String, String, u32, String)> { +) -> anyhow::Result<( + String, + String, + rbw::api::KdfType, + u32, + Option<u32>, + Option<u32>, + String, +)> { let mut err_msg = None; for i in 1_u8..=3 { let err = if i > 1 { @@ -217,11 +234,11 @@ async fn two_factor( }; let code = rbw::pinentry::getpin( &config_pinentry().await?, - "Authenticator App", - "Enter the 6 digit verification code from your authenticator app.", + provider.header(), + provider.message(), err.as_deref(), tty, - true, + provider.grab(), ) .await .context("failed to read code from pinentry")?; @@ -235,11 +252,22 @@ async fn two_factor( ) .await { - Ok((access_token, refresh_token, iterations, protected_key)) => { + Ok(( + access_token, + refresh_token, + kdf, + iterations, + memory, + parallelism, + protected_key, + )) => { return Ok(( access_token, refresh_token, + kdf, iterations, + memory, + parallelism, protected_key, )) } @@ -249,10 +277,9 @@ async fn two_factor( message, }) .context("failed to log in to bitwarden instance"); - } else { - err_msg = Some(message); - continue; } + err_msg = Some(message); + continue; } // can get this if the user passes an empty string Err(rbw::error::Error::TwoFactorRequired { .. }) => { @@ -262,10 +289,9 @@ async fn two_factor( message, }) .context("failed to log in to bitwarden instance"); - } else { - err_msg = Some(message); - continue; } + err_msg = Some(message); + continue; } Err(e) => { return Err(e) @@ -278,11 +304,13 @@ async fn two_factor( } async fn login_success( - sock: &mut crate::sock::Sock, - state: std::sync::Arc<tokio::sync::RwLock<crate::agent::State>>, + state: std::sync::Arc<tokio::sync::Mutex<crate::agent::State>>, access_token: String, refresh_token: String, + kdf: rbw::api::KdfType, iterations: u32, + memory: Option<u32>, + parallelism: Option<u32>, protected_key: String, password: rbw::locked::Password, mut db: rbw::db::Db, @@ -290,35 +318,37 @@ async fn login_success( ) -> anyhow::Result<()> { db.access_token = Some(access_token.to_string()); db.refresh_token = Some(refresh_token.to_string()); + db.kdf = Some(kdf); db.iterations = Some(iterations); + db.memory = memory; + db.parallelism = parallelism; db.protected_key = Some(protected_key.to_string()); save_db(&db).await?; - sync(sock, false).await?; + sync(None, state.clone()).await?; let db = load_db().await?; - let protected_private_key = - if let Some(protected_private_key) = db.protected_private_key { - protected_private_key - } else { - return Err(anyhow::anyhow!( - "failed to find protected private key in db" - )); - }; + let Some(protected_private_key) = db.protected_private_key else { + return Err(anyhow::anyhow!( + "failed to find protected private key in db" + )); + }; let res = rbw::actions::unlock( &email, &password, + kdf, iterations, + memory, + parallelism, &protected_key, &protected_private_key, &db.protected_org_keys, - ) - .await; + ); match res { Ok((keys, org_keys)) => { - let mut state = state.write().await; + let mut state = state.lock().await; state.priv_key = Some(keys); state.org_keys = Some(org_keys); } @@ -330,39 +360,40 @@ async fn login_success( pub async fn unlock( sock: &mut crate::sock::Sock, - state: std::sync::Arc<tokio::sync::RwLock<crate::agent::State>>, + state: std::sync::Arc<tokio::sync::Mutex<crate::agent::State>>, tty: Option<&str>, ) -> anyhow::Result<()> { - if state.read().await.needs_unlock() { + if state.lock().await.needs_unlock() { let db = load_db().await?; - let iterations = if let Some(iterations) = db.iterations { - iterations - } else { + let Some(kdf) = db.kdf else { + return Err(anyhow::anyhow!("failed to find kdf type in db")); + }; + + let Some(iterations) = db.iterations else { return Err(anyhow::anyhow!( "failed to find number of iterations in db" )); }; - let protected_key = if let Some(protected_key) = db.protected_key { - protected_key - } else { + + let memory = db.memory; + let parallelism = db.parallelism; + + let Some(protected_key) = db.protected_key else { return Err(anyhow::anyhow!( "failed to find protected key in db" )); }; - let protected_private_key = - if let Some(protected_private_key) = db.protected_private_key { - protected_private_key - } else { - return Err(anyhow::anyhow!( - "failed to find protected private key in db" - )); - }; + let Some(protected_private_key) = db.protected_private_key else { + return Err(anyhow::anyhow!( + "failed to find protected private key in db" + )); + }; let email = config_email().await?; let mut err_msg = None; - for i in 1u8..=3 { + for i in 1_u8..=3 { let err = if i > 1 { // this unwrap is safe because we only ever continue the loop // if we have set err_msg @@ -373,7 +404,10 @@ pub async fn unlock( let password = rbw::pinentry::getpin( &config_pinentry().await?, "Master Password", - "Unlock the local database", + &format!( + "Unlock the local database for '{}'", + rbw::dirs::profile() + ), err.as_deref(), tty, true, @@ -383,13 +417,14 @@ pub async fn unlock( match rbw::actions::unlock( &email, &password, + kdf, iterations, + memory, + parallelism, &protected_key, &protected_private_key, &db.protected_org_keys, - ) - .await - { + ) { Ok((keys, org_keys)) => { unlock_success(state, keys, org_keys).await?; break; @@ -400,10 +435,9 @@ pub async fn unlock( message, }) .context("failed to unlock database"); - } else { - err_msg = Some(message); - continue; } + err_msg = Some(message); + continue; } Err(e) => return Err(e).context("failed to unlock database"), } @@ -416,11 +450,11 @@ pub async fn unlock( } async fn unlock_success( - state: std::sync::Arc<tokio::sync::RwLock<crate::agent::State>>, + state: std::sync::Arc<tokio::sync::Mutex<crate::agent::State>>, keys: rbw::locked::Keys, org_keys: std::collections::HashMap<String, rbw::locked::Keys>, ) -> anyhow::Result<()> { - let mut state = state.write().await; + let mut state = state.lock().await; state.priv_key = Some(keys); state.org_keys = Some(org_keys); Ok(()) @@ -428,9 +462,9 @@ async fn unlock_success( pub async fn lock( sock: &mut crate::sock::Sock, - state: std::sync::Arc<tokio::sync::RwLock<crate::agent::State>>, + state: std::sync::Arc<tokio::sync::Mutex<crate::agent::State>>, ) -> anyhow::Result<()> { - state.write().await.clear(); + state.lock().await.clear(); respond_ack(sock).await?; @@ -439,10 +473,9 @@ pub async fn lock( pub async fn check_lock( sock: &mut crate::sock::Sock, - state: std::sync::Arc<tokio::sync::RwLock<crate::agent::State>>, - _tty: Option<&str>, + state: std::sync::Arc<tokio::sync::Mutex<crate::agent::State>>, ) -> anyhow::Result<()> { - if state.read().await.needs_unlock() { + if state.lock().await.needs_unlock() { return Err(anyhow::anyhow!("agent is locked")); } @@ -452,8 +485,8 @@ pub async fn check_lock( } pub async fn sync( - sock: &mut crate::sock::Sock, - ack: bool, + sock: Option<&mut crate::sock::Sock>, + state: std::sync::Arc<tokio::sync::Mutex<crate::agent::State>>, ) -> anyhow::Result<()> { let mut db = load_db().await?; @@ -482,7 +515,11 @@ pub async fn sync( db.entries = entries; save_db(&db).await?; - if ack { + if let Err(e) = subscribe_to_notifications(state.clone()).await { + eprintln!("failed to subscribe to notifications: {e}"); + } + + if let Some(sock) = sock { respond_ack(sock).await?; } @@ -491,14 +528,12 @@ pub async fn sync( pub async fn decrypt( sock: &mut crate::sock::Sock, - state: std::sync::Arc<tokio::sync::RwLock<crate::agent::State>>, + state: std::sync::Arc<tokio::sync::Mutex<crate::agent::State>>, cipherstring: &str, org_id: Option<&str>, ) -> anyhow::Result<()> { - let state = state.read().await; - let keys = if let Some(keys) = state.key(org_id) { - keys - } else { + let state = state.lock().await; + let Some(keys) = state.key(org_id) else { return Err(anyhow::anyhow!( "failed to find decryption keys in in-memory state" )); @@ -519,14 +554,12 @@ pub async fn decrypt( pub async fn encrypt( sock: &mut crate::sock::Sock, - state: std::sync::Arc<tokio::sync::RwLock<crate::agent::State>>, + state: std::sync::Arc<tokio::sync::Mutex<crate::agent::State>>, plaintext: &str, org_id: Option<&str>, ) -> anyhow::Result<()> { - let state = state.read().await; - let keys = if let Some(keys) = state.key(org_id) { - keys - } else { + let state = state.lock().await; + let Some(keys) = state.key(org_id) else { return Err(anyhow::anyhow!( "failed to find encryption keys in in-memory state" )); @@ -542,6 +575,25 @@ pub async fn encrypt( Ok(()) } +pub async fn clipboard_store( + sock: &mut crate::sock::Sock, + state: std::sync::Arc<tokio::sync::Mutex<crate::agent::State>>, + text: &str, +) -> anyhow::Result<()> { + state + .lock() + .await + .clipboard + .set_contents(text.to_owned()) + .map_err(|e| { + anyhow::anyhow!("couldn't store value to clipboard: {e}") + })?; + + respond_ack(sock).await?; + + Ok(()) +} + pub async fn version(sock: &mut crate::sock::Sock) -> anyhow::Result<()> { sock.send(&rbw::protocol::Response::Version { version: rbw::protocol::version(), @@ -579,11 +631,10 @@ async fn respond_encrypt( async fn config_email() -> anyhow::Result<String> { let config = rbw::config::Config::load_async().await?; - if let Some(email) = config.email { - Ok(email) - } else { - Err(anyhow::anyhow!("failed to find email address in config")) - } + config.email.map_or_else( + || Err(anyhow::anyhow!("failed to find email address in config")), + Ok, + ) } async fn load_db() -> anyhow::Result<rbw::db::Db> { @@ -617,3 +668,35 @@ async fn config_pinentry() -> anyhow::Result<String> { let config = rbw::config::Config::load_async().await?; Ok(config.pinentry) } + +pub async fn subscribe_to_notifications( + state: std::sync::Arc<tokio::sync::Mutex<crate::agent::State>>, +) -> anyhow::Result<()> { + if state.lock().await.notifications_handler.is_connected() { + return Ok(()); + } + + let config = rbw::config::Config::load_async() + .await + .context("Config is missing")?; + let email = config.email.clone().context("Config is missing email")?; + let db = rbw::db::Db::load_async(config.server_name().as_str(), &email) + .await?; + let access_token = + db.access_token.context("Error getting access token")?; + + let websocket_url = format!( + "{}/hub?access_token={}", + config.notifications_url(), + access_token + ) + .replace("https://", "wss://"); + + let mut state = state.lock().await; + state + .notifications_handler + .connect(websocket_url) + .await + .err() + .map_or_else(|| Ok(()), |err| Err(anyhow::anyhow!(err.to_string()))) +} diff --git a/src/bin/rbw-agent/agent.rs b/src/bin/rbw-agent/agent.rs index fae8c7b..a3fecb4 100644 --- a/src/bin/rbw-agent/agent.rs +++ b/src/bin/rbw-agent/agent.rs @@ -1,24 +1,25 @@ use anyhow::Context as _; +use futures_util::StreamExt as _; -#[derive(Debug)] -pub enum TimeoutEvent { - Set, - Clear, -} +use crate::notifications; pub struct State { pub priv_key: Option<rbw::locked::Keys>, pub org_keys: Option<std::collections::HashMap<String, rbw::locked::Keys>>, - pub timeout_chan: tokio::sync::mpsc::UnboundedSender<TimeoutEvent>, + pub timeout: crate::timeout::Timeout, + pub timeout_duration: std::time::Duration, + pub sync_timeout: crate::timeout::Timeout, + pub sync_timeout_duration: std::time::Duration, + pub notifications_handler: crate::notifications::Handler, + pub clipboard: Box<dyn copypasta::ClipboardProvider>, } impl State { pub fn key(&self, org_id: Option<&str>) -> Option<&rbw::locked::Keys> { - match org_id { - Some(id) => self.org_keys.as_ref().and_then(|h| h.get(id)), - None => self.priv_key.as_ref(), - } + org_id.map_or(self.priv_key.as_ref(), |id| { + self.org_keys.as_ref().and_then(|h| h.get(id)) + }) } pub fn needs_unlock(&self) -> bool { @@ -26,103 +27,163 @@ impl State { } pub fn set_timeout(&mut self) { - // no real better option to unwrap here - self.timeout_chan.send(TimeoutEvent::Set).unwrap(); + self.timeout.set(self.timeout_duration); } pub fn clear(&mut self) { self.priv_key = None; - self.org_keys = Default::default(); - // no real better option to unwrap here - self.timeout_chan.send(TimeoutEvent::Clear).unwrap(); + self.org_keys = None; + self.timeout.clear(); + } + + pub fn set_sync_timeout(&mut self) { + self.sync_timeout.set(self.sync_timeout_duration); } } pub struct Agent { - timeout_duration: tokio::time::Duration, - timeout: Option<std::pin::Pin<Box<tokio::time::Sleep>>>, - timeout_chan: tokio::sync::mpsc::UnboundedReceiver<TimeoutEvent>, - state: std::sync::Arc<tokio::sync::RwLock<State>>, + timer_r: tokio::sync::mpsc::UnboundedReceiver<()>, + sync_timer_r: tokio::sync::mpsc::UnboundedReceiver<()>, + state: std::sync::Arc<tokio::sync::Mutex<State>>, } impl Agent { pub fn new() -> anyhow::Result<Self> { let config = rbw::config::Config::load()?; let timeout_duration = - tokio::time::Duration::from_secs(config.lock_timeout); - let (w, r) = tokio::sync::mpsc::unbounded_channel(); + std::time::Duration::from_secs(config.lock_timeout); + let sync_timeout_duration = + std::time::Duration::from_secs(config.sync_interval); + let (timeout, timer_r) = crate::timeout::Timeout::new(); + let (sync_timeout, sync_timer_r) = crate::timeout::Timeout::new(); + if sync_timeout_duration > std::time::Duration::ZERO { + sync_timeout.set(sync_timeout_duration); + } + let notifications_handler = crate::notifications::Handler::new(); + let clipboard: Box<dyn copypasta::ClipboardProvider> = + copypasta::ClipboardContext::new().map_or_else( + |e| { + log::warn!("couldn't create clipboard context: {e}"); + let clipboard = Box::new( + // infallible + copypasta::nop_clipboard::NopClipboardContext::new() + .unwrap(), + ); + let clipboard: Box<dyn copypasta::ClipboardProvider> = + clipboard; + clipboard + }, + |clipboard| { + let clipboard = Box::new(clipboard); + let clipboard: Box<dyn copypasta::ClipboardProvider> = + clipboard; + clipboard + }, + ); Ok(Self { - timeout_duration, - timeout: None, - timeout_chan: r, - state: std::sync::Arc::new(tokio::sync::RwLock::new(State { + timer_r, + sync_timer_r, + state: std::sync::Arc::new(tokio::sync::Mutex::new(State { priv_key: None, - org_keys: Default::default(), - timeout_chan: w, + org_keys: None, + timeout, + timeout_duration, + sync_timeout, + sync_timeout_duration, + notifications_handler, + clipboard, })), }) } - fn set_timeout(&mut self) { - self.timeout = - Some(Box::pin(tokio::time::sleep(self.timeout_duration))); - } - - fn clear_timeout(&mut self) { - self.timeout = None; - } - pub async fn run( - &mut self, + self, listener: tokio::net::UnixListener, ) -> anyhow::Result<()> { - // tokio only supports timeouts up to 2^36 milliseconds - let mut forever = Box::pin(tokio::time::sleep( - tokio::time::Duration::from_secs(60 * 60 * 24 * 365 * 2), - )); - loop { - let timeout = if let Some(timeout) = &mut self.timeout { - timeout - } else { - &mut forever - }; - tokio::select! { - sock = listener.accept() => { + pub enum Event { + Request(std::io::Result<tokio::net::UnixStream>), + Timeout(()), + Sync(()), + } + + let notifications = self + .state + .lock() + .await + .notifications_handler + .get_channel() + .await; + let notifications = + tokio_stream::wrappers::UnboundedReceiverStream::new( + notifications, + ) + .map(|message| match message { + notifications::Message::Logout => Event::Timeout(()), + notifications::Message::Sync => Event::Sync(()), + }) + .boxed(); + + let mut stream = futures_util::stream::select_all([ + tokio_stream::wrappers::UnixListenerStream::new(listener) + .map(Event::Request) + .boxed(), + tokio_stream::wrappers::UnboundedReceiverStream::new( + self.timer_r, + ) + .map(Event::Timeout) + .boxed(), + tokio_stream::wrappers::UnboundedReceiverStream::new( + self.sync_timer_r, + ) + .map(Event::Sync) + .boxed(), + notifications, + ]); + while let Some(event) = stream.next().await { + match event { + Event::Request(res) => { let mut sock = crate::sock::Sock::new( - sock.context("failed to accept incoming connection")?.0 + res.context("failed to accept incoming connection")?, ); let state = self.state.clone(); tokio::spawn(async move { - let res - = handle_request(&mut sock, state.clone()).await; + let res = + handle_request(&mut sock, state.clone()).await; if let Err(e) = res { // unwrap is the only option here sock.send(&rbw::protocol::Response::Error { - error: format!("{:#}", e), - }).await.unwrap(); + error: format!("{e:#}"), + }) + .await + .unwrap(); } }); } - _ = timeout => { + Event::Timeout(()) => { + self.state.lock().await.clear(); + } + Event::Sync(()) => { let state = self.state.clone(); - tokio::spawn(async move{ - state.write().await.clear(); + tokio::spawn(async move { + // this could fail if we aren't logged in, but we + // don't care about that + if let Err(e) = + crate::actions::sync(None, state.clone()).await + { + eprintln!("failed to sync: {e:#}"); + } }); - } - Some(ev) = self.timeout_chan.recv() => { - match ev { - TimeoutEvent::Set => self.set_timeout(), - TimeoutEvent::Clear => self.clear_timeout(), - } + self.state.lock().await.set_sync_timeout(); } } } + Ok(()) } } async fn handle_request( sock: &mut crate::sock::Sock, - state: std::sync::Arc<tokio::sync::RwLock<State>>, + state: std::sync::Arc<tokio::sync::Mutex<State>>, ) -> anyhow::Result<()> { let req = sock.recv().await?; let req = match req { @@ -148,12 +209,7 @@ async fn handle_request( true } rbw::protocol::Action::CheckLock => { - crate::actions::check_lock( - sock, - state.clone(), - req.tty.as_deref(), - ) - .await?; + crate::actions::check_lock(sock, state.clone()).await?; false } rbw::protocol::Action::Lock => { @@ -161,7 +217,7 @@ async fn handle_request( false } rbw::protocol::Action::Sync => { - crate::actions::sync(sock, true).await?; + crate::actions::sync(Some(sock), state.clone()).await?; false } rbw::protocol::Action::Decrypt { @@ -187,6 +243,11 @@ async fn handle_request( .await?; true } + rbw::protocol::Action::ClipboardStore { text } => { + crate::actions::clipboard_store(sock, state.clone(), text) + .await?; + true + } rbw::protocol::Action::Quit => std::process::exit(0), rbw::protocol::Action::Version => { crate::actions::version(sock).await?; @@ -195,7 +256,7 @@ async fn handle_request( }; if set_timeout { - state.write().await.set_timeout(); + state.lock().await.set_timeout(); } Ok(()) diff --git a/src/bin/rbw-agent/daemon.rs b/src/bin/rbw-agent/daemon.rs index 923a217..06db891 100644 --- a/src/bin/rbw-agent/daemon.rs +++ b/src/bin/rbw-agent/daemon.rs @@ -1,25 +1,15 @@ pub struct StartupAck { - writer: std::os::unix::io::RawFd, + writer: std::os::unix::io::OwnedFd, } impl StartupAck { - pub fn ack(&self) -> anyhow::Result<()> { - nix::unistd::write(self.writer, &[0])?; - nix::unistd::close(self.writer)?; + pub fn ack(self) -> anyhow::Result<()> { + rustix::io::write(&self.writer, &[0])?; Ok(()) } } -impl Drop for StartupAck { - fn drop(&mut self) { - // best effort close here, can't do better in a destructor - let _ = nix::unistd::close(self.writer); - } -} - pub fn daemonize() -> anyhow::Result<StartupAck> { - rbw::dirs::make_all()?; - let stdout = std::fs::OpenOptions::new() .append(true) .create(true) @@ -29,33 +19,38 @@ pub fn daemonize() -> anyhow::Result<StartupAck> { .create(true) .open(rbw::dirs::agent_stderr_file())?; - let (r, w) = nix::unistd::pipe()?; - let res = daemonize::Daemonize::new() + let (r, w) = rustix::pipe::pipe()?; + let daemonize = daemonize::Daemonize::new() .pid_file(rbw::dirs::pid_file()) .stdout(stdout) - .stderr(stderr) - .exit_action(move || { + .stderr(stderr); + let res = match daemonize.execute() { + daemonize::Outcome::Parent(_) => { + drop(w); + let mut buf = [0; 1]; // unwraps are necessary because not really a good way to handle // errors here otherwise - let _ = nix::unistd::close(w); - let mut buf = [0; 1]; - nix::unistd::read(r, &mut buf).unwrap(); - nix::unistd::close(r).unwrap(); - }) - .start(); - let _ = nix::unistd::close(r); + rustix::io::read(&r, &mut buf).unwrap(); + drop(r); + std::process::exit(0); + } + daemonize::Outcome::Child(res) => res, + }; + + drop(r); match res { Ok(_) => (), Err(e) => { - match e { - daemonize::DaemonizeError::LockPidfile(_) => { - // this means that there is already an agent running, so - // return a special exit code to allow the cli to detect - // this case and not error out - std::process::exit(23); - } - _ => panic!("failed to daemonize: {}", e), + // XXX super gross, but daemonize removed the ability to match + // on specific error types for some reason? + if e.to_string().contains("unable to lock pid file") { + // this means that there is already an agent running, so + // return a special exit code to allow the cli to detect + // this case and not error out + std::process::exit(23); + } else { + panic!("failed to daemonize: {e}"); } } } diff --git a/src/bin/rbw-agent/debugger.rs b/src/bin/rbw-agent/debugger.rs index 59bbe50..be5260c 100644 --- a/src/bin/rbw-agent/debugger.rs +++ b/src/bin/rbw-agent/debugger.rs @@ -12,7 +12,7 @@ pub fn disable_tracing() -> anyhow::Result<()> { if ret == 0 { Ok(()) } else { - let e = nix::Error::last(); + let e = std::io::Error::last_os_error(); Err(anyhow::anyhow!("failed to disable PTRACE_ATTACH, agent memory may be dumpable by other processes: {}", e)) } } diff --git a/src/bin/rbw-agent/main.rs b/src/bin/rbw-agent/main.rs index 69411ae..d470e10 100644 --- a/src/bin/rbw-agent/main.rs +++ b/src/bin/rbw-agent/main.rs @@ -1,4 +1,19 @@ +#![warn(clippy::cargo)] +#![warn(clippy::pedantic)] +#![warn(clippy::nursery)] +#![warn(clippy::as_conversions)] +#![warn(clippy::get_unwrap)] +#![allow(clippy::cognitive_complexity)] +#![allow(clippy::missing_const_for_fn)] +#![allow(clippy::similar_names)] +#![allow(clippy::struct_excessive_bools)] #![allow(clippy::too_many_arguments)] +#![allow(clippy::too_many_lines)] +#![allow(clippy::type_complexity)] +#![allow(clippy::multiple_crate_versions)] +#![allow(clippy::large_enum_variant)] +// this one looks plausibly useful, but currently has too many bugs +#![allow(clippy::significant_drop_tightening)] use anyhow::Context as _; @@ -6,7 +21,9 @@ mod actions; mod agent; mod daemon; mod debugger; +mod notifications; mod sock; +mod timeout; async fn tokio_main( startup_ack: Option<crate::daemon::StartupAck>, @@ -17,7 +34,7 @@ async fn tokio_main( startup_ack.ack()?; } - let mut agent = crate::agent::Agent::new()?; + let agent = crate::agent::Agent::new()?; agent.run(listener).await?; Ok(()) @@ -29,11 +46,11 @@ fn real_main() -> anyhow::Result<()> { ) .init(); - let no_daemonize = if let Some(arg) = std::env::args().nth(1) { - arg == "--no-daemonize" - } else { - false - }; + let no_daemonize = std::env::args() + .nth(1) + .map_or(false, |arg| arg == "--no-daemonize"); + + rbw::dirs::make_all()?; let startup_ack = if no_daemonize { None @@ -69,7 +86,7 @@ fn main() { if let Err(e) = res { // XXX log file? - eprintln!("{:#}", e); + eprintln!("{e:#}"); std::process::exit(1); } } diff --git a/src/bin/rbw-agent/notifications.rs b/src/bin/rbw-agent/notifications.rs new file mode 100644 index 0000000..8176603 --- /dev/null +++ b/src/bin/rbw-agent/notifications.rs @@ -0,0 +1,174 @@ +use futures_util::{SinkExt as _, StreamExt as _}; + +#[derive(Clone, Copy, Debug)] +pub enum Message { + Sync, + Logout, +} + +pub struct Handler { + write: Option< + futures::stream::SplitSink< + tokio_tungstenite::WebSocketStream< + tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>, + >, + tokio_tungstenite::tungstenite::Message, + >, + >, + read_handle: Option<tokio::task::JoinHandle<()>>, + sending_channels: std::sync::Arc< + tokio::sync::RwLock<Vec<tokio::sync::mpsc::UnboundedSender<Message>>>, + >, +} + +impl Handler { + pub fn new() -> Self { + Self { + write: None, + read_handle: None, + sending_channels: std::sync::Arc::new(tokio::sync::RwLock::new( + Vec::new(), + )), + } + } + + pub async fn connect( + &mut self, + url: String, + ) -> Result<(), Box<dyn std::error::Error>> { + if self.is_connected() { + self.disconnect().await?; + } + + let (write, read_handle) = + subscribe_to_notifications(url, self.sending_channels.clone()) + .await?; + + self.write = Some(write); + self.read_handle = Some(read_handle); + Ok(()) + } + + pub fn is_connected(&self) -> bool { + self.write.is_some() + && self.read_handle.is_some() + && !self.read_handle.as_ref().unwrap().is_finished() + } + + pub async fn disconnect( + &mut self, + ) -> Result<(), Box<dyn std::error::Error>> { + self.sending_channels.write().await.clear(); + if let Some(mut write) = self.write.take() { + write + .send(tokio_tungstenite::tungstenite::Message::Close(None)) + .await?; + write.close().await?; + self.read_handle.take().unwrap().await?; + } + self.write = None; + self.read_handle = None; + Ok(()) + } + + pub async fn get_channel( + &mut self, + ) -> tokio::sync::mpsc::UnboundedReceiver<Message> { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + self.sending_channels.write().await.push(tx); + rx + } +} + +async fn subscribe_to_notifications( + url: String, + sending_channels: std::sync::Arc< + tokio::sync::RwLock<Vec<tokio::sync::mpsc::UnboundedSender<Message>>>, + >, +) -> Result< + ( + futures_util::stream::SplitSink< + tokio_tungstenite::WebSocketStream< + tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>, + >, + tokio_tungstenite::tungstenite::Message, + >, + tokio::task::JoinHandle<()>, + ), + Box<dyn std::error::Error>, +> { + let url = url::Url::parse(url.as_str())?; + let (ws_stream, _response) = + tokio_tungstenite::connect_async(url).await?; + let (mut write, read) = ws_stream.split(); + + write + .send(tokio_tungstenite::tungstenite::Message::Text( + "{\"protocol\":\"messagepack\",\"version\":1}\x1e".to_string(), + )) + .await + .unwrap(); + + let read_future = async move { + let sending_channels = &sending_channels; + read.for_each(|message| async move { + match message { + Ok(message) => { + if let Some(message) = parse_message(message) { + let sending_channels = sending_channels.read().await; + let sending_channels = sending_channels.as_slice(); + for channel in sending_channels { + channel.send(message).unwrap(); + } + } + } + Err(e) => { + eprintln!("websocket error: {e:?}"); + } + } + }) + .await; + }; + + Ok((write, tokio::spawn(read_future))) +} + +fn parse_message( + message: tokio_tungstenite::tungstenite::Message, +) -> Option<Message> { + let tokio_tungstenite::tungstenite::Message::Binary(data) = message + else { + return None; + }; + + // the first few bytes with the 0x80 bit set, plus one byte terminating the length contain the length of the message + let len_buffer_length = data.iter().position(|&x| (x & 0x80) == 0)? + 1; + + let unpacked_messagepack = + rmpv::decode::read_value(&mut &data[len_buffer_length..]).ok()?; + + let unpacked_message = unpacked_messagepack.as_array()?; + let message_type = unpacked_message.first()?.as_u64()?; + // invocation + if message_type != 1 { + return None; + } + let target = unpacked_message.get(3)?.as_str()?; + if target != "ReceiveMessage" { + return None; + } + + let args = unpacked_message.get(4)?.as_array()?; + let map = args.first()?.as_map()?; + for (k, v) in map { + if k.as_str()? == "Type" { + let ty = v.as_i64()?; + return match ty { + 11 => Some(Message::Logout), + _ => Some(Message::Sync), + }; + } + } + + None +} diff --git a/src/bin/rbw-agent/sock.rs b/src/bin/rbw-agent/sock.rs index 311176c..280b8cc 100644 --- a/src/bin/rbw-agent/sock.rs +++ b/src/bin/rbw-agent/sock.rs @@ -36,9 +36,8 @@ impl Sock { buf.read_line(&mut line) .await .context("failed to read message from socket")?; - Ok(serde_json::from_str(&line).map_err(|e| { - format!("failed to parse message '{}': {}", line, e) - })) + Ok(serde_json::from_str(&line) + .map_err(|e| format!("failed to parse message '{line}': {e}"))) } } diff --git a/src/bin/rbw-agent/timeout.rs b/src/bin/rbw-agent/timeout.rs new file mode 100644 index 0000000..e2aba06 --- /dev/null +++ b/src/bin/rbw-agent/timeout.rs @@ -0,0 +1,66 @@ +use futures_util::StreamExt as _; + +#[derive(Debug, Hash, Eq, PartialEq, Copy, Clone)] +enum Streams { + Requests, + Timer, +} + +#[derive(Debug)] +enum Action { + Set(std::time::Duration), + Clear, +} + +pub struct Timeout { + req_w: tokio::sync::mpsc::UnboundedSender<Action>, +} + +impl Timeout { + pub fn new() -> (Self, tokio::sync::mpsc::UnboundedReceiver<()>) { + let (req_w, req_r) = tokio::sync::mpsc::unbounded_channel(); + let (timer_w, timer_r) = tokio::sync::mpsc::unbounded_channel(); + tokio::spawn(async move { + enum Event { + Request(Action), + Timer, + } + let mut stream = tokio_stream::StreamMap::new(); + stream.insert( + Streams::Requests, + tokio_stream::wrappers::UnboundedReceiverStream::new(req_r) + .map(Event::Request) + .boxed(), + ); + while let Some(event) = stream.next().await { + match event { + (_, Event::Request(Action::Set(dur))) => { + stream.insert( + Streams::Timer, + futures_util::stream::once(tokio::time::sleep( + dur, + )) + .map(|()| Event::Timer) + .boxed(), + ); + } + (_, Event::Request(Action::Clear)) => { + stream.remove(&Streams::Timer); + } + (_, Event::Timer) => { + timer_w.send(()).unwrap(); + } + } + } + }); + (Self { req_w }, timer_r) + } + + pub fn set(&self, dur: std::time::Duration) { + self.req_w.send(Action::Set(dur)).unwrap(); + } + + pub fn clear(&self) { + self.req_w.send(Action::Clear).unwrap(); + } +} diff --git a/src/bin/rbw/actions.rs b/src/bin/rbw/actions.rs index 39fde15..c84ccd4 100644 --- a/src/bin/rbw/actions.rs +++ b/src/bin/rbw/actions.rs @@ -1,4 +1,4 @@ -use anyhow::Context as _; +use anyhow::{bail, Context as _}; use std::io::Read as _; pub fn register() -> anyhow::Result<()> { @@ -31,11 +31,17 @@ pub fn quit() -> anyhow::Result<()> { let pidfile = rbw::dirs::pid_file(); let mut pid = String::new(); std::fs::File::open(pidfile)?.read_to_string(&mut pid)?; - let pid = nix::unistd::Pid::from_raw(pid.parse()?); + let Some(pid) = + rustix::process::Pid::from_raw(pid.trim_end().parse()?) + else { + bail!("failed to read pid from pidfile"); + }; sock.send(&rbw::protocol::Request { - tty: nix::unistd::ttyname(0) + tty: rustix::termios::ttyname(std::io::stdin(), vec![]) .ok() - .and_then(|p| p.to_str().map(|s| s.to_string())), + .and_then(|p| { + p.to_str().map(std::string::ToString::to_string).ok() + }), action: rbw::protocol::Action::Quit, })?; wait_for_exit(pid); @@ -57,9 +63,11 @@ pub fn decrypt( ) -> anyhow::Result<String> { let mut sock = connect()?; sock.send(&rbw::protocol::Request { - tty: nix::unistd::ttyname(0) + tty: rustix::termios::ttyname(std::io::stdin(), vec![]) .ok() - .and_then(|p| p.to_str().map(|s| s.to_string())), + .and_then(|p| { + p.to_str().map(std::string::ToString::to_string).ok() + }), action: rbw::protocol::Action::Decrypt { cipherstring: cipherstring.to_string(), org_id: org_id.map(std::string::ToString::to_string), @@ -82,9 +90,11 @@ pub fn encrypt( ) -> anyhow::Result<String> { let mut sock = connect()?; sock.send(&rbw::protocol::Request { - tty: nix::unistd::ttyname(0) + tty: rustix::termios::ttyname(std::io::stdin(), vec![]) .ok() - .and_then(|p| p.to_str().map(|s| s.to_string())), + .and_then(|p| { + p.to_str().map(std::string::ToString::to_string).ok() + }), action: rbw::protocol::Action::Encrypt { plaintext: plaintext.to_string(), org_id: org_id.map(std::string::ToString::to_string), @@ -101,12 +111,20 @@ pub fn encrypt( } } +pub fn clipboard_store(text: &str) -> anyhow::Result<()> { + simple_action(rbw::protocol::Action::ClipboardStore { + text: text.to_string(), + }) +} + pub fn version() -> anyhow::Result<u32> { let mut sock = connect()?; sock.send(&rbw::protocol::Request { - tty: nix::unistd::ttyname(0) + tty: rustix::termios::ttyname(std::io::stdin(), vec![]) .ok() - .and_then(|p| p.to_str().map(|s| s.to_string())), + .and_then(|p| { + p.to_str().map(std::string::ToString::to_string).ok() + }), action: rbw::protocol::Action::Version, })?; @@ -124,9 +142,11 @@ fn simple_action(action: rbw::protocol::Action) -> anyhow::Result<()> { let mut sock = connect()?; sock.send(&rbw::protocol::Request { - tty: nix::unistd::ttyname(0) + tty: rustix::termios::ttyname(std::io::stdin(), vec![]) .ok() - .and_then(|p| p.to_str().map(|s| s.to_string())), + .and_then(|p| { + p.to_str().map(std::string::ToString::to_string).ok() + }), action, })?; @@ -152,9 +172,9 @@ fn connect() -> anyhow::Result<crate::sock::Sock> { }) } -fn wait_for_exit(pid: nix::unistd::Pid) { +fn wait_for_exit(pid: rustix::process::Pid) { loop { - if nix::sys::signal::kill(pid, None).is_err() { + if rustix::process::test_kill_process(pid).is_err() { break; } std::thread::sleep(std::time::Duration::from_millis(10)); diff --git a/src/bin/rbw/commands.rs b/src/bin/rbw/commands.rs index 9efd966..6d36eb3 100644 --- a/src/bin/rbw/commands.rs +++ b/src/bin/rbw/commands.rs @@ -1,4 +1,9 @@ use anyhow::Context as _; +use serde::Serialize; +use std::fmt::{Display, Formatter, Result as FmtResult}; +use std::io; +use std::io::prelude::Write; +use url::Url; const MISSING_CONFIG_HELP: &str = "Before using rbw, you must configure the email address you would like to \ @@ -11,6 +16,36 @@ const MISSING_CONFIG_HELP: &str = rbw config set identity_url <url>\n"; #[derive(Debug, Clone)] +pub enum Needle { + Name(String), + Uri(Url), + Uuid(uuid::Uuid), +} + +impl Display for Needle { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + let value = match &self { + Self::Name(name) => name.clone(), + Self::Uri(uri) => uri.to_string(), + Self::Uuid(uuid) => uuid.to_string(), + }; + write!(f, "{value}") + } +} + +#[allow(clippy::unnecessary_wraps)] +pub fn parse_needle(arg: &str) -> Result<Needle, std::convert::Infallible> { + if let Ok(uuid) = uuid::Uuid::parse_str(arg) { + return Ok(Needle::Uuid(uuid)); + } + if let Ok(url) = Url::parse(arg) { + return Ok(Needle::Uri(url)); + } + + Ok(Needle::Name(arg.to_string())) +} + +#[derive(Debug, Clone, Serialize)] #[cfg_attr(test, derive(Eq, PartialEq))] struct DecryptedCipher { id: String, @@ -23,25 +58,25 @@ struct DecryptedCipher { } impl DecryptedCipher { - fn display_short(&self, desc: &str) -> bool { + fn display_short(&self, desc: &str, clipboard: bool) -> bool { match &self.data { DecryptedData::Login { password, .. } => { - if let Some(password) = password { - println!("{}", password); - true - } else { - eprintln!("entry for '{}' had no password", desc); - false - } + password.as_ref().map_or_else( + || { + eprintln!("entry for '{desc}' had no password"); + false + }, + |password| val_display_or_store(clipboard, password), + ) } DecryptedData::Card { number, .. } => { - if let Some(number) = number { - println!("{}", number); - true - } else { - eprintln!("entry for '{}' had no card number", desc); - false - } + number.as_ref().map_or_else( + || { + eprintln!("entry for '{desc}' had no card number"); + false + }, + |number| val_display_or_store(clipboard, number), + ) } DecryptedData::Identity { title, @@ -54,30 +89,272 @@ impl DecryptedCipher { [title, first_name, middle_name, last_name] .iter() .copied() - .cloned() .flatten() + .cloned() .collect(); if names.is_empty() { - eprintln!("entry for '{}' had no name", desc); + eprintln!("entry for '{desc}' had no name"); false } else { - println!("{}", names.join(" ")); - true + val_display_or_store(clipboard, &names.join(" ")) } } - DecryptedData::SecureNote {} => { - if let Some(notes) = &self.notes { - println!("{}", notes); - true - } else { - eprintln!("entry for '{}' had no notes", desc); + DecryptedData::SecureNote {} => self.notes.as_ref().map_or_else( + || { + eprintln!("entry for '{desc}' had no notes"); false + }, + |notes| val_display_or_store(clipboard, notes), + ), + } + } + + fn display_field(&self, desc: &str, field: &str, clipboard: bool) { + let field = field.to_lowercase(); + let field = field.as_str(); + match &self.data { + DecryptedData::Login { + username, + totp, + uris, + .. + } => match field { + "notes" => { + if let Some(notes) = &self.notes { + val_display_or_store(clipboard, notes); + } } - } + "username" | "user" => { + if let Some(username) = &username { + val_display_or_store(clipboard, username); + } + } + "totp" | "code" => { + if let Some(totp) = totp { + match generate_totp(totp) { + Ok(code) => { + val_display_or_store(clipboard, &code); + } + Err(e) => { + eprintln!("{e}"); + } + } + } + } + "uris" | "urls" | "sites" => { + if let Some(uris) = uris { + let uri_strs: Vec<_> = uris + .iter() + .map(|uri| uri.uri.to_string()) + .collect(); + val_display_or_store(clipboard, &uri_strs.join("\n")); + } + } + "password" => { + self.display_short(desc, clipboard); + } + _ => { + for f in &self.fields { + if let Some(name) = &f.name { + if name.to_lowercase().as_str().contains(field) { + val_display_or_store( + clipboard, + f.value.as_deref().unwrap_or(""), + ); + break; + } + } + } + } + }, + DecryptedData::Card { + cardholder_name, + brand, + exp_month, + exp_year, + code, + .. + } => match field { + "number" | "card" => { + self.display_short(desc, clipboard); + } + "exp" => { + if let (Some(month), Some(year)) = (exp_month, exp_year) { + val_display_or_store( + clipboard, + &format!("{month}/{year}"), + ); + } + } + "exp_month" | "month" => { + if let Some(exp_month) = exp_month { + val_display_or_store(clipboard, exp_month); + } + } + "exp_year" | "year" => { + if let Some(exp_year) = exp_year { + val_display_or_store(clipboard, exp_year); + } + } + "cvv" => { + if let Some(code) = code { + val_display_or_store(clipboard, code); + } + } + "name" | "cardholder" => { + if let Some(cardholder_name) = cardholder_name { + val_display_or_store(clipboard, cardholder_name); + } + } + "brand" | "type" => { + if let Some(brand) = brand { + val_display_or_store(clipboard, brand); + } + } + "notes" => { + if let Some(notes) = &self.notes { + val_display_or_store(clipboard, notes); + } + } + _ => { + for f in &self.fields { + if let Some(name) = &f.name { + if name.to_lowercase().as_str().contains(field) { + val_display_or_store( + clipboard, + f.value.as_deref().unwrap_or(""), + ); + break; + } + } + } + } + }, + DecryptedData::Identity { + address1, + address2, + address3, + city, + state, + postal_code, + country, + phone, + email, + ssn, + license_number, + passport_number, + username, + .. + } => match field { + "name" => { + self.display_short(desc, clipboard); + } + "email" => { + if let Some(email) = email { + val_display_or_store(clipboard, email); + } + } + "address" => { + let mut strs = vec![]; + if let Some(address1) = address1 { + strs.push(address1.clone()); + } + if let Some(address2) = address2 { + strs.push(address2.clone()); + } + if let Some(address3) = address3 { + strs.push(address3.clone()); + } + if !strs.is_empty() { + val_display_or_store(clipboard, &strs.join("\n")); + } + } + "city" => { + if let Some(city) = city { + val_display_or_store(clipboard, city); + } + } + "state" => { + if let Some(state) = state { + val_display_or_store(clipboard, state); + } + } + "postcode" | "zipcode" | "zip" => { + if let Some(postal_code) = postal_code { + val_display_or_store(clipboard, postal_code); + } + } + "country" => { + if let Some(country) = country { + val_display_or_store(clipboard, country); + } + } + "phone" => { + if let Some(phone) = phone { + val_display_or_store(clipboard, phone); + } + } + "ssn" => { + if let Some(ssn) = ssn { + val_display_or_store(clipboard, ssn); + } + } + "license" => { + if let Some(license_number) = license_number { + val_display_or_store(clipboard, license_number); + } + } + "passport" => { + if let Some(passport_number) = passport_number { + val_display_or_store(clipboard, passport_number); + } + } + "username" => { + if let Some(username) = username { + val_display_or_store(clipboard, username); + } + } + "notes" => { + if let Some(notes) = &self.notes { + val_display_or_store(clipboard, notes); + } + } + _ => { + for f in &self.fields { + if let Some(name) = &f.name { + if name.to_lowercase().as_str().contains(field) { + val_display_or_store( + clipboard, + f.value.as_deref().unwrap_or(""), + ); + break; + } + } + } + } + }, + DecryptedData::SecureNote {} => match field { + "note" | "notes" => { + self.display_short(desc, clipboard); + } + _ => { + for f in &self.fields { + if let Some(name) = &f.name { + if name.to_lowercase().as_str().contains(field) { + val_display_or_store( + clipboard, + f.value.as_deref().unwrap_or(""), + ); + break; + } + } + } + } + }, } } - fn display_long(&self, desc: &str) { + fn display_long(&self, desc: &str, clipboard: bool) { match &self.data { DecryptedData::Login { username, @@ -85,29 +362,31 @@ impl DecryptedCipher { uris, .. } => { - let mut displayed = self.display_short(desc); + let mut displayed = self.display_short(desc, clipboard); displayed |= - self.display_field("Username", username.as_deref()); + display_field("Username", username.as_deref(), clipboard); displayed |= - self.display_field("TOTP Secret", totp.as_deref()); + display_field("TOTP Secret", totp.as_deref(), clipboard); if let Some(uris) = uris { for uri in uris { displayed |= - self.display_field("URI", Some(&uri.uri)); + display_field("URI", Some(&uri.uri), clipboard); let match_type = - uri.match_type.map(|ty| format!("{}", ty)); - displayed |= self.display_field( + uri.match_type.map(|ty| format!("{ty}")); + displayed |= display_field( "Match type", match_type.as_deref(), + clipboard, ); } } for field in &self.fields { - displayed |= self.display_field( + displayed |= display_field( field.name.as_deref().unwrap_or("(null)"), Some(field.value.as_deref().unwrap_or("")), + clipboard, ); } @@ -115,7 +394,7 @@ impl DecryptedCipher { if displayed { println!(); } - println!("{}", notes); + println!("{notes}"); } } DecryptedData::Card { @@ -126,24 +405,28 @@ impl DecryptedCipher { code, .. } => { - let mut displayed = self.display_short(desc); + let mut displayed = self.display_short(desc, clipboard); if let (Some(exp_month), Some(exp_year)) = (exp_month, exp_year) { - println!("Expiration: {}/{}", exp_month, exp_year); + println!("Expiration: {exp_month}/{exp_year}"); displayed = true; } - displayed |= self.display_field("CVV", code.as_deref()); + displayed |= display_field("CVV", code.as_deref(), clipboard); + displayed |= display_field( + "Name", + cardholder_name.as_deref(), + clipboard, + ); displayed |= - self.display_field("Name", cardholder_name.as_deref()); - displayed |= self.display_field("Brand", brand.as_deref()); + display_field("Brand", brand.as_deref(), clipboard); if let Some(notes) = &self.notes { if displayed { println!(); } - println!("{}", notes); + println!("{notes}"); } } DecryptedData::Identity { @@ -162,74 +445,110 @@ impl DecryptedCipher { username, .. } => { - let mut displayed = self.display_short(desc); + let mut displayed = self.display_short(desc, clipboard); displayed |= - self.display_field("Address", address1.as_deref()); + display_field("Address", address1.as_deref(), clipboard); + displayed |= + display_field("Address", address2.as_deref(), clipboard); displayed |= - self.display_field("Address", address2.as_deref()); + display_field("Address", address3.as_deref(), clipboard); displayed |= - self.display_field("Address", address3.as_deref()); - displayed |= self.display_field("City", city.as_deref()); - displayed |= self.display_field("State", state.as_deref()); + display_field("City", city.as_deref(), clipboard); displayed |= - self.display_field("Postcode", postal_code.as_deref()); + display_field("State", state.as_deref(), clipboard); + displayed |= display_field( + "Postcode", + postal_code.as_deref(), + clipboard, + ); displayed |= - self.display_field("Country", country.as_deref()); - displayed |= self.display_field("Phone", phone.as_deref()); - displayed |= self.display_field("Email", email.as_deref()); - displayed |= self.display_field("SSN", ssn.as_deref()); + display_field("Country", country.as_deref(), clipboard); displayed |= - self.display_field("License", license_number.as_deref()); - displayed |= self - .display_field("Passport", passport_number.as_deref()); + display_field("Phone", phone.as_deref(), clipboard); displayed |= - self.display_field("Username", username.as_deref()); + display_field("Email", email.as_deref(), clipboard); + displayed |= display_field("SSN", ssn.as_deref(), clipboard); + displayed |= display_field( + "License", + license_number.as_deref(), + clipboard, + ); + displayed |= display_field( + "Passport", + passport_number.as_deref(), + clipboard, + ); + displayed |= + display_field("Username", username.as_deref(), clipboard); if let Some(notes) = &self.notes { if displayed { println!(); } - println!("{}", notes); + println!("{notes}"); } } DecryptedData::SecureNote {} => { - self.display_short(desc); + self.display_short(desc, clipboard); } } } - fn display_field(&self, name: &str, field: Option<&str>) -> bool { - if let Some(field) = field { - println!("{}: {}", name, field); - true - } else { - false - } - } - fn display_name(&self) -> String { match &self.data { DecryptedData::Login { username, .. } => { - if let Some(username) = username { - format!("{}@{}", username, self.name) - } else { - self.name.clone() - } + username.as_ref().map_or_else( + || self.name.clone(), + |username| format!("{}@{}", username, self.name), + ) } _ => self.name.clone(), } } + fn display_json(&self, desc: &str) -> anyhow::Result<()> { + serde_json::to_writer_pretty(std::io::stdout(), &self) + .context(format!("failed to write entry '{desc}' to stdout"))?; + println!(); + + Ok(()) + } + fn exact_match( &self, - name: &str, + needle: &Needle, username: Option<&str>, folder: Option<&str>, try_match_folder: bool, ) -> bool { - if name != self.name { - return false; + match needle { + Needle::Name(name) => { + if &self.name != name { + return false; + } + } + Needle::Uri(given_uri) => { + match &self.data { + DecryptedData::Login { + uris: Some(uris), .. + } => { + if !uris.iter().any(|uri| uri.matches_url(given_uri)) + { + return false; + } + } + _ => { + // not sure what else to do here, but open to suggestions + return false; + } + } + } + Needle::Uuid(uuid) => { + if uuid::Uuid::parse_str(&self.id) != Ok(*uuid) { + return false; + } + } } if let Some(given_username) = username { @@ -312,7 +631,23 @@ impl DecryptedCipher { } } -#[derive(Debug, Clone)] +fn val_display_or_store(clipboard: bool, password: &str) -> bool { + if clipboard { + match clipboard_store(password) { + Ok(()) => true, + Err(e) => { + eprintln!("{e}"); + false + } + } + } else { + println!("{password}"); + true + } +} + +#[derive(Debug, Clone, Serialize)] +#[serde(untagged)] #[cfg_attr(test, derive(Eq, PartialEq))] enum DecryptedData { Login { @@ -351,27 +686,97 @@ enum DecryptedData { SecureNote, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize)] #[cfg_attr(test, derive(Eq, PartialEq))] struct DecryptedField { name: Option<String>, value: Option<String>, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize)] #[cfg_attr(test, derive(Eq, PartialEq))] struct DecryptedHistoryEntry { last_used_date: String, password: String, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize)] #[cfg_attr(test, derive(Eq, PartialEq))] struct DecryptedUri { uri: String, match_type: Option<rbw::api::UriMatchType>, } +impl DecryptedUri { + fn matches_url(&self, url: &Url) -> bool { + match self.match_type.unwrap_or(rbw::api::UriMatchType::Domain) { + rbw::api::UriMatchType::Domain => { + let Some(given_domain_port) = domain_port(url) else { + return false; + }; + if let Ok(self_url) = url::Url::parse(&self.uri) { + if let Some(self_domain_port) = domain_port(&self_url) { + if self_url.scheme() == url.scheme() + && (self_domain_port == given_domain_port + || given_domain_port.ends_with(&format!( + ".{self_domain_port}" + ))) + { + return true; + } + } + } + self.uri == given_domain_port + || given_domain_port.ends_with(&format!(".{}", self.uri)) + } + rbw::api::UriMatchType::Host => { + let Some(given_host_port) = host_port(url) else { + return false; + }; + if let Ok(self_url) = url::Url::parse(&self.uri) { + if let Some(self_host_port) = host_port(&self_url) { + if self_url.scheme() == url.scheme() + && self_host_port == given_host_port + { + return true; + } + } + } + self.uri == given_host_port + } + rbw::api::UriMatchType::StartsWith => { + url.to_string().starts_with(&self.uri) + } + rbw::api::UriMatchType::Exact => url.to_string() == self.uri, + rbw::api::UriMatchType::RegularExpression => { + let Ok(rx) = regex::Regex::new(&self.uri) else { + return false; + }; + rx.is_match(url.as_ref()) + } + rbw::api::UriMatchType::Never => false, + } + } +} + +fn host_port(url: &Url) -> Option<String> { + let host = url.host_str()?; + Some( + url.port().map_or_else( + || host.to_string(), + |port| format!("{host}:{port}"), + ), + ) +} + +fn domain_port(url: &Url) -> Option<String> { + let domain = url.domain()?; + Some(url.port().map_or_else( + || domain.to_string(), + |port| format!("{domain}:{port}"), + )) +} + enum ListField { Name, Id, @@ -393,11 +798,16 @@ impl std::convert::TryFrom<&String> for ListField { } } -const HELP: &str = r#" +const HELP_PW: &str = r" # The first line of this file will be the password, and the remainder of the # file (after any blank lines after the password) will be stored as a note. # Lines with leading # will be ignored. -"#; +"; + +const HELP_NOTES: &str = r" +# The content of this file will be stored as a note. +# Lines with leading # will be ignored. +"; pub fn config_show() -> anyhow::Result<()> { let config = rbw::config::Config::load()?; @@ -415,6 +825,13 @@ pub fn config_set(key: &str, value: &str) -> anyhow::Result<()> { "email" => config.email = Some(value.to_string()), "base_url" => config.base_url = Some(value.to_string()), "identity_url" => config.identity_url = Some(value.to_string()), + "notifications_url" => { + config.notifications_url = Some(value.to_string()); + } + "client_cert_path" => { + config.client_cert_path = + Some(std::path::PathBuf::from(value.to_string())); + } "lock_timeout" => { let timeout = value .parse() @@ -425,6 +842,12 @@ pub fn config_set(key: &str, value: &str) -> anyhow::Result<()> { config.lock_timeout = timeout; } } + "sync_interval" => { + let interval = value + .parse() + .context("failed to parse value for sync_interval")?; + config.sync_interval = interval; + } "pinentry" => config.pinentry = value.to_string(), _ => return Err(anyhow::anyhow!("invalid config key: {}", key)), } @@ -447,8 +870,10 @@ pub fn config_unset(key: &str) -> anyhow::Result<()> { "email" => config.email = None, "base_url" => config.base_url = None, "identity_url" => config.identity_url = None, + "notifications_url" => config.notifications_url = None, + "client_cert_path" => config.client_cert_path = None, "lock_timeout" => { - config.lock_timeout = rbw::config::default_lock_timeout() + config.lock_timeout = rbw::config::default_lock_timeout(); } "pinentry" => config.pinentry = rbw::config::default_pinentry(), _ => return Err(anyhow::anyhow!("invalid config key: {}", key)), @@ -465,6 +890,13 @@ pub fn config_unset(key: &str) -> anyhow::Result<()> { Ok(()) } +fn clipboard_store(val: &str) -> anyhow::Result<()> { + ensure_agent()?; + crate::actions::clipboard_store(val)?; + + Ok(()) +} + pub fn register() -> anyhow::Result<()> { ensure_agent()?; crate::actions::register()?; @@ -514,8 +946,7 @@ pub fn list(fields: &[String]) -> anyhow::Result<()> { let mut ciphers: Vec<DecryptedCipher> = db .entries .iter() - .cloned() - .map(|entry| decrypt_cipher(&entry)) + .map(decrypt_cipher) .collect::<anyhow::Result<_>>()?; ciphers.sort_unstable_by(|a, b| a.name.cmp(&b.name)); @@ -526,30 +957,40 @@ pub fn list(fields: &[String]) -> anyhow::Result<()> { ListField::Name => cipher.name.clone(), ListField::Id => cipher.id.clone(), ListField::User => match &cipher.data { - DecryptedData::Login { username, .. } => username - .as_ref() - .map(std::string::ToString::to_string) - .unwrap_or_else(|| "".to_string()), - _ => "".to_string(), + DecryptedData::Login { username, .. } => { + username.as_ref().map_or_else( + String::new, + std::string::ToString::to_string, + ) + } + _ => String::new(), }, - ListField::Folder => cipher - .folder - .as_ref() - .map(std::string::ToString::to_string) - .unwrap_or_else(|| "".to_string()), + ListField::Folder => cipher.folder.as_ref().map_or_else( + String::new, + std::string::ToString::to_string, + ), }) .collect(); - println!("{}", values.join("\t")); + + // write to stdout but don't panic when pipe get's closed + // this happens when piping stdout in a shell + match writeln!(&mut io::stdout(), "{}", values.join("\t")) { + Err(e) if e.kind() == std::io::ErrorKind::BrokenPipe => Ok(()), + res => res, + }?; } Ok(()) } pub fn get( - name: &str, + needle: &Needle, user: Option<&str>, folder: Option<&str>, + field: Option<&str>, full: bool, + raw: bool, + clipboard: bool, ) -> anyhow::Result<()> { unlock()?; @@ -557,26 +998,30 @@ pub fn get( let desc = format!( "{}{}", - user.map(|s| format!("{}@", s)) - .unwrap_or_else(|| "".to_string()), - name + user.map_or_else(String::new, |s| format!("{s}@")), + needle ); - let (_, decrypted) = find_entry(&db, name, user, folder) - .with_context(|| format!("couldn't find entry for '{}'", desc))?; - if full { - decrypted.display_long(&desc); + let (_, decrypted) = find_entry(&db, needle, user, folder) + .with_context(|| format!("couldn't find entry for '{desc}'"))?; + if raw { + decrypted.display_json(&desc)?; + } else if full { + decrypted.display_long(&desc, clipboard); + } else if let Some(field) = field { + decrypted.display_field(&desc, field, clipboard); } else { - decrypted.display_short(&desc); + decrypted.display_short(&desc, clipboard); } Ok(()) } pub fn code( - name: &str, + needle: &Needle, user: Option<&str>, folder: Option<&str>, + clipboard: bool, ) -> anyhow::Result<()> { unlock()?; @@ -584,17 +1029,16 @@ pub fn code( let desc = format!( "{}{}", - user.map(|s| format!("{}@", s)) - .unwrap_or_else(|| "".to_string()), - name + user.map_or_else(String::new, |s| format!("{s}@")), + needle ); - let (_, decrypted) = find_entry(&db, name, user, folder) - .with_context(|| format!("couldn't find entry for '{}'", desc))?; + let (_, decrypted) = find_entry(&db, needle, user, folder) + .with_context(|| format!("couldn't find entry for '{desc}'"))?; if let DecryptedData::Login { totp, .. } = decrypted.data { if let Some(totp) = totp { - println!("{}", generate_totp(&totp)?) + val_display_or_store(clipboard, &generate_totp(&totp)?); } else { return Err(anyhow::anyhow!( "entry does not contain a totp secret" @@ -610,7 +1054,7 @@ pub fn code( pub fn add( name: &str, username: Option<&str>, - uris: Vec<(String, Option<rbw::api::UriMatchType>)>, + uris: &[(String, Option<rbw::api::UriMatchType>)], folder: Option<&str>, ) -> anyhow::Result<()> { unlock()?; @@ -627,7 +1071,7 @@ pub fn add( .map(|username| crate::actions::encrypt(username, None)) .transpose()?; - let contents = rbw::edit::edit("", HELP)?; + let contents = rbw::edit::edit("", HELP_PW)?; let (password, notes) = parse_editor(&contents); let password = password @@ -707,13 +1151,13 @@ pub fn add( pub fn generate( name: Option<&str>, username: Option<&str>, - uris: Vec<(String, Option<rbw::api::UriMatchType>)>, + uris: &[(String, Option<rbw::api::UriMatchType>)], folder: Option<&str>, len: usize, ty: rbw::pwgen::Type, ) -> anyhow::Result<()> { let password = rbw::pwgen::pwgen(ty, len); - println!("{}", password); + println!("{password}"); if let Some(name) = name { unlock()?; @@ -813,24 +1257,23 @@ pub fn edit( let desc = format!( "{}{}", - username - .map(|s| format!("{}@", s)) - .unwrap_or_else(|| "".to_string()), + username.map_or_else(String::new, |s| format!("{s}@")), name ); - let (entry, decrypted) = find_entry(&db, name, username, folder) - .with_context(|| format!("couldn't find entry for '{}'", desc))?; + let (entry, decrypted) = + find_entry(&db, &Needle::Name(name.to_string()), username, folder) + .with_context(|| format!("couldn't find entry for '{desc}'"))?; - let (data, notes, history) = match &decrypted.data { + let (data, fields, notes, history) = match &decrypted.data { DecryptedData::Login { password, .. } => { let mut contents = format!("{}\n", password.as_deref().unwrap_or("")); if let Some(notes) = decrypted.notes { - contents.push_str(&format!("\n{}\n", notes)); + contents.push_str(&format!("\n{notes}\n")); } - let contents = rbw::edit::edit(&contents, HELP)?; + let contents = rbw::edit::edit(&contents, HELP_PW)?; let (password, notes) = parse_editor(&contents); let password = password @@ -847,16 +1290,15 @@ pub fn edit( }) .transpose()?; let mut history = entry.history.clone(); - let (entry_username, entry_password, entry_uris, entry_totp) = - match &entry.data { - rbw::db::EntryData::Login { - username, - password, - uris, - totp, - } => (username, password, uris, totp), - _ => unreachable!(), - }; + let rbw::db::EntryData::Login { + username: entry_username, + password: entry_password, + uris: entry_uris, + totp: entry_totp, + } = &entry.data + else { + unreachable!(); + }; if let Some(prev_password) = entry_password.clone() { let new_history_entry = rbw::db::HistoryEntry { @@ -874,14 +1316,34 @@ pub fn edit( let data = rbw::db::EntryData::Login { username: entry_username.clone(), password, - uris: entry_uris.to_vec(), + uris: entry_uris.clone(), totp: entry_totp.clone(), }; - (data, notes, history) + (data, entry.fields, notes, history) + } + DecryptedData::SecureNote {} => { + let data = rbw::db::EntryData::SecureNote {}; + + let editor_content = decrypted.notes.map_or_else( + || "\n".to_string(), + |notes| format!("{notes}\n"), + ); + let contents = rbw::edit::edit(&editor_content, HELP_NOTES)?; + + // prepend blank line to be parsed as pw by `parse_editor` + let (_, notes) = parse_editor(&format!("\n{contents}\n")); + + let notes = notes + .map(|notes| { + crate::actions::encrypt(¬es, entry.org_id.as_deref()) + }) + .transpose()?; + + (data, entry.fields, notes, entry.history) } _ => { return Err(anyhow::anyhow!( - "modifications are only supported for login entries" + "modifications are only supported for login and note entries" )); } }; @@ -893,6 +1355,7 @@ pub fn edit( entry.org_id.as_deref(), &entry.name, &data, + &fields, notes.as_deref(), entry.folder_id.as_deref(), &history, @@ -918,14 +1381,13 @@ pub fn remove( let desc = format!( "{}{}", - username - .map(|s| format!("{}@", s)) - .unwrap_or_else(|| "".to_string()), + username.map_or_else(String::new, |s| format!("{s}@")), name ); - let (entry, _) = find_entry(&db, name, username, folder) - .with_context(|| format!("couldn't find entry for '{}'", desc))?; + let (entry, _) = + find_entry(&db, &Needle::Name(name.to_string()), username, folder) + .with_context(|| format!("couldn't find entry for '{desc}'"))?; if let (Some(access_token), ()) = rbw::actions::remove(access_token, refresh_token, &entry.id)? @@ -950,14 +1412,13 @@ pub fn history( let desc = format!( "{}{}", - username - .map(|s| format!("{}@", s)) - .unwrap_or_else(|| "".to_string()), + username.map_or_else(String::new, |s| format!("{s}@")), name ); - let (_, decrypted) = find_entry(&db, name, username, folder) - .with_context(|| format!("couldn't find entry for '{}'", desc))?; + let (_, decrypted) = + find_entry(&db, &Needle::Name(name.to_string()), username, folder) + .with_context(|| format!("couldn't find entry for '{desc}'"))?; for history in decrypted.history { println!("{}: {}", history.last_used_date, history.password); } @@ -1017,7 +1478,7 @@ fn ensure_agent_once() -> anyhow::Result<()> { let agent_path = std::env::var("RBW_AGENT"); let agent_path = agent_path .as_ref() - .map(|s| s.as_str()) + .map(std::string::String::as_str) .unwrap_or("rbw-agent"); let status = std::process::Command::new(agent_path) .status() @@ -1052,45 +1513,42 @@ fn version_or_quit() -> anyhow::Result<u32> { fn find_entry( db: &rbw::db::Db, - name: &str, + needle: &Needle, username: Option<&str>, folder: Option<&str>, ) -> anyhow::Result<(rbw::db::Entry, DecryptedCipher)> { - match uuid::Uuid::parse_str(name) { - Ok(_) => { - for cipher in &db.entries { - if name == cipher.id { - return Ok((cipher.clone(), decrypt_cipher(cipher)?)); - } + if let Needle::Uuid(uuid) = needle { + for cipher in &db.entries { + if uuid::Uuid::parse_str(&cipher.id) == Ok(*uuid) { + return Ok((cipher.clone(), decrypt_cipher(cipher)?)); } - Err(anyhow::anyhow!("no entry found")) - } - Err(_) => { - let ciphers: Vec<(rbw::db::Entry, DecryptedCipher)> = db - .entries - .iter() - .cloned() - .map(|entry| { - decrypt_cipher(&entry).map(|decrypted| (entry, decrypted)) - }) - .collect::<anyhow::Result<_>>()?; - find_entry_raw(&ciphers, name, username, folder) } + Err(anyhow::anyhow!("no entry found")) + } else { + let ciphers: Vec<(rbw::db::Entry, DecryptedCipher)> = db + .entries + .iter() + .cloned() + .map(|entry| { + decrypt_cipher(&entry).map(|decrypted| (entry, decrypted)) + }) + .collect::<anyhow::Result<_>>()?; + find_entry_raw(&ciphers, needle, username, folder) } } fn find_entry_raw( entries: &[(rbw::db::Entry, DecryptedCipher)], - name: &str, + needle: &Needle, username: Option<&str>, folder: Option<&str>, ) -> anyhow::Result<(rbw::db::Entry, DecryptedCipher)> { let mut matches: Vec<(rbw::db::Entry, DecryptedCipher)> = entries .iter() - .cloned() - .filter(|(_, decrypted_cipher)| { - decrypted_cipher.exact_match(name, username, folder, true) + .filter(|&(_, decrypted_cipher)| { + decrypted_cipher.exact_match(needle, username, folder, true) }) + .cloned() .collect(); if matches.len() == 1 { @@ -1100,10 +1558,10 @@ fn find_entry_raw( if folder.is_none() { matches = entries .iter() - .cloned() - .filter(|(_, decrypted_cipher)| { - decrypted_cipher.exact_match(name, username, folder, false) + .filter(|&(_, decrypted_cipher)| { + decrypted_cipher.exact_match(needle, username, folder, false) }) + .cloned() .collect(); if matches.len() == 1 { @@ -1111,29 +1569,32 @@ fn find_entry_raw( } } - matches = entries - .iter() - .cloned() - .filter(|(_, decrypted_cipher)| { - decrypted_cipher.partial_match(name, username, folder, true) - }) - .collect(); - - if matches.len() == 1 { - return Ok(matches[0].clone()); - } - - if folder.is_none() { + if let Needle::Name(name) = needle { matches = entries .iter() - .cloned() - .filter(|(_, decrypted_cipher)| { - decrypted_cipher.partial_match(name, username, folder, false) + .filter(|&(_, decrypted_cipher)| { + decrypted_cipher.partial_match(name, username, folder, true) }) + .cloned() .collect(); + if matches.len() == 1 { return Ok(matches[0].clone()); } + + if folder.is_none() { + matches = entries + .iter() + .filter(|&(_, decrypted_cipher)| { + decrypted_cipher + .partial_match(name, username, folder, false) + }) + .cloned() + .collect(); + if matches.len() == 1 { + return Ok(matches[0].clone()); + } + } } if matches.is_empty() { @@ -1435,8 +1896,11 @@ fn parse_editor(contents: &str) -> (Option<String>, Option<String>) { let mut notes: String = lines .skip_while(|line| line.is_empty()) .filter(|line| !line.starts_with('#')) - .map(|line| format!("{}\n", line)) - .collect(); + .fold(String::new(), |mut notes, line| { + notes.push_str(line); + notes.push('\n'); + notes + }); while notes.ends_with('\n') { notes.pop(); } @@ -1447,36 +1911,54 @@ fn parse_editor(contents: &str) -> (Option<String>, Option<String>) { fn load_db() -> anyhow::Result<rbw::db::Db> { let config = rbw::config::Config::load()?; - if let Some(email) = &config.email { - rbw::db::Db::load(&config.server_name(), email) - .map_err(anyhow::Error::new) - } else { - Err(anyhow::anyhow!("failed to find email address in config")) - } + config.email.as_ref().map_or_else( + || Err(anyhow::anyhow!("failed to find email address in config")), + |email| { + rbw::db::Db::load(&config.server_name(), email) + .map_err(anyhow::Error::new) + }, + ) } fn save_db(db: &rbw::db::Db) -> anyhow::Result<()> { let config = rbw::config::Config::load()?; - if let Some(email) = &config.email { - db.save(&config.server_name(), email) - .map_err(anyhow::Error::new) - } else { - Err(anyhow::anyhow!("failed to find email address in config")) - } + config.email.as_ref().map_or_else( + || Err(anyhow::anyhow!("failed to find email address in config")), + |email| { + db.save(&config.server_name(), email) + .map_err(anyhow::Error::new) + }, + ) } fn remove_db() -> anyhow::Result<()> { let config = rbw::config::Config::load()?; - if let Some(email) = &config.email { - rbw::db::Db::remove(&config.server_name(), email) - .map_err(anyhow::Error::new) - } else { - Err(anyhow::anyhow!("failed to find email address in config")) - } + config.email.as_ref().map_or_else( + || Err(anyhow::anyhow!("failed to find email address in config")), + |email| { + rbw::db::Db::remove(&config.server_name(), email) + .map_err(anyhow::Error::new) + }, + ) } -fn parse_totp_secret(secret: &str) -> anyhow::Result<Vec<u8>> { - let secret_str = if let Ok(u) = url::Url::parse(secret) { +struct TotpParams { + secret: Vec<u8>, + algorithm: String, + digits: u32, + period: u64, +} + +fn decode_totp_secret(secret: &str) -> anyhow::Result<Vec<u8>> { + base32::decode( + base32::Alphabet::RFC4648 { padding: false }, + &secret.replace(' ', ""), + ) + .ok_or_else(|| anyhow::anyhow!("totp secret was not valid base32")) +} + +fn parse_totp_secret(secret: &str) -> anyhow::Result<TotpParams> { + if let Ok(u) = url::Url::parse(secret) { if u.scheme() != "otpauth" { return Err(anyhow::anyhow!( "totp secret url must have otpauth scheme" @@ -1489,32 +1971,80 @@ fn parse_totp_secret(secret: &str) -> anyhow::Result<Vec<u8>> { } let query: std::collections::HashMap<_, _> = u.query_pairs().collect(); - query - .get("secret") - .ok_or_else(|| { - anyhow::anyhow!("totp secret url must have secret") - })? - .to_string() + Ok(TotpParams { + secret: decode_totp_secret(query + .get("secret") + .ok_or_else(|| { + anyhow::anyhow!("totp secret url must have secret") + })?)?, + algorithm:query.get("algorithm").map_or_else(||{String::from("SHA1")},|alg|{alg.to_string()} ), + digits: match query.get("digits") { + Some(dig) => { + dig.parse::<u32>().map_err(|_|{ + anyhow::anyhow!("digits parameter in totp url must be a valid integer.") + })? + } + None => 6, + }, + period: match query.get("period") { + Some(dig) => { + dig.parse::<u64>().map_err(|_|{ + anyhow::anyhow!("period parameter in totp url must be a valid integer.") + })? + } + None => totp_lite::DEFAULT_STEP, + } + }) } else { - secret.to_string() - }; - base32::decode( - base32::Alphabet::RFC4648 { padding: false }, - &secret_str.replace(" ", ""), - ) - .ok_or_else(|| anyhow::anyhow!("totp secret was not valid base32")) + Ok(TotpParams { + secret: decode_totp_secret(secret)?, + algorithm: String::from("SHA1"), + digits: 6, + period: totp_lite::DEFAULT_STEP, + }) + } } fn generate_totp(secret: &str) -> anyhow::Result<String> { - let key = parse_totp_secret(secret)?; - Ok(totp_lite::totp_custom::<totp_lite::Sha1>( - totp_lite::DEFAULT_STEP, - 6, - &key, - std::time::SystemTime::now() - .duration_since(std::time::SystemTime::UNIX_EPOCH)? - .as_secs(), - )) + let totp_params = parse_totp_secret(secret)?; + let alg = totp_params.algorithm.as_str(); + match alg { + "SHA1" => Ok(totp_lite::totp_custom::<totp_lite::Sha1>( + totp_params.period, + totp_params.digits, + &totp_params.secret, + std::time::SystemTime::now() + .duration_since(std::time::SystemTime::UNIX_EPOCH)? + .as_secs(), + )), + "SHA256" => Ok(totp_lite::totp_custom::<totp_lite::Sha256>( + totp_params.period, + totp_params.digits, + &totp_params.secret, + std::time::SystemTime::now() + .duration_since(std::time::SystemTime::UNIX_EPOCH)? + .as_secs(), + )), + "SHA512" => Ok(totp_lite::totp_custom::<totp_lite::Sha512>( + totp_params.period, + totp_params.digits, + &totp_params.secret, + std::time::SystemTime::now() + .duration_since(std::time::SystemTime::UNIX_EPOCH)? + .as_secs(), + )), + _ => Err(anyhow::anyhow!(format!( + "{} is not a valid totp algorithm", + alg + ))), + } +} + +fn display_field(name: &str, field: Option<&str>, clipboard: bool) -> bool { + field.map_or_else( + || false, + |field| val_display_or_store(clipboard, &format!("{name}: {field}")), + ) } #[cfg(test)] @@ -1524,15 +2054,15 @@ mod test { #[test] fn test_find_entry() { let entries = &[ - make_entry("github", Some("foo"), None), - make_entry("gitlab", Some("foo"), None), - make_entry("gitlab", Some("bar"), None), - make_entry("gitter", Some("baz"), None), - make_entry("git", Some("foo"), None), - make_entry("bitwarden", None, None), - make_entry("github", Some("foo"), Some("websites")), - make_entry("github", Some("foo"), Some("ssh")), - make_entry("github", Some("root"), Some("ssh")), + make_entry("github", Some("foo"), None, &[]), + make_entry("gitlab", Some("foo"), None, &[]), + make_entry("gitlab", Some("bar"), None, &[]), + make_entry("gitter", Some("baz"), None, &[]), + make_entry("git", Some("foo"), None, &[]), + make_entry("bitwarden", None, None, &[]), + make_entry("github", Some("foo"), Some("websites"), &[]), + make_entry("github", Some("foo"), Some("ssh"), &[]), + make_entry("github", Some("root"), Some("ssh"), &[]), ]; assert!( @@ -1591,47 +2121,681 @@ mod test { ); } + #[test] + fn test_find_by_uuid() { + let entries = &[ + make_entry("github", Some("foo"), None, &[]), + make_entry("gitlab", Some("foo"), None, &[]), + make_entry("gitlab", Some("bar"), None, &[]), + ]; + + assert!( + one_match(entries, &entries[0].0.id, None, None, 0), + "foo@github" + ); + assert!( + one_match(entries, &entries[1].0.id, None, None, 1), + "foo@gitlab" + ); + assert!( + one_match(entries, &entries[2].0.id, None, None, 2), + "bar@gitlab" + ); + + assert!( + one_match( + entries, + &entries[0].0.id.to_uppercase(), + None, + None, + 0 + ), + "foo@github" + ); + assert!( + one_match( + entries, + &entries[0].0.id.to_lowercase(), + None, + None, + 0 + ), + "foo@github" + ); + } + + #[test] + fn test_find_by_url_default() { + let entries = &[ + make_entry("one", None, None, &[("https://one.com/", None)]), + make_entry("two", None, None, &[("https://two.com/login", None)]), + make_entry( + "three", + None, + None, + &[("https://login.three.com/", None)], + ), + make_entry("four", None, None, &[("four.com", None)]), + make_entry( + "five", + None, + None, + &[("https://five.com:8080/", None)], + ), + make_entry("six", None, None, &[("six.com:8080", None)]), + ]; + + assert!(one_match(entries, "https://one.com/", None, None, 0), "one"); + assert!( + one_match(entries, "https://login.one.com/", None, None, 0), + "one" + ); + assert!( + one_match(entries, "https://one.com:443/", None, None, 0), + "one" + ); + assert!(no_matches(entries, "one.com", None, None), "one"); + assert!(no_matches(entries, "https", None, None), "one"); + assert!(no_matches(entries, "com", None, None), "one"); + assert!(no_matches(entries, "https://com/", None, None), "one"); + + assert!(one_match(entries, "https://two.com/", None, None, 1), "two"); + assert!( + one_match(entries, "https://two.com/other-page", None, None, 1), + "two" + ); + + assert!( + one_match(entries, "https://login.three.com/", None, None, 2), + "three" + ); + assert!( + no_matches(entries, "https://three.com/", None, None), + "three" + ); + + assert!( + one_match(entries, "https://four.com/", None, None, 3), + "four" + ); + + assert!( + one_match(entries, "https://five.com:8080/", None, None, 4), + "five" + ); + assert!(no_matches(entries, "https://five.com/", None, None), "five"); + + assert!( + one_match(entries, "https://six.com:8080/", None, None, 5), + "six" + ); + assert!(no_matches(entries, "https://six.com/", None, None), "six"); + } + + #[test] + fn test_find_by_url_domain() { + let entries = &[ + make_entry( + "one", + None, + None, + &[("https://one.com/", Some(rbw::api::UriMatchType::Domain))], + ), + make_entry( + "two", + None, + None, + &[( + "https://two.com/login", + Some(rbw::api::UriMatchType::Domain), + )], + ), + make_entry( + "three", + None, + None, + &[( + "https://login.three.com/", + Some(rbw::api::UriMatchType::Domain), + )], + ), + make_entry( + "four", + None, + None, + &[("four.com", Some(rbw::api::UriMatchType::Domain))], + ), + make_entry( + "five", + None, + None, + &[( + "https://five.com:8080/", + Some(rbw::api::UriMatchType::Domain), + )], + ), + make_entry( + "six", + None, + None, + &[("six.com:8080", Some(rbw::api::UriMatchType::Domain))], + ), + ]; + + assert!(one_match(entries, "https://one.com/", None, None, 0), "one"); + assert!( + one_match(entries, "https://login.one.com/", None, None, 0), + "one" + ); + assert!( + one_match(entries, "https://one.com:443/", None, None, 0), + "one" + ); + assert!(no_matches(entries, "one.com", None, None), "one"); + assert!(no_matches(entries, "https", None, None), "one"); + assert!(no_matches(entries, "com", None, None), "one"); + assert!(no_matches(entries, "https://com/", None, None), "one"); + + assert!(one_match(entries, "https://two.com/", None, None, 1), "two"); + assert!( + one_match(entries, "https://two.com/other-page", None, None, 1), + "two" + ); + + assert!( + one_match(entries, "https://login.three.com/", None, None, 2), + "three" + ); + assert!( + no_matches(entries, "https://three.com/", None, None), + "three" + ); + + assert!( + one_match(entries, "https://four.com/", None, None, 3), + "four" + ); + + assert!( + one_match(entries, "https://five.com:8080/", None, None, 4), + "five" + ); + assert!(no_matches(entries, "https://five.com/", None, None), "five"); + + assert!( + one_match(entries, "https://six.com:8080/", None, None, 5), + "six" + ); + assert!(no_matches(entries, "https://six.com/", None, None), "six"); + } + + #[test] + fn test_find_by_url_host() { + let entries = &[ + make_entry( + "one", + None, + None, + &[("https://one.com/", Some(rbw::api::UriMatchType::Host))], + ), + make_entry( + "two", + None, + None, + &[( + "https://two.com/login", + Some(rbw::api::UriMatchType::Host), + )], + ), + make_entry( + "three", + None, + None, + &[( + "https://login.three.com/", + Some(rbw::api::UriMatchType::Host), + )], + ), + make_entry( + "four", + None, + None, + &[("four.com", Some(rbw::api::UriMatchType::Host))], + ), + make_entry( + "five", + None, + None, + &[( + "https://five.com:8080/", + Some(rbw::api::UriMatchType::Host), + )], + ), + make_entry( + "six", + None, + None, + &[("six.com:8080", Some(rbw::api::UriMatchType::Host))], + ), + ]; + + assert!(one_match(entries, "https://one.com/", None, None, 0), "one"); + assert!( + no_matches(entries, "https://login.one.com/", None, None), + "one" + ); + assert!( + one_match(entries, "https://one.com:443/", None, None, 0), + "one" + ); + assert!(no_matches(entries, "one.com", None, None), "one"); + assert!(no_matches(entries, "https", None, None), "one"); + assert!(no_matches(entries, "com", None, None), "one"); + assert!(no_matches(entries, "https://com/", None, None), "one"); + + assert!(one_match(entries, "https://two.com/", None, None, 1), "two"); + assert!( + one_match(entries, "https://two.com/other-page", None, None, 1), + "two" + ); + + assert!( + one_match(entries, "https://login.three.com/", None, None, 2), + "three" + ); + assert!( + no_matches(entries, "https://three.com/", None, None), + "three" + ); + + assert!( + one_match(entries, "https://four.com/", None, None, 3), + "four" + ); + + assert!( + one_match(entries, "https://five.com:8080/", None, None, 4), + "five" + ); + assert!(no_matches(entries, "https://five.com/", None, None), "five"); + + assert!( + one_match(entries, "https://six.com:8080/", None, None, 5), + "six" + ); + assert!(no_matches(entries, "https://six.com/", None, None), "six"); + } + + #[test] + fn test_find_by_url_starts_with() { + let entries = &[ + make_entry( + "one", + None, + None, + &[( + "https://one.com/", + Some(rbw::api::UriMatchType::StartsWith), + )], + ), + make_entry( + "two", + None, + None, + &[( + "https://two.com/login", + Some(rbw::api::UriMatchType::StartsWith), + )], + ), + make_entry( + "three", + None, + None, + &[( + "https://login.three.com/", + Some(rbw::api::UriMatchType::StartsWith), + )], + ), + ]; + + assert!(one_match(entries, "https://one.com/", None, None, 0), "one"); + assert!( + no_matches(entries, "https://login.one.com/", None, None), + "one" + ); + assert!( + one_match(entries, "https://one.com:443/", None, None, 0), + "one" + ); + assert!(no_matches(entries, "one.com", None, None), "one"); + assert!(no_matches(entries, "https", None, None), "one"); + assert!(no_matches(entries, "com", None, None), "one"); + assert!(no_matches(entries, "https://com/", None, None), "one"); + + assert!( + one_match(entries, "https://two.com/login", None, None, 1), + "two" + ); + assert!( + one_match(entries, "https://two.com/login/sso", None, None, 1), + "two" + ); + assert!(no_matches(entries, "https://two.com/", None, None), "two"); + assert!( + no_matches(entries, "https://two.com/other-page", None, None), + "two" + ); + + assert!( + one_match(entries, "https://login.three.com/", None, None, 2), + "three" + ); + assert!( + no_matches(entries, "https://three.com/", None, None), + "three" + ); + } + + #[test] + fn test_find_by_url_exact() { + let entries = &[ + make_entry( + "one", + None, + None, + &[("https://one.com/", Some(rbw::api::UriMatchType::Exact))], + ), + make_entry( + "two", + None, + None, + &[( + "https://two.com/login", + Some(rbw::api::UriMatchType::Exact), + )], + ), + make_entry( + "three", + None, + None, + &[( + "https://login.three.com/", + Some(rbw::api::UriMatchType::Exact), + )], + ), + ]; + + assert!(one_match(entries, "https://one.com/", None, None, 0), "one"); + assert!( + no_matches(entries, "https://login.one.com/", None, None), + "one" + ); + assert!( + one_match(entries, "https://one.com:443/", None, None, 0), + "one" + ); + assert!(no_matches(entries, "one.com", None, None), "one"); + assert!(no_matches(entries, "https", None, None), "one"); + assert!(no_matches(entries, "com", None, None), "one"); + assert!(no_matches(entries, "https://com/", None, None), "one"); + + assert!( + one_match(entries, "https://two.com/login", None, None, 1), + "two" + ); + assert!( + no_matches(entries, "https://two.com/login/sso", None, None), + "two" + ); + assert!(no_matches(entries, "https://two.com/", None, None), "two"); + assert!( + no_matches(entries, "https://two.com/other-page", None, None), + "two" + ); + + assert!( + one_match(entries, "https://login.three.com/", None, None, 2), + "three" + ); + assert!( + no_matches(entries, "https://three.com/", None, None), + "three" + ); + } + + #[test] + fn test_find_by_url_regex() { + let entries = &[ + make_entry( + "one", + None, + None, + &[( + r"^https://one\.com/$", + Some(rbw::api::UriMatchType::RegularExpression), + )], + ), + make_entry( + "two", + None, + None, + &[( + r"^https://two\.com/(login|start)", + Some(rbw::api::UriMatchType::RegularExpression), + )], + ), + make_entry( + "three", + None, + None, + &[( + r"^https://(login\.)?three\.com/$", + Some(rbw::api::UriMatchType::RegularExpression), + )], + ), + ]; + + assert!(one_match(entries, "https://one.com/", None, None, 0), "one"); + assert!( + no_matches(entries, "https://login.one.com/", None, None), + "one" + ); + assert!( + one_match(entries, "https://one.com:443/", None, None, 0), + "one" + ); + assert!(no_matches(entries, "one.com", None, None), "one"); + assert!(no_matches(entries, "https", None, None), "one"); + assert!(no_matches(entries, "com", None, None), "one"); + assert!(no_matches(entries, "https://com/", None, None), "one"); + + assert!( + one_match(entries, "https://two.com/login", None, None, 1), + "two" + ); + assert!( + one_match(entries, "https://two.com/start", None, None, 1), + "two" + ); + assert!( + one_match(entries, "https://two.com/login/sso", None, None, 1), + "two" + ); + assert!(no_matches(entries, "https://two.com/", None, None), "two"); + assert!( + no_matches(entries, "https://two.com/other-page", None, None), + "two" + ); + + assert!( + one_match(entries, "https://login.three.com/", None, None, 2), + "three" + ); + assert!( + one_match(entries, "https://three.com/", None, None, 2), + "three" + ); + assert!( + no_matches(entries, "https://www.three.com/", None, None), + "three" + ); + } + + #[test] + fn test_find_by_url_never() { + let entries = &[ + make_entry( + "one", + None, + None, + &[("https://one.com/", Some(rbw::api::UriMatchType::Never))], + ), + make_entry( + "two", + None, + None, + &[( + "https://two.com/login", + Some(rbw::api::UriMatchType::Never), + )], + ), + make_entry( + "three", + None, + None, + &[( + "https://login.three.com/", + Some(rbw::api::UriMatchType::Never), + )], + ), + make_entry( + "four", + None, + None, + &[("four.com", Some(rbw::api::UriMatchType::Never))], + ), + make_entry( + "five", + None, + None, + &[( + "https://five.com:8080/", + Some(rbw::api::UriMatchType::Never), + )], + ), + make_entry( + "six", + None, + None, + &[("six.com:8080", Some(rbw::api::UriMatchType::Never))], + ), + ]; + + assert!(no_matches(entries, "https://one.com/", None, None), "one"); + assert!( + no_matches(entries, "https://login.one.com/", None, None), + "one" + ); + assert!( + no_matches(entries, "https://one.com:443/", None, None), + "one" + ); + assert!(no_matches(entries, "one.com", None, None), "one"); + assert!(no_matches(entries, "https", None, None), "one"); + assert!(no_matches(entries, "com", None, None), "one"); + assert!(no_matches(entries, "https://com/", None, None), "one"); + + assert!(no_matches(entries, "https://two.com/", None, None), "two"); + assert!( + no_matches(entries, "https://two.com/other-page", None, None), + "two" + ); + + assert!( + no_matches(entries, "https://login.three.com/", None, None), + "three" + ); + assert!( + no_matches(entries, "https://three.com/", None, None), + "three" + ); + + assert!(no_matches(entries, "https://four.com/", None, None), "four"); + + assert!( + no_matches(entries, "https://five.com:8080/", None, None), + "five" + ); + assert!(no_matches(entries, "https://five.com/", None, None), "five"); + + assert!( + no_matches(entries, "https://six.com:8080/", None, None), + "six" + ); + assert!(no_matches(entries, "https://six.com/", None, None), "six"); + } + + #[track_caller] fn one_match( entries: &[(rbw::db::Entry, DecryptedCipher)], - name: &str, + needle: &str, username: Option<&str>, folder: Option<&str>, idx: usize, ) -> bool { entries_eq( - &find_entry_raw(entries, name, username, folder).unwrap(), + &find_entry_raw( + entries, + &parse_needle(needle).unwrap(), + username, + folder, + ) + .unwrap(), &entries[idx], ) } + #[track_caller] fn no_matches( entries: &[(rbw::db::Entry, DecryptedCipher)], - name: &str, + needle: &str, username: Option<&str>, folder: Option<&str>, ) -> bool { - let res = find_entry_raw(entries, name, username, folder); + let res = find_entry_raw( + entries, + &parse_needle(needle).unwrap(), + username, + folder, + ); if let Err(e) = res { - format!("{}", e).contains("no entry found") + format!("{e}").contains("no entry found") } else { false } } + #[track_caller] fn many_matches( entries: &[(rbw::db::Entry, DecryptedCipher)], - name: &str, + needle: &str, username: Option<&str>, folder: Option<&str>, ) -> bool { - let res = find_entry_raw(entries, name, username, folder); + let res = find_entry_raw( + entries, + &parse_needle(needle).unwrap(), + username, + folder, + ); if let Err(e) = res { - format!("{}", e).contains("multiple entries found") + format!("{e}").contains("multiple entries found") } else { false } } + #[track_caller] fn entries_eq( a: &(rbw::db::Entry, DecryptedCipher), b: &(rbw::db::Entry, DecryptedCipher), @@ -1643,10 +2807,12 @@ mod test { name: &str, username: Option<&str>, folder: Option<&str>, + uris: &[(&str, Option<rbw::api::UriMatchType>)], ) -> (rbw::db::Entry, DecryptedCipher) { + let id = uuid::Uuid::new_v4(); ( rbw::db::Entry { - id: "irrelevant".to_string(), + id: id.to_string(), org_id: None, folder: folder.map(|_| "encrypted folder name".to_string()), folder_id: None, @@ -1656,7 +2822,13 @@ mod test { "this is the encrypted username".to_string() }), password: None, - uris: vec![], + uris: uris + .iter() + .map(|(_, match_type)| rbw::db::Uri { + uri: "this is the encrypted uri".to_string(), + match_type: *match_type, + }) + .collect(), totp: None, }, fields: vec![], @@ -1664,14 +2836,21 @@ mod test { history: vec![], }, DecryptedCipher { - id: "irrelevant".to_string(), + id: id.to_string(), folder: folder.map(std::string::ToString::to_string), name: name.to_string(), data: DecryptedData::Login { username: username.map(std::string::ToString::to_string), password: None, totp: None, - uris: None, + uris: Some( + uris.iter() + .map(|(uri, match_type)| DecryptedUri { + uri: (*uri).to_string(), + match_type: *match_type, + }) + .collect(), + ), }, fields: vec![], notes: None, diff --git a/src/bin/rbw/main.rs b/src/bin/rbw/main.rs index 85631c5..2fb96bf 100644 --- a/src/bin/rbw/main.rs +++ b/src/bin/rbw/main.rs @@ -1,23 +1,36 @@ +#![warn(clippy::cargo)] +#![warn(clippy::pedantic)] +#![warn(clippy::nursery)] +#![warn(clippy::as_conversions)] +#![warn(clippy::get_unwrap)] +#![allow(clippy::cognitive_complexity)] +#![allow(clippy::missing_const_for_fn)] +#![allow(clippy::similar_names)] +#![allow(clippy::struct_excessive_bools)] +#![allow(clippy::too_many_arguments)] +#![allow(clippy::too_many_lines)] +#![allow(clippy::type_complexity)] +#![allow(clippy::multiple_crate_versions)] #![allow(clippy::large_enum_variant)] use anyhow::Context as _; +use clap::{CommandFactory as _, Parser as _}; use std::io::Write as _; -use structopt::StructOpt as _; mod actions; mod commands; mod sock; -#[derive(Debug, structopt::StructOpt)] -#[structopt(about = "Unofficial Bitwarden CLI")] +#[derive(Debug, clap::Parser)] +#[command(version, about = "Unofficial Bitwarden CLI")] enum Opt { - #[structopt(about = "Get or set configuration options")] + #[command(about = "Get or set configuration options")] Config { - #[structopt(subcommand)] + #[command(subcommand)] config: Config, }, - #[structopt( + #[command( about = "Register this device with the Bitwarden server", long_about = "Register this device with the Bitwarden server\n\n\ The official Bitwarden server includes bot detection to prevent \ @@ -28,60 +41,68 @@ enum Opt { )] Register, - #[structopt(about = "Log in to the Bitwarden server")] + #[command(about = "Log in to the Bitwarden server")] Login, - #[structopt(about = "Unlock the local Bitwarden database")] + #[command(about = "Unlock the local Bitwarden database")] Unlock, - #[structopt(about = "Check if the local Bitwarden database is unlocked")] + #[command(about = "Check if the local Bitwarden database is unlocked")] Unlocked, - #[structopt(about = "Update the local copy of the Bitwarden database")] + #[command(about = "Update the local copy of the Bitwarden database")] Sync, - #[structopt( + #[command( about = "List all entries in the local Bitwarden database", visible_alias = "ls" )] List { - #[structopt( + #[arg( long, help = "Fields to display. \ Available options are id, name, user, folder. \ Multiple fields will be separated by tabs.", default_value = "name", - use_delimiter = true + use_value_delimiter = true )] fields: Vec<String>, }, - #[structopt(about = "Display the password for a given entry")] + #[command(about = "Display the password for a given entry")] Get { - #[structopt(help = "Name or UUID of the entry to display")] - name: String, - #[structopt(help = "Username of the entry to display")] + #[arg(help = "Name, URI or UUID of the entry to display", value_parser = commands::parse_needle)] + needle: commands::Needle, + #[arg(help = "Username of the entry to display")] user: Option<String>, - #[structopt(long, help = "Folder name to search in")] + #[arg(long, help = "Folder name to search in")] folder: Option<String>, - #[structopt( - long, - help = "Display the notes in addition to the password" - )] + #[arg(short, long, help = "Field to get")] + field: Option<String>, + #[arg(long, help = "Display the notes in addition to the password")] full: bool, + #[structopt(long, help = "Display output as JSON")] + raw: bool, + #[structopt(long, help = "Copy result to clipboard")] + clipboard: bool, }, - #[structopt(about = "Display the authenticator code for a given entry")] + #[command( + about = "Display the authenticator code for a given entry", + visible_alias = "totp" + )] Code { - #[structopt(help = "Name or UUID of the entry to display")] - name: String, - #[structopt(help = "Username of the entry to display")] + #[arg(help = "Name, URI or UUID of the entry to display", value_parser = commands::parse_needle)] + needle: commands::Needle, + #[arg(help = "Username of the entry to display")] user: Option<String>, - #[structopt(long, help = "Folder name to search in")] + #[arg(long, help = "Folder name to search in")] folder: Option<String>, + #[structopt(long, help = "Copy result to clipboard")] + clipboard: bool, }, - #[structopt( + #[command( about = "Add a new password to the database", long_about = "Add a new password to the database\n\n\ This command will open a text editor to enter \ @@ -91,28 +112,27 @@ enum Opt { remainder will be saved as a note." )] Add { - #[structopt(help = "Name of the password entry")] + #[arg(help = "Name of the password entry")] name: String, - #[structopt(help = "Username for the password entry")] + #[arg(help = "Username for the password entry")] user: Option<String>, - #[structopt( + #[arg( long, help = "URI for the password entry", - multiple = true, number_of_values = 1 )] uri: Vec<String>, - #[structopt(long, help = "Folder for the password entry")] + #[arg(long, help = "Folder for the password entry")] folder: Option<String>, }, - #[structopt( + #[command( about = "Generate a new password", long_about = "Generate a new password\n\n\ If given a password entry name, also save the generated \ password to the database.", visible_alias = "gen", - group = structopt::clap::ArgGroup::with_name("password-type").args(&[ + group = clap::ArgGroup::new("password-type").args(&[ "no-symbols", "only-numbers", "nonconfusables", @@ -120,39 +140,38 @@ enum Opt { ]) )] Generate { - #[structopt(help = "Length of the password to generate")] + #[arg(help = "Length of the password to generate")] len: usize, - #[structopt(help = "Name of the password entry")] + #[arg(help = "Name of the password entry")] name: Option<String>, - #[structopt(help = "Username for the password entry")] + #[arg(help = "Username for the password entry")] user: Option<String>, - #[structopt( + #[arg( long, help = "URI for the password entry", - multiple = true, number_of_values = 1 )] uri: Vec<String>, - #[structopt(long, help = "Folder for the password entry")] + #[arg(long, help = "Folder for the password entry")] folder: Option<String>, - #[structopt( + #[arg( long = "no-symbols", help = "Generate a password with no special characters" )] no_symbols: bool, - #[structopt( + #[arg( long = "only-numbers", help = "Generate a password consisting of only numbers" )] only_numbers: bool, - #[structopt( + #[arg( long, help = "Generate a password without visually similar \ characters (useful for passwords intended to be \ written down)" )] nonconfusables: bool, - #[structopt( + #[arg( long, help = "Generate a password of multiple dictionary \ words chosen from the EFF word list. The len \ @@ -162,7 +181,7 @@ enum Opt { diceware: bool, }, - #[structopt( + #[command( about = "Modify an existing password", long_about = "Modify an existing password\n\n\ This command will open a text editor with the existing \ @@ -173,50 +192,48 @@ enum Opt { as a note." )] Edit { - #[structopt(help = "Name or UUID of the password entry")] + #[arg(help = "Name or UUID of the password entry")] name: String, - #[structopt(help = "Username for the password entry")] + #[arg(help = "Username for the password entry")] user: Option<String>, - #[structopt(long, help = "Folder name to search in")] + #[arg(long, help = "Folder name to search in")] folder: Option<String>, }, - #[structopt(about = "Remove a given entry", visible_alias = "rm")] + #[command(about = "Remove a given entry", visible_alias = "rm")] Remove { - #[structopt(help = "Name or UUID of the password entry")] + #[arg(help = "Name or UUID of the password entry")] name: String, - #[structopt(help = "Username for the password entry")] + #[arg(help = "Username for the password entry")] user: Option<String>, - #[structopt(long, help = "Folder name to search in")] + #[arg(long, help = "Folder name to search in")] folder: Option<String>, }, - #[structopt(about = "View the password history for a given entry")] + #[command(about = "View the password history for a given entry")] History { - #[structopt(help = "Name or UUID of the password entry")] + #[arg(help = "Name or UUID of the password entry")] name: String, - #[structopt(help = "Username for the password entry")] + #[arg(help = "Username for the password entry")] user: Option<String>, - #[structopt(long, help = "Folder name to search in")] + #[arg(long, help = "Folder name to search in")] folder: Option<String>, }, - #[structopt(about = "Lock the password database")] + #[command(about = "Lock the password database")] Lock, - #[structopt(about = "Remove the local copy of the password database")] + #[command(about = "Remove the local copy of the password database")] Purge, - #[structopt( - name = "stop-agent", - about = "Terminate the background agent" - )] + #[command(name = "stop-agent", about = "Terminate the background agent")] StopAgent, - #[structopt( + + #[command( name = "gen-completions", about = "Generate completion script for the given shell" )] - GenCompletions { shell: String }, + GenCompletions { shell: clap_complete::Shell }, } impl Opt { @@ -246,20 +263,20 @@ impl Opt { } } -#[derive(Debug, structopt::StructOpt)] +#[derive(Debug, clap::Parser)] enum Config { - #[structopt(about = "Show the values of all configuration settings")] + #[command(about = "Show the values of all configuration settings")] Show, - #[structopt(about = "Set a configuration option")] + #[command(about = "Set a configuration option")] Set { - #[structopt(help = "Configuration key to set")] + #[arg(help = "Configuration key to set")] key: String, - #[structopt(help = "Value to set the configuration option to")] + #[arg(help = "Value to set the configuration option to")] value: String, }, - #[structopt(about = "Reset a configuration option to its default")] + #[command(about = "Reset a configuration option to its default")] Unset { - #[structopt(help = "Configuration key to unset")] + #[arg(help = "Configuration key to unset")] key: String, }, } @@ -275,15 +292,18 @@ impl Config { } } -#[paw::main] -fn main(opt: Opt) { +fn main() { + let opt = Opt::parse(); + env_logger::Builder::from_env( env_logger::Env::default().default_filter_or("info"), ) .format(|buf, record| { - if let Some((w, _)) = term_size::dimensions() { + if let Some((terminal_size::Width(w), _)) = + terminal_size::terminal_size() + { let out = format!("{}: {}", record.level(), record.args()); - writeln!(buf, "{}", textwrap::fill(&out, w - 1)) + writeln!(buf, "{}", textwrap::fill(&out, usize::from(w) - 1)) } else { writeln!(buf, "{}: {}", record.level(), record.args()) } @@ -303,14 +323,33 @@ fn main(opt: Opt) { Opt::Sync => commands::sync(), Opt::List { fields } => commands::list(fields), Opt::Get { - name, + needle, user, folder, + field, full, - } => commands::get(name, user.as_deref(), folder.as_deref(), *full), - Opt::Code { name, user, folder } => { - commands::code(name, user.as_deref(), folder.as_deref()) - } + raw, + clipboard, + } => commands::get( + needle, + user.as_deref(), + folder.as_deref(), + field.as_deref(), + *full, + *raw, + *clipboard, + ), + Opt::Code { + needle, + user, + folder, + clipboard, + } => commands::code( + needle, + user.as_deref(), + folder.as_deref(), + *clipboard, + ), Opt::Add { name, user, @@ -319,7 +358,7 @@ fn main(opt: Opt) { } => commands::add( name, user.as_deref(), - uri.iter() + &uri.iter() // XXX not sure what the ui for specifying the match type // should be .map(|uri| (uri.clone(), None)) @@ -351,7 +390,7 @@ fn main(opt: Opt) { commands::generate( name.as_deref(), user.as_deref(), - uri.iter() + &uri.iter() // XXX not sure what the ui for specifying the match type // should be .map(|uri| (uri.clone(), None)) @@ -373,25 +412,20 @@ fn main(opt: Opt) { Opt::Lock => commands::lock(), Opt::Purge => commands::purge(), Opt::StopAgent => commands::stop_agent(), - Opt::GenCompletions { shell } => gen_completions(shell), + Opt::GenCompletions { shell } => { + clap_complete::generate( + *shell, + &mut Opt::command(), + "rbw", + &mut std::io::stdout(), + ); + Ok(()) + } } .context(format!("rbw {}", opt.subcommand_name())); if let Err(e) = res { - eprintln!("{:#}", e); + eprintln!("{e:#}"); std::process::exit(1); } } - -fn gen_completions(shell: &str) -> anyhow::Result<()> { - let shell = match shell { - "bash" => structopt::clap::Shell::Bash, - "zsh" => structopt::clap::Shell::Zsh, - "fish" => structopt::clap::Shell::Fish, - "powershell" => structopt::clap::Shell::PowerShell, - "elvish" => structopt::clap::Shell::Elvish, - _ => return Err(anyhow::anyhow!("unknown shell {}", shell)), - }; - Opt::clap().gen_completions_to("rbw", shell, &mut std::io::stdout()); - Ok(()) -} diff --git a/src/cipherstring.rs b/src/cipherstring.rs index 39254c7..883cb34 100644 --- a/src/cipherstring.rs +++ b/src/cipherstring.rs @@ -1,10 +1,11 @@ use crate::prelude::*; -use block_modes::BlockMode as _; -use block_padding::Padding as _; -use hmac::{Mac as _, NewMac as _}; +use aes::cipher::{ + BlockDecryptMut as _, BlockEncryptMut as _, KeyIvInit as _, +}; +use hmac::Mac as _; +use pkcs8::DecodePrivateKey as _; use rand::RngCore as _; -use rsa::pkcs8::FromPrivateKey as _; use zeroize::Zeroize as _; pub enum CipherString { @@ -51,15 +52,15 @@ impl CipherString { }); } - let iv = base64::decode(parts[0]) + let iv = crate::base64::decode(parts[0]) .map_err(|source| Error::InvalidBase64 { source })?; - let ciphertext = base64::decode(parts[1]) + let ciphertext = crate::base64::decode(parts[1]) .map_err(|source| Error::InvalidBase64 { source })?; let mac = if parts.len() > 2 { - Some(base64::decode(parts[2]).map_err(|source| { - Error::InvalidBase64 { source } - })?) + Some(crate::base64::decode(parts[2]).map_err( + |source| Error::InvalidBase64 { source }, + )?) } else { None }; @@ -76,7 +77,7 @@ impl CipherString { // https://github.com/bitwarden/jslib/blob/785b681f61f81690de6df55159ab07ae710bcfad/src/enums/encryptionType.ts#L8 // format is: <cipher_text_b64>|<hmac_sig> let contents = contents.split('|').next().unwrap(); - let ciphertext = base64::decode(contents) + let ciphertext = crate::base64::decode(contents) .map_err(|source| Error::InvalidBase64 { source })?; Ok(Self::Asymmetric { ciphertext }) } @@ -98,12 +99,12 @@ impl CipherString { ) -> Result<Self> { let iv = random_iv(); - let cipher = block_modes::Cbc::< - aes::Aes256, - block_modes::block_padding::Pkcs7, - >::new_from_slices(keys.enc_key(), &iv) - .map_err(|source| Error::CreateBlockMode { source })?; - let ciphertext = cipher.encrypt_vec(plaintext); + let cipher = cbc::Encryptor::<aes::Aes256>::new( + keys.enc_key().into(), + iv.as_slice().into(), + ); + let ciphertext = + cipher.encrypt_padded_vec_mut::<block_padding::Pkcs7>(plaintext); let mut digest = hmac::Hmac::<sha2::Sha256>::new_from_slice(keys.mac_key()) @@ -136,7 +137,7 @@ impl CipherString { mac.as_deref(), )?; cipher - .decrypt_vec(ciphertext) + .decrypt_padded_vec_mut::<block_padding::Pkcs7>(ciphertext) .map_err(|source| Error::Decrypt { source }) } else { Err(Error::InvalidCipherString { @@ -166,7 +167,7 @@ impl CipherString { mac.as_deref(), )?; cipher - .decrypt(res.data_mut()) + .decrypt_padded_mut::<block_padding::Pkcs7>(res.data_mut()) .map_err(|source| Error::Decrypt { source })?; Ok(res) } else { @@ -184,15 +185,12 @@ impl CipherString { ) -> Result<crate::locked::Vec> { if let Self::Asymmetric { ciphertext } = self { let privkey_data = private_key.private_key(); - let privkey_data = block_padding::Pkcs7::unpad(privkey_data) - .map_err(|_| Error::Padding)?; + let privkey_data = + pkcs7_unpad(privkey_data).ok_or(Error::Padding)?; let pkey = rsa::RsaPrivateKey::from_pkcs8_der(privkey_data) .map_err(|source| Error::RsaPkcs8 { source })?; let mut bytes = pkey - .decrypt( - rsa::padding::PaddingScheme::new_oaep::<sha1::Sha1>(), - ciphertext, - ) + .decrypt(rsa::Oaep::new::<sha1::Sha1>(), ciphertext) .map_err(|source| Error::Rsa { source })?; // XXX it'd be great if the rsa crate would let us decrypt @@ -218,8 +216,7 @@ fn decrypt_common_symmetric( iv: &[u8], ciphertext: &[u8], mac: Option<&[u8]>, -) -> Result<block_modes::Cbc<aes::Aes256, block_modes::block_padding::Pkcs7>> -{ +) -> Result<cbc::Decryptor<aes::Aes256>> { if let Some(mac) = mac { let mut key = hmac::Hmac::<sha2::Sha256>::new_from_slice(keys.mac_key()) @@ -227,15 +224,12 @@ fn decrypt_common_symmetric( key.update(iv); key.update(ciphertext); - if key.verify(mac).is_err() { + if key.verify(mac.into()).is_err() { return Err(Error::InvalidMac); } } - block_modes::Cbc::< - aes::Aes256, - block_modes::block_padding::Pkcs7, - >::new_from_slices(keys.enc_key(), iv) + cbc::Decryptor::<aes::Aes256>::new_from_slices(keys.enc_key(), iv) .map_err(|source| Error::CreateBlockMode { source }) } @@ -247,18 +241,18 @@ impl std::fmt::Display for CipherString { ciphertext, mac, } => { - let iv = base64::encode(&iv); - let ciphertext = base64::encode(&ciphertext); + let iv = crate::base64::encode(iv); + let ciphertext = crate::base64::encode(ciphertext); if let Some(mac) = &mac { - let mac = base64::encode(&mac); - write!(f, "2.{}|{}|{}", iv, ciphertext, mac) + let mac = crate::base64::encode(mac); + write!(f, "2.{iv}|{ciphertext}|{mac}") } else { - write!(f, "2.{}|{}", iv, ciphertext) + write!(f, "2.{iv}|{ciphertext}") } } Self::Asymmetric { ciphertext } => { - let ciphertext = base64::encode(&ciphertext); - write!(f, "4.{}", ciphertext) + let ciphertext = crate::base64::encode(ciphertext); + write!(f, "4.{ciphertext}") } } } @@ -270,3 +264,51 @@ fn random_iv() -> Vec<u8> { rng.fill_bytes(&mut iv); iv } + +// XXX this should ideally just be block_padding::Pkcs7::unpad, but i can't +// figure out how to get the generic types to work out +fn pkcs7_unpad(b: &[u8]) -> Option<&[u8]> { + if b.is_empty() { + return None; + } + + let padding_val = b[b.len() - 1]; + if padding_val == 0 { + return None; + } + + let padding_len = usize::from(padding_val); + if padding_len > b.len() { + return None; + } + + for c in b.iter().copied().skip(b.len() - padding_len) { + if c != padding_val { + return None; + } + } + + Some(&b[..b.len() - padding_len]) +} + +#[test] +fn test_pkcs7_unpad() { + let tests = [ + (&[][..], None), + (&[0x01][..], Some(&[][..])), + (&[0x02, 0x02][..], Some(&[][..])), + (&[0x03, 0x03, 0x03][..], Some(&[][..])), + (&[0x69, 0x01][..], Some(&[0x69][..])), + (&[0x69, 0x02, 0x02][..], Some(&[0x69][..])), + (&[0x69, 0x03, 0x03, 0x03][..], Some(&[0x69][..])), + (&[0x02][..], None), + (&[0x03][..], None), + (&[0x69, 0x69, 0x03, 0x03][..], None), + (&[0x00][..], None), + (&[0x02, 0x00][..], None), + ]; + for (input, expected) in tests { + let got = pkcs7_unpad(input); + assert_eq!(got, expected); + } +} diff --git a/src/config.rs b/src/config.rs index bbc39f7..efb1b5f 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,51 +1,59 @@ use crate::prelude::*; use std::io::{Read as _, Write as _}; -use tokio::io::AsyncReadExt as _; +use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _}; #[derive(serde::Serialize, serde::Deserialize, Debug)] pub struct Config { pub email: Option<String>, pub base_url: Option<String>, pub identity_url: Option<String>, + pub notifications_url: Option<String>, #[serde(default = "default_lock_timeout")] pub lock_timeout: u64, + #[serde(default = "default_sync_interval")] + pub sync_interval: u64, #[serde(default = "default_pinentry")] pub pinentry: String, - #[serde(default = "stub_device_id")] - pub device_id: String, + pub client_cert_path: Option<std::path::PathBuf>, + // backcompat, no longer generated in new configs + #[serde(skip_serializing)] + pub device_id: Option<String>, } impl Default for Config { fn default() -> Self { Self { - email: Default::default(), - base_url: Default::default(), - identity_url: Default::default(), + email: None, + base_url: None, + identity_url: None, + notifications_url: None, lock_timeout: default_lock_timeout(), + sync_interval: default_sync_interval(), pinentry: default_pinentry(), - device_id: default_device_id(), + client_cert_path: None, + device_id: None, } } } +#[must_use] pub fn default_lock_timeout() -> u64 { 3600 } -pub fn default_pinentry() -> String { - "pinentry".to_string() -} - -fn default_device_id() -> String { - uuid::Uuid::new_v4().to_hyphenated().to_string() +#[must_use] +pub fn default_sync_interval() -> u64 { + 3600 } -fn stub_device_id() -> String { - String::from("fix") +#[must_use] +pub fn default_pinentry() -> String { + "pinentry".to_string() } impl Config { + #[must_use] pub fn new() -> Self { Self::default() } @@ -127,36 +135,103 @@ impl Config { } pub fn validate() -> Result<()> { - let mut config = Self::load()?; + let config = Self::load()?; if config.email.is_none() { return Err(Error::ConfigMissingEmail); } - if config.device_id == stub_device_id() { - config.device_id = default_device_id(); - config.save()?; - } Ok(()) } + #[must_use] pub fn base_url(&self) -> String { self.base_url.clone().map_or_else( || "https://api.bitwarden.com".to_string(), - |url| format!("{}/api", url.trim_end_matches('/')), + |url| { + let clean_url = url.trim_end_matches('/').to_string(); + if clean_url == "https://api.bitwarden.eu" { + clean_url + } else { + format!("{clean_url}/api") + } + }, ) } + #[must_use] pub fn identity_url(&self) -> String { self.identity_url.clone().unwrap_or_else(|| { self.base_url.clone().map_or_else( || "https://identity.bitwarden.com".to_string(), - |url| format!("{}/identity", url.trim_end_matches('/')), + |url| { + let clean_url = url.trim_end_matches('/').to_string(); + if clean_url == "https://identity.bitwarden.eu" { + clean_url + } else { + format!("{clean_url}/identity") + } + }, + ) + }) + } + + #[must_use] + pub fn notifications_url(&self) -> String { + self.notifications_url.clone().unwrap_or_else(|| { + self.base_url.clone().map_or_else( + || "https://notifications.bitwarden.com".to_string(), + |url| { + let clean_url = url.trim_end_matches('/').to_string(); + if clean_url == "https://notifications.bitwarden.eu" { + clean_url + } else { + format!("{clean_url}/notifications") + } + }, ) }) } + #[must_use] + pub fn client_cert_path(&self) -> Option<&std::path::Path> { + self.client_cert_path.as_deref() + } + + #[must_use] pub fn server_name(&self) -> String { self.base_url .clone() .unwrap_or_else(|| "default".to_string()) } } + +pub async fn device_id(config: &Config) -> Result<String> { + let file = crate::dirs::device_id_file(); + if let Ok(mut fh) = tokio::fs::File::open(&file).await { + let mut s = String::new(); + fh.read_to_string(&mut s) + .await + .map_err(|e| Error::LoadDeviceId { + source: e, + file: file.clone(), + })?; + Ok(s.trim().to_string()) + } else { + let id = config.device_id.as_ref().map_or_else( + || uuid::Uuid::new_v4().hyphenated().to_string(), + String::to_string, + ); + let mut fh = tokio::fs::File::create(&file).await.map_err(|e| { + Error::LoadDeviceId { + source: e, + file: file.clone(), + } + })?; + fh.write_all(id.as_bytes()).await.map_err(|e| { + Error::LoadDeviceId { + source: e, + file: file.clone(), + } + })?; + Ok(id) + } +} @@ -147,8 +147,10 @@ pub enum EntryData { serde::Serialize, serde::Deserialize, Debug, Clone, Eq, PartialEq, )] pub struct Field { + pub ty: crate::api::FieldType, pub name: Option<String>, pub value: Option<String>, + pub linked_id: Option<crate::api::LinkedIdType>, } #[derive( @@ -164,7 +166,10 @@ pub struct Db { pub access_token: Option<String>, pub refresh_token: Option<String>, + pub kdf: Option<crate::api::KdfType>, pub iterations: Option<u32>, + pub memory: Option<u32>, + pub parallelism: Option<u32>, pub protected_key: Option<String>, pub protected_private_key: Option<String>, pub protected_org_keys: std::collections::HashMap<String, String>, @@ -173,6 +178,7 @@ pub struct Db { } impl Db { + #[must_use] pub fn new() -> Self { Self::default() } @@ -287,10 +293,12 @@ impl Db { Ok(()) } + #[must_use] pub fn needs_login(&self) -> bool { self.access_token.is_none() || self.refresh_token.is_none() || self.iterations.is_none() + || self.kdf.is_none() || self.protected_key.is_none() } } diff --git a/src/dirs.rs b/src/dirs.rs index 285a0d5..2fa6e50 100644 --- a/src/dirs.rs +++ b/src/dirs.rs @@ -37,59 +37,89 @@ pub fn make_all() -> Result<()> { Ok(()) } +#[must_use] pub fn config_file() -> std::path::PathBuf { config_dir().join("config.json") } const INVALID_PATH: &percent_encoding::AsciiSet = &percent_encoding::CONTROLS.add(b'/').add(b'%').add(b':'); +#[must_use] pub fn db_file(server: &str, email: &str) -> std::path::PathBuf { let server = percent_encoding::percent_encode(server.as_bytes(), INVALID_PATH) .to_string(); - cache_dir().join(format!("{}:{}.json", server, email)) + cache_dir().join(format!("{server}:{email}.json")) } +#[must_use] pub fn pid_file() -> std::path::PathBuf { runtime_dir().join("pidfile") } +#[must_use] pub fn agent_stdout_file() -> std::path::PathBuf { data_dir().join("agent.out") } +#[must_use] pub fn agent_stderr_file() -> std::path::PathBuf { data_dir().join("agent.err") } +#[must_use] +pub fn device_id_file() -> std::path::PathBuf { + data_dir().join("device_id") +} + +#[must_use] pub fn socket_file() -> std::path::PathBuf { runtime_dir().join("socket") } +#[must_use] fn config_dir() -> std::path::PathBuf { - let project_dirs = directories::ProjectDirs::from("", "", "rbw").unwrap(); + let project_dirs = + directories::ProjectDirs::from("", "", &profile()).unwrap(); project_dirs.config_dir().to_path_buf() } +#[must_use] fn cache_dir() -> std::path::PathBuf { - let project_dirs = directories::ProjectDirs::from("", "", "rbw").unwrap(); + let project_dirs = + directories::ProjectDirs::from("", "", &profile()).unwrap(); project_dirs.cache_dir().to_path_buf() } +#[must_use] fn data_dir() -> std::path::PathBuf { - let project_dirs = directories::ProjectDirs::from("", "", "rbw").unwrap(); + let project_dirs = + directories::ProjectDirs::from("", "", &profile()).unwrap(); project_dirs.data_dir().to_path_buf() } +#[must_use] fn runtime_dir() -> std::path::PathBuf { - let project_dirs = directories::ProjectDirs::from("", "", "rbw").unwrap(); - match project_dirs.runtime_dir() { - Some(dir) => dir.to_path_buf(), - None => format!( - "{}/rbw-{}", - std::env::temp_dir().to_string_lossy(), - nix::unistd::getuid().as_raw() - ) - .into(), + let project_dirs = + directories::ProjectDirs::from("", "", &profile()).unwrap(); + project_dirs.runtime_dir().map_or_else( + || { + format!( + "{}/{}-{}", + std::env::temp_dir().to_string_lossy(), + &profile(), + rustix::process::getuid().as_raw() + ) + .into() + }, + std::path::Path::to_path_buf, + ) +} + +#[must_use] +pub fn profile() -> String { + match std::env::var("RBW_PROFILE") { + Ok(profile) if !profile.is_empty() => format!("rbw-{profile}"), + _ => "rbw".to_string(), } } diff --git a/src/edit.rs b/src/edit.rs index 1a831a7..360f31f 100644 --- a/src/edit.rs +++ b/src/edit.rs @@ -2,7 +2,17 @@ use crate::prelude::*; use std::io::{Read as _, Write as _}; +use is_terminal::IsTerminal as _; + pub fn edit(contents: &str, help: &str) -> Result<String> { + if !std::io::stdin().is_terminal() { + // directly read from piped content + return match std::io::read_to_string(std::io::stdin()) { + Err(e) => Err(Error::FailedToReadFromStdin { err: e }), + Ok(res) => Ok(res), + }; + } + let mut var = "VISUAL"; let editor = std::env::var_os(var).unwrap_or_else(|| { var = "EDITOR"; @@ -30,6 +40,7 @@ pub fn edit(contents: &str, help: &str) -> Result<String> { let editor = std::path::Path::new(&editor); let mut editor_args = vec![]; + #[allow(clippy::single_match_else)] // more to come match editor.file_name() { Some(editor) => match editor.to_str() { Some("vim" | "nvim") => { @@ -52,7 +63,7 @@ pub fn edit(contents: &str, help: &str) -> Result<String> { (editor, editor_args) }; - let res = std::process::Command::new(&cmd).args(&args).status(); + let res = std::process::Command::new(cmd).args(&args).status(); match res { Ok(res) => { if !res.success() { @@ -80,8 +91,6 @@ pub fn edit(contents: &str, help: &str) -> Result<String> { } fn contains_shell_metacharacters(cmd: &std::ffi::OsStr) -> bool { - match cmd.to_str() { - Some(s) => s.contains(&[' ', '$', '\'', '"'][..]), - None => false, - } + cmd.to_str() + .map_or(false, |s| s.contains(&[' ', '$', '\'', '"'][..])) } diff --git a/src/error.rs b/src/error.rs index 8116de2..db0503a 100644 --- a/src/error.rs +++ b/src/error.rs @@ -4,14 +4,10 @@ pub enum Error { ConfigMissingEmail, #[error("failed to create block mode decryptor")] - CreateBlockMode { - source: block_modes::InvalidKeyIvLength, - }, + CreateBlockMode { source: aes::cipher::InvalidLength }, #[error("failed to create block mode decryptor")] - CreateHmac { - source: hmac::crypto_mac::InvalidKeyLength, - }, + CreateHmac { source: aes::cipher::InvalidLength }, #[error("failed to create directory at {}", .file.display())] CreateDirectory { @@ -19,12 +15,18 @@ pub enum Error { file: std::path::PathBuf, }, + #[error("failed to create reqwest client")] + CreateReqwestClient { source: reqwest::Error }, + #[error("failed to decrypt")] - Decrypt { source: block_modes::BlockModeError }, + Decrypt { source: block_padding::UnpadError }, #[error("failed to parse pinentry output ({out:?})")] FailedToParsePinentry { out: String }, + #[error("failed to read from stdin: {err}")] + FailedToReadFromStdin { err: std::io::Error }, + #[error( "failed to run editor {}: {err}", .editor.to_string_lossy(), @@ -116,6 +118,18 @@ pub enum Error { file: std::path::PathBuf, }, + #[error("failed to load device id from {}", .file.display())] + LoadDeviceId { + source: tokio::io::Error, + file: std::path::PathBuf, + }, + + #[error("failed to load client cert from {}", .file.display())] + LoadClientCert { + source: tokio::io::Error, + file: std::path::PathBuf, + }, + #[error("invalid padding")] Padding, @@ -125,6 +139,12 @@ pub enum Error { #[error("pbkdf2 requires at least 1 iteration (got 0)")] Pbkdf2ZeroIterations, + #[error("failed to run pbkdf2")] + Pbkdf2, + + #[error("failed to run argon2")] + Argon2, + #[error("pinentry cancelled")] PinentryCancelled, @@ -207,6 +227,9 @@ pub enum Error { #[error("error writing to pinentry stdin")] WriteStdin { source: tokio::io::Error }, + + #[error("invalid kdf type: {ty}")] + InvalidKdfType { ty: String }, } pub type Result<T> = std::result::Result<T, Error>; diff --git a/src/identity.rs b/src/identity.rs index 90d4fad..fd46b85 100644 --- a/src/identity.rs +++ b/src/identity.rs @@ -1,5 +1,7 @@ use crate::prelude::*; +use sha1::Digest as _; + pub struct Identity { pub email: String, pub keys: crate::locked::Keys, @@ -10,8 +12,13 @@ impl Identity { pub fn new( email: &str, password: &crate::locked::Password, + kdf: crate::api::KdfType, iterations: u32, + memory: Option<u32>, + parallelism: Option<u32>, ) -> Result<Self> { + let email = email.trim().to_lowercase(); + let iterations = std::num::NonZeroU32::new(iterations) .ok_or(Error::Pbkdf2ZeroIterations)?; @@ -19,12 +26,43 @@ impl Identity { keys.extend(std::iter::repeat(0).take(64)); let enc_key = &mut keys.data_mut()[0..32]; - pbkdf2::pbkdf2::<hmac::Hmac<sha2::Sha256>>( - password.password(), - email.as_bytes(), - iterations.get(), - enc_key, - ); + + match kdf { + crate::api::KdfType::Pbkdf2 => { + pbkdf2::pbkdf2::<hmac::Hmac<sha2::Sha256>>( + password.password(), + email.as_bytes(), + iterations.get(), + enc_key, + ) + .map_err(|_| Error::Pbkdf2)?; + } + + crate::api::KdfType::Argon2id => { + let mut hasher = sha2::Sha256::new(); + hasher.update(email.as_bytes()); + let salt = hasher.finalize(); + + let argon2_config = argon2::Argon2::new( + argon2::Algorithm::Argon2id, + argon2::Version::V0x13, + argon2::Params::new( + memory.unwrap() * 1024, + iterations.get(), + parallelism.unwrap(), + Some(32), + ) + .unwrap(), + ); + argon2::Argon2::hash_password_into( + &argon2_config, + password.password(), + &salt, + enc_key, + ) + .map_err(|_| Error::Argon2)?; + } + }; let mut hash = crate::locked::Vec::new(); hash.extend(std::iter::repeat(0).take(32)); @@ -33,7 +71,8 @@ impl Identity { password.password(), 1, hash.data_mut(), - ); + ) + .map_err(|_| Error::Pbkdf2)?; let hkdf = hkdf::Hkdf::<sha2::Sha256>::from_prk(enc_key) .map_err(|_| Error::HkdfExpand)?; @@ -1,21 +1,24 @@ +#![warn(clippy::cargo)] #![warn(clippy::pedantic)] #![warn(clippy::nursery)] -#![allow(clippy::default_trait_access)] -#![allow(clippy::implicit_hasher)] -#![allow(clippy::large_enum_variant)] +#![warn(clippy::as_conversions)] +#![warn(clippy::get_unwrap)] +#![allow(clippy::cognitive_complexity)] #![allow(clippy::missing_const_for_fn)] -#![allow(clippy::missing_errors_doc)] -#![allow(clippy::missing_panics_doc)] -#![allow(clippy::must_use_candidate)] #![allow(clippy::similar_names)] -#![allow(clippy::single_match)] +#![allow(clippy::struct_excessive_bools)] #![allow(clippy::too_many_arguments)] #![allow(clippy::too_many_lines)] #![allow(clippy::type_complexity)] -#![allow(clippy::unused_async)] +#![allow(clippy::multiple_crate_versions)] +#![allow(clippy::large_enum_variant)] +// we aren't really documenting apis anyway +#![allow(clippy::missing_errors_doc)] +#![allow(clippy::missing_panics_doc)] pub mod actions; pub mod api; +pub mod base64; pub mod cipherstring; pub mod config; pub mod db; diff --git a/src/locked.rs b/src/locked.rs index 4ddf021..bfa642d 100644 --- a/src/locked.rs +++ b/src/locked.rs @@ -18,10 +18,12 @@ impl Default for Vec { } impl Vec { + #[must_use] pub fn new() -> Self { Self::default() } + #[must_use] pub fn data(&self) -> &[u8] { self.data.as_slice() } @@ -65,10 +67,12 @@ pub struct Password { } impl Password { + #[must_use] pub fn new(password: Vec) -> Self { Self { password } } + #[must_use] pub fn password(&self) -> &[u8] { self.password.data() } @@ -80,14 +84,17 @@ pub struct Keys { } impl Keys { + #[must_use] pub fn new(keys: Vec) -> Self { Self { keys } } + #[must_use] pub fn enc_key(&self) -> &[u8] { &self.keys.data()[0..32] } + #[must_use] pub fn mac_key(&self) -> &[u8] { &self.keys.data()[32..64] } @@ -99,10 +106,12 @@ pub struct PasswordHash { } impl PasswordHash { + #[must_use] pub fn new(hash: Vec) -> Self { Self { hash } } + #[must_use] pub fn hash(&self) -> &[u8] { self.hash.data() } @@ -114,10 +123,12 @@ pub struct PrivateKey { } impl PrivateKey { + #[must_use] pub fn new(private_key: Vec) -> Self { Self { private_key } } + #[must_use] pub fn private_key(&self) -> &[u8] { self.private_key.data() } @@ -130,6 +141,7 @@ pub struct ApiKey { } impl ApiKey { + #[must_use] pub fn new(client_id: Password, client_secret: Password) -> Self { Self { client_id, @@ -137,10 +149,12 @@ impl ApiKey { } } + #[must_use] pub fn client_id(&self) -> &[u8] { self.client_id.password() } + #[must_use] pub fn client_secret(&self) -> &[u8] { self.client_secret.password() } diff --git a/src/pinentry.rs b/src/pinentry.rs index b4d2bb0..e2a83ed 100644 --- a/src/pinentry.rs +++ b/src/pinentry.rs @@ -1,5 +1,6 @@ use crate::prelude::*; +use std::convert::TryFrom as _; use tokio::io::AsyncWriteExt as _; pub async fn getpin( @@ -33,18 +34,18 @@ pub async fn getpin( .map_err(|source| Error::WriteStdin { source })?; ncommands += 1; stdin - .write_all(format!("SETPROMPT {}\n", prompt).as_bytes()) + .write_all(format!("SETPROMPT {prompt}\n").as_bytes()) .await .map_err(|source| Error::WriteStdin { source })?; ncommands += 1; stdin - .write_all(format!("SETDESC {}\n", desc).as_bytes()) + .write_all(format!("SETDESC {desc}\n").as_bytes()) .await .map_err(|source| Error::WriteStdin { source })?; ncommands += 1; if let Some(err) = err { stdin - .write_all(format!("SETERROR {}\n", err).as_bytes()) + .write_all(format!("SETERROR {err}\n").as_bytes()) .await .map_err(|source| Error::WriteStdin { source })?; ncommands += 1; @@ -76,15 +77,13 @@ pub async fn getpin( Ok(crate::locked::Password::new(buf)) } -async fn read_password< - R: tokio::io::AsyncRead + tokio::io::AsyncReadExt + Unpin, ->( +async fn read_password<R>( mut ncommands: u8, data: &mut [u8], mut r: R, ) -> Result<usize> where - R: Send, + R: tokio::io::AsyncRead + tokio::io::AsyncReadExt + Unpin + Send, { let mut len = 0; loop { @@ -119,7 +118,7 @@ where }); } return Err(Error::PinentryErrorMessage { - error: format!("unknown error ({})", code), + error: format!("unknown error ({code})"), }); } None => { @@ -138,6 +137,14 @@ where .read(&mut data[len..]) .await .map_err(|source| Error::PinentryReadOutput { source })?; + if bytes == 0 { + return Err(Error::PinentryReadOutput { + source: std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "unexpected EOF", + ), + }); + } len += bytes; } } @@ -161,9 +168,11 @@ fn percent_decode(buf: &mut [u8]) -> usize { if c == b'%' && read_idx + 2 < len { if let Some(h) = char::from(buf[read_idx + 1]).to_digit(16) { - #[allow(clippy::cast_possible_truncation)] if let Some(l) = char::from(buf[read_idx + 2]).to_digit(16) { - c = h as u8 * 0x10 + l as u8; + // h and l were parsed from a single hex digit, so they + // must be in the range 0-15, so these unwraps are safe + c = u8::try_from(h).unwrap() * 0x10 + + u8::try_from(l).unwrap(); read_idx += 2; } } diff --git a/src/protocol.rs b/src/protocol.rs index 14fa7f9..e883441 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -1,8 +1,6 @@ -// https://github.com/rust-lang/rust-clippy/issues/6902 -#![allow(clippy::use_self)] - // eventually it would be nice to make this a const function so that we could // just get the version from a variable directly, but this is fine for now +#[must_use] pub fn version() -> u32 { let major = env!("CARGO_PKG_VERSION_MAJOR"); let minor = env!("CARGO_PKG_VERSION_MINOR"); @@ -36,6 +34,9 @@ pub enum Action { plaintext: String, org_id: Option<String>, }, + ClipboardStore { + text: String, + }, Quit, Version, } diff --git a/src/pwgen.rs b/src/pwgen.rs index 55151e6..a112d73 100644 --- a/src/pwgen.rs +++ b/src/pwgen.rs @@ -15,6 +15,7 @@ pub enum Type { Diceware, } +#[must_use] pub fn pwgen(ty: Type, len: usize) -> String { let mut rng = rand::thread_rng(); @@ -101,6 +102,6 @@ mod test { for c in s.chars() { set.insert(c); } - assert!(set.len() < s.len()) + assert!(set.len() < s.len()); } } |