use futures::stream::SplitSink; use futures_util::{SinkExt, StreamExt}; use tokio::{net::TcpStream, task::JoinHandle}; use tokio_tungstenite::{ connect_async, tungstenite::protocol::Message, MaybeTlsStream, WebSocketStream, }; #[derive(Copy, Clone)] pub enum NotificationMessage { SyncCipherUpdate, SyncCipherCreate, SyncLoginDelete, SyncFolderDelete, SyncCiphers, SyncVault, SyncOrgKeys, SyncFolderCreate, SyncFolderUpdate, SyncCipherDelete, SyncSettings, Logout, } fn parse_messagepack(data: &[u8]) -> Option { // the first few bytes with the 0x80 bit set, plus one byte terminating the length contain the length of the message let len_buffer_length = data.iter().position(|&x| (x & 0x80) == 0)? + 1; let unpacked_messagepack = rmpv::decode::read_value(&mut &data[len_buffer_length..]).ok()?; if !unpacked_messagepack.is_array() { return None; } let unpacked_message = unpacked_messagepack.as_array().unwrap(); let message_type = unpacked_message.iter().next().unwrap().as_u64().unwrap(); let message = match message_type { 0 => Some(NotificationMessage::SyncCipherUpdate), 1 => Some(NotificationMessage::SyncCipherCreate), 2 => Some(NotificationMessage::SyncLoginDelete), 3 => Some(NotificationMessage::SyncFolderDelete), 4 => Some(NotificationMessage::SyncCiphers), 5 => Some(NotificationMessage::SyncVault), 6 => Some(NotificationMessage::SyncOrgKeys), 7 => Some(NotificationMessage::SyncFolderCreate), 8 => Some(NotificationMessage::SyncFolderUpdate), 9 => Some(NotificationMessage::SyncCipherDelete), 10 => Some(NotificationMessage::SyncSettings), 11 => Some(NotificationMessage::Logout), _ => None, }; return message; } pub struct NotificationsHandler { write: Option< futures::stream::SplitSink< tokio_tungstenite::WebSocketStream< tokio_tungstenite::MaybeTlsStream, >, Message, >, >, read_handle: Option>, sending_channels: std::sync::Arc< tokio::sync::RwLock< Vec>, >, >, } impl NotificationsHandler { pub fn new() -> Self { Self { write: None, read_handle: None, sending_channels: std::sync::Arc::new(tokio::sync::RwLock::new( Vec::new(), )), } } pub async fn connect( &mut self, url: String, ) -> Result<(), Box> { if self.is_connected() { self.disconnect().await?; } let (write, read_handle) = subscribe_to_notifications(url, self.sending_channels.clone()) .await?; self.write = Some(write); self.read_handle = Some(read_handle); return Ok(()); } pub fn is_connected(&self) -> bool { self.write.is_some() && self.read_handle.is_some() && !self.read_handle.as_ref().unwrap().is_finished() } pub async fn disconnect( &mut self, ) -> Result<(), Box> { self.sending_channels.write().await.clear(); if let Some(mut write) = self.write.take() { write.send(Message::Close(None)).await?; write.close().await?; self.read_handle.take().unwrap().await?; } self.write = None; self.read_handle = None; Ok(()) } pub async fn get_channel( &mut self, ) -> tokio::sync::mpsc::UnboundedReceiver { let (tx, rx) = tokio::sync::mpsc::unbounded_channel::(); self.sending_channels.write().await.push(tx); return rx; } } async fn subscribe_to_notifications( url: String, sending_channels: std::sync::Arc< tokio::sync::RwLock< Vec>, >, >, ) -> Result< ( SplitSink>, Message>, JoinHandle<()>, ), Box, > { let url = url::Url::parse(url.as_str())?; let (ws_stream, _response) = connect_async(url).await?; let (mut write, read) = ws_stream.split(); write .send(Message::Text( "{\"protocol\":\"messagepack\",\"version\":1}\n".to_string(), )) .await .unwrap(); let read_future = async move { read.map(|message| { (message, sending_channels.clone()) }).for_each(|(message, a)| async move { let a = a.read().await; match message { Ok(Message::Binary(binary)) => { let msgpack = parse_messagepack(&binary); if let Some(msg) = msgpack { for channel in a.iter() { let res = channel.send(msg); if res.is_err() { eprintln!("error sending websocket message to channel"); } } } }, Err(e) => { eprintln!("websocket error: {:?}", e); }, _ => {} } }).await; }; return Ok((write, tokio::spawn(read_future))); }