aboutsummaryrefslogtreecommitdiffstats
path: root/src/bin/rbw-agent/notifications.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/bin/rbw-agent/notifications.rs')
-rw-r--r--src/bin/rbw-agent/notifications.rs187
1 files changed, 187 insertions, 0 deletions
diff --git a/src/bin/rbw-agent/notifications.rs b/src/bin/rbw-agent/notifications.rs
new file mode 100644
index 0000000..e8f84b0
--- /dev/null
+++ b/src/bin/rbw-agent/notifications.rs
@@ -0,0 +1,187 @@
+use futures::stream::SplitSink;
+use futures_util::{SinkExt, StreamExt};
+use tokio::{net::TcpStream, task::JoinHandle};
+use tokio_tungstenite::{
+ connect_async, tungstenite::protocol::Message, MaybeTlsStream,
+ WebSocketStream,
+};
+
+#[derive(Copy, Clone)]
+pub enum NotificationMessage {
+ SyncCipherUpdate,
+ SyncCipherCreate,
+ SyncLoginDelete,
+ SyncFolderDelete,
+ SyncCiphers,
+
+ SyncVault,
+ SyncOrgKeys,
+ SyncFolderCreate,
+ SyncFolderUpdate,
+ SyncCipherDelete,
+ SyncSettings,
+
+ Logout,
+}
+
+fn parse_messagepack(data: &[u8]) -> Option<NotificationMessage> {
+ // the first few bytes with the 0x80 bit set, plus one byte terminating the length contain the length of the message
+ let len_buffer_length = data.iter().position(|&x| (x & 0x80) == 0)? + 1;
+
+ let unpacked_messagepack =
+ rmpv::decode::read_value(&mut &data[len_buffer_length..]).ok()?;
+ if !unpacked_messagepack.is_array() {
+ return None;
+ }
+
+ let unpacked_message = unpacked_messagepack.as_array().unwrap();
+ let message_type =
+ unpacked_message.iter().next().unwrap().as_u64().unwrap();
+
+ let message = match message_type {
+ 0 => Some(NotificationMessage::SyncCipherUpdate),
+ 1 => Some(NotificationMessage::SyncCipherCreate),
+ 2 => Some(NotificationMessage::SyncLoginDelete),
+ 3 => Some(NotificationMessage::SyncFolderDelete),
+ 4 => Some(NotificationMessage::SyncCiphers),
+ 5 => Some(NotificationMessage::SyncVault),
+ 6 => Some(NotificationMessage::SyncOrgKeys),
+ 7 => Some(NotificationMessage::SyncFolderCreate),
+ 8 => Some(NotificationMessage::SyncFolderUpdate),
+ 9 => Some(NotificationMessage::SyncCipherDelete),
+ 10 => Some(NotificationMessage::SyncSettings),
+ 11 => Some(NotificationMessage::Logout),
+ _ => None,
+ };
+
+ return message;
+}
+
+pub struct NotificationsHandler {
+ write: Option<
+ futures::stream::SplitSink<
+ tokio_tungstenite::WebSocketStream<
+ tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
+ >,
+ Message,
+ >,
+ >,
+ read_handle: Option<tokio::task::JoinHandle<()>>,
+ sending_channels: std::sync::Arc<
+ tokio::sync::RwLock<
+ Vec<tokio::sync::mpsc::UnboundedSender<NotificationMessage>>,
+ >,
+ >,
+}
+
+impl NotificationsHandler {
+ pub fn new() -> Self {
+ Self {
+ write: None,
+ read_handle: None,
+ sending_channels: std::sync::Arc::new(tokio::sync::RwLock::new(
+ Vec::new(),
+ )),
+ }
+ }
+
+ pub async fn connect(
+ &mut self,
+ url: String,
+ ) -> Result<(), Box<dyn std::error::Error>> {
+ if self.is_connected() {
+ self.disconnect().await?;
+ }
+
+ let (write, read_handle) =
+ subscribe_to_notifications(url, self.sending_channels.clone())
+ .await?;
+
+ self.write = Some(write);
+ self.read_handle = Some(read_handle);
+ return Ok(());
+ }
+
+ pub fn is_connected(&self) -> bool {
+ self.write.is_some()
+ && self.read_handle.is_some()
+ && !self.read_handle.as_ref().unwrap().is_finished()
+ }
+
+ pub async fn disconnect(
+ &mut self,
+ ) -> Result<(), Box<dyn std::error::Error>> {
+ self.sending_channels.write().await.clear();
+ if let Some(mut write) = self.write.take() {
+ write.send(Message::Close(None)).await?;
+ write.close().await?;
+ self.read_handle.take().unwrap().await?;
+ }
+ self.write = None;
+ self.read_handle = None;
+ Ok(())
+ }
+
+ pub async fn get_channel(
+ &mut self,
+ ) -> tokio::sync::mpsc::UnboundedReceiver<NotificationMessage> {
+ let (tx, rx) =
+ tokio::sync::mpsc::unbounded_channel::<NotificationMessage>();
+ self.sending_channels.write().await.push(tx);
+ return rx;
+ }
+}
+
+async fn subscribe_to_notifications(
+ url: String,
+ sending_channels: std::sync::Arc<
+ tokio::sync::RwLock<
+ Vec<tokio::sync::mpsc::UnboundedSender<NotificationMessage>>,
+ >,
+ >,
+) -> Result<
+ (
+ SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
+ JoinHandle<()>,
+ ),
+ Box<dyn std::error::Error>,
+> {
+ let url = url::Url::parse(url.as_str())?;
+ let (ws_stream, _response) = connect_async(url).await?;
+ let (mut write, read) = ws_stream.split();
+
+ write
+ .send(Message::Text(
+ "{\"protocol\":\"messagepack\",\"version\":1}\n".to_string(),
+ ))
+ .await
+ .unwrap();
+
+ let read_future = async move {
+ read.map(|message| {
+ (message, sending_channels.clone())
+ }).for_each(|(message, a)| async move {
+ let a = a.read().await;
+
+ match message {
+ Ok(Message::Binary(binary)) => {
+ let msgpack = parse_messagepack(&binary);
+ if let Some(msg) = msgpack {
+ for channel in a.iter() {
+ let res = channel.send(msg);
+ if res.is_err() {
+ eprintln!("error sending websocket message to channel");
+ }
+ }
+ }
+ },
+ Err(e) => {
+ eprintln!("websocket error: {:?}", e);
+ },
+ _ => {}
+ }
+ }).await;
+ };
+
+ return Ok((write, tokio::spawn(read_future)));
+}