From 66cf6aea2d2fc355543470dab762211d9c8ad306 Mon Sep 17 00:00:00 2001 From: Bernd Schoolmann Date: Thu, 27 Apr 2023 02:38:29 +0200 Subject: Cargo format and reconnect websocket on sync --- src/bin/rbw-agent/notifications.rs | 113 ++++++++++++++++++++++++++----------- 1 file changed, 79 insertions(+), 34 deletions(-) (limited to 'src/bin/rbw-agent/notifications.rs') diff --git a/src/bin/rbw-agent/notifications.rs b/src/bin/rbw-agent/notifications.rs index b575cf9..e8f84b0 100644 --- a/src/bin/rbw-agent/notifications.rs +++ b/src/bin/rbw-agent/notifications.rs @@ -1,7 +1,10 @@ -use futures::{stream::SplitSink}; -use tokio::{net::{TcpStream}, task::JoinHandle}; -use tokio_tungstenite::{connect_async, tungstenite::protocol::Message, WebSocketStream, MaybeTlsStream}; -use futures_util::{StreamExt, SinkExt}; +use futures::stream::SplitSink; +use futures_util::{SinkExt, StreamExt}; +use tokio::{net::TcpStream, task::JoinHandle}; +use tokio_tungstenite::{ + connect_async, tungstenite::protocol::Message, MaybeTlsStream, + WebSocketStream, +}; #[derive(Copy, Clone)] pub enum NotificationMessage { @@ -21,43 +24,54 @@ pub enum NotificationMessage { Logout, } - - fn parse_messagepack(data: &[u8]) -> Option { // the first few bytes with the 0x80 bit set, plus one byte terminating the length contain the length of the message - let len_buffer_length = data.iter().position(|&x| (x & 0x80) == 0 )? + 1; + let len_buffer_length = data.iter().position(|&x| (x & 0x80) == 0)? + 1; - let unpacked_messagepack = rmpv::decode::read_value(&mut &data[len_buffer_length..]).ok()?; + let unpacked_messagepack = + rmpv::decode::read_value(&mut &data[len_buffer_length..]).ok()?; if !unpacked_messagepack.is_array() { return None; } let unpacked_message = unpacked_messagepack.as_array().unwrap(); - let message_type = unpacked_message.iter().next().unwrap().as_u64().unwrap(); + let message_type = + unpacked_message.iter().next().unwrap().as_u64().unwrap(); let message = match message_type { - 0 => Some(NotificationMessage::SyncCipherUpdate), - 1 => Some(NotificationMessage::SyncCipherCreate), - 2 => Some(NotificationMessage::SyncLoginDelete), - 3 => Some(NotificationMessage::SyncFolderDelete), - 4 => Some(NotificationMessage::SyncCiphers), - 5 => Some(NotificationMessage::SyncVault), - 6 => Some(NotificationMessage::SyncOrgKeys), - 7 => Some(NotificationMessage::SyncFolderCreate), - 8 => Some(NotificationMessage::SyncFolderUpdate), - 9 => Some(NotificationMessage::SyncCipherDelete), + 0 => Some(NotificationMessage::SyncCipherUpdate), + 1 => Some(NotificationMessage::SyncCipherCreate), + 2 => Some(NotificationMessage::SyncLoginDelete), + 3 => Some(NotificationMessage::SyncFolderDelete), + 4 => Some(NotificationMessage::SyncCiphers), + 5 => Some(NotificationMessage::SyncVault), + 6 => Some(NotificationMessage::SyncOrgKeys), + 7 => Some(NotificationMessage::SyncFolderCreate), + 8 => Some(NotificationMessage::SyncFolderUpdate), + 9 => Some(NotificationMessage::SyncCipherDelete), 10 => Some(NotificationMessage::SyncSettings), 11 => Some(NotificationMessage::Logout), - _ => None + _ => None, }; return message; } pub struct NotificationsHandler { - write: Option>, Message>>, + write: Option< + futures::stream::SplitSink< + tokio_tungstenite::WebSocketStream< + tokio_tungstenite::MaybeTlsStream, + >, + Message, + >, + >, read_handle: Option>, - sending_channels : std::sync::Arc>>>, + sending_channels: std::sync::Arc< + tokio::sync::RwLock< + Vec>, + >, + >, } impl NotificationsHandler { @@ -65,27 +79,38 @@ impl NotificationsHandler { Self { write: None, read_handle: None, - sending_channels: std::sync::Arc::new(tokio::sync::RwLock::new(Vec::new())), + sending_channels: std::sync::Arc::new(tokio::sync::RwLock::new( + Vec::new(), + )), } } - pub async fn connect(&mut self, url: String) -> Result<(), Box> { + pub async fn connect( + &mut self, + url: String, + ) -> Result<(), Box> { if self.is_connected() { self.disconnect().await?; } - let (write, read_handle) = subscribe_to_notifications(url, self.sending_channels.clone()).await?; - + let (write, read_handle) = + subscribe_to_notifications(url, self.sending_channels.clone()) + .await?; + self.write = Some(write); self.read_handle = Some(read_handle); return Ok(()); } pub fn is_connected(&self) -> bool { - self.write.is_some() && self.read_handle.is_some() && !self.read_handle.as_ref().unwrap().is_finished() + self.write.is_some() + && self.read_handle.is_some() + && !self.read_handle.as_ref().unwrap().is_finished() } - pub async fn disconnect(&mut self) -> Result<(), Box> { + pub async fn disconnect( + &mut self, + ) -> Result<(), Box> { self.sending_channels.write().await.clear(); if let Some(mut write) = self.write.take() { write.send(Message::Close(None)).await?; @@ -97,20 +122,40 @@ impl NotificationsHandler { Ok(()) } - pub async fn get_channel(&mut self) -> tokio::sync::mpsc::UnboundedReceiver { - let (tx, rx) = tokio::sync::mpsc::unbounded_channel::(); + pub async fn get_channel( + &mut self, + ) -> tokio::sync::mpsc::UnboundedReceiver { + let (tx, rx) = + tokio::sync::mpsc::unbounded_channel::(); self.sending_channels.write().await.push(tx); return rx; } - } -async fn subscribe_to_notifications(url: String, sending_channels: std::sync::Arc>>>) -> Result<(SplitSink>, Message>, JoinHandle<()>), Box> { +async fn subscribe_to_notifications( + url: String, + sending_channels: std::sync::Arc< + tokio::sync::RwLock< + Vec>, + >, + >, +) -> Result< + ( + SplitSink>, Message>, + JoinHandle<()>, + ), + Box, +> { let url = url::Url::parse(url.as_str())?; let (ws_stream, _response) = connect_async(url).await?; let (mut write, read) = ws_stream.split(); - write.send(Message::Text("{\"protocol\":\"messagepack\",\"version\":1}\n".to_string())).await.unwrap(); + write + .send(Message::Text( + "{\"protocol\":\"messagepack\",\"version\":1}\n".to_string(), + )) + .await + .unwrap(); let read_future = async move { read.map(|message| { @@ -139,4 +184,4 @@ async fn subscribe_to_notifications(url: String, sending_channels: std::sync::Ar }; return Ok((write, tokio::spawn(read_future))); -} \ No newline at end of file +} -- cgit v1.2.3-54-g00ecf