aboutsummaryrefslogtreecommitdiffstats
path: root/src/bin/rbw-agent/agent.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/bin/rbw-agent/agent.rs')
-rw-r--r--src/bin/rbw-agent/agent.rs205
1 files changed, 133 insertions, 72 deletions
diff --git a/src/bin/rbw-agent/agent.rs b/src/bin/rbw-agent/agent.rs
index fae8c7b..a3fecb4 100644
--- a/src/bin/rbw-agent/agent.rs
+++ b/src/bin/rbw-agent/agent.rs
@@ -1,24 +1,25 @@
use anyhow::Context as _;
+use futures_util::StreamExt as _;
-#[derive(Debug)]
-pub enum TimeoutEvent {
- Set,
- Clear,
-}
+use crate::notifications;
pub struct State {
pub priv_key: Option<rbw::locked::Keys>,
pub org_keys:
Option<std::collections::HashMap<String, rbw::locked::Keys>>,
- pub timeout_chan: tokio::sync::mpsc::UnboundedSender<TimeoutEvent>,
+ pub timeout: crate::timeout::Timeout,
+ pub timeout_duration: std::time::Duration,
+ pub sync_timeout: crate::timeout::Timeout,
+ pub sync_timeout_duration: std::time::Duration,
+ pub notifications_handler: crate::notifications::Handler,
+ pub clipboard: Box<dyn copypasta::ClipboardProvider>,
}
impl State {
pub fn key(&self, org_id: Option<&str>) -> Option<&rbw::locked::Keys> {
- match org_id {
- Some(id) => self.org_keys.as_ref().and_then(|h| h.get(id)),
- None => self.priv_key.as_ref(),
- }
+ org_id.map_or(self.priv_key.as_ref(), |id| {
+ self.org_keys.as_ref().and_then(|h| h.get(id))
+ })
}
pub fn needs_unlock(&self) -> bool {
@@ -26,103 +27,163 @@ impl State {
}
pub fn set_timeout(&mut self) {
- // no real better option to unwrap here
- self.timeout_chan.send(TimeoutEvent::Set).unwrap();
+ self.timeout.set(self.timeout_duration);
}
pub fn clear(&mut self) {
self.priv_key = None;
- self.org_keys = Default::default();
- // no real better option to unwrap here
- self.timeout_chan.send(TimeoutEvent::Clear).unwrap();
+ self.org_keys = None;
+ self.timeout.clear();
+ }
+
+ pub fn set_sync_timeout(&mut self) {
+ self.sync_timeout.set(self.sync_timeout_duration);
}
}
pub struct Agent {
- timeout_duration: tokio::time::Duration,
- timeout: Option<std::pin::Pin<Box<tokio::time::Sleep>>>,
- timeout_chan: tokio::sync::mpsc::UnboundedReceiver<TimeoutEvent>,
- state: std::sync::Arc<tokio::sync::RwLock<State>>,
+ timer_r: tokio::sync::mpsc::UnboundedReceiver<()>,
+ sync_timer_r: tokio::sync::mpsc::UnboundedReceiver<()>,
+ state: std::sync::Arc<tokio::sync::Mutex<State>>,
}
impl Agent {
pub fn new() -> anyhow::Result<Self> {
let config = rbw::config::Config::load()?;
let timeout_duration =
- tokio::time::Duration::from_secs(config.lock_timeout);
- let (w, r) = tokio::sync::mpsc::unbounded_channel();
+ std::time::Duration::from_secs(config.lock_timeout);
+ let sync_timeout_duration =
+ std::time::Duration::from_secs(config.sync_interval);
+ let (timeout, timer_r) = crate::timeout::Timeout::new();
+ let (sync_timeout, sync_timer_r) = crate::timeout::Timeout::new();
+ if sync_timeout_duration > std::time::Duration::ZERO {
+ sync_timeout.set(sync_timeout_duration);
+ }
+ let notifications_handler = crate::notifications::Handler::new();
+ let clipboard: Box<dyn copypasta::ClipboardProvider> =
+ copypasta::ClipboardContext::new().map_or_else(
+ |e| {
+ log::warn!("couldn't create clipboard context: {e}");
+ let clipboard = Box::new(
+ // infallible
+ copypasta::nop_clipboard::NopClipboardContext::new()
+ .unwrap(),
+ );
+ let clipboard: Box<dyn copypasta::ClipboardProvider> =
+ clipboard;
+ clipboard
+ },
+ |clipboard| {
+ let clipboard = Box::new(clipboard);
+ let clipboard: Box<dyn copypasta::ClipboardProvider> =
+ clipboard;
+ clipboard
+ },
+ );
Ok(Self {
- timeout_duration,
- timeout: None,
- timeout_chan: r,
- state: std::sync::Arc::new(tokio::sync::RwLock::new(State {
+ timer_r,
+ sync_timer_r,
+ state: std::sync::Arc::new(tokio::sync::Mutex::new(State {
priv_key: None,
- org_keys: Default::default(),
- timeout_chan: w,
+ org_keys: None,
+ timeout,
+ timeout_duration,
+ sync_timeout,
+ sync_timeout_duration,
+ notifications_handler,
+ clipboard,
})),
})
}
- fn set_timeout(&mut self) {
- self.timeout =
- Some(Box::pin(tokio::time::sleep(self.timeout_duration)));
- }
-
- fn clear_timeout(&mut self) {
- self.timeout = None;
- }
-
pub async fn run(
- &mut self,
+ self,
listener: tokio::net::UnixListener,
) -> anyhow::Result<()> {
- // tokio only supports timeouts up to 2^36 milliseconds
- let mut forever = Box::pin(tokio::time::sleep(
- tokio::time::Duration::from_secs(60 * 60 * 24 * 365 * 2),
- ));
- loop {
- let timeout = if let Some(timeout) = &mut self.timeout {
- timeout
- } else {
- &mut forever
- };
- tokio::select! {
- sock = listener.accept() => {
+ pub enum Event {
+ Request(std::io::Result<tokio::net::UnixStream>),
+ Timeout(()),
+ Sync(()),
+ }
+
+ let notifications = self
+ .state
+ .lock()
+ .await
+ .notifications_handler
+ .get_channel()
+ .await;
+ let notifications =
+ tokio_stream::wrappers::UnboundedReceiverStream::new(
+ notifications,
+ )
+ .map(|message| match message {
+ notifications::Message::Logout => Event::Timeout(()),
+ notifications::Message::Sync => Event::Sync(()),
+ })
+ .boxed();
+
+ let mut stream = futures_util::stream::select_all([
+ tokio_stream::wrappers::UnixListenerStream::new(listener)
+ .map(Event::Request)
+ .boxed(),
+ tokio_stream::wrappers::UnboundedReceiverStream::new(
+ self.timer_r,
+ )
+ .map(Event::Timeout)
+ .boxed(),
+ tokio_stream::wrappers::UnboundedReceiverStream::new(
+ self.sync_timer_r,
+ )
+ .map(Event::Sync)
+ .boxed(),
+ notifications,
+ ]);
+ while let Some(event) = stream.next().await {
+ match event {
+ Event::Request(res) => {
let mut sock = crate::sock::Sock::new(
- sock.context("failed to accept incoming connection")?.0
+ res.context("failed to accept incoming connection")?,
);
let state = self.state.clone();
tokio::spawn(async move {
- let res
- = handle_request(&mut sock, state.clone()).await;
+ let res =
+ handle_request(&mut sock, state.clone()).await;
if let Err(e) = res {
// unwrap is the only option here
sock.send(&rbw::protocol::Response::Error {
- error: format!("{:#}", e),
- }).await.unwrap();
+ error: format!("{e:#}"),
+ })
+ .await
+ .unwrap();
}
});
}
- _ = timeout => {
+ Event::Timeout(()) => {
+ self.state.lock().await.clear();
+ }
+ Event::Sync(()) => {
let state = self.state.clone();
- tokio::spawn(async move{
- state.write().await.clear();
+ tokio::spawn(async move {
+ // this could fail if we aren't logged in, but we
+ // don't care about that
+ if let Err(e) =
+ crate::actions::sync(None, state.clone()).await
+ {
+ eprintln!("failed to sync: {e:#}");
+ }
});
- }
- Some(ev) = self.timeout_chan.recv() => {
- match ev {
- TimeoutEvent::Set => self.set_timeout(),
- TimeoutEvent::Clear => self.clear_timeout(),
- }
+ self.state.lock().await.set_sync_timeout();
}
}
}
+ Ok(())
}
}
async fn handle_request(
sock: &mut crate::sock::Sock,
- state: std::sync::Arc<tokio::sync::RwLock<State>>,
+ state: std::sync::Arc<tokio::sync::Mutex<State>>,
) -> anyhow::Result<()> {
let req = sock.recv().await?;
let req = match req {
@@ -148,12 +209,7 @@ async fn handle_request(
true
}
rbw::protocol::Action::CheckLock => {
- crate::actions::check_lock(
- sock,
- state.clone(),
- req.tty.as_deref(),
- )
- .await?;
+ crate::actions::check_lock(sock, state.clone()).await?;
false
}
rbw::protocol::Action::Lock => {
@@ -161,7 +217,7 @@ async fn handle_request(
false
}
rbw::protocol::Action::Sync => {
- crate::actions::sync(sock, true).await?;
+ crate::actions::sync(Some(sock), state.clone()).await?;
false
}
rbw::protocol::Action::Decrypt {
@@ -187,6 +243,11 @@ async fn handle_request(
.await?;
true
}
+ rbw::protocol::Action::ClipboardStore { text } => {
+ crate::actions::clipboard_store(sock, state.clone(), text)
+ .await?;
+ true
+ }
rbw::protocol::Action::Quit => std::process::exit(0),
rbw::protocol::Action::Version => {
crate::actions::version(sock).await?;
@@ -195,7 +256,7 @@ async fn handle_request(
};
if set_timeout {
- state.write().await.set_timeout();
+ state.lock().await.set_timeout();
}
Ok(())