diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/bin/rbw-agent/agent.rs | 35 | ||||
-rw-r--r-- | src/bin/rbw-agent/notifications.rs | 162 |
2 files changed, 91 insertions, 106 deletions
diff --git a/src/bin/rbw-agent/agent.rs b/src/bin/rbw-agent/agent.rs index 29e400b..5769c0e 100644 --- a/src/bin/rbw-agent/agent.rs +++ b/src/bin/rbw-agent/agent.rs @@ -83,31 +83,28 @@ impl Agent { self, listener: tokio::net::UnixListener, ) -> anyhow::Result<()> { - enum Event { + pub enum Event { Request(std::io::Result<tokio::net::UnixStream>), Timeout(()), Sync(()), } - let c: tokio::sync::mpsc::UnboundedReceiver< - notifications::NotificationMessage, - > = { - self.state - .lock() - .await - .notifications_handler - .get_channel() - .await - }; + let notifications = self + .state + .lock() + .await + .notifications_handler + .get_channel() + .await; let notifications = - tokio_stream::wrappers::UnboundedReceiverStream::new(c) - .map(|message| match message { - notifications::NotificationMessage::Logout => { - Event::Timeout(()) - } - _ => Event::Sync(()), - }) - .boxed(); + tokio_stream::wrappers::UnboundedReceiverStream::new( + notifications, + ) + .map(|message| match message { + notifications::Message::Logout => Event::Timeout(()), + notifications::Message::Sync => Event::Sync(()), + }) + .boxed(); let mut stream = futures_util::stream::select_all([ tokio_stream::wrappers::UnixListenerStream::new(listener) diff --git a/src/bin/rbw-agent/notifications.rs b/src/bin/rbw-agent/notifications.rs index 69ebda5..297fa9c 100644 --- a/src/bin/rbw-agent/notifications.rs +++ b/src/bin/rbw-agent/notifications.rs @@ -1,74 +1,23 @@ -use futures::stream::SplitSink; -use futures_util::{SinkExt, StreamExt}; -use tokio::{net::TcpStream, task::JoinHandle}; -use tokio_tungstenite::{ - connect_async, tungstenite::protocol::Message, MaybeTlsStream, - WebSocketStream, -}; - -#[derive(Copy, Clone)] -pub enum NotificationMessage { - SyncCipherUpdate, - SyncCipherCreate, - SyncLoginDelete, - SyncFolderDelete, - SyncCiphers, - - SyncVault, - SyncOrgKeys, - SyncFolderCreate, - SyncFolderUpdate, - SyncCipherDelete, - SyncSettings, +use futures_util::{SinkExt as _, StreamExt as _}; +#[derive(Clone, Copy, Debug)] +pub enum Message { + Sync, Logout, } -fn parse_messagepack(data: &[u8]) -> Option<NotificationMessage> { - // the first few bytes with the 0x80 bit set, plus one byte terminating the length contain the length of the message - let len_buffer_length = data.iter().position(|&x| (x & 0x80) == 0)? + 1; - - let unpacked_messagepack = - rmpv::decode::read_value(&mut &data[len_buffer_length..]).ok()?; - if !unpacked_messagepack.is_array() { - return None; - } - - let unpacked_message = unpacked_messagepack.as_array().unwrap(); - let message_type = - unpacked_message.iter().next().unwrap().as_u64().unwrap(); - - match message_type { - 0 => Some(NotificationMessage::SyncCipherUpdate), - 1 => Some(NotificationMessage::SyncCipherCreate), - 2 => Some(NotificationMessage::SyncLoginDelete), - 3 => Some(NotificationMessage::SyncFolderDelete), - 4 => Some(NotificationMessage::SyncCiphers), - 5 => Some(NotificationMessage::SyncVault), - 6 => Some(NotificationMessage::SyncOrgKeys), - 7 => Some(NotificationMessage::SyncFolderCreate), - 8 => Some(NotificationMessage::SyncFolderUpdate), - 9 => Some(NotificationMessage::SyncCipherDelete), - 10 => Some(NotificationMessage::SyncSettings), - 11 => Some(NotificationMessage::Logout), - _ => None, - } -} - pub struct Handler { write: Option< futures::stream::SplitSink< tokio_tungstenite::WebSocketStream< tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>, >, - Message, + tokio_tungstenite::tungstenite::Message, >, >, read_handle: Option<tokio::task::JoinHandle<()>>, sending_channels: std::sync::Arc< - tokio::sync::RwLock< - Vec<tokio::sync::mpsc::UnboundedSender<NotificationMessage>>, - >, + tokio::sync::RwLock<Vec<tokio::sync::mpsc::UnboundedSender<Message>>>, >, } @@ -111,7 +60,9 @@ impl Handler { ) -> Result<(), Box<dyn std::error::Error>> { self.sending_channels.write().await.clear(); if let Some(mut write) = self.write.take() { - write.send(Message::Close(None)).await?; + write + .send(tokio_tungstenite::tungstenite::Message::Close(None)) + .await?; write.close().await?; self.read_handle.take().unwrap().await?; } @@ -122,9 +73,8 @@ impl Handler { pub async fn get_channel( &mut self, - ) -> tokio::sync::mpsc::UnboundedReceiver<NotificationMessage> { - let (tx, rx) = - tokio::sync::mpsc::unbounded_channel::<NotificationMessage>(); + ) -> tokio::sync::mpsc::UnboundedReceiver<Message> { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); self.sending_channels.write().await.push(tx); rx } @@ -133,54 +83,92 @@ impl Handler { async fn subscribe_to_notifications( url: String, sending_channels: std::sync::Arc< - tokio::sync::RwLock< - Vec<tokio::sync::mpsc::UnboundedSender<NotificationMessage>>, - >, + tokio::sync::RwLock<Vec<tokio::sync::mpsc::UnboundedSender<Message>>>, >, ) -> Result< ( - SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>, - JoinHandle<()>, + futures_util::stream::SplitSink< + tokio_tungstenite::WebSocketStream< + tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>, + >, + tokio_tungstenite::tungstenite::Message, + >, + tokio::task::JoinHandle<()>, ), Box<dyn std::error::Error>, > { let url = url::Url::parse(url.as_str())?; - let (ws_stream, _response) = connect_async(url).await?; + let (ws_stream, _response) = + tokio_tungstenite::connect_async(url).await?; let (mut write, read) = ws_stream.split(); write - .send(Message::Text( - "{\"protocol\":\"messagepack\",\"version\":1}\n".to_string(), + .send(tokio_tungstenite::tungstenite::Message::Text( + "{\"protocol\":\"messagepack\",\"version\":1}\x1e".to_string(), )) .await .unwrap(); let read_future = async move { - read.map(|message| { - (message, sending_channels.clone()) - }).for_each(|(message, a)| async move { - let a = a.read().await; - + let sending_channels = &sending_channels; + read.for_each(|message| async move { match message { - Ok(Message::Binary(binary)) => { - let msgpack = parse_messagepack(&binary); - if let Some(msg) = msgpack { - let channels = a.as_slice(); - for channel in channels { - let res = channel.send(msg); - if res.is_err() { - eprintln!("error sending websocket message to channel"); - } + Ok(message) => { + if let Some(message) = parse_message(message) { + let sending_channels = sending_channels.read().await; + let sending_channels = sending_channels.as_slice(); + for channel in sending_channels { + channel.send(message).unwrap(); } } - }, + } Err(e) => { eprintln!("websocket error: {e:?}"); - }, - _ => {} + } } - }).await; + }) + .await; }; Ok((write, tokio::spawn(read_future))) } + +fn parse_message( + message: tokio_tungstenite::tungstenite::Message, +) -> Option<Message> { + let tokio_tungstenite::tungstenite::Message::Binary(data) = message + else { + return None; + }; + + // the first few bytes with the 0x80 bit set, plus one byte terminating the length contain the length of the message + let len_buffer_length = data.iter().position(|&x| (x & 0x80) == 0)? + 1; + + let unpacked_messagepack = + rmpv::decode::read_value(&mut &data[len_buffer_length..]).ok()?; + + let unpacked_message = unpacked_messagepack.as_array()?; + let message_type = unpacked_message.get(0)?.as_u64()?; + // invocation + if message_type != 1 { + return None; + } + let target = unpacked_message.get(3)?.as_str()?; + if target != "ReceiveMessage" { + return None; + } + + let args = unpacked_message.get(4)?.as_array()?; + let map = args.get(0)?.as_map()?; + for (k, v) in map { + if k.as_str()? == "Type" { + let ty = v.as_i64()?; + return match ty { + 11 => Some(Message::Logout), + _ => Some(Message::Sync), + }; + } + } + + None +} |