From 66cf6aea2d2fc355543470dab762211d9c8ad306 Mon Sep 17 00:00:00 2001 From: Bernd Schoolmann Date: Thu, 27 Apr 2023 02:38:29 +0200 Subject: Cargo format and reconnect websocket on sync --- src/bin/rbw-agent/actions.rs | 34 +++++++---- src/bin/rbw-agent/agent.rs | 66 ++++++++++++++-------- src/bin/rbw-agent/main.rs | 2 +- src/bin/rbw-agent/notifications.rs | 113 ++++++++++++++++++++++++++----------- 4 files changed, 148 insertions(+), 67 deletions(-) diff --git a/src/bin/rbw-agent/actions.rs b/src/bin/rbw-agent/actions.rs index f5b9dc0..2f34c6b 100644 --- a/src/bin/rbw-agent/actions.rs +++ b/src/bin/rbw-agent/actions.rs @@ -212,7 +212,6 @@ pub async fn login( eprintln!("failed to subscribe to notifications: {}", e) } - respond_ack(sock).await?; Ok(()) @@ -664,24 +663,39 @@ async fn config_pinentry() -> anyhow::Result { Ok(config.pinentry) } -pub async fn subscribe_to_notifications(state: std::sync::Arc>) -> anyhow::Result<()> { +pub async fn subscribe_to_notifications( + state: std::sync::Arc>, +) -> anyhow::Result<()> { // access token might be out of date, so we do a sync to refresh it sync(None).await?; - let config = rbw::config::Config::load_async().await.context("Config is missing")?; + let config = rbw::config::Config::load_async() + .await + .context("Config is missing")?; let email = config.email.clone().context("Config is missing email")?; - let db = rbw::db::Db::load_async(&config.server_name().as_str(), &email).await?; - let access_token = db.access_token.context("Error getting access token")?; - - let mut websocket_url = config.base_url.clone().expect("config is missing base url").replace("https://", "wss://") + "/notifications/hub?access_token="; + let db = rbw::db::Db::load_async(&config.server_name().as_str(), &email) + .await?; + let access_token = + db.access_token.context("Error getting access token")?; + + let mut websocket_url = config + .base_url + .clone() + .expect("config is missing base url") + .replace("https://", "wss://") + + "/notifications/hub?access_token="; websocket_url = websocket_url + &access_token; let mut state = state.write().await; - let err = state.notifications_handler.connect(websocket_url).await.err(); - + let err = state + .notifications_handler + .connect(websocket_url) + .await + .err(); + if let Some(err) = err { return Err(anyhow::anyhow!(err.to_string())); } else { Ok(()) } -} \ No newline at end of file +} diff --git a/src/bin/rbw-agent/agent.rs b/src/bin/rbw-agent/agent.rs index fb21728..b88121d 100644 --- a/src/bin/rbw-agent/agent.rs +++ b/src/bin/rbw-agent/agent.rs @@ -1,3 +1,4 @@ +use aes::cipher::typenum::private::IsNotEqualPrivate; use anyhow::Context as _; use futures_util::StreamExt as _; @@ -58,7 +59,8 @@ impl Agent { if sync_timeout_duration > std::time::Duration::ZERO { sync_timeout.set(sync_timeout_duration); } - let notifications_handler = crate::notifications::NotificationsHandler::new(); + let notifications_handler = + crate::notifications::NotificationsHandler::new(); Ok(Self { timer_r, sync_timer_r, @@ -78,7 +80,9 @@ impl Agent { self, listener: tokio::net::UnixListener, ) -> anyhow::Result<()> { - let err = crate::actions::subscribe_to_notifications(self.state.clone()).await; + let err = + crate::actions::subscribe_to_notifications(self.state.clone()) + .await; if let Err(e) = err { eprintln!("failed to subscribe to notifications: {e:#}") } @@ -88,25 +92,27 @@ impl Agent { Timeout(()), Sync(()), } - - let c: tokio::sync::mpsc::UnboundedReceiver = { - self.state.write().await.notifications_handler.get_channel().await + + let c: tokio::sync::mpsc::UnboundedReceiver< + notifications::NotificationMessage, + > = { + self.state + .write() + .await + .notifications_handler + .get_channel() + .await }; - let notifications = tokio_stream::wrappers::UnboundedReceiverStream::new( - c, - ) - .map(|message| { - match message { - notifications::NotificationMessage::Logout => { - Event::Timeout(()) - } - _ => { - Event::Sync(()) - } - } - }) - .boxed(); - + let notifications = + tokio_stream::wrappers::UnboundedReceiverStream::new(c) + .map(|message| match message { + notifications::NotificationMessage::Logout => { + Event::Timeout(()) + } + _ => Event::Sync(()), + }) + .boxed(); + let mut stream = futures_util::stream::select_all([ tokio_stream::wrappers::UnixListenerStream::new(listener) .map(Event::Request) @@ -121,7 +127,7 @@ impl Agent { ) .map(Event::Sync) .boxed(), - notifications, + notifications, ]); while let Some(event) = stream.next().await { match event { @@ -149,8 +155,24 @@ impl Agent { Event::Sync(()) => { // this could fail if we aren't logged in, but we don't // care about that + let state = self.state.clone(); tokio::spawn(async move { - let _ = crate::actions::sync(None).await; + let result = crate::actions::sync(None).await; + if let Err(e) = result { + eprintln!("failed to sync: {e:#}"); + } else { + if !state + .write() + .await + .notifications_handler + .is_connected() + { + let err = crate::actions::subscribe_to_notifications(state).await; + if let Err(e) = err { + eprintln!("failed to subscribe to notifications: {e:#}") + } + } + } }); self.state.write().await.set_sync_timeout(); } diff --git a/src/bin/rbw-agent/main.rs b/src/bin/rbw-agent/main.rs index 5e0fa61..a9477df 100644 --- a/src/bin/rbw-agent/main.rs +++ b/src/bin/rbw-agent/main.rs @@ -19,9 +19,9 @@ mod actions; mod agent; mod daemon; mod debugger; +mod notifications; mod sock; mod timeout; -mod notifications; async fn tokio_main( startup_ack: Option, diff --git a/src/bin/rbw-agent/notifications.rs b/src/bin/rbw-agent/notifications.rs index b575cf9..e8f84b0 100644 --- a/src/bin/rbw-agent/notifications.rs +++ b/src/bin/rbw-agent/notifications.rs @@ -1,7 +1,10 @@ -use futures::{stream::SplitSink}; -use tokio::{net::{TcpStream}, task::JoinHandle}; -use tokio_tungstenite::{connect_async, tungstenite::protocol::Message, WebSocketStream, MaybeTlsStream}; -use futures_util::{StreamExt, SinkExt}; +use futures::stream::SplitSink; +use futures_util::{SinkExt, StreamExt}; +use tokio::{net::TcpStream, task::JoinHandle}; +use tokio_tungstenite::{ + connect_async, tungstenite::protocol::Message, MaybeTlsStream, + WebSocketStream, +}; #[derive(Copy, Clone)] pub enum NotificationMessage { @@ -21,43 +24,54 @@ pub enum NotificationMessage { Logout, } - - fn parse_messagepack(data: &[u8]) -> Option { // the first few bytes with the 0x80 bit set, plus one byte terminating the length contain the length of the message - let len_buffer_length = data.iter().position(|&x| (x & 0x80) == 0 )? + 1; + let len_buffer_length = data.iter().position(|&x| (x & 0x80) == 0)? + 1; - let unpacked_messagepack = rmpv::decode::read_value(&mut &data[len_buffer_length..]).ok()?; + let unpacked_messagepack = + rmpv::decode::read_value(&mut &data[len_buffer_length..]).ok()?; if !unpacked_messagepack.is_array() { return None; } let unpacked_message = unpacked_messagepack.as_array().unwrap(); - let message_type = unpacked_message.iter().next().unwrap().as_u64().unwrap(); + let message_type = + unpacked_message.iter().next().unwrap().as_u64().unwrap(); let message = match message_type { - 0 => Some(NotificationMessage::SyncCipherUpdate), - 1 => Some(NotificationMessage::SyncCipherCreate), - 2 => Some(NotificationMessage::SyncLoginDelete), - 3 => Some(NotificationMessage::SyncFolderDelete), - 4 => Some(NotificationMessage::SyncCiphers), - 5 => Some(NotificationMessage::SyncVault), - 6 => Some(NotificationMessage::SyncOrgKeys), - 7 => Some(NotificationMessage::SyncFolderCreate), - 8 => Some(NotificationMessage::SyncFolderUpdate), - 9 => Some(NotificationMessage::SyncCipherDelete), + 0 => Some(NotificationMessage::SyncCipherUpdate), + 1 => Some(NotificationMessage::SyncCipherCreate), + 2 => Some(NotificationMessage::SyncLoginDelete), + 3 => Some(NotificationMessage::SyncFolderDelete), + 4 => Some(NotificationMessage::SyncCiphers), + 5 => Some(NotificationMessage::SyncVault), + 6 => Some(NotificationMessage::SyncOrgKeys), + 7 => Some(NotificationMessage::SyncFolderCreate), + 8 => Some(NotificationMessage::SyncFolderUpdate), + 9 => Some(NotificationMessage::SyncCipherDelete), 10 => Some(NotificationMessage::SyncSettings), 11 => Some(NotificationMessage::Logout), - _ => None + _ => None, }; return message; } pub struct NotificationsHandler { - write: Option>, Message>>, + write: Option< + futures::stream::SplitSink< + tokio_tungstenite::WebSocketStream< + tokio_tungstenite::MaybeTlsStream, + >, + Message, + >, + >, read_handle: Option>, - sending_channels : std::sync::Arc>>>, + sending_channels: std::sync::Arc< + tokio::sync::RwLock< + Vec>, + >, + >, } impl NotificationsHandler { @@ -65,27 +79,38 @@ impl NotificationsHandler { Self { write: None, read_handle: None, - sending_channels: std::sync::Arc::new(tokio::sync::RwLock::new(Vec::new())), + sending_channels: std::sync::Arc::new(tokio::sync::RwLock::new( + Vec::new(), + )), } } - pub async fn connect(&mut self, url: String) -> Result<(), Box> { + pub async fn connect( + &mut self, + url: String, + ) -> Result<(), Box> { if self.is_connected() { self.disconnect().await?; } - let (write, read_handle) = subscribe_to_notifications(url, self.sending_channels.clone()).await?; - + let (write, read_handle) = + subscribe_to_notifications(url, self.sending_channels.clone()) + .await?; + self.write = Some(write); self.read_handle = Some(read_handle); return Ok(()); } pub fn is_connected(&self) -> bool { - self.write.is_some() && self.read_handle.is_some() && !self.read_handle.as_ref().unwrap().is_finished() + self.write.is_some() + && self.read_handle.is_some() + && !self.read_handle.as_ref().unwrap().is_finished() } - pub async fn disconnect(&mut self) -> Result<(), Box> { + pub async fn disconnect( + &mut self, + ) -> Result<(), Box> { self.sending_channels.write().await.clear(); if let Some(mut write) = self.write.take() { write.send(Message::Close(None)).await?; @@ -97,20 +122,40 @@ impl NotificationsHandler { Ok(()) } - pub async fn get_channel(&mut self) -> tokio::sync::mpsc::UnboundedReceiver { - let (tx, rx) = tokio::sync::mpsc::unbounded_channel::(); + pub async fn get_channel( + &mut self, + ) -> tokio::sync::mpsc::UnboundedReceiver { + let (tx, rx) = + tokio::sync::mpsc::unbounded_channel::(); self.sending_channels.write().await.push(tx); return rx; } - } -async fn subscribe_to_notifications(url: String, sending_channels: std::sync::Arc>>>) -> Result<(SplitSink>, Message>, JoinHandle<()>), Box> { +async fn subscribe_to_notifications( + url: String, + sending_channels: std::sync::Arc< + tokio::sync::RwLock< + Vec>, + >, + >, +) -> Result< + ( + SplitSink>, Message>, + JoinHandle<()>, + ), + Box, +> { let url = url::Url::parse(url.as_str())?; let (ws_stream, _response) = connect_async(url).await?; let (mut write, read) = ws_stream.split(); - write.send(Message::Text("{\"protocol\":\"messagepack\",\"version\":1}\n".to_string())).await.unwrap(); + write + .send(Message::Text( + "{\"protocol\":\"messagepack\",\"version\":1}\n".to_string(), + )) + .await + .unwrap(); let read_future = async move { read.map(|message| { @@ -139,4 +184,4 @@ async fn subscribe_to_notifications(url: String, sending_channels: std::sync::Ar }; return Ok((write, tokio::spawn(read_future))); -} \ No newline at end of file +} -- cgit v1.2.3-54-g00ecf