diff options
Diffstat (limited to 'src/bin/rbw-agent')
-rw-r--r-- | src/bin/rbw-agent/actions.rs | 331 | ||||
-rw-r--r-- | src/bin/rbw-agent/agent.rs | 205 | ||||
-rw-r--r-- | src/bin/rbw-agent/daemon.rs | 59 | ||||
-rw-r--r-- | src/bin/rbw-agent/debugger.rs | 2 | ||||
-rw-r--r-- | src/bin/rbw-agent/main.rs | 31 | ||||
-rw-r--r-- | src/bin/rbw-agent/notifications.rs | 174 | ||||
-rw-r--r-- | src/bin/rbw-agent/sock.rs | 5 | ||||
-rw-r--r-- | src/bin/rbw-agent/timeout.rs | 66 |
8 files changed, 634 insertions, 239 deletions
diff --git a/src/bin/rbw-agent/actions.rs b/src/bin/rbw-agent/actions.rs index 1cc71c3..674442b 100644 --- a/src/bin/rbw-agent/actions.rs +++ b/src/bin/rbw-agent/actions.rs @@ -10,9 +10,7 @@ pub async fn register( let url_str = config_base_url().await?; let url = reqwest::Url::parse(&url_str) .context("failed to parse base url")?; - let host = if let Some(host) = url.host_str() { - host - } else { + let Some(host) = url.host_str() else { return Err(anyhow::anyhow!( "couldn't find host in rbw base url {}", url_str @@ -33,7 +31,7 @@ pub async fn register( let client_id = rbw::pinentry::getpin( &config_pinentry().await?, "API key client__id", - &format!("Log in to {}", host), + &format!("Log in to {host}"), err.as_deref(), tty, false, @@ -43,7 +41,7 @@ pub async fn register( let client_secret = rbw::pinentry::getpin( &config_pinentry().await?, "API key client__secret", - &format!("Log in to {}", host), + &format!("Log in to {host}"), err.as_deref(), tty, false, @@ -61,10 +59,9 @@ pub async fn register( message, }) .context("failed to log in to bitwarden instance"); - } else { - err_msg = Some(message); - continue; } + err_msg = Some(message); + continue; } Err(e) => { return Err(e) @@ -81,7 +78,7 @@ pub async fn register( pub async fn login( sock: &mut crate::sock::Sock, - state: std::sync::Arc<tokio::sync::RwLock<crate::agent::State>>, + state: std::sync::Arc<tokio::sync::Mutex<crate::agent::State>>, tty: Option<&str>, ) -> anyhow::Result<()> { let db = load_db().await.unwrap_or_else(|_| rbw::db::Db::new()); @@ -90,9 +87,7 @@ pub async fn login( let url_str = config_base_url().await?; let url = reqwest::Url::parse(&url_str) .context("failed to parse base url")?; - let host = if let Some(host) = url.host_str() { - host - } else { + let Some(host) = url.host_str() else { return Err(anyhow::anyhow!( "couldn't find host in rbw base url {}", url_str @@ -102,7 +97,7 @@ pub async fn login( let email = config_email().await?; let mut err_msg = None; - for i in 1_u8..=3 { + 'attempts: for i in 1_u8..=3 { let err = if i > 1 { // this unwrap is safe because we only ever continue the loop // if we have set err_msg @@ -113,7 +108,7 @@ pub async fn login( let password = rbw::pinentry::getpin( &config_pinentry().await?, "Master Password", - &format!("Log in to {}", host), + &format!("Log in to {host}"), err.as_deref(), tty, true, @@ -126,55 +121,70 @@ pub async fn login( Ok(( access_token, refresh_token, + kdf, iterations, + memory, + parallelism, protected_key, )) => { login_success( - sock, - state, + state.clone(), access_token, refresh_token, + kdf, iterations, + memory, + parallelism, protected_key, password, db, email, ) .await?; - break; + break 'attempts; } Err(rbw::error::Error::TwoFactorRequired { providers }) => { - if providers.contains( - &rbw::api::TwoFactorProviderType::Authenticator, - ) { - let ( - access_token, - refresh_token, - iterations, - protected_key, - ) = two_factor( - tty, - &email, - password.clone(), - rbw::api::TwoFactorProviderType::Authenticator, - ) - .await?; - login_success( - sock, - state, - access_token, - refresh_token, - iterations, - protected_key, - password, - db, - email, - ) - .await?; - break; - } else { - return Err(anyhow::anyhow!("TODO")); + let supported_types = vec![ + rbw::api::TwoFactorProviderType::Authenticator, + rbw::api::TwoFactorProviderType::Yubikey, + rbw::api::TwoFactorProviderType::Email, + ]; + + for provider in supported_types { + if providers.contains(&provider) { + let ( + access_token, + refresh_token, + kdf, + iterations, + memory, + parallelism, + protected_key, + ) = two_factor( + tty, + &email, + password.clone(), + provider, + ) + .await?; + login_success( + state.clone(), + access_token, + refresh_token, + kdf, + iterations, + memory, + parallelism, + protected_key, + password, + db, + email, + ) + .await?; + break 'attempts; + } } + return Err(anyhow::anyhow!("TODO")); } Err(rbw::error::Error::IncorrectPassword { message }) => { if i == 3 { @@ -182,10 +192,9 @@ pub async fn login( message, }) .context("failed to log in to bitwarden instance"); - } else { - err_msg = Some(message); - continue; } + err_msg = Some(message); + continue; } Err(e) => { return Err(e) @@ -205,7 +214,15 @@ async fn two_factor( email: &str, password: rbw::locked::Password, provider: rbw::api::TwoFactorProviderType, -) -> anyhow::Result<(String, String, u32, String)> { +) -> anyhow::Result<( + String, + String, + rbw::api::KdfType, + u32, + Option<u32>, + Option<u32>, + String, +)> { let mut err_msg = None; for i in 1_u8..=3 { let err = if i > 1 { @@ -217,11 +234,11 @@ async fn two_factor( }; let code = rbw::pinentry::getpin( &config_pinentry().await?, - "Authenticator App", - "Enter the 6 digit verification code from your authenticator app.", + provider.header(), + provider.message(), err.as_deref(), tty, - true, + provider.grab(), ) .await .context("failed to read code from pinentry")?; @@ -235,11 +252,22 @@ async fn two_factor( ) .await { - Ok((access_token, refresh_token, iterations, protected_key)) => { + Ok(( + access_token, + refresh_token, + kdf, + iterations, + memory, + parallelism, + protected_key, + )) => { return Ok(( access_token, refresh_token, + kdf, iterations, + memory, + parallelism, protected_key, )) } @@ -249,10 +277,9 @@ async fn two_factor( message, }) .context("failed to log in to bitwarden instance"); - } else { - err_msg = Some(message); - continue; } + err_msg = Some(message); + continue; } // can get this if the user passes an empty string Err(rbw::error::Error::TwoFactorRequired { .. }) => { @@ -262,10 +289,9 @@ async fn two_factor( message, }) .context("failed to log in to bitwarden instance"); - } else { - err_msg = Some(message); - continue; } + err_msg = Some(message); + continue; } Err(e) => { return Err(e) @@ -278,11 +304,13 @@ async fn two_factor( } async fn login_success( - sock: &mut crate::sock::Sock, - state: std::sync::Arc<tokio::sync::RwLock<crate::agent::State>>, + state: std::sync::Arc<tokio::sync::Mutex<crate::agent::State>>, access_token: String, refresh_token: String, + kdf: rbw::api::KdfType, iterations: u32, + memory: Option<u32>, + parallelism: Option<u32>, protected_key: String, password: rbw::locked::Password, mut db: rbw::db::Db, @@ -290,35 +318,37 @@ async fn login_success( ) -> anyhow::Result<()> { db.access_token = Some(access_token.to_string()); db.refresh_token = Some(refresh_token.to_string()); + db.kdf = Some(kdf); db.iterations = Some(iterations); + db.memory = memory; + db.parallelism = parallelism; db.protected_key = Some(protected_key.to_string()); save_db(&db).await?; - sync(sock, false).await?; + sync(None, state.clone()).await?; let db = load_db().await?; - let protected_private_key = - if let Some(protected_private_key) = db.protected_private_key { - protected_private_key - } else { - return Err(anyhow::anyhow!( - "failed to find protected private key in db" - )); - }; + let Some(protected_private_key) = db.protected_private_key else { + return Err(anyhow::anyhow!( + "failed to find protected private key in db" + )); + }; let res = rbw::actions::unlock( &email, &password, + kdf, iterations, + memory, + parallelism, &protected_key, &protected_private_key, &db.protected_org_keys, - ) - .await; + ); match res { Ok((keys, org_keys)) => { - let mut state = state.write().await; + let mut state = state.lock().await; state.priv_key = Some(keys); state.org_keys = Some(org_keys); } @@ -330,39 +360,40 @@ async fn login_success( pub async fn unlock( sock: &mut crate::sock::Sock, - state: std::sync::Arc<tokio::sync::RwLock<crate::agent::State>>, + state: std::sync::Arc<tokio::sync::Mutex<crate::agent::State>>, tty: Option<&str>, ) -> anyhow::Result<()> { - if state.read().await.needs_unlock() { + if state.lock().await.needs_unlock() { let db = load_db().await?; - let iterations = if let Some(iterations) = db.iterations { - iterations - } else { + let Some(kdf) = db.kdf else { + return Err(anyhow::anyhow!("failed to find kdf type in db")); + }; + + let Some(iterations) = db.iterations else { return Err(anyhow::anyhow!( "failed to find number of iterations in db" )); }; - let protected_key = if let Some(protected_key) = db.protected_key { - protected_key - } else { + + let memory = db.memory; + let parallelism = db.parallelism; + + let Some(protected_key) = db.protected_key else { return Err(anyhow::anyhow!( "failed to find protected key in db" )); }; - let protected_private_key = - if let Some(protected_private_key) = db.protected_private_key { - protected_private_key - } else { - return Err(anyhow::anyhow!( - "failed to find protected private key in db" - )); - }; + let Some(protected_private_key) = db.protected_private_key else { + return Err(anyhow::anyhow!( + "failed to find protected private key in db" + )); + }; let email = config_email().await?; let mut err_msg = None; - for i in 1u8..=3 { + for i in 1_u8..=3 { let err = if i > 1 { // this unwrap is safe because we only ever continue the loop // if we have set err_msg @@ -373,7 +404,10 @@ pub async fn unlock( let password = rbw::pinentry::getpin( &config_pinentry().await?, "Master Password", - "Unlock the local database", + &format!( + "Unlock the local database for '{}'", + rbw::dirs::profile() + ), err.as_deref(), tty, true, @@ -383,13 +417,14 @@ pub async fn unlock( match rbw::actions::unlock( &email, &password, + kdf, iterations, + memory, + parallelism, &protected_key, &protected_private_key, &db.protected_org_keys, - ) - .await - { + ) { Ok((keys, org_keys)) => { unlock_success(state, keys, org_keys).await?; break; @@ -400,10 +435,9 @@ pub async fn unlock( message, }) .context("failed to unlock database"); - } else { - err_msg = Some(message); - continue; } + err_msg = Some(message); + continue; } Err(e) => return Err(e).context("failed to unlock database"), } @@ -416,11 +450,11 @@ pub async fn unlock( } async fn unlock_success( - state: std::sync::Arc<tokio::sync::RwLock<crate::agent::State>>, + state: std::sync::Arc<tokio::sync::Mutex<crate::agent::State>>, keys: rbw::locked::Keys, org_keys: std::collections::HashMap<String, rbw::locked::Keys>, ) -> anyhow::Result<()> { - let mut state = state.write().await; + let mut state = state.lock().await; state.priv_key = Some(keys); state.org_keys = Some(org_keys); Ok(()) @@ -428,9 +462,9 @@ async fn unlock_success( pub async fn lock( sock: &mut crate::sock::Sock, - state: std::sync::Arc<tokio::sync::RwLock<crate::agent::State>>, + state: std::sync::Arc<tokio::sync::Mutex<crate::agent::State>>, ) -> anyhow::Result<()> { - state.write().await.clear(); + state.lock().await.clear(); respond_ack(sock).await?; @@ -439,10 +473,9 @@ pub async fn lock( pub async fn check_lock( sock: &mut crate::sock::Sock, - state: std::sync::Arc<tokio::sync::RwLock<crate::agent::State>>, - _tty: Option<&str>, + state: std::sync::Arc<tokio::sync::Mutex<crate::agent::State>>, ) -> anyhow::Result<()> { - if state.read().await.needs_unlock() { + if state.lock().await.needs_unlock() { return Err(anyhow::anyhow!("agent is locked")); } @@ -452,8 +485,8 @@ pub async fn check_lock( } pub async fn sync( - sock: &mut crate::sock::Sock, - ack: bool, + sock: Option<&mut crate::sock::Sock>, + state: std::sync::Arc<tokio::sync::Mutex<crate::agent::State>>, ) -> anyhow::Result<()> { let mut db = load_db().await?; @@ -482,7 +515,11 @@ pub async fn sync( db.entries = entries; save_db(&db).await?; - if ack { + if let Err(e) = subscribe_to_notifications(state.clone()).await { + eprintln!("failed to subscribe to notifications: {e}"); + } + + if let Some(sock) = sock { respond_ack(sock).await?; } @@ -491,14 +528,12 @@ pub async fn sync( pub async fn decrypt( sock: &mut crate::sock::Sock, - state: std::sync::Arc<tokio::sync::RwLock<crate::agent::State>>, + state: std::sync::Arc<tokio::sync::Mutex<crate::agent::State>>, cipherstring: &str, org_id: Option<&str>, ) -> anyhow::Result<()> { - let state = state.read().await; - let keys = if let Some(keys) = state.key(org_id) { - keys - } else { + let state = state.lock().await; + let Some(keys) = state.key(org_id) else { return Err(anyhow::anyhow!( "failed to find decryption keys in in-memory state" )); @@ -519,14 +554,12 @@ pub async fn decrypt( pub async fn encrypt( sock: &mut crate::sock::Sock, - state: std::sync::Arc<tokio::sync::RwLock<crate::agent::State>>, + state: std::sync::Arc<tokio::sync::Mutex<crate::agent::State>>, plaintext: &str, org_id: Option<&str>, ) -> anyhow::Result<()> { - let state = state.read().await; - let keys = if let Some(keys) = state.key(org_id) { - keys - } else { + let state = state.lock().await; + let Some(keys) = state.key(org_id) else { return Err(anyhow::anyhow!( "failed to find encryption keys in in-memory state" )); @@ -542,6 +575,25 @@ pub async fn encrypt( Ok(()) } +pub async fn clipboard_store( + sock: &mut crate::sock::Sock, + state: std::sync::Arc<tokio::sync::Mutex<crate::agent::State>>, + text: &str, +) -> anyhow::Result<()> { + state + .lock() + .await + .clipboard + .set_contents(text.to_owned()) + .map_err(|e| { + anyhow::anyhow!("couldn't store value to clipboard: {e}") + })?; + + respond_ack(sock).await?; + + Ok(()) +} + pub async fn version(sock: &mut crate::sock::Sock) -> anyhow::Result<()> { sock.send(&rbw::protocol::Response::Version { version: rbw::protocol::version(), @@ -579,11 +631,10 @@ async fn respond_encrypt( async fn config_email() -> anyhow::Result<String> { let config = rbw::config::Config::load_async().await?; - if let Some(email) = config.email { - Ok(email) - } else { - Err(anyhow::anyhow!("failed to find email address in config")) - } + config.email.map_or_else( + || Err(anyhow::anyhow!("failed to find email address in config")), + Ok, + ) } async fn load_db() -> anyhow::Result<rbw::db::Db> { @@ -617,3 +668,35 @@ async fn config_pinentry() -> anyhow::Result<String> { let config = rbw::config::Config::load_async().await?; Ok(config.pinentry) } + +pub async fn subscribe_to_notifications( + state: std::sync::Arc<tokio::sync::Mutex<crate::agent::State>>, +) -> anyhow::Result<()> { + if state.lock().await.notifications_handler.is_connected() { + return Ok(()); + } + + let config = rbw::config::Config::load_async() + .await + .context("Config is missing")?; + let email = config.email.clone().context("Config is missing email")?; + let db = rbw::db::Db::load_async(config.server_name().as_str(), &email) + .await?; + let access_token = + db.access_token.context("Error getting access token")?; + + let websocket_url = format!( + "{}/hub?access_token={}", + config.notifications_url(), + access_token + ) + .replace("https://", "wss://"); + + let mut state = state.lock().await; + state + .notifications_handler + .connect(websocket_url) + .await + .err() + .map_or_else(|| Ok(()), |err| Err(anyhow::anyhow!(err.to_string()))) +} diff --git a/src/bin/rbw-agent/agent.rs b/src/bin/rbw-agent/agent.rs index fae8c7b..a3fecb4 100644 --- a/src/bin/rbw-agent/agent.rs +++ b/src/bin/rbw-agent/agent.rs @@ -1,24 +1,25 @@ use anyhow::Context as _; +use futures_util::StreamExt as _; -#[derive(Debug)] -pub enum TimeoutEvent { - Set, - Clear, -} +use crate::notifications; pub struct State { pub priv_key: Option<rbw::locked::Keys>, pub org_keys: Option<std::collections::HashMap<String, rbw::locked::Keys>>, - pub timeout_chan: tokio::sync::mpsc::UnboundedSender<TimeoutEvent>, + pub timeout: crate::timeout::Timeout, + pub timeout_duration: std::time::Duration, + pub sync_timeout: crate::timeout::Timeout, + pub sync_timeout_duration: std::time::Duration, + pub notifications_handler: crate::notifications::Handler, + pub clipboard: Box<dyn copypasta::ClipboardProvider>, } impl State { pub fn key(&self, org_id: Option<&str>) -> Option<&rbw::locked::Keys> { - match org_id { - Some(id) => self.org_keys.as_ref().and_then(|h| h.get(id)), - None => self.priv_key.as_ref(), - } + org_id.map_or(self.priv_key.as_ref(), |id| { + self.org_keys.as_ref().and_then(|h| h.get(id)) + }) } pub fn needs_unlock(&self) -> bool { @@ -26,103 +27,163 @@ impl State { } pub fn set_timeout(&mut self) { - // no real better option to unwrap here - self.timeout_chan.send(TimeoutEvent::Set).unwrap(); + self.timeout.set(self.timeout_duration); } pub fn clear(&mut self) { self.priv_key = None; - self.org_keys = Default::default(); - // no real better option to unwrap here - self.timeout_chan.send(TimeoutEvent::Clear).unwrap(); + self.org_keys = None; + self.timeout.clear(); + } + + pub fn set_sync_timeout(&mut self) { + self.sync_timeout.set(self.sync_timeout_duration); } } pub struct Agent { - timeout_duration: tokio::time::Duration, - timeout: Option<std::pin::Pin<Box<tokio::time::Sleep>>>, - timeout_chan: tokio::sync::mpsc::UnboundedReceiver<TimeoutEvent>, - state: std::sync::Arc<tokio::sync::RwLock<State>>, + timer_r: tokio::sync::mpsc::UnboundedReceiver<()>, + sync_timer_r: tokio::sync::mpsc::UnboundedReceiver<()>, + state: std::sync::Arc<tokio::sync::Mutex<State>>, } impl Agent { pub fn new() -> anyhow::Result<Self> { let config = rbw::config::Config::load()?; let timeout_duration = - tokio::time::Duration::from_secs(config.lock_timeout); - let (w, r) = tokio::sync::mpsc::unbounded_channel(); + std::time::Duration::from_secs(config.lock_timeout); + let sync_timeout_duration = + std::time::Duration::from_secs(config.sync_interval); + let (timeout, timer_r) = crate::timeout::Timeout::new(); + let (sync_timeout, sync_timer_r) = crate::timeout::Timeout::new(); + if sync_timeout_duration > std::time::Duration::ZERO { + sync_timeout.set(sync_timeout_duration); + } + let notifications_handler = crate::notifications::Handler::new(); + let clipboard: Box<dyn copypasta::ClipboardProvider> = + copypasta::ClipboardContext::new().map_or_else( + |e| { + log::warn!("couldn't create clipboard context: {e}"); + let clipboard = Box::new( + // infallible + copypasta::nop_clipboard::NopClipboardContext::new() + .unwrap(), + ); + let clipboard: Box<dyn copypasta::ClipboardProvider> = + clipboard; + clipboard + }, + |clipboard| { + let clipboard = Box::new(clipboard); + let clipboard: Box<dyn copypasta::ClipboardProvider> = + clipboard; + clipboard + }, + ); Ok(Self { - timeout_duration, - timeout: None, - timeout_chan: r, - state: std::sync::Arc::new(tokio::sync::RwLock::new(State { + timer_r, + sync_timer_r, + state: std::sync::Arc::new(tokio::sync::Mutex::new(State { priv_key: None, - org_keys: Default::default(), - timeout_chan: w, + org_keys: None, + timeout, + timeout_duration, + sync_timeout, + sync_timeout_duration, + notifications_handler, + clipboard, })), }) } - fn set_timeout(&mut self) { - self.timeout = - Some(Box::pin(tokio::time::sleep(self.timeout_duration))); - } - - fn clear_timeout(&mut self) { - self.timeout = None; - } - pub async fn run( - &mut self, + self, listener: tokio::net::UnixListener, ) -> anyhow::Result<()> { - // tokio only supports timeouts up to 2^36 milliseconds - let mut forever = Box::pin(tokio::time::sleep( - tokio::time::Duration::from_secs(60 * 60 * 24 * 365 * 2), - )); - loop { - let timeout = if let Some(timeout) = &mut self.timeout { - timeout - } else { - &mut forever - }; - tokio::select! { - sock = listener.accept() => { + pub enum Event { + Request(std::io::Result<tokio::net::UnixStream>), + Timeout(()), + Sync(()), + } + + let notifications = self + .state + .lock() + .await + .notifications_handler + .get_channel() + .await; + let notifications = + tokio_stream::wrappers::UnboundedReceiverStream::new( + notifications, + ) + .map(|message| match message { + notifications::Message::Logout => Event::Timeout(()), + notifications::Message::Sync => Event::Sync(()), + }) + .boxed(); + + let mut stream = futures_util::stream::select_all([ + tokio_stream::wrappers::UnixListenerStream::new(listener) + .map(Event::Request) + .boxed(), + tokio_stream::wrappers::UnboundedReceiverStream::new( + self.timer_r, + ) + .map(Event::Timeout) + .boxed(), + tokio_stream::wrappers::UnboundedReceiverStream::new( + self.sync_timer_r, + ) + .map(Event::Sync) + .boxed(), + notifications, + ]); + while let Some(event) = stream.next().await { + match event { + Event::Request(res) => { let mut sock = crate::sock::Sock::new( - sock.context("failed to accept incoming connection")?.0 + res.context("failed to accept incoming connection")?, ); let state = self.state.clone(); tokio::spawn(async move { - let res - = handle_request(&mut sock, state.clone()).await; + let res = + handle_request(&mut sock, state.clone()).await; if let Err(e) = res { // unwrap is the only option here sock.send(&rbw::protocol::Response::Error { - error: format!("{:#}", e), - }).await.unwrap(); + error: format!("{e:#}"), + }) + .await + .unwrap(); } }); } - _ = timeout => { + Event::Timeout(()) => { + self.state.lock().await.clear(); + } + Event::Sync(()) => { let state = self.state.clone(); - tokio::spawn(async move{ - state.write().await.clear(); + tokio::spawn(async move { + // this could fail if we aren't logged in, but we + // don't care about that + if let Err(e) = + crate::actions::sync(None, state.clone()).await + { + eprintln!("failed to sync: {e:#}"); + } }); - } - Some(ev) = self.timeout_chan.recv() => { - match ev { - TimeoutEvent::Set => self.set_timeout(), - TimeoutEvent::Clear => self.clear_timeout(), - } + self.state.lock().await.set_sync_timeout(); } } } + Ok(()) } } async fn handle_request( sock: &mut crate::sock::Sock, - state: std::sync::Arc<tokio::sync::RwLock<State>>, + state: std::sync::Arc<tokio::sync::Mutex<State>>, ) -> anyhow::Result<()> { let req = sock.recv().await?; let req = match req { @@ -148,12 +209,7 @@ async fn handle_request( true } rbw::protocol::Action::CheckLock => { - crate::actions::check_lock( - sock, - state.clone(), - req.tty.as_deref(), - ) - .await?; + crate::actions::check_lock(sock, state.clone()).await?; false } rbw::protocol::Action::Lock => { @@ -161,7 +217,7 @@ async fn handle_request( false } rbw::protocol::Action::Sync => { - crate::actions::sync(sock, true).await?; + crate::actions::sync(Some(sock), state.clone()).await?; false } rbw::protocol::Action::Decrypt { @@ -187,6 +243,11 @@ async fn handle_request( .await?; true } + rbw::protocol::Action::ClipboardStore { text } => { + crate::actions::clipboard_store(sock, state.clone(), text) + .await?; + true + } rbw::protocol::Action::Quit => std::process::exit(0), rbw::protocol::Action::Version => { crate::actions::version(sock).await?; @@ -195,7 +256,7 @@ async fn handle_request( }; if set_timeout { - state.write().await.set_timeout(); + state.lock().await.set_timeout(); } Ok(()) diff --git a/src/bin/rbw-agent/daemon.rs b/src/bin/rbw-agent/daemon.rs index 923a217..06db891 100644 --- a/src/bin/rbw-agent/daemon.rs +++ b/src/bin/rbw-agent/daemon.rs @@ -1,25 +1,15 @@ pub struct StartupAck { - writer: std::os::unix::io::RawFd, + writer: std::os::unix::io::OwnedFd, } impl StartupAck { - pub fn ack(&self) -> anyhow::Result<()> { - nix::unistd::write(self.writer, &[0])?; - nix::unistd::close(self.writer)?; + pub fn ack(self) -> anyhow::Result<()> { + rustix::io::write(&self.writer, &[0])?; Ok(()) } } -impl Drop for StartupAck { - fn drop(&mut self) { - // best effort close here, can't do better in a destructor - let _ = nix::unistd::close(self.writer); - } -} - pub fn daemonize() -> anyhow::Result<StartupAck> { - rbw::dirs::make_all()?; - let stdout = std::fs::OpenOptions::new() .append(true) .create(true) @@ -29,33 +19,38 @@ pub fn daemonize() -> anyhow::Result<StartupAck> { .create(true) .open(rbw::dirs::agent_stderr_file())?; - let (r, w) = nix::unistd::pipe()?; - let res = daemonize::Daemonize::new() + let (r, w) = rustix::pipe::pipe()?; + let daemonize = daemonize::Daemonize::new() .pid_file(rbw::dirs::pid_file()) .stdout(stdout) - .stderr(stderr) - .exit_action(move || { + .stderr(stderr); + let res = match daemonize.execute() { + daemonize::Outcome::Parent(_) => { + drop(w); + let mut buf = [0; 1]; // unwraps are necessary because not really a good way to handle // errors here otherwise - let _ = nix::unistd::close(w); - let mut buf = [0; 1]; - nix::unistd::read(r, &mut buf).unwrap(); - nix::unistd::close(r).unwrap(); - }) - .start(); - let _ = nix::unistd::close(r); + rustix::io::read(&r, &mut buf).unwrap(); + drop(r); + std::process::exit(0); + } + daemonize::Outcome::Child(res) => res, + }; + + drop(r); match res { Ok(_) => (), Err(e) => { - match e { - daemonize::DaemonizeError::LockPidfile(_) => { - // this means that there is already an agent running, so - // return a special exit code to allow the cli to detect - // this case and not error out - std::process::exit(23); - } - _ => panic!("failed to daemonize: {}", e), + // XXX super gross, but daemonize removed the ability to match + // on specific error types for some reason? + if e.to_string().contains("unable to lock pid file") { + // this means that there is already an agent running, so + // return a special exit code to allow the cli to detect + // this case and not error out + std::process::exit(23); + } else { + panic!("failed to daemonize: {e}"); } } } diff --git a/src/bin/rbw-agent/debugger.rs b/src/bin/rbw-agent/debugger.rs index 59bbe50..be5260c 100644 --- a/src/bin/rbw-agent/debugger.rs +++ b/src/bin/rbw-agent/debugger.rs @@ -12,7 +12,7 @@ pub fn disable_tracing() -> anyhow::Result<()> { if ret == 0 { Ok(()) } else { - let e = nix::Error::last(); + let e = std::io::Error::last_os_error(); Err(anyhow::anyhow!("failed to disable PTRACE_ATTACH, agent memory may be dumpable by other processes: {}", e)) } } diff --git a/src/bin/rbw-agent/main.rs b/src/bin/rbw-agent/main.rs index 69411ae..d470e10 100644 --- a/src/bin/rbw-agent/main.rs +++ b/src/bin/rbw-agent/main.rs @@ -1,4 +1,19 @@ +#![warn(clippy::cargo)] +#![warn(clippy::pedantic)] +#![warn(clippy::nursery)] +#![warn(clippy::as_conversions)] +#![warn(clippy::get_unwrap)] +#![allow(clippy::cognitive_complexity)] +#![allow(clippy::missing_const_for_fn)] +#![allow(clippy::similar_names)] +#![allow(clippy::struct_excessive_bools)] #![allow(clippy::too_many_arguments)] +#![allow(clippy::too_many_lines)] +#![allow(clippy::type_complexity)] +#![allow(clippy::multiple_crate_versions)] +#![allow(clippy::large_enum_variant)] +// this one looks plausibly useful, but currently has too many bugs +#![allow(clippy::significant_drop_tightening)] use anyhow::Context as _; @@ -6,7 +21,9 @@ mod actions; mod agent; mod daemon; mod debugger; +mod notifications; mod sock; +mod timeout; async fn tokio_main( startup_ack: Option<crate::daemon::StartupAck>, @@ -17,7 +34,7 @@ async fn tokio_main( startup_ack.ack()?; } - let mut agent = crate::agent::Agent::new()?; + let agent = crate::agent::Agent::new()?; agent.run(listener).await?; Ok(()) @@ -29,11 +46,11 @@ fn real_main() -> anyhow::Result<()> { ) .init(); - let no_daemonize = if let Some(arg) = std::env::args().nth(1) { - arg == "--no-daemonize" - } else { - false - }; + let no_daemonize = std::env::args() + .nth(1) + .map_or(false, |arg| arg == "--no-daemonize"); + + rbw::dirs::make_all()?; let startup_ack = if no_daemonize { None @@ -69,7 +86,7 @@ fn main() { if let Err(e) = res { // XXX log file? - eprintln!("{:#}", e); + eprintln!("{e:#}"); std::process::exit(1); } } diff --git a/src/bin/rbw-agent/notifications.rs b/src/bin/rbw-agent/notifications.rs new file mode 100644 index 0000000..8176603 --- /dev/null +++ b/src/bin/rbw-agent/notifications.rs @@ -0,0 +1,174 @@ +use futures_util::{SinkExt as _, StreamExt as _}; + +#[derive(Clone, Copy, Debug)] +pub enum Message { + Sync, + Logout, +} + +pub struct Handler { + write: Option< + futures::stream::SplitSink< + tokio_tungstenite::WebSocketStream< + tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>, + >, + tokio_tungstenite::tungstenite::Message, + >, + >, + read_handle: Option<tokio::task::JoinHandle<()>>, + sending_channels: std::sync::Arc< + tokio::sync::RwLock<Vec<tokio::sync::mpsc::UnboundedSender<Message>>>, + >, +} + +impl Handler { + pub fn new() -> Self { + Self { + write: None, + read_handle: None, + sending_channels: std::sync::Arc::new(tokio::sync::RwLock::new( + Vec::new(), + )), + } + } + + pub async fn connect( + &mut self, + url: String, + ) -> Result<(), Box<dyn std::error::Error>> { + if self.is_connected() { + self.disconnect().await?; + } + + let (write, read_handle) = + subscribe_to_notifications(url, self.sending_channels.clone()) + .await?; + + self.write = Some(write); + self.read_handle = Some(read_handle); + Ok(()) + } + + pub fn is_connected(&self) -> bool { + self.write.is_some() + && self.read_handle.is_some() + && !self.read_handle.as_ref().unwrap().is_finished() + } + + pub async fn disconnect( + &mut self, + ) -> Result<(), Box<dyn std::error::Error>> { + self.sending_channels.write().await.clear(); + if let Some(mut write) = self.write.take() { + write + .send(tokio_tungstenite::tungstenite::Message::Close(None)) + .await?; + write.close().await?; + self.read_handle.take().unwrap().await?; + } + self.write = None; + self.read_handle = None; + Ok(()) + } + + pub async fn get_channel( + &mut self, + ) -> tokio::sync::mpsc::UnboundedReceiver<Message> { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + self.sending_channels.write().await.push(tx); + rx + } +} + +async fn subscribe_to_notifications( + url: String, + sending_channels: std::sync::Arc< + tokio::sync::RwLock<Vec<tokio::sync::mpsc::UnboundedSender<Message>>>, + >, +) -> Result< + ( + futures_util::stream::SplitSink< + tokio_tungstenite::WebSocketStream< + tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>, + >, + tokio_tungstenite::tungstenite::Message, + >, + tokio::task::JoinHandle<()>, + ), + Box<dyn std::error::Error>, +> { + let url = url::Url::parse(url.as_str())?; + let (ws_stream, _response) = + tokio_tungstenite::connect_async(url).await?; + let (mut write, read) = ws_stream.split(); + + write + .send(tokio_tungstenite::tungstenite::Message::Text( + "{\"protocol\":\"messagepack\",\"version\":1}\x1e".to_string(), + )) + .await + .unwrap(); + + let read_future = async move { + let sending_channels = &sending_channels; + read.for_each(|message| async move { + match message { + Ok(message) => { + if let Some(message) = parse_message(message) { + let sending_channels = sending_channels.read().await; + let sending_channels = sending_channels.as_slice(); + for channel in sending_channels { + channel.send(message).unwrap(); + } + } + } + Err(e) => { + eprintln!("websocket error: {e:?}"); + } + } + }) + .await; + }; + + Ok((write, tokio::spawn(read_future))) +} + +fn parse_message( + message: tokio_tungstenite::tungstenite::Message, +) -> Option<Message> { + let tokio_tungstenite::tungstenite::Message::Binary(data) = message + else { + return None; + }; + + // the first few bytes with the 0x80 bit set, plus one byte terminating the length contain the length of the message + let len_buffer_length = data.iter().position(|&x| (x & 0x80) == 0)? + 1; + + let unpacked_messagepack = + rmpv::decode::read_value(&mut &data[len_buffer_length..]).ok()?; + + let unpacked_message = unpacked_messagepack.as_array()?; + let message_type = unpacked_message.first()?.as_u64()?; + // invocation + if message_type != 1 { + return None; + } + let target = unpacked_message.get(3)?.as_str()?; + if target != "ReceiveMessage" { + return None; + } + + let args = unpacked_message.get(4)?.as_array()?; + let map = args.first()?.as_map()?; + for (k, v) in map { + if k.as_str()? == "Type" { + let ty = v.as_i64()?; + return match ty { + 11 => Some(Message::Logout), + _ => Some(Message::Sync), + }; + } + } + + None +} diff --git a/src/bin/rbw-agent/sock.rs b/src/bin/rbw-agent/sock.rs index 311176c..280b8cc 100644 --- a/src/bin/rbw-agent/sock.rs +++ b/src/bin/rbw-agent/sock.rs @@ -36,9 +36,8 @@ impl Sock { buf.read_line(&mut line) .await .context("failed to read message from socket")?; - Ok(serde_json::from_str(&line).map_err(|e| { - format!("failed to parse message '{}': {}", line, e) - })) + Ok(serde_json::from_str(&line) + .map_err(|e| format!("failed to parse message '{line}': {e}"))) } } diff --git a/src/bin/rbw-agent/timeout.rs b/src/bin/rbw-agent/timeout.rs new file mode 100644 index 0000000..e2aba06 --- /dev/null +++ b/src/bin/rbw-agent/timeout.rs @@ -0,0 +1,66 @@ +use futures_util::StreamExt as _; + +#[derive(Debug, Hash, Eq, PartialEq, Copy, Clone)] +enum Streams { + Requests, + Timer, +} + +#[derive(Debug)] +enum Action { + Set(std::time::Duration), + Clear, +} + +pub struct Timeout { + req_w: tokio::sync::mpsc::UnboundedSender<Action>, +} + +impl Timeout { + pub fn new() -> (Self, tokio::sync::mpsc::UnboundedReceiver<()>) { + let (req_w, req_r) = tokio::sync::mpsc::unbounded_channel(); + let (timer_w, timer_r) = tokio::sync::mpsc::unbounded_channel(); + tokio::spawn(async move { + enum Event { + Request(Action), + Timer, + } + let mut stream = tokio_stream::StreamMap::new(); + stream.insert( + Streams::Requests, + tokio_stream::wrappers::UnboundedReceiverStream::new(req_r) + .map(Event::Request) + .boxed(), + ); + while let Some(event) = stream.next().await { + match event { + (_, Event::Request(Action::Set(dur))) => { + stream.insert( + Streams::Timer, + futures_util::stream::once(tokio::time::sleep( + dur, + )) + .map(|()| Event::Timer) + .boxed(), + ); + } + (_, Event::Request(Action::Clear)) => { + stream.remove(&Streams::Timer); + } + (_, Event::Timer) => { + timer_w.send(()).unwrap(); + } + } + } + }); + (Self { req_w }, timer_r) + } + + pub fn set(&self, dur: std::time::Duration) { + self.req_w.send(Action::Set(dur)).unwrap(); + } + + pub fn clear(&self) { + self.req_w.send(Action::Clear).unwrap(); + } +} |