diff options
Diffstat (limited to 'src/bin/rbw-agent/agent.rs')
-rw-r--r-- | src/bin/rbw-agent/agent.rs | 205 |
1 files changed, 133 insertions, 72 deletions
diff --git a/src/bin/rbw-agent/agent.rs b/src/bin/rbw-agent/agent.rs index fae8c7b..a3fecb4 100644 --- a/src/bin/rbw-agent/agent.rs +++ b/src/bin/rbw-agent/agent.rs @@ -1,24 +1,25 @@ use anyhow::Context as _; +use futures_util::StreamExt as _; -#[derive(Debug)] -pub enum TimeoutEvent { - Set, - Clear, -} +use crate::notifications; pub struct State { pub priv_key: Option<rbw::locked::Keys>, pub org_keys: Option<std::collections::HashMap<String, rbw::locked::Keys>>, - pub timeout_chan: tokio::sync::mpsc::UnboundedSender<TimeoutEvent>, + pub timeout: crate::timeout::Timeout, + pub timeout_duration: std::time::Duration, + pub sync_timeout: crate::timeout::Timeout, + pub sync_timeout_duration: std::time::Duration, + pub notifications_handler: crate::notifications::Handler, + pub clipboard: Box<dyn copypasta::ClipboardProvider>, } impl State { pub fn key(&self, org_id: Option<&str>) -> Option<&rbw::locked::Keys> { - match org_id { - Some(id) => self.org_keys.as_ref().and_then(|h| h.get(id)), - None => self.priv_key.as_ref(), - } + org_id.map_or(self.priv_key.as_ref(), |id| { + self.org_keys.as_ref().and_then(|h| h.get(id)) + }) } pub fn needs_unlock(&self) -> bool { @@ -26,103 +27,163 @@ impl State { } pub fn set_timeout(&mut self) { - // no real better option to unwrap here - self.timeout_chan.send(TimeoutEvent::Set).unwrap(); + self.timeout.set(self.timeout_duration); } pub fn clear(&mut self) { self.priv_key = None; - self.org_keys = Default::default(); - // no real better option to unwrap here - self.timeout_chan.send(TimeoutEvent::Clear).unwrap(); + self.org_keys = None; + self.timeout.clear(); + } + + pub fn set_sync_timeout(&mut self) { + self.sync_timeout.set(self.sync_timeout_duration); } } pub struct Agent { - timeout_duration: tokio::time::Duration, - timeout: Option<std::pin::Pin<Box<tokio::time::Sleep>>>, - timeout_chan: tokio::sync::mpsc::UnboundedReceiver<TimeoutEvent>, - state: std::sync::Arc<tokio::sync::RwLock<State>>, + timer_r: tokio::sync::mpsc::UnboundedReceiver<()>, + sync_timer_r: tokio::sync::mpsc::UnboundedReceiver<()>, + state: std::sync::Arc<tokio::sync::Mutex<State>>, } impl Agent { pub fn new() -> anyhow::Result<Self> { let config = rbw::config::Config::load()?; let timeout_duration = - tokio::time::Duration::from_secs(config.lock_timeout); - let (w, r) = tokio::sync::mpsc::unbounded_channel(); + std::time::Duration::from_secs(config.lock_timeout); + let sync_timeout_duration = + std::time::Duration::from_secs(config.sync_interval); + let (timeout, timer_r) = crate::timeout::Timeout::new(); + let (sync_timeout, sync_timer_r) = crate::timeout::Timeout::new(); + if sync_timeout_duration > std::time::Duration::ZERO { + sync_timeout.set(sync_timeout_duration); + } + let notifications_handler = crate::notifications::Handler::new(); + let clipboard: Box<dyn copypasta::ClipboardProvider> = + copypasta::ClipboardContext::new().map_or_else( + |e| { + log::warn!("couldn't create clipboard context: {e}"); + let clipboard = Box::new( + // infallible + copypasta::nop_clipboard::NopClipboardContext::new() + .unwrap(), + ); + let clipboard: Box<dyn copypasta::ClipboardProvider> = + clipboard; + clipboard + }, + |clipboard| { + let clipboard = Box::new(clipboard); + let clipboard: Box<dyn copypasta::ClipboardProvider> = + clipboard; + clipboard + }, + ); Ok(Self { - timeout_duration, - timeout: None, - timeout_chan: r, - state: std::sync::Arc::new(tokio::sync::RwLock::new(State { + timer_r, + sync_timer_r, + state: std::sync::Arc::new(tokio::sync::Mutex::new(State { priv_key: None, - org_keys: Default::default(), - timeout_chan: w, + org_keys: None, + timeout, + timeout_duration, + sync_timeout, + sync_timeout_duration, + notifications_handler, + clipboard, })), }) } - fn set_timeout(&mut self) { - self.timeout = - Some(Box::pin(tokio::time::sleep(self.timeout_duration))); - } - - fn clear_timeout(&mut self) { - self.timeout = None; - } - pub async fn run( - &mut self, + self, listener: tokio::net::UnixListener, ) -> anyhow::Result<()> { - // tokio only supports timeouts up to 2^36 milliseconds - let mut forever = Box::pin(tokio::time::sleep( - tokio::time::Duration::from_secs(60 * 60 * 24 * 365 * 2), - )); - loop { - let timeout = if let Some(timeout) = &mut self.timeout { - timeout - } else { - &mut forever - }; - tokio::select! { - sock = listener.accept() => { + pub enum Event { + Request(std::io::Result<tokio::net::UnixStream>), + Timeout(()), + Sync(()), + } + + let notifications = self + .state + .lock() + .await + .notifications_handler + .get_channel() + .await; + let notifications = + tokio_stream::wrappers::UnboundedReceiverStream::new( + notifications, + ) + .map(|message| match message { + notifications::Message::Logout => Event::Timeout(()), + notifications::Message::Sync => Event::Sync(()), + }) + .boxed(); + + let mut stream = futures_util::stream::select_all([ + tokio_stream::wrappers::UnixListenerStream::new(listener) + .map(Event::Request) + .boxed(), + tokio_stream::wrappers::UnboundedReceiverStream::new( + self.timer_r, + ) + .map(Event::Timeout) + .boxed(), + tokio_stream::wrappers::UnboundedReceiverStream::new( + self.sync_timer_r, + ) + .map(Event::Sync) + .boxed(), + notifications, + ]); + while let Some(event) = stream.next().await { + match event { + Event::Request(res) => { let mut sock = crate::sock::Sock::new( - sock.context("failed to accept incoming connection")?.0 + res.context("failed to accept incoming connection")?, ); let state = self.state.clone(); tokio::spawn(async move { - let res - = handle_request(&mut sock, state.clone()).await; + let res = + handle_request(&mut sock, state.clone()).await; if let Err(e) = res { // unwrap is the only option here sock.send(&rbw::protocol::Response::Error { - error: format!("{:#}", e), - }).await.unwrap(); + error: format!("{e:#}"), + }) + .await + .unwrap(); } }); } - _ = timeout => { + Event::Timeout(()) => { + self.state.lock().await.clear(); + } + Event::Sync(()) => { let state = self.state.clone(); - tokio::spawn(async move{ - state.write().await.clear(); + tokio::spawn(async move { + // this could fail if we aren't logged in, but we + // don't care about that + if let Err(e) = + crate::actions::sync(None, state.clone()).await + { + eprintln!("failed to sync: {e:#}"); + } }); - } - Some(ev) = self.timeout_chan.recv() => { - match ev { - TimeoutEvent::Set => self.set_timeout(), - TimeoutEvent::Clear => self.clear_timeout(), - } + self.state.lock().await.set_sync_timeout(); } } } + Ok(()) } } async fn handle_request( sock: &mut crate::sock::Sock, - state: std::sync::Arc<tokio::sync::RwLock<State>>, + state: std::sync::Arc<tokio::sync::Mutex<State>>, ) -> anyhow::Result<()> { let req = sock.recv().await?; let req = match req { @@ -148,12 +209,7 @@ async fn handle_request( true } rbw::protocol::Action::CheckLock => { - crate::actions::check_lock( - sock, - state.clone(), - req.tty.as_deref(), - ) - .await?; + crate::actions::check_lock(sock, state.clone()).await?; false } rbw::protocol::Action::Lock => { @@ -161,7 +217,7 @@ async fn handle_request( false } rbw::protocol::Action::Sync => { - crate::actions::sync(sock, true).await?; + crate::actions::sync(Some(sock), state.clone()).await?; false } rbw::protocol::Action::Decrypt { @@ -187,6 +243,11 @@ async fn handle_request( .await?; true } + rbw::protocol::Action::ClipboardStore { text } => { + crate::actions::clipboard_store(sock, state.clone(), text) + .await?; + true + } rbw::protocol::Action::Quit => std::process::exit(0), rbw::protocol::Action::Version => { crate::actions::version(sock).await?; @@ -195,7 +256,7 @@ async fn handle_request( }; if set_timeout { - state.write().await.set_timeout(); + state.lock().await.set_timeout(); } Ok(()) |