aboutsummaryrefslogtreecommitdiffstats
path: root/src/bin
diff options
context:
space:
mode:
authorJesse Luehrs <doy@tozt.net>2023-07-19 01:05:42 -0400
committerJesse Luehrs <doy@tozt.net>2023-07-19 01:05:42 -0400
commit76ab1de92a36b151b3c817e16737fd703567a3f2 (patch)
tree688f604bce050f154ec35b7f6faf0c161a5c79b6 /src/bin
parent7a0eae68c1f3496a1d421b61f66115a7889d7e92 (diff)
downloadrbw-76ab1de92a36b151b3c817e16737fd703567a3f2.tar.gz
rbw-76ab1de92a36b151b3c817e16737fd703567a3f2.zip
more correct websocket notification handling
the servers tend to be fairly chatty with messages, mostly pings and heartbeats of various sorts, and we don't want to sync on all of those. also, the message type in the first array element of the messagepack structure is not the same thing as the UpdateType - that is stored as an argument to the ReceiveMessage invocation, so we need to parse a bit further to get the actual UpdateType. this still just does a full sync on any changes, though.
Diffstat (limited to 'src/bin')
-rw-r--r--src/bin/rbw-agent/agent.rs35
-rw-r--r--src/bin/rbw-agent/notifications.rs162
2 files changed, 91 insertions, 106 deletions
diff --git a/src/bin/rbw-agent/agent.rs b/src/bin/rbw-agent/agent.rs
index 29e400b..5769c0e 100644
--- a/src/bin/rbw-agent/agent.rs
+++ b/src/bin/rbw-agent/agent.rs
@@ -83,31 +83,28 @@ impl Agent {
self,
listener: tokio::net::UnixListener,
) -> anyhow::Result<()> {
- enum Event {
+ pub enum Event {
Request(std::io::Result<tokio::net::UnixStream>),
Timeout(()),
Sync(()),
}
- let c: tokio::sync::mpsc::UnboundedReceiver<
- notifications::NotificationMessage,
- > = {
- self.state
- .lock()
- .await
- .notifications_handler
- .get_channel()
- .await
- };
+ let notifications = self
+ .state
+ .lock()
+ .await
+ .notifications_handler
+ .get_channel()
+ .await;
let notifications =
- tokio_stream::wrappers::UnboundedReceiverStream::new(c)
- .map(|message| match message {
- notifications::NotificationMessage::Logout => {
- Event::Timeout(())
- }
- _ => Event::Sync(()),
- })
- .boxed();
+ tokio_stream::wrappers::UnboundedReceiverStream::new(
+ notifications,
+ )
+ .map(|message| match message {
+ notifications::Message::Logout => Event::Timeout(()),
+ notifications::Message::Sync => Event::Sync(()),
+ })
+ .boxed();
let mut stream = futures_util::stream::select_all([
tokio_stream::wrappers::UnixListenerStream::new(listener)
diff --git a/src/bin/rbw-agent/notifications.rs b/src/bin/rbw-agent/notifications.rs
index 69ebda5..297fa9c 100644
--- a/src/bin/rbw-agent/notifications.rs
+++ b/src/bin/rbw-agent/notifications.rs
@@ -1,74 +1,23 @@
-use futures::stream::SplitSink;
-use futures_util::{SinkExt, StreamExt};
-use tokio::{net::TcpStream, task::JoinHandle};
-use tokio_tungstenite::{
- connect_async, tungstenite::protocol::Message, MaybeTlsStream,
- WebSocketStream,
-};
-
-#[derive(Copy, Clone)]
-pub enum NotificationMessage {
- SyncCipherUpdate,
- SyncCipherCreate,
- SyncLoginDelete,
- SyncFolderDelete,
- SyncCiphers,
-
- SyncVault,
- SyncOrgKeys,
- SyncFolderCreate,
- SyncFolderUpdate,
- SyncCipherDelete,
- SyncSettings,
+use futures_util::{SinkExt as _, StreamExt as _};
+#[derive(Clone, Copy, Debug)]
+pub enum Message {
+ Sync,
Logout,
}
-fn parse_messagepack(data: &[u8]) -> Option<NotificationMessage> {
- // the first few bytes with the 0x80 bit set, plus one byte terminating the length contain the length of the message
- let len_buffer_length = data.iter().position(|&x| (x & 0x80) == 0)? + 1;
-
- let unpacked_messagepack =
- rmpv::decode::read_value(&mut &data[len_buffer_length..]).ok()?;
- if !unpacked_messagepack.is_array() {
- return None;
- }
-
- let unpacked_message = unpacked_messagepack.as_array().unwrap();
- let message_type =
- unpacked_message.iter().next().unwrap().as_u64().unwrap();
-
- match message_type {
- 0 => Some(NotificationMessage::SyncCipherUpdate),
- 1 => Some(NotificationMessage::SyncCipherCreate),
- 2 => Some(NotificationMessage::SyncLoginDelete),
- 3 => Some(NotificationMessage::SyncFolderDelete),
- 4 => Some(NotificationMessage::SyncCiphers),
- 5 => Some(NotificationMessage::SyncVault),
- 6 => Some(NotificationMessage::SyncOrgKeys),
- 7 => Some(NotificationMessage::SyncFolderCreate),
- 8 => Some(NotificationMessage::SyncFolderUpdate),
- 9 => Some(NotificationMessage::SyncCipherDelete),
- 10 => Some(NotificationMessage::SyncSettings),
- 11 => Some(NotificationMessage::Logout),
- _ => None,
- }
-}
-
pub struct Handler {
write: Option<
futures::stream::SplitSink<
tokio_tungstenite::WebSocketStream<
tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
>,
- Message,
+ tokio_tungstenite::tungstenite::Message,
>,
>,
read_handle: Option<tokio::task::JoinHandle<()>>,
sending_channels: std::sync::Arc<
- tokio::sync::RwLock<
- Vec<tokio::sync::mpsc::UnboundedSender<NotificationMessage>>,
- >,
+ tokio::sync::RwLock<Vec<tokio::sync::mpsc::UnboundedSender<Message>>>,
>,
}
@@ -111,7 +60,9 @@ impl Handler {
) -> Result<(), Box<dyn std::error::Error>> {
self.sending_channels.write().await.clear();
if let Some(mut write) = self.write.take() {
- write.send(Message::Close(None)).await?;
+ write
+ .send(tokio_tungstenite::tungstenite::Message::Close(None))
+ .await?;
write.close().await?;
self.read_handle.take().unwrap().await?;
}
@@ -122,9 +73,8 @@ impl Handler {
pub async fn get_channel(
&mut self,
- ) -> tokio::sync::mpsc::UnboundedReceiver<NotificationMessage> {
- let (tx, rx) =
- tokio::sync::mpsc::unbounded_channel::<NotificationMessage>();
+ ) -> tokio::sync::mpsc::UnboundedReceiver<Message> {
+ let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
self.sending_channels.write().await.push(tx);
rx
}
@@ -133,54 +83,92 @@ impl Handler {
async fn subscribe_to_notifications(
url: String,
sending_channels: std::sync::Arc<
- tokio::sync::RwLock<
- Vec<tokio::sync::mpsc::UnboundedSender<NotificationMessage>>,
- >,
+ tokio::sync::RwLock<Vec<tokio::sync::mpsc::UnboundedSender<Message>>>,
>,
) -> Result<
(
- SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
- JoinHandle<()>,
+ futures_util::stream::SplitSink<
+ tokio_tungstenite::WebSocketStream<
+ tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
+ >,
+ tokio_tungstenite::tungstenite::Message,
+ >,
+ tokio::task::JoinHandle<()>,
),
Box<dyn std::error::Error>,
> {
let url = url::Url::parse(url.as_str())?;
- let (ws_stream, _response) = connect_async(url).await?;
+ let (ws_stream, _response) =
+ tokio_tungstenite::connect_async(url).await?;
let (mut write, read) = ws_stream.split();
write
- .send(Message::Text(
- "{\"protocol\":\"messagepack\",\"version\":1}\n".to_string(),
+ .send(tokio_tungstenite::tungstenite::Message::Text(
+ "{\"protocol\":\"messagepack\",\"version\":1}\x1e".to_string(),
))
.await
.unwrap();
let read_future = async move {
- read.map(|message| {
- (message, sending_channels.clone())
- }).for_each(|(message, a)| async move {
- let a = a.read().await;
-
+ let sending_channels = &sending_channels;
+ read.for_each(|message| async move {
match message {
- Ok(Message::Binary(binary)) => {
- let msgpack = parse_messagepack(&binary);
- if let Some(msg) = msgpack {
- let channels = a.as_slice();
- for channel in channels {
- let res = channel.send(msg);
- if res.is_err() {
- eprintln!("error sending websocket message to channel");
- }
+ Ok(message) => {
+ if let Some(message) = parse_message(message) {
+ let sending_channels = sending_channels.read().await;
+ let sending_channels = sending_channels.as_slice();
+ for channel in sending_channels {
+ channel.send(message).unwrap();
}
}
- },
+ }
Err(e) => {
eprintln!("websocket error: {e:?}");
- },
- _ => {}
+ }
}
- }).await;
+ })
+ .await;
};
Ok((write, tokio::spawn(read_future)))
}
+
+fn parse_message(
+ message: tokio_tungstenite::tungstenite::Message,
+) -> Option<Message> {
+ let tokio_tungstenite::tungstenite::Message::Binary(data) = message
+ else {
+ return None;
+ };
+
+ // the first few bytes with the 0x80 bit set, plus one byte terminating the length contain the length of the message
+ let len_buffer_length = data.iter().position(|&x| (x & 0x80) == 0)? + 1;
+
+ let unpacked_messagepack =
+ rmpv::decode::read_value(&mut &data[len_buffer_length..]).ok()?;
+
+ let unpacked_message = unpacked_messagepack.as_array()?;
+ let message_type = unpacked_message.get(0)?.as_u64()?;
+ // invocation
+ if message_type != 1 {
+ return None;
+ }
+ let target = unpacked_message.get(3)?.as_str()?;
+ if target != "ReceiveMessage" {
+ return None;
+ }
+
+ let args = unpacked_message.get(4)?.as_array()?;
+ let map = args.get(0)?.as_map()?;
+ for (k, v) in map {
+ if k.as_str()? == "Type" {
+ let ty = v.as_i64()?;
+ return match ty {
+ 11 => Some(Message::Logout),
+ _ => Some(Message::Sync),
+ };
+ }
+ }
+
+ None
+}