From b659cc500476a7b4b94bc6659d46922be9465b99 Mon Sep 17 00:00:00 2001 From: Jesse Luehrs Date: Sat, 25 Mar 2023 18:29:04 -0400 Subject: stop using tokio::select! --- src/bin/rbw-agent/agent.rs | 91 ++++++++++++++++++-------------------------- src/bin/rbw-agent/main.rs | 3 +- src/bin/rbw-agent/timeout.rs | 66 ++++++++++++++++++++++++++++++++ 3 files changed, 106 insertions(+), 54 deletions(-) create mode 100644 src/bin/rbw-agent/timeout.rs (limited to 'src/bin') diff --git a/src/bin/rbw-agent/agent.rs b/src/bin/rbw-agent/agent.rs index b36fbb7..8fa6768 100644 --- a/src/bin/rbw-agent/agent.rs +++ b/src/bin/rbw-agent/agent.rs @@ -1,16 +1,12 @@ use anyhow::Context as _; - -#[derive(Debug)] -pub enum TimeoutEvent { - Set, - Clear, -} +use futures_util::StreamExt as _; pub struct State { pub priv_key: Option, pub org_keys: Option>, - pub timeout_chan: tokio::sync::mpsc::UnboundedSender, + pub timeout: crate::timeout::Timeout, + pub timeout_duration: std::time::Duration, } impl State { @@ -25,22 +21,18 @@ impl State { } pub fn set_timeout(&mut self) { - // no real better option to unwrap here - self.timeout_chan.send(TimeoutEvent::Set).unwrap(); + self.timeout.set(self.timeout_duration); } pub fn clear(&mut self) { self.priv_key = None; self.org_keys = None; - // no real better option to unwrap here - self.timeout_chan.send(TimeoutEvent::Clear).unwrap(); + self.timeout.clear(); } } pub struct Agent { - timeout_duration: tokio::time::Duration, - timeout: Option>>, - timeout_chan: tokio::sync::mpsc::UnboundedReceiver, + timer_r: tokio::sync::mpsc::UnboundedReceiver<()>, state: std::sync::Arc>, } @@ -48,70 +40,63 @@ impl Agent { pub fn new() -> anyhow::Result { let config = rbw::config::Config::load()?; let timeout_duration = - tokio::time::Duration::from_secs(config.lock_timeout); - let (w, r) = tokio::sync::mpsc::unbounded_channel(); + std::time::Duration::from_secs(config.lock_timeout); + let (timeout, timer_r) = crate::timeout::Timeout::new(); Ok(Self { - timeout_duration, - timeout: None, - timeout_chan: r, + timer_r, state: std::sync::Arc::new(tokio::sync::RwLock::new(State { priv_key: None, org_keys: None, - timeout_chan: w, + timeout, + timeout_duration, })), }) } - fn set_timeout(&mut self) { - self.timeout = - Some(Box::pin(tokio::time::sleep(self.timeout_duration))); - } - - fn clear_timeout(&mut self) { - self.timeout = None; - } - pub async fn run( - &mut self, + self, listener: tokio::net::UnixListener, ) -> anyhow::Result<()> { - // tokio only supports timeouts up to 2^36 milliseconds - let mut forever = Box::pin(tokio::time::sleep( - tokio::time::Duration::from_secs(60 * 60 * 24 * 365 * 2), - )); - loop { - let timeout = self.timeout.as_mut().unwrap_or(&mut forever); - tokio::select! { - sock = listener.accept() => { + enum Event { + Request(std::io::Result), + Timeout(()), + } + let mut stream = futures_util::stream::select_all([ + tokio_stream::wrappers::UnixListenerStream::new(listener) + .map(Event::Request) + .boxed(), + tokio_stream::wrappers::UnboundedReceiverStream::new( + self.timer_r, + ) + .map(Event::Timeout) + .boxed(), + ]); + while let Some(event) = stream.next().await { + match event { + Event::Request(res) => { let mut sock = crate::sock::Sock::new( - sock.context("failed to accept incoming connection")?.0 + res.context("failed to accept incoming connection")?, ); let state = self.state.clone(); tokio::spawn(async move { - let res - = handle_request(&mut sock, state.clone()).await; + let res = + handle_request(&mut sock, state.clone()).await; if let Err(e) = res { // unwrap is the only option here sock.send(&rbw::protocol::Response::Error { error: format!("{e:#}"), - }).await.unwrap(); + }) + .await + .unwrap(); } }); } - _ = timeout => { - let state = self.state.clone(); - tokio::spawn(async move{ - state.write().await.clear(); - }); - } - Some(ev) = self.timeout_chan.recv() => { - match ev { - TimeoutEvent::Set => self.set_timeout(), - TimeoutEvent::Clear => self.clear_timeout(), - } + Event::Timeout(()) => { + self.state.write().await.clear(); } } } + Ok(()) } } diff --git a/src/bin/rbw-agent/main.rs b/src/bin/rbw-agent/main.rs index ad1fe80..81eee3a 100644 --- a/src/bin/rbw-agent/main.rs +++ b/src/bin/rbw-agent/main.rs @@ -20,6 +20,7 @@ mod agent; mod daemon; mod debugger; mod sock; +mod timeout; async fn tokio_main( startup_ack: Option, @@ -30,7 +31,7 @@ async fn tokio_main( startup_ack.ack()?; } - let mut agent = crate::agent::Agent::new()?; + let agent = crate::agent::Agent::new()?; agent.run(listener).await?; Ok(()) diff --git a/src/bin/rbw-agent/timeout.rs b/src/bin/rbw-agent/timeout.rs new file mode 100644 index 0000000..e613ff0 --- /dev/null +++ b/src/bin/rbw-agent/timeout.rs @@ -0,0 +1,66 @@ +use futures_util::StreamExt as _; + +#[derive(Debug, Hash, Eq, PartialEq, Copy, Clone)] +enum Streams { + Requests, + Timer, +} + +#[derive(Debug)] +enum Action { + Set(std::time::Duration), + Clear, +} + +pub struct Timeout { + req_w: tokio::sync::mpsc::UnboundedSender, +} + +impl Timeout { + pub fn new() -> (Self, tokio::sync::mpsc::UnboundedReceiver<()>) { + let (req_w, req_r) = tokio::sync::mpsc::unbounded_channel(); + let (timer_w, timer_r) = tokio::sync::mpsc::unbounded_channel(); + tokio::spawn(async move { + enum Event { + Request(Action), + Timer, + } + let mut stream = tokio_stream::StreamMap::new(); + stream.insert( + Streams::Requests, + tokio_stream::wrappers::UnboundedReceiverStream::new(req_r) + .map(Event::Request) + .boxed(), + ); + while let Some(event) = stream.next().await { + match event { + (_, Event::Request(Action::Set(dur))) => { + stream.insert( + Streams::Timer, + futures_util::stream::once(tokio::time::sleep( + dur, + )) + .map(|_| Event::Timer) + .boxed(), + ); + } + (_, Event::Request(Action::Clear)) => { + stream.remove(&Streams::Timer); + } + (_, Event::Timer) => { + timer_w.send(()).unwrap(); + } + } + } + }); + (Self { req_w }, timer_r) + } + + pub fn set(&self, dur: std::time::Duration) { + self.req_w.send(Action::Set(dur)).unwrap(); + } + + pub fn clear(&self) { + self.req_w.send(Action::Clear).unwrap(); + } +} -- cgit v1.2.3-54-g00ecf