diff options
Diffstat (limited to 'src/bin/rbw-agent/agent.rs')
-rw-r--r-- | src/bin/rbw-agent/agent.rs | 91 |
1 files changed, 38 insertions, 53 deletions
diff --git a/src/bin/rbw-agent/agent.rs b/src/bin/rbw-agent/agent.rs index b36fbb7..8fa6768 100644 --- a/src/bin/rbw-agent/agent.rs +++ b/src/bin/rbw-agent/agent.rs @@ -1,16 +1,12 @@ use anyhow::Context as _; - -#[derive(Debug)] -pub enum TimeoutEvent { - Set, - Clear, -} +use futures_util::StreamExt as _; pub struct State { pub priv_key: Option<rbw::locked::Keys>, pub org_keys: Option<std::collections::HashMap<String, rbw::locked::Keys>>, - pub timeout_chan: tokio::sync::mpsc::UnboundedSender<TimeoutEvent>, + pub timeout: crate::timeout::Timeout, + pub timeout_duration: std::time::Duration, } impl State { @@ -25,22 +21,18 @@ impl State { } pub fn set_timeout(&mut self) { - // no real better option to unwrap here - self.timeout_chan.send(TimeoutEvent::Set).unwrap(); + self.timeout.set(self.timeout_duration); } pub fn clear(&mut self) { self.priv_key = None; self.org_keys = None; - // no real better option to unwrap here - self.timeout_chan.send(TimeoutEvent::Clear).unwrap(); + self.timeout.clear(); } } pub struct Agent { - timeout_duration: tokio::time::Duration, - timeout: Option<std::pin::Pin<Box<tokio::time::Sleep>>>, - timeout_chan: tokio::sync::mpsc::UnboundedReceiver<TimeoutEvent>, + timer_r: tokio::sync::mpsc::UnboundedReceiver<()>, state: std::sync::Arc<tokio::sync::RwLock<State>>, } @@ -48,70 +40,63 @@ impl Agent { pub fn new() -> anyhow::Result<Self> { let config = rbw::config::Config::load()?; let timeout_duration = - tokio::time::Duration::from_secs(config.lock_timeout); - let (w, r) = tokio::sync::mpsc::unbounded_channel(); + std::time::Duration::from_secs(config.lock_timeout); + let (timeout, timer_r) = crate::timeout::Timeout::new(); Ok(Self { - timeout_duration, - timeout: None, - timeout_chan: r, + timer_r, state: std::sync::Arc::new(tokio::sync::RwLock::new(State { priv_key: None, org_keys: None, - timeout_chan: w, + timeout, + timeout_duration, })), }) } - fn set_timeout(&mut self) { - self.timeout = - Some(Box::pin(tokio::time::sleep(self.timeout_duration))); - } - - fn clear_timeout(&mut self) { - self.timeout = None; - } - pub async fn run( - &mut self, + self, listener: tokio::net::UnixListener, ) -> anyhow::Result<()> { - // tokio only supports timeouts up to 2^36 milliseconds - let mut forever = Box::pin(tokio::time::sleep( - tokio::time::Duration::from_secs(60 * 60 * 24 * 365 * 2), - )); - loop { - let timeout = self.timeout.as_mut().unwrap_or(&mut forever); - tokio::select! { - sock = listener.accept() => { + enum Event { + Request(std::io::Result<tokio::net::UnixStream>), + Timeout(()), + } + let mut stream = futures_util::stream::select_all([ + tokio_stream::wrappers::UnixListenerStream::new(listener) + .map(Event::Request) + .boxed(), + tokio_stream::wrappers::UnboundedReceiverStream::new( + self.timer_r, + ) + .map(Event::Timeout) + .boxed(), + ]); + while let Some(event) = stream.next().await { + match event { + Event::Request(res) => { let mut sock = crate::sock::Sock::new( - sock.context("failed to accept incoming connection")?.0 + res.context("failed to accept incoming connection")?, ); let state = self.state.clone(); tokio::spawn(async move { - let res - = handle_request(&mut sock, state.clone()).await; + let res = + handle_request(&mut sock, state.clone()).await; if let Err(e) = res { // unwrap is the only option here sock.send(&rbw::protocol::Response::Error { error: format!("{e:#}"), - }).await.unwrap(); + }) + .await + .unwrap(); } }); } - _ = timeout => { - let state = self.state.clone(); - tokio::spawn(async move{ - state.write().await.clear(); - }); - } - Some(ev) = self.timeout_chan.recv() => { - match ev { - TimeoutEvent::Set => self.set_timeout(), - TimeoutEvent::Clear => self.clear_timeout(), - } + Event::Timeout(()) => { + self.state.write().await.clear(); } } } + Ok(()) } } |