aboutsummaryrefslogtreecommitdiffstats
path: root/src/bin/rbw-agent/notifications.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/bin/rbw-agent/notifications.rs')
-rw-r--r--src/bin/rbw-agent/notifications.rs174
1 files changed, 174 insertions, 0 deletions
diff --git a/src/bin/rbw-agent/notifications.rs b/src/bin/rbw-agent/notifications.rs
new file mode 100644
index 0000000..8176603
--- /dev/null
+++ b/src/bin/rbw-agent/notifications.rs
@@ -0,0 +1,174 @@
+use futures_util::{SinkExt as _, StreamExt as _};
+
+#[derive(Clone, Copy, Debug)]
+pub enum Message {
+ Sync,
+ Logout,
+}
+
+pub struct Handler {
+ write: Option<
+ futures::stream::SplitSink<
+ tokio_tungstenite::WebSocketStream<
+ tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
+ >,
+ tokio_tungstenite::tungstenite::Message,
+ >,
+ >,
+ read_handle: Option<tokio::task::JoinHandle<()>>,
+ sending_channels: std::sync::Arc<
+ tokio::sync::RwLock<Vec<tokio::sync::mpsc::UnboundedSender<Message>>>,
+ >,
+}
+
+impl Handler {
+ pub fn new() -> Self {
+ Self {
+ write: None,
+ read_handle: None,
+ sending_channels: std::sync::Arc::new(tokio::sync::RwLock::new(
+ Vec::new(),
+ )),
+ }
+ }
+
+ pub async fn connect(
+ &mut self,
+ url: String,
+ ) -> Result<(), Box<dyn std::error::Error>> {
+ if self.is_connected() {
+ self.disconnect().await?;
+ }
+
+ let (write, read_handle) =
+ subscribe_to_notifications(url, self.sending_channels.clone())
+ .await?;
+
+ self.write = Some(write);
+ self.read_handle = Some(read_handle);
+ Ok(())
+ }
+
+ pub fn is_connected(&self) -> bool {
+ self.write.is_some()
+ && self.read_handle.is_some()
+ && !self.read_handle.as_ref().unwrap().is_finished()
+ }
+
+ pub async fn disconnect(
+ &mut self,
+ ) -> Result<(), Box<dyn std::error::Error>> {
+ self.sending_channels.write().await.clear();
+ if let Some(mut write) = self.write.take() {
+ write
+ .send(tokio_tungstenite::tungstenite::Message::Close(None))
+ .await?;
+ write.close().await?;
+ self.read_handle.take().unwrap().await?;
+ }
+ self.write = None;
+ self.read_handle = None;
+ Ok(())
+ }
+
+ pub async fn get_channel(
+ &mut self,
+ ) -> tokio::sync::mpsc::UnboundedReceiver<Message> {
+ let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
+ self.sending_channels.write().await.push(tx);
+ rx
+ }
+}
+
+async fn subscribe_to_notifications(
+ url: String,
+ sending_channels: std::sync::Arc<
+ tokio::sync::RwLock<Vec<tokio::sync::mpsc::UnboundedSender<Message>>>,
+ >,
+) -> Result<
+ (
+ futures_util::stream::SplitSink<
+ tokio_tungstenite::WebSocketStream<
+ tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
+ >,
+ tokio_tungstenite::tungstenite::Message,
+ >,
+ tokio::task::JoinHandle<()>,
+ ),
+ Box<dyn std::error::Error>,
+> {
+ let url = url::Url::parse(url.as_str())?;
+ let (ws_stream, _response) =
+ tokio_tungstenite::connect_async(url).await?;
+ let (mut write, read) = ws_stream.split();
+
+ write
+ .send(tokio_tungstenite::tungstenite::Message::Text(
+ "{\"protocol\":\"messagepack\",\"version\":1}\x1e".to_string(),
+ ))
+ .await
+ .unwrap();
+
+ let read_future = async move {
+ let sending_channels = &sending_channels;
+ read.for_each(|message| async move {
+ match message {
+ Ok(message) => {
+ if let Some(message) = parse_message(message) {
+ let sending_channels = sending_channels.read().await;
+ let sending_channels = sending_channels.as_slice();
+ for channel in sending_channels {
+ channel.send(message).unwrap();
+ }
+ }
+ }
+ Err(e) => {
+ eprintln!("websocket error: {e:?}");
+ }
+ }
+ })
+ .await;
+ };
+
+ Ok((write, tokio::spawn(read_future)))
+}
+
+fn parse_message(
+ message: tokio_tungstenite::tungstenite::Message,
+) -> Option<Message> {
+ let tokio_tungstenite::tungstenite::Message::Binary(data) = message
+ else {
+ return None;
+ };
+
+ // the first few bytes with the 0x80 bit set, plus one byte terminating the length contain the length of the message
+ let len_buffer_length = data.iter().position(|&x| (x & 0x80) == 0)? + 1;
+
+ let unpacked_messagepack =
+ rmpv::decode::read_value(&mut &data[len_buffer_length..]).ok()?;
+
+ let unpacked_message = unpacked_messagepack.as_array()?;
+ let message_type = unpacked_message.first()?.as_u64()?;
+ // invocation
+ if message_type != 1 {
+ return None;
+ }
+ let target = unpacked_message.get(3)?.as_str()?;
+ if target != "ReceiveMessage" {
+ return None;
+ }
+
+ let args = unpacked_message.get(4)?.as_array()?;
+ let map = args.first()?.as_map()?;
+ for (k, v) in map {
+ if k.as_str()? == "Type" {
+ let ty = v.as_i64()?;
+ return match ty {
+ 11 => Some(Message::Logout),
+ _ => Some(Message::Sync),
+ };
+ }
+ }
+
+ None
+}