aboutsummaryrefslogtreecommitdiffstats
path: root/src/bin
diff options
context:
space:
mode:
authorJesse Luehrs <doy@tozt.net>2023-07-16 16:20:01 -0400
committerGitHub <noreply@github.com>2023-07-16 16:20:01 -0400
commitb06eab0609451ec449a88ed5141a658e16197eb0 (patch)
tree278750c6c431c6fce9092f7890c0a8bad466f513 /src/bin
parentc8c99cf06fdc45568fc10b98257963ae63e9c486 (diff)
parent389655d8f76b49a8a391deda28cf75bd99d17a96 (diff)
downloadrbw-b06eab0609451ec449a88ed5141a658e16197eb0.tar.gz
rbw-b06eab0609451ec449a88ed5141a658e16197eb0.zip
Merge branch 'main' into result-to-clipboard
Diffstat (limited to 'src/bin')
-rw-r--r--src/bin/rbw-agent/actions.rs49
-rw-r--r--src/bin/rbw-agent/agent.rs54
-rw-r--r--src/bin/rbw-agent/main.rs1
-rw-r--r--src/bin/rbw-agent/notifications.rs187
4 files changed, 288 insertions, 3 deletions
diff --git a/src/bin/rbw-agent/actions.rs b/src/bin/rbw-agent/actions.rs
index 7b5dc58..4d77133 100644
--- a/src/bin/rbw-agent/actions.rs
+++ b/src/bin/rbw-agent/actions.rs
@@ -1,3 +1,5 @@
+use std::f32::consts::E;
+
use anyhow::Context as _;
pub async fn register(
@@ -130,7 +132,7 @@ pub async fn login(
protected_key,
)) => {
login_success(
- state,
+ state.clone(),
access_token,
refresh_token,
kdf,
@@ -148,6 +150,7 @@ pub async fn login(
Err(rbw::error::Error::TwoFactorRequired { providers }) => {
let supported_types = vec![
rbw::api::TwoFactorProviderType::Authenticator,
+ rbw::api::TwoFactorProviderType::Yubikey,
rbw::api::TwoFactorProviderType::Email,
];
@@ -169,7 +172,7 @@ pub async fn login(
)
.await?;
login_success(
- state,
+ state.clone(),
access_token,
refresh_token,
kdf,
@@ -205,6 +208,11 @@ pub async fn login(
}
}
+ let err = subscribe_to_notifications(state.clone()).await.err();
+ if let Some(e) = err {
+ eprintln!("failed to subscribe to notifications: {}", e)
+ }
+
respond_ack(sock).await?;
Ok(())
@@ -655,3 +663,40 @@ async fn config_pinentry() -> anyhow::Result<String> {
let config = rbw::config::Config::load_async().await?;
Ok(config.pinentry)
}
+
+pub async fn subscribe_to_notifications(
+ state: std::sync::Arc<tokio::sync::RwLock<crate::agent::State>>,
+) -> anyhow::Result<()> {
+ // access token might be out of date, so we do a sync to refresh it
+ sync(None).await?;
+
+ let config = rbw::config::Config::load_async()
+ .await
+ .context("Config is missing")?;
+ let email = config.email.clone().context("Config is missing email")?;
+ let db = rbw::db::Db::load_async(&config.server_name().as_str(), &email)
+ .await?;
+ let access_token =
+ db.access_token.context("Error getting access token")?;
+
+ let mut websocket_url = config
+ .base_url
+ .clone()
+ .expect("config is missing base url")
+ .replace("https://", "wss://")
+ + "/notifications/hub?access_token=";
+ websocket_url = websocket_url + &access_token;
+
+ let mut state = state.write().await;
+ let err = state
+ .notifications_handler
+ .connect(websocket_url)
+ .await
+ .err();
+
+ if let Some(err) = err {
+ return Err(anyhow::anyhow!(err.to_string()));
+ } else {
+ Ok(())
+ }
+}
diff --git a/src/bin/rbw-agent/agent.rs b/src/bin/rbw-agent/agent.rs
index 7dcab16..b88121d 100644
--- a/src/bin/rbw-agent/agent.rs
+++ b/src/bin/rbw-agent/agent.rs
@@ -1,6 +1,9 @@
+use aes::cipher::typenum::private::IsNotEqualPrivate;
use anyhow::Context as _;
use futures_util::StreamExt as _;
+use crate::notifications;
+
pub struct State {
pub priv_key: Option<rbw::locked::Keys>,
pub org_keys:
@@ -9,6 +12,7 @@ pub struct State {
pub timeout_duration: std::time::Duration,
pub sync_timeout: crate::timeout::Timeout,
pub sync_timeout_duration: std::time::Duration,
+ pub notifications_handler: crate::notifications::NotificationsHandler,
}
impl State {
@@ -55,6 +59,8 @@ impl Agent {
if sync_timeout_duration > std::time::Duration::ZERO {
sync_timeout.set(sync_timeout_duration);
}
+ let notifications_handler =
+ crate::notifications::NotificationsHandler::new();
Ok(Self {
timer_r,
sync_timer_r,
@@ -65,6 +71,7 @@ impl Agent {
timeout_duration,
sync_timeout,
sync_timeout_duration,
+ notifications_handler,
})),
})
}
@@ -73,11 +80,39 @@ impl Agent {
self,
listener: tokio::net::UnixListener,
) -> anyhow::Result<()> {
+ let err =
+ crate::actions::subscribe_to_notifications(self.state.clone())
+ .await;
+ if let Err(e) = err {
+ eprintln!("failed to subscribe to notifications: {e:#}")
+ }
+
enum Event {
Request(std::io::Result<tokio::net::UnixStream>),
Timeout(()),
Sync(()),
}
+
+ let c: tokio::sync::mpsc::UnboundedReceiver<
+ notifications::NotificationMessage,
+ > = {
+ self.state
+ .write()
+ .await
+ .notifications_handler
+ .get_channel()
+ .await
+ };
+ let notifications =
+ tokio_stream::wrappers::UnboundedReceiverStream::new(c)
+ .map(|message| match message {
+ notifications::NotificationMessage::Logout => {
+ Event::Timeout(())
+ }
+ _ => Event::Sync(()),
+ })
+ .boxed();
+
let mut stream = futures_util::stream::select_all([
tokio_stream::wrappers::UnixListenerStream::new(listener)
.map(Event::Request)
@@ -92,6 +127,7 @@ impl Agent {
)
.map(Event::Sync)
.boxed(),
+ notifications,
]);
while let Some(event) = stream.next().await {
match event {
@@ -119,8 +155,24 @@ impl Agent {
Event::Sync(()) => {
// this could fail if we aren't logged in, but we don't
// care about that
+ let state = self.state.clone();
tokio::spawn(async move {
- let _ = crate::actions::sync(None).await;
+ let result = crate::actions::sync(None).await;
+ if let Err(e) = result {
+ eprintln!("failed to sync: {e:#}");
+ } else {
+ if !state
+ .write()
+ .await
+ .notifications_handler
+ .is_connected()
+ {
+ let err = crate::actions::subscribe_to_notifications(state).await;
+ if let Err(e) = err {
+ eprintln!("failed to subscribe to notifications: {e:#}")
+ }
+ }
+ }
});
self.state.write().await.set_sync_timeout();
}
diff --git a/src/bin/rbw-agent/main.rs b/src/bin/rbw-agent/main.rs
index 81eee3a..a9477df 100644
--- a/src/bin/rbw-agent/main.rs
+++ b/src/bin/rbw-agent/main.rs
@@ -19,6 +19,7 @@ mod actions;
mod agent;
mod daemon;
mod debugger;
+mod notifications;
mod sock;
mod timeout;
diff --git a/src/bin/rbw-agent/notifications.rs b/src/bin/rbw-agent/notifications.rs
new file mode 100644
index 0000000..e8f84b0
--- /dev/null
+++ b/src/bin/rbw-agent/notifications.rs
@@ -0,0 +1,187 @@
+use futures::stream::SplitSink;
+use futures_util::{SinkExt, StreamExt};
+use tokio::{net::TcpStream, task::JoinHandle};
+use tokio_tungstenite::{
+ connect_async, tungstenite::protocol::Message, MaybeTlsStream,
+ WebSocketStream,
+};
+
+#[derive(Copy, Clone)]
+pub enum NotificationMessage {
+ SyncCipherUpdate,
+ SyncCipherCreate,
+ SyncLoginDelete,
+ SyncFolderDelete,
+ SyncCiphers,
+
+ SyncVault,
+ SyncOrgKeys,
+ SyncFolderCreate,
+ SyncFolderUpdate,
+ SyncCipherDelete,
+ SyncSettings,
+
+ Logout,
+}
+
+fn parse_messagepack(data: &[u8]) -> Option<NotificationMessage> {
+ // the first few bytes with the 0x80 bit set, plus one byte terminating the length contain the length of the message
+ let len_buffer_length = data.iter().position(|&x| (x & 0x80) == 0)? + 1;
+
+ let unpacked_messagepack =
+ rmpv::decode::read_value(&mut &data[len_buffer_length..]).ok()?;
+ if !unpacked_messagepack.is_array() {
+ return None;
+ }
+
+ let unpacked_message = unpacked_messagepack.as_array().unwrap();
+ let message_type =
+ unpacked_message.iter().next().unwrap().as_u64().unwrap();
+
+ let message = match message_type {
+ 0 => Some(NotificationMessage::SyncCipherUpdate),
+ 1 => Some(NotificationMessage::SyncCipherCreate),
+ 2 => Some(NotificationMessage::SyncLoginDelete),
+ 3 => Some(NotificationMessage::SyncFolderDelete),
+ 4 => Some(NotificationMessage::SyncCiphers),
+ 5 => Some(NotificationMessage::SyncVault),
+ 6 => Some(NotificationMessage::SyncOrgKeys),
+ 7 => Some(NotificationMessage::SyncFolderCreate),
+ 8 => Some(NotificationMessage::SyncFolderUpdate),
+ 9 => Some(NotificationMessage::SyncCipherDelete),
+ 10 => Some(NotificationMessage::SyncSettings),
+ 11 => Some(NotificationMessage::Logout),
+ _ => None,
+ };
+
+ return message;
+}
+
+pub struct NotificationsHandler {
+ write: Option<
+ futures::stream::SplitSink<
+ tokio_tungstenite::WebSocketStream<
+ tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
+ >,
+ Message,
+ >,
+ >,
+ read_handle: Option<tokio::task::JoinHandle<()>>,
+ sending_channels: std::sync::Arc<
+ tokio::sync::RwLock<
+ Vec<tokio::sync::mpsc::UnboundedSender<NotificationMessage>>,
+ >,
+ >,
+}
+
+impl NotificationsHandler {
+ pub fn new() -> Self {
+ Self {
+ write: None,
+ read_handle: None,
+ sending_channels: std::sync::Arc::new(tokio::sync::RwLock::new(
+ Vec::new(),
+ )),
+ }
+ }
+
+ pub async fn connect(
+ &mut self,
+ url: String,
+ ) -> Result<(), Box<dyn std::error::Error>> {
+ if self.is_connected() {
+ self.disconnect().await?;
+ }
+
+ let (write, read_handle) =
+ subscribe_to_notifications(url, self.sending_channels.clone())
+ .await?;
+
+ self.write = Some(write);
+ self.read_handle = Some(read_handle);
+ return Ok(());
+ }
+
+ pub fn is_connected(&self) -> bool {
+ self.write.is_some()
+ && self.read_handle.is_some()
+ && !self.read_handle.as_ref().unwrap().is_finished()
+ }
+
+ pub async fn disconnect(
+ &mut self,
+ ) -> Result<(), Box<dyn std::error::Error>> {
+ self.sending_channels.write().await.clear();
+ if let Some(mut write) = self.write.take() {
+ write.send(Message::Close(None)).await?;
+ write.close().await?;
+ self.read_handle.take().unwrap().await?;
+ }
+ self.write = None;
+ self.read_handle = None;
+ Ok(())
+ }
+
+ pub async fn get_channel(
+ &mut self,
+ ) -> tokio::sync::mpsc::UnboundedReceiver<NotificationMessage> {
+ let (tx, rx) =
+ tokio::sync::mpsc::unbounded_channel::<NotificationMessage>();
+ self.sending_channels.write().await.push(tx);
+ return rx;
+ }
+}
+
+async fn subscribe_to_notifications(
+ url: String,
+ sending_channels: std::sync::Arc<
+ tokio::sync::RwLock<
+ Vec<tokio::sync::mpsc::UnboundedSender<NotificationMessage>>,
+ >,
+ >,
+) -> Result<
+ (
+ SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
+ JoinHandle<()>,
+ ),
+ Box<dyn std::error::Error>,
+> {
+ let url = url::Url::parse(url.as_str())?;
+ let (ws_stream, _response) = connect_async(url).await?;
+ let (mut write, read) = ws_stream.split();
+
+ write
+ .send(Message::Text(
+ "{\"protocol\":\"messagepack\",\"version\":1}\n".to_string(),
+ ))
+ .await
+ .unwrap();
+
+ let read_future = async move {
+ read.map(|message| {
+ (message, sending_channels.clone())
+ }).for_each(|(message, a)| async move {
+ let a = a.read().await;
+
+ match message {
+ Ok(Message::Binary(binary)) => {
+ let msgpack = parse_messagepack(&binary);
+ if let Some(msg) = msgpack {
+ for channel in a.iter() {
+ let res = channel.send(msg);
+ if res.is_err() {
+ eprintln!("error sending websocket message to channel");
+ }
+ }
+ }
+ },
+ Err(e) => {
+ eprintln!("websocket error: {:?}", e);
+ },
+ _ => {}
+ }
+ }).await;
+ };
+
+ return Ok((write, tokio::spawn(read_future)));
+}