diff options
Diffstat (limited to 'src/bin/rbw-agent/notifications.rs')
-rw-r--r-- | src/bin/rbw-agent/notifications.rs | 210 |
1 files changed, 93 insertions, 117 deletions
diff --git a/src/bin/rbw-agent/notifications.rs b/src/bin/rbw-agent/notifications.rs index ffdefe9..c72fe38 100644 --- a/src/bin/rbw-agent/notifications.rs +++ b/src/bin/rbw-agent/notifications.rs @@ -1,17 +1,12 @@ -use tokio_tungstenite::{connect_async, tungstenite::protocol::Message}; +use futures::{stream::SplitSink}; +use tokio::{net::{TcpStream}, task::JoinHandle}; +use tokio_tungstenite::{connect_async, tungstenite::protocol::Message, WebSocketStream, MaybeTlsStream}; use futures_util::{StreamExt, SinkExt}; -struct SyncCipherUpdate { - id: String -} - -struct SyncCipherCreate { - id: String -} - -enum NotificationMessage { - SyncCipherUpdate(SyncCipherUpdate), - SyncCipherCreate(SyncCipherCreate), +#[derive(Copy, Clone)] +pub enum NotificationMessage { + SyncCipherUpdate, + SyncCipherCreate, SyncLoginDelete, SyncFolderDelete, SyncCiphers, @@ -24,59 +19,25 @@ enum NotificationMessage { SyncSettings, Logout, +} - SyncSendCreate, - SyncSendUpdate, - SyncSendDelete, - - AuthRequest, - AuthRequestResponse, - None, -} fn parse_messagepack(data: &[u8]) -> Option<NotificationMessage> { - if data.len() < 2 { - return None; - } - - // the first few bytes with th 0x80 bit set, plus one byte terminating the length contain the length of the message + // the first few bytes with the 0x80 bit set, plus one byte terminating the length contain the length of the message let len_buffer_length = data.iter().position(|&x| (x & 0x80) == 0 )? + 1; - println!("len_buffer_length: {:?}", len_buffer_length); - println!("data: {:?}", data); - let unpacked_messagepack = rmpv::decode::read_value(&mut &data[len_buffer_length..]).ok().unwrap(); - println!("unpacked_messagepack: {:?}", unpacked_messagepack); + let unpacked_messagepack = rmpv::decode::read_value(&mut &data[len_buffer_length..]).ok()?; if !unpacked_messagepack.is_array() { return None; } + let unpacked_message = unpacked_messagepack.as_array().unwrap(); - println!("unpacked_message: {:?}", unpacked_message); - let message_type = unpacked_message.iter().next()?.as_u64()?; - let message = unpacked_message.iter().skip(4).next()?.as_array()?.first()?.as_map()?; - let payload = message.iter().filter(|x| x.0.as_str().unwrap() == "Payload").next()?.1.as_map()?; - println!("message_type: {:?}", message_type); - println!("payload: {:?}", payload); + let message_type = unpacked_message.iter().next().unwrap().as_u64().unwrap(); let message = match message_type { - 0 => { - let id = payload.iter().filter(|x| x.0.as_str().unwrap() == "Id").next()?.1.as_str()?; - - Some(NotificationMessage::SyncCipherUpdate( - SyncCipherUpdate { - id: id.to_string() - } - )) - }, - 1 => { - let id = payload.iter().filter(|x| x.0.as_str().unwrap() == "Id").next()?.1.as_str()?; - - Some(NotificationMessage::SyncCipherCreate( - SyncCipherCreate { - id: id.to_string() - } - )) - }, + 0 => Some(NotificationMessage::SyncCipherUpdate), + 1 => Some(NotificationMessage::SyncCipherCreate), 2 => Some(NotificationMessage::SyncLoginDelete), 3 => Some(NotificationMessage::SyncFolderDelete), 4 => Some(NotificationMessage::SyncCiphers), @@ -87,80 +48,95 @@ fn parse_messagepack(data: &[u8]) -> Option<NotificationMessage> { 9 => Some(NotificationMessage::SyncCipherDelete), 10 => Some(NotificationMessage::SyncSettings), 11 => Some(NotificationMessage::Logout), - 12 => Some(NotificationMessage::SyncSendCreate), - 13 => Some(NotificationMessage::SyncSendUpdate), - 14 => Some(NotificationMessage::SyncSendDelete), - 15 => Some(NotificationMessage::AuthRequest), - 16 => Some(NotificationMessage::AuthRequestResponse), - 100 => Some(NotificationMessage::None), _ => None }; return message; } -pub async fn subscribe_to_notifications(url: String) { - let url = url::Url::parse(url.as_str()).unwrap(); +pub struct NotificationsHandler { + write: Option<futures::stream::SplitSink<tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>, Message>>, + read_handle: Option<tokio::task::JoinHandle<()>>, + sending_channels : std::sync::Arc<tokio::sync::RwLock<Vec<tokio::sync::mpsc::UnboundedSender<NotificationMessage>>>>, +} - let (ws_stream, _response) = connect_async(url).await.expect("Failed to connect"); +impl NotificationsHandler { + pub fn new() -> Self { + Self { + write: None, + read_handle: None, + sending_channels: std::sync::Arc::new(tokio::sync::RwLock::new(Vec::new())), + } + } - let (mut write, read) = ws_stream.split(); + pub async fn connect(&mut self, url: String) -> Result<(), Box<dyn std::error::Error>> { + if self.is_connected() { + self.disconnect().await?; + } - write.send(Message::Text("{\"protocol\":\"messagepack\",\"version\":1}\n".to_string())).await.unwrap(); + //subscribe_to_notifications(url, self.sending_channels.clone()).await?; + let (write, read_handle) = subscribe_to_notifications(url, self.sending_channels.clone()).await?; + + self.write = Some(write); + self.read_handle = Some(read_handle); + return Ok(()); + } + + pub fn is_connected(&self) -> bool { + self.write.is_some() + } - let read_future = read.for_each(|message| async { - match message { - Ok(Message::Binary(binary)) => { - let msg = parse_messagepack(&binary); - match msg { - Some(NotificationMessage::SyncCipherUpdate(update)) => { - println!("Websocket sent SyncCipherUpdate for id: {:?}", update.id); - crate::actions::sync(None).await.unwrap(); - println!("Synced") - }, - Some(NotificationMessage::SyncCipherCreate(update)) => { - println!("Websocket sent SyncCipherUpdate for id: {:?}", update.id); - crate::actions::sync(None).await.unwrap(); - println!("Synced") - }, - Some(NotificationMessage::SyncLoginDelete) => { - crate::actions::sync(None).await.unwrap(); - }, - Some(NotificationMessage::SyncFolderDelete) => { - crate::actions::sync(None).await.unwrap(); - }, - Some(NotificationMessage::SyncCiphers) => { - crate::actions::sync(None).await.unwrap(); - }, - Some(NotificationMessage::SyncVault) => { - crate::actions::sync(None).await.unwrap(); - }, - Some(NotificationMessage::SyncOrgKeys) => { - crate::actions::sync(None).await.unwrap(); - }, - Some(NotificationMessage::SyncFolderCreate) => { - crate::actions::sync(None).await.unwrap(); - }, - Some(NotificationMessage::SyncFolderUpdate) => { - crate::actions::sync(None).await.unwrap(); - }, - Some(NotificationMessage::SyncCipherDelete) => { - crate::actions::sync(None).await.unwrap(); - }, - Some(NotificationMessage::Logout) => { - println!("Websocket sent Logout"); - // todo: proper logout? - std::process::exit(0); - }, - _ => {} - } - }, - Err(e) => { - println!("websocket error: {:?}", e); - }, - _ => {} + pub async fn disconnect(&mut self) -> Result<(), Box<dyn std::error::Error>> { + self.sending_channels.write().await.clear(); + if let Some(mut write) = self.write.take() { + write.send(Message::Close(None)).await?; + write.close().await?; } - }); + Ok(()) + } + + pub async fn get_channel(&mut self) -> tokio::sync::mpsc::UnboundedReceiver<NotificationMessage> { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<NotificationMessage>(); + self.sending_channels.write().await.push(tx); + return rx; + } - read_future.await; } + +async fn subscribe_to_notifications(url: String, sending_channels: std::sync::Arc<tokio::sync::RwLock<Vec<tokio::sync::mpsc::UnboundedSender<NotificationMessage>>>>) -> Result<(SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>, JoinHandle<()>), Box<dyn std::error::Error>> { + let url = url::Url::parse(url.as_str())?; + println!("Connecting to {}", url); + let (ws_stream, _response) = connect_async(url).await.expect("Failed to connect"); + let (mut write, read) = ws_stream.split(); + + write.send(Message::Text("{\"protocol\":\"messagepack\",\"version\":1}\n".to_string())).await.unwrap(); + + let read_future = async move { + read.map(|message| { + (message, sending_channels.clone()) + }).for_each(|(message, a)| async move { + let a = a.read().await; + + match message { + Ok(Message::Binary(binary)) => { + if binary.len() < 4 { + return; + } + + let msg1 = parse_messagepack(&binary); + if let Some(msg) = msg1 { + for channel in a.iter() { + let res = channel.send(msg); + } + } + }, + Err(e) => { + println!("websocket error: {:?}", e); + }, + _ => {} + } + }).await; + }; + + return Ok((write, tokio::spawn(read_future))); +}
\ No newline at end of file |