From b659cc500476a7b4b94bc6659d46922be9465b99 Mon Sep 17 00:00:00 2001 From: Jesse Luehrs Date: Sat, 25 Mar 2023 18:29:04 -0400 Subject: stop using tokio::select! --- Cargo.lock | 41 ++++++++++++++++---- Cargo.toml | 2 + src/bin/rbw-agent/agent.rs | 91 ++++++++++++++++++-------------------------- src/bin/rbw-agent/main.rs | 3 +- src/bin/rbw-agent/timeout.rs | 66 ++++++++++++++++++++++++++++++++ 5 files changed, 141 insertions(+), 62 deletions(-) create mode 100644 src/bin/rbw-agent/timeout.rs diff --git a/Cargo.lock b/Cargo.lock index 92a042c..4a42f2b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -391,15 +391,26 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec90ff4d0fe1f57d600049061dc6bb68ed03c7d2fbd697274c41805dcb3f8608" +checksum = "86d7a0c1aa76363dac491de0ee99faf6941128376f1cf96f07db7603b7de69dd" [[package]] name = "futures-io" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfb8371b6fb2aeb2d280374607aeabfc99d95c72edfe51692e42d3d7f0d08531" +checksum = "89d422fa3cbe3b40dca574ab087abb5bc98258ea57eea3fd6f1fa7162c778b91" + +[[package]] +name = "futures-macro" +version = "0.3.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3eb14ed937631bd8b8b8977f2c198443447a8355b6e3ca599f38c975e5a963b6" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] [[package]] name = "futures-sink" @@ -409,18 +420,19 @@ checksum = "f310820bb3e8cfd46c80db4d7fb8353e15dfff853a127158425f31e0be6c8364" [[package]] name = "futures-task" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcf79a1bf610b10f42aea489289c5a2c478a786509693b80cd39c44ccd936366" +checksum = "fd65540d33b37b16542a0438c12e6aeead10d4ac5d05bd3f805b8f35ab592879" [[package]] name = "futures-util" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c1d6de3acfef38d2be4b1f543f553131788603495be83da675e180c8d6b7bd1" +checksum = "3ef6b17e481503ec85211fed8f39d1970f128935ca1f814cd32ac4a6842e84ab" dependencies = [ "futures-core", "futures-io", + "futures-macro", "futures-task", "memchr", "pin-project-lite", @@ -1042,6 +1054,7 @@ dependencies = [ "daemonize", "directories", "env_logger", + "futures-util", "hkdf", "hmac", "humantime", @@ -1066,6 +1079,7 @@ dependencies = [ "textwrap", "thiserror", "tokio", + "tokio-stream", "totp-lite", "url", "uuid", @@ -1607,6 +1621,17 @@ dependencies = [ "webpki", ] +[[package]] +name = "tokio-stream" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fb52b74f05dbf495a8fba459fdc331812b96aa086d9eb78101fa0d4569c3313" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + [[package]] name = "tokio-util" version = "0.7.7" diff --git a/Cargo.toml b/Cargo.toml index ae6d641..e4c7ea3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ clap_complete = "4.1.4" daemonize = "0.5.0" directories = "4.0.1" env_logger = "0.10.0" +futures-util = "0.3.27" hkdf = "0.12.3" hmac = { version = "0.12.1", features = ["std"] } humantime = "2.1.0" @@ -51,6 +52,7 @@ terminal_size = "0.2.5" textwrap = "0.16.0" thiserror = "1.0.39" tokio = { version = "1.26.0", features = ["full"] } +tokio-stream = { version = "0.1.12", features = ["net"] } totp-lite = "2.0.0" url = "2.3.1" uuid = { version = "1.3.0", features = ["v4"] } diff --git a/src/bin/rbw-agent/agent.rs b/src/bin/rbw-agent/agent.rs index b36fbb7..8fa6768 100644 --- a/src/bin/rbw-agent/agent.rs +++ b/src/bin/rbw-agent/agent.rs @@ -1,16 +1,12 @@ use anyhow::Context as _; - -#[derive(Debug)] -pub enum TimeoutEvent { - Set, - Clear, -} +use futures_util::StreamExt as _; pub struct State { pub priv_key: Option, pub org_keys: Option>, - pub timeout_chan: tokio::sync::mpsc::UnboundedSender, + pub timeout: crate::timeout::Timeout, + pub timeout_duration: std::time::Duration, } impl State { @@ -25,22 +21,18 @@ impl State { } pub fn set_timeout(&mut self) { - // no real better option to unwrap here - self.timeout_chan.send(TimeoutEvent::Set).unwrap(); + self.timeout.set(self.timeout_duration); } pub fn clear(&mut self) { self.priv_key = None; self.org_keys = None; - // no real better option to unwrap here - self.timeout_chan.send(TimeoutEvent::Clear).unwrap(); + self.timeout.clear(); } } pub struct Agent { - timeout_duration: tokio::time::Duration, - timeout: Option>>, - timeout_chan: tokio::sync::mpsc::UnboundedReceiver, + timer_r: tokio::sync::mpsc::UnboundedReceiver<()>, state: std::sync::Arc>, } @@ -48,70 +40,63 @@ impl Agent { pub fn new() -> anyhow::Result { let config = rbw::config::Config::load()?; let timeout_duration = - tokio::time::Duration::from_secs(config.lock_timeout); - let (w, r) = tokio::sync::mpsc::unbounded_channel(); + std::time::Duration::from_secs(config.lock_timeout); + let (timeout, timer_r) = crate::timeout::Timeout::new(); Ok(Self { - timeout_duration, - timeout: None, - timeout_chan: r, + timer_r, state: std::sync::Arc::new(tokio::sync::RwLock::new(State { priv_key: None, org_keys: None, - timeout_chan: w, + timeout, + timeout_duration, })), }) } - fn set_timeout(&mut self) { - self.timeout = - Some(Box::pin(tokio::time::sleep(self.timeout_duration))); - } - - fn clear_timeout(&mut self) { - self.timeout = None; - } - pub async fn run( - &mut self, + self, listener: tokio::net::UnixListener, ) -> anyhow::Result<()> { - // tokio only supports timeouts up to 2^36 milliseconds - let mut forever = Box::pin(tokio::time::sleep( - tokio::time::Duration::from_secs(60 * 60 * 24 * 365 * 2), - )); - loop { - let timeout = self.timeout.as_mut().unwrap_or(&mut forever); - tokio::select! { - sock = listener.accept() => { + enum Event { + Request(std::io::Result), + Timeout(()), + } + let mut stream = futures_util::stream::select_all([ + tokio_stream::wrappers::UnixListenerStream::new(listener) + .map(Event::Request) + .boxed(), + tokio_stream::wrappers::UnboundedReceiverStream::new( + self.timer_r, + ) + .map(Event::Timeout) + .boxed(), + ]); + while let Some(event) = stream.next().await { + match event { + Event::Request(res) => { let mut sock = crate::sock::Sock::new( - sock.context("failed to accept incoming connection")?.0 + res.context("failed to accept incoming connection")?, ); let state = self.state.clone(); tokio::spawn(async move { - let res - = handle_request(&mut sock, state.clone()).await; + let res = + handle_request(&mut sock, state.clone()).await; if let Err(e) = res { // unwrap is the only option here sock.send(&rbw::protocol::Response::Error { error: format!("{e:#}"), - }).await.unwrap(); + }) + .await + .unwrap(); } }); } - _ = timeout => { - let state = self.state.clone(); - tokio::spawn(async move{ - state.write().await.clear(); - }); - } - Some(ev) = self.timeout_chan.recv() => { - match ev { - TimeoutEvent::Set => self.set_timeout(), - TimeoutEvent::Clear => self.clear_timeout(), - } + Event::Timeout(()) => { + self.state.write().await.clear(); } } } + Ok(()) } } diff --git a/src/bin/rbw-agent/main.rs b/src/bin/rbw-agent/main.rs index ad1fe80..81eee3a 100644 --- a/src/bin/rbw-agent/main.rs +++ b/src/bin/rbw-agent/main.rs @@ -20,6 +20,7 @@ mod agent; mod daemon; mod debugger; mod sock; +mod timeout; async fn tokio_main( startup_ack: Option, @@ -30,7 +31,7 @@ async fn tokio_main( startup_ack.ack()?; } - let mut agent = crate::agent::Agent::new()?; + let agent = crate::agent::Agent::new()?; agent.run(listener).await?; Ok(()) diff --git a/src/bin/rbw-agent/timeout.rs b/src/bin/rbw-agent/timeout.rs new file mode 100644 index 0000000..e613ff0 --- /dev/null +++ b/src/bin/rbw-agent/timeout.rs @@ -0,0 +1,66 @@ +use futures_util::StreamExt as _; + +#[derive(Debug, Hash, Eq, PartialEq, Copy, Clone)] +enum Streams { + Requests, + Timer, +} + +#[derive(Debug)] +enum Action { + Set(std::time::Duration), + Clear, +} + +pub struct Timeout { + req_w: tokio::sync::mpsc::UnboundedSender, +} + +impl Timeout { + pub fn new() -> (Self, tokio::sync::mpsc::UnboundedReceiver<()>) { + let (req_w, req_r) = tokio::sync::mpsc::unbounded_channel(); + let (timer_w, timer_r) = tokio::sync::mpsc::unbounded_channel(); + tokio::spawn(async move { + enum Event { + Request(Action), + Timer, + } + let mut stream = tokio_stream::StreamMap::new(); + stream.insert( + Streams::Requests, + tokio_stream::wrappers::UnboundedReceiverStream::new(req_r) + .map(Event::Request) + .boxed(), + ); + while let Some(event) = stream.next().await { + match event { + (_, Event::Request(Action::Set(dur))) => { + stream.insert( + Streams::Timer, + futures_util::stream::once(tokio::time::sleep( + dur, + )) + .map(|_| Event::Timer) + .boxed(), + ); + } + (_, Event::Request(Action::Clear)) => { + stream.remove(&Streams::Timer); + } + (_, Event::Timer) => { + timer_w.send(()).unwrap(); + } + } + } + }); + (Self { req_w }, timer_r) + } + + pub fn set(&self, dur: std::time::Duration) { + self.req_w.send(Action::Set(dur)).unwrap(); + } + + pub fn clear(&self) { + self.req_w.send(Action::Clear).unwrap(); + } +} -- cgit v1.2.3-54-g00ecf