aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJesse Luehrs <doy@tozt.net>2023-03-25 18:29:04 -0400
committerJesse Luehrs <doy@tozt.net>2023-03-25 23:14:16 -0400
commitb659cc500476a7b4b94bc6659d46922be9465b99 (patch)
treecdf83f293d4951bf2565cc8f92e73413d7e81819
parentad0a078a5a2c4c9efd16b20ff47cc4e1ef922dab (diff)
downloadrbw-b659cc500476a7b4b94bc6659d46922be9465b99.tar.gz
rbw-b659cc500476a7b4b94bc6659d46922be9465b99.zip
stop using tokio::select!
-rw-r--r--Cargo.lock41
-rw-r--r--Cargo.toml2
-rw-r--r--src/bin/rbw-agent/agent.rs91
-rw-r--r--src/bin/rbw-agent/main.rs3
-rw-r--r--src/bin/rbw-agent/timeout.rs66
5 files changed, 141 insertions, 62 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 92a042c..4a42f2b 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -391,15 +391,26 @@ dependencies = [
[[package]]
name = "futures-core"
-version = "0.3.26"
+version = "0.3.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ec90ff4d0fe1f57d600049061dc6bb68ed03c7d2fbd697274c41805dcb3f8608"
+checksum = "86d7a0c1aa76363dac491de0ee99faf6941128376f1cf96f07db7603b7de69dd"
[[package]]
name = "futures-io"
-version = "0.3.26"
+version = "0.3.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "bfb8371b6fb2aeb2d280374607aeabfc99d95c72edfe51692e42d3d7f0d08531"
+checksum = "89d422fa3cbe3b40dca574ab087abb5bc98258ea57eea3fd6f1fa7162c778b91"
+
+[[package]]
+name = "futures-macro"
+version = "0.3.27"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3eb14ed937631bd8b8b8977f2c198443447a8355b6e3ca599f38c975e5a963b6"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
[[package]]
name = "futures-sink"
@@ -409,18 +420,19 @@ checksum = "f310820bb3e8cfd46c80db4d7fb8353e15dfff853a127158425f31e0be6c8364"
[[package]]
name = "futures-task"
-version = "0.3.26"
+version = "0.3.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "dcf79a1bf610b10f42aea489289c5a2c478a786509693b80cd39c44ccd936366"
+checksum = "fd65540d33b37b16542a0438c12e6aeead10d4ac5d05bd3f805b8f35ab592879"
[[package]]
name = "futures-util"
-version = "0.3.26"
+version = "0.3.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9c1d6de3acfef38d2be4b1f543f553131788603495be83da675e180c8d6b7bd1"
+checksum = "3ef6b17e481503ec85211fed8f39d1970f128935ca1f814cd32ac4a6842e84ab"
dependencies = [
"futures-core",
"futures-io",
+ "futures-macro",
"futures-task",
"memchr",
"pin-project-lite",
@@ -1042,6 +1054,7 @@ dependencies = [
"daemonize",
"directories",
"env_logger",
+ "futures-util",
"hkdf",
"hmac",
"humantime",
@@ -1066,6 +1079,7 @@ dependencies = [
"textwrap",
"thiserror",
"tokio",
+ "tokio-stream",
"totp-lite",
"url",
"uuid",
@@ -1608,6 +1622,17 @@ dependencies = [
]
[[package]]
+name = "tokio-stream"
+version = "0.1.12"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8fb52b74f05dbf495a8fba459fdc331812b96aa086d9eb78101fa0d4569c3313"
+dependencies = [
+ "futures-core",
+ "pin-project-lite",
+ "tokio",
+]
+
+[[package]]
name = "tokio-util"
version = "0.7.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
diff --git a/Cargo.toml b/Cargo.toml
index ae6d641..e4c7ea3 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -27,6 +27,7 @@ clap_complete = "4.1.4"
daemonize = "0.5.0"
directories = "4.0.1"
env_logger = "0.10.0"
+futures-util = "0.3.27"
hkdf = "0.12.3"
hmac = { version = "0.12.1", features = ["std"] }
humantime = "2.1.0"
@@ -51,6 +52,7 @@ terminal_size = "0.2.5"
textwrap = "0.16.0"
thiserror = "1.0.39"
tokio = { version = "1.26.0", features = ["full"] }
+tokio-stream = { version = "0.1.12", features = ["net"] }
totp-lite = "2.0.0"
url = "2.3.1"
uuid = { version = "1.3.0", features = ["v4"] }
diff --git a/src/bin/rbw-agent/agent.rs b/src/bin/rbw-agent/agent.rs
index b36fbb7..8fa6768 100644
--- a/src/bin/rbw-agent/agent.rs
+++ b/src/bin/rbw-agent/agent.rs
@@ -1,16 +1,12 @@
use anyhow::Context as _;
-
-#[derive(Debug)]
-pub enum TimeoutEvent {
- Set,
- Clear,
-}
+use futures_util::StreamExt as _;
pub struct State {
pub priv_key: Option<rbw::locked::Keys>,
pub org_keys:
Option<std::collections::HashMap<String, rbw::locked::Keys>>,
- pub timeout_chan: tokio::sync::mpsc::UnboundedSender<TimeoutEvent>,
+ pub timeout: crate::timeout::Timeout,
+ pub timeout_duration: std::time::Duration,
}
impl State {
@@ -25,22 +21,18 @@ impl State {
}
pub fn set_timeout(&mut self) {
- // no real better option to unwrap here
- self.timeout_chan.send(TimeoutEvent::Set).unwrap();
+ self.timeout.set(self.timeout_duration);
}
pub fn clear(&mut self) {
self.priv_key = None;
self.org_keys = None;
- // no real better option to unwrap here
- self.timeout_chan.send(TimeoutEvent::Clear).unwrap();
+ self.timeout.clear();
}
}
pub struct Agent {
- timeout_duration: tokio::time::Duration,
- timeout: Option<std::pin::Pin<Box<tokio::time::Sleep>>>,
- timeout_chan: tokio::sync::mpsc::UnboundedReceiver<TimeoutEvent>,
+ timer_r: tokio::sync::mpsc::UnboundedReceiver<()>,
state: std::sync::Arc<tokio::sync::RwLock<State>>,
}
@@ -48,70 +40,63 @@ impl Agent {
pub fn new() -> anyhow::Result<Self> {
let config = rbw::config::Config::load()?;
let timeout_duration =
- tokio::time::Duration::from_secs(config.lock_timeout);
- let (w, r) = tokio::sync::mpsc::unbounded_channel();
+ std::time::Duration::from_secs(config.lock_timeout);
+ let (timeout, timer_r) = crate::timeout::Timeout::new();
Ok(Self {
- timeout_duration,
- timeout: None,
- timeout_chan: r,
+ timer_r,
state: std::sync::Arc::new(tokio::sync::RwLock::new(State {
priv_key: None,
org_keys: None,
- timeout_chan: w,
+ timeout,
+ timeout_duration,
})),
})
}
- fn set_timeout(&mut self) {
- self.timeout =
- Some(Box::pin(tokio::time::sleep(self.timeout_duration)));
- }
-
- fn clear_timeout(&mut self) {
- self.timeout = None;
- }
-
pub async fn run(
- &mut self,
+ self,
listener: tokio::net::UnixListener,
) -> anyhow::Result<()> {
- // tokio only supports timeouts up to 2^36 milliseconds
- let mut forever = Box::pin(tokio::time::sleep(
- tokio::time::Duration::from_secs(60 * 60 * 24 * 365 * 2),
- ));
- loop {
- let timeout = self.timeout.as_mut().unwrap_or(&mut forever);
- tokio::select! {
- sock = listener.accept() => {
+ enum Event {
+ Request(std::io::Result<tokio::net::UnixStream>),
+ Timeout(()),
+ }
+ let mut stream = futures_util::stream::select_all([
+ tokio_stream::wrappers::UnixListenerStream::new(listener)
+ .map(Event::Request)
+ .boxed(),
+ tokio_stream::wrappers::UnboundedReceiverStream::new(
+ self.timer_r,
+ )
+ .map(Event::Timeout)
+ .boxed(),
+ ]);
+ while let Some(event) = stream.next().await {
+ match event {
+ Event::Request(res) => {
let mut sock = crate::sock::Sock::new(
- sock.context("failed to accept incoming connection")?.0
+ res.context("failed to accept incoming connection")?,
);
let state = self.state.clone();
tokio::spawn(async move {
- let res
- = handle_request(&mut sock, state.clone()).await;
+ let res =
+ handle_request(&mut sock, state.clone()).await;
if let Err(e) = res {
// unwrap is the only option here
sock.send(&rbw::protocol::Response::Error {
error: format!("{e:#}"),
- }).await.unwrap();
+ })
+ .await
+ .unwrap();
}
});
}
- _ = timeout => {
- let state = self.state.clone();
- tokio::spawn(async move{
- state.write().await.clear();
- });
- }
- Some(ev) = self.timeout_chan.recv() => {
- match ev {
- TimeoutEvent::Set => self.set_timeout(),
- TimeoutEvent::Clear => self.clear_timeout(),
- }
+ Event::Timeout(()) => {
+ self.state.write().await.clear();
}
}
}
+ Ok(())
}
}
diff --git a/src/bin/rbw-agent/main.rs b/src/bin/rbw-agent/main.rs
index ad1fe80..81eee3a 100644
--- a/src/bin/rbw-agent/main.rs
+++ b/src/bin/rbw-agent/main.rs
@@ -20,6 +20,7 @@ mod agent;
mod daemon;
mod debugger;
mod sock;
+mod timeout;
async fn tokio_main(
startup_ack: Option<crate::daemon::StartupAck>,
@@ -30,7 +31,7 @@ async fn tokio_main(
startup_ack.ack()?;
}
- let mut agent = crate::agent::Agent::new()?;
+ let agent = crate::agent::Agent::new()?;
agent.run(listener).await?;
Ok(())
diff --git a/src/bin/rbw-agent/timeout.rs b/src/bin/rbw-agent/timeout.rs
new file mode 100644
index 0000000..e613ff0
--- /dev/null
+++ b/src/bin/rbw-agent/timeout.rs
@@ -0,0 +1,66 @@
+use futures_util::StreamExt as _;
+
+#[derive(Debug, Hash, Eq, PartialEq, Copy, Clone)]
+enum Streams {
+ Requests,
+ Timer,
+}
+
+#[derive(Debug)]
+enum Action {
+ Set(std::time::Duration),
+ Clear,
+}
+
+pub struct Timeout {
+ req_w: tokio::sync::mpsc::UnboundedSender<Action>,
+}
+
+impl Timeout {
+ pub fn new() -> (Self, tokio::sync::mpsc::UnboundedReceiver<()>) {
+ let (req_w, req_r) = tokio::sync::mpsc::unbounded_channel();
+ let (timer_w, timer_r) = tokio::sync::mpsc::unbounded_channel();
+ tokio::spawn(async move {
+ enum Event {
+ Request(Action),
+ Timer,
+ }
+ let mut stream = tokio_stream::StreamMap::new();
+ stream.insert(
+ Streams::Requests,
+ tokio_stream::wrappers::UnboundedReceiverStream::new(req_r)
+ .map(Event::Request)
+ .boxed(),
+ );
+ while let Some(event) = stream.next().await {
+ match event {
+ (_, Event::Request(Action::Set(dur))) => {
+ stream.insert(
+ Streams::Timer,
+ futures_util::stream::once(tokio::time::sleep(
+ dur,
+ ))
+ .map(|_| Event::Timer)
+ .boxed(),
+ );
+ }
+ (_, Event::Request(Action::Clear)) => {
+ stream.remove(&Streams::Timer);
+ }
+ (_, Event::Timer) => {
+ timer_w.send(()).unwrap();
+ }
+ }
+ }
+ });
+ (Self { req_w }, timer_r)
+ }
+
+ pub fn set(&self, dur: std::time::Duration) {
+ self.req_w.send(Action::Set(dur)).unwrap();
+ }
+
+ pub fn clear(&self) {
+ self.req_w.send(Action::Clear).unwrap();
+ }
+}