aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJesse Luehrs <doy@tozt.net>2023-07-16 16:20:01 -0400
committerGitHub <noreply@github.com>2023-07-16 16:20:01 -0400
commitb06eab0609451ec449a88ed5141a658e16197eb0 (patch)
tree278750c6c431c6fce9092f7890c0a8bad466f513
parentc8c99cf06fdc45568fc10b98257963ae63e9c486 (diff)
parent389655d8f76b49a8a391deda28cf75bd99d17a96 (diff)
downloadrbw-b06eab0609451ec449a88ed5141a658e16197eb0.tar.gz
rbw-b06eab0609451ec449a88ed5141a658e16197eb0.zip
Merge branch 'main' into result-to-clipboard
-rw-r--r--Cargo.lock145
-rw-r--r--Cargo.toml4
-rwxr-xr-xbin/rbw-pinentry-keyring72
-rw-r--r--src/api.rs2
-rw-r--r--src/bin/rbw-agent/actions.rs49
-rw-r--r--src/bin/rbw-agent/agent.rs54
-rw-r--r--src/bin/rbw-agent/main.rs1
-rw-r--r--src/bin/rbw-agent/notifications.rs187
8 files changed, 493 insertions, 21 deletions
diff --git a/Cargo.lock b/Cargo.lock
index ee40de6..0ca3587 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -81,6 +81,12 @@ checksum = "23ce669cd6c8588f79e15cf450314f9638f967fc5770ff1c7c1deb0925ea7cfa"
[[package]]
name = "base64"
+version = "0.13.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
+
+[[package]]
+name = "base64"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4a4ddaa51a5bc52a6948f74c06d20aaaddb71924eab79b8c97a8c556e942d6a"
@@ -425,58 +431,87 @@ dependencies = [
]
[[package]]
+name = "futures"
+version = "0.3.28"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "23342abe12aba583913b2e62f22225ff9c950774065e4bfb61a19cd9770fec40"
+dependencies = [
+ "futures-channel",
+ "futures-core",
+ "futures-executor",
+ "futures-io",
+ "futures-sink",
+ "futures-task",
+ "futures-util",
+]
+
+[[package]]
name = "futures-channel"
-version = "0.3.27"
+version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "164713a5a0dcc3e7b4b1ed7d3b433cabc18025386f9339346e8daf15963cf7ac"
+checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2"
dependencies = [
"futures-core",
+ "futures-sink",
]
[[package]]
name = "futures-core"
-version = "0.3.27"
+version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "86d7a0c1aa76363dac491de0ee99faf6941128376f1cf96f07db7603b7de69dd"
+checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c"
+
+[[package]]
+name = "futures-executor"
+version = "0.3.28"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ccecee823288125bd88b4d7f565c9e58e41858e47ab72e8ea2d64e93624386e0"
+dependencies = [
+ "futures-core",
+ "futures-task",
+ "futures-util",
+]
[[package]]
name = "futures-io"
-version = "0.3.27"
+version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "89d422fa3cbe3b40dca574ab087abb5bc98258ea57eea3fd6f1fa7162c778b91"
+checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964"
[[package]]
name = "futures-macro"
-version = "0.3.27"
+version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3eb14ed937631bd8b8b8977f2c198443447a8355b6e3ca599f38c975e5a963b6"
+checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72"
dependencies = [
"proc-macro2",
"quote",
- "syn 1.0.109",
+ "syn 2.0.10",
]
[[package]]
name = "futures-sink"
-version = "0.3.27"
+version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ec93083a4aecafb2a80a885c9de1f0ccae9dbd32c2bb54b0c3a65690e0b8d2f2"
+checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e"
[[package]]
name = "futures-task"
-version = "0.3.27"
+version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fd65540d33b37b16542a0438c12e6aeead10d4ac5d05bd3f805b8f35ab592879"
+checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65"
[[package]]
name = "futures-util"
-version = "0.3.27"
+version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3ef6b17e481503ec85211fed8f39d1970f128935ca1f814cd32ac4a6842e84ab"
+checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533"
dependencies = [
+ "futures-channel",
"futures-core",
"futures-io",
"futures-macro",
+ "futures-sink",
"futures-task",
"memchr",
"pin-project-lite",
@@ -1055,6 +1090,12 @@ dependencies = [
]
[[package]]
+name = "paste"
+version = "1.0.12"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9f746c4065a8fa3fe23974dd82f15431cc8d40779821001404d10d2e79ca7d79"
+
+[[package]]
name = "pbkdf2"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1183,7 +1224,7 @@ dependencies = [
"arrayvec",
"async-trait",
"base32",
- "base64",
+ "base64 0.21.0",
"block-padding",
"cbc",
"clap",
@@ -1192,6 +1233,8 @@ dependencies = [
"daemonize",
"directories",
"env_logger",
+ "futures",
+ "futures-channel",
"futures-util",
"hkdf",
"hmac",
@@ -1205,6 +1248,7 @@ dependencies = [
"rand",
"region",
"reqwest",
+ "rmpv",
"rsa",
"serde",
"serde_json",
@@ -1218,6 +1262,7 @@ dependencies = [
"thiserror",
"tokio",
"tokio-stream",
+ "tokio-tungstenite",
"totp-lite",
"url",
"uuid",
@@ -1279,7 +1324,7 @@ version = "0.11.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ba30cc2c0cd02af1222ed216ba659cdb2f879dfe3181852fe7c50b1d0005949"
dependencies = [
- "base64",
+ "base64 0.21.0",
"bytes",
"encoding_rs",
"futures-core",
@@ -1328,6 +1373,27 @@ dependencies = [
]
[[package]]
+name = "rmp"
+version = "0.8.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "44519172358fd6d58656c86ab8e7fbc9e1490c3e8f14d35ed78ca0dd07403c9f"
+dependencies = [
+ "byteorder",
+ "num-traits",
+ "paste",
+]
+
+[[package]]
+name = "rmpv"
+version = "1.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "de8813b3a2f95c5138fe5925bfb8784175d88d6bff059ba8ce090aa891319754"
+dependencies = [
+ "num-traits",
+ "rmp",
+]
+
+[[package]]
name = "rsa"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1391,7 +1457,7 @@ version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b"
dependencies = [
- "base64",
+ "base64 0.21.0",
]
[[package]]
@@ -1816,6 +1882,22 @@ dependencies = [
]
[[package]]
+name = "tokio-tungstenite"
+version = "0.18.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "54319c93411147bced34cb5609a80e0a8e44c5999c93903a81cd866630ec0bfd"
+dependencies = [
+ "futures-util",
+ "log",
+ "rustls",
+ "rustls-native-certs",
+ "tokio",
+ "tokio-rustls",
+ "tungstenite",
+ "webpki",
+]
+
+[[package]]
name = "tokio-util"
version = "0.7.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1874,6 +1956,27 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed"
[[package]]
+name = "tungstenite"
+version = "0.18.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "30ee6ab729cd4cf0fd55218530c4522ed30b7b6081752839b68fcec8d0960788"
+dependencies = [
+ "base64 0.13.1",
+ "byteorder",
+ "bytes",
+ "http",
+ "httparse",
+ "log",
+ "rand",
+ "rustls",
+ "sha1",
+ "thiserror",
+ "url",
+ "utf-8",
+ "webpki",
+]
+
+[[package]]
name = "typenum"
version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1934,6 +2037,12 @@ dependencies = [
]
[[package]]
+name = "utf-8"
+version = "0.7.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
+
+[[package]]
name = "uuid"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
diff --git a/Cargo.toml b/Cargo.toml
index e6cd8c3..fa19acd 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -27,6 +27,8 @@ clap_complete = "4.1.5"
daemonize = "0.5.0"
directories = "5.0.0"
env_logger = "0.10.0"
+futures = "0.3.28"
+futures-channel = "0.3.28"
futures-util = "0.3.27"
hkdf = "0.12.3"
hmac = { version = "0.12.1", features = ["std"] }
@@ -58,6 +60,8 @@ url = "2.3.1"
uuid = { version = "1.3.0", features = ["v4"] }
zeroize = "1.5.7"
copypasta = "0.8.2"
+rmpv = "1.0.0"
+tokio-tungstenite = { version = "0.18.0", features = ["rustls-tls-native-roots"] }
[package.metadata.deb]
depends = "pinentry"
diff --git a/bin/rbw-pinentry-keyring b/bin/rbw-pinentry-keyring
new file mode 100755
index 0000000..1626853
--- /dev/null
+++ b/bin/rbw-pinentry-keyring
@@ -0,0 +1,72 @@
+#!/bin/bash
+
+[[ -z "${RBW_PROFILE}" ]] && rbw_profile='rbw' || rbw_profile="rbw-${RBW_PROFILE}"
+
+set -eEuo pipefail
+
+function help() {
+ cat <<EOHELP
+Use this script as pinentry to store master password for rbw into your keyring
+
+Usage
+- run "rbw-pinentry-keyring setup" once to save master password to keyring
+- add "rbw-pinentry-keyring" as "pinentry" in rbw config (${XDG_CONFIG_HOME}/rbw/config.json)
+- use rbw as normal
+Notes
+- needs "secret-tool" to access keyring
+- setup tested with pinentry-gnome3, but you can run the "secret-tool store"-command manually as well
+- master passwords are stored into the keyring as plaintext, so secure your keyring appropriately
+- supports multiple profiles, simply set RBW_PROFILE during setup
+- can easily be rewritten to use other backends than keyring by setting the "secret_value"-variable
+EOHELP
+}
+
+function setup() {
+ cmd="SETTITLE rbw\n"
+ cmd+="SETPROMPT Master Password\n"
+ cmd+="SETDESC Please enter the master password for '$rbw_profile'\n"
+ cmd+="GETPIN\n"
+ password="$(printf "$cmd" | pinentry | grep -E "^D " | cut -d' ' -f2)"
+ if [ -n "$password" ]; then
+ echo -n "$password" | secret-tool store --label="$rbw_profile master password" application rbw profile "$rbw_profile" type master_password
+ fi
+}
+
+function getpin() {
+ echo 'OK'
+
+ while IFS=' ' read -r command args ; do
+ case "$command" in
+ SETPROMPT|SETTITLE| SETDESC)
+ echo 'OK'
+ ;;
+ GETPIN)
+ secret_value="$(secret-tool lookup application rbw profile "$rbw_profile" type master_password)"
+ if [ -z "$secret_value" ]; then
+ exit 1
+ fi
+ printf 'D %s\n' "$secret_value"
+ echo 'OK'
+ ;;
+ BYE)
+ exit
+ ;;
+ *)
+ echo 'ERR Unknown command'
+ ;;
+ esac
+ done
+}
+
+command="$1"
+case "$command" in
+ -h|--help|help)
+ help
+ ;;
+ -s|--setup|setup)
+ setup
+ ;;
+ *)
+ getpin
+ ;;
+esac
diff --git a/src/api.rs b/src/api.rs
index bf608b3..fb4fc42 100644
--- a/src/api.rs
+++ b/src/api.rs
@@ -62,6 +62,7 @@ impl TwoFactorProviderType {
pub fn message(&self) -> &str {
match *self {
Self::Authenticator => "Enter the 6 digit verification code from your authenticator app.",
+ Self::Yubikey => "Insert your Yubikey and push the button.",
Self::Email => "Enter the PIN you received via email.",
_ => "Enter the code."
}
@@ -71,6 +72,7 @@ impl TwoFactorProviderType {
pub fn header(&self) -> &str {
match *self {
Self::Authenticator => "Authenticator App",
+ Self::Yubikey => "Yubikey",
Self::Email => "Email Code",
_ => "Two Factor Authentication",
}
diff --git a/src/bin/rbw-agent/actions.rs b/src/bin/rbw-agent/actions.rs
index 7b5dc58..4d77133 100644
--- a/src/bin/rbw-agent/actions.rs
+++ b/src/bin/rbw-agent/actions.rs
@@ -1,3 +1,5 @@
+use std::f32::consts::E;
+
use anyhow::Context as _;
pub async fn register(
@@ -130,7 +132,7 @@ pub async fn login(
protected_key,
)) => {
login_success(
- state,
+ state.clone(),
access_token,
refresh_token,
kdf,
@@ -148,6 +150,7 @@ pub async fn login(
Err(rbw::error::Error::TwoFactorRequired { providers }) => {
let supported_types = vec![
rbw::api::TwoFactorProviderType::Authenticator,
+ rbw::api::TwoFactorProviderType::Yubikey,
rbw::api::TwoFactorProviderType::Email,
];
@@ -169,7 +172,7 @@ pub async fn login(
)
.await?;
login_success(
- state,
+ state.clone(),
access_token,
refresh_token,
kdf,
@@ -205,6 +208,11 @@ pub async fn login(
}
}
+ let err = subscribe_to_notifications(state.clone()).await.err();
+ if let Some(e) = err {
+ eprintln!("failed to subscribe to notifications: {}", e)
+ }
+
respond_ack(sock).await?;
Ok(())
@@ -655,3 +663,40 @@ async fn config_pinentry() -> anyhow::Result<String> {
let config = rbw::config::Config::load_async().await?;
Ok(config.pinentry)
}
+
+pub async fn subscribe_to_notifications(
+ state: std::sync::Arc<tokio::sync::RwLock<crate::agent::State>>,
+) -> anyhow::Result<()> {
+ // access token might be out of date, so we do a sync to refresh it
+ sync(None).await?;
+
+ let config = rbw::config::Config::load_async()
+ .await
+ .context("Config is missing")?;
+ let email = config.email.clone().context("Config is missing email")?;
+ let db = rbw::db::Db::load_async(&config.server_name().as_str(), &email)
+ .await?;
+ let access_token =
+ db.access_token.context("Error getting access token")?;
+
+ let mut websocket_url = config
+ .base_url
+ .clone()
+ .expect("config is missing base url")
+ .replace("https://", "wss://")
+ + "/notifications/hub?access_token=";
+ websocket_url = websocket_url + &access_token;
+
+ let mut state = state.write().await;
+ let err = state
+ .notifications_handler
+ .connect(websocket_url)
+ .await
+ .err();
+
+ if let Some(err) = err {
+ return Err(anyhow::anyhow!(err.to_string()));
+ } else {
+ Ok(())
+ }
+}
diff --git a/src/bin/rbw-agent/agent.rs b/src/bin/rbw-agent/agent.rs
index 7dcab16..b88121d 100644
--- a/src/bin/rbw-agent/agent.rs
+++ b/src/bin/rbw-agent/agent.rs
@@ -1,6 +1,9 @@
+use aes::cipher::typenum::private::IsNotEqualPrivate;
use anyhow::Context as _;
use futures_util::StreamExt as _;
+use crate::notifications;
+
pub struct State {
pub priv_key: Option<rbw::locked::Keys>,
pub org_keys:
@@ -9,6 +12,7 @@ pub struct State {
pub timeout_duration: std::time::Duration,
pub sync_timeout: crate::timeout::Timeout,
pub sync_timeout_duration: std::time::Duration,
+ pub notifications_handler: crate::notifications::NotificationsHandler,
}
impl State {
@@ -55,6 +59,8 @@ impl Agent {
if sync_timeout_duration > std::time::Duration::ZERO {
sync_timeout.set(sync_timeout_duration);
}
+ let notifications_handler =
+ crate::notifications::NotificationsHandler::new();
Ok(Self {
timer_r,
sync_timer_r,
@@ -65,6 +71,7 @@ impl Agent {
timeout_duration,
sync_timeout,
sync_timeout_duration,
+ notifications_handler,
})),
})
}
@@ -73,11 +80,39 @@ impl Agent {
self,
listener: tokio::net::UnixListener,
) -> anyhow::Result<()> {
+ let err =
+ crate::actions::subscribe_to_notifications(self.state.clone())
+ .await;
+ if let Err(e) = err {
+ eprintln!("failed to subscribe to notifications: {e:#}")
+ }
+
enum Event {
Request(std::io::Result<tokio::net::UnixStream>),
Timeout(()),
Sync(()),
}
+
+ let c: tokio::sync::mpsc::UnboundedReceiver<
+ notifications::NotificationMessage,
+ > = {
+ self.state
+ .write()
+ .await
+ .notifications_handler
+ .get_channel()
+ .await
+ };
+ let notifications =
+ tokio_stream::wrappers::UnboundedReceiverStream::new(c)
+ .map(|message| match message {
+ notifications::NotificationMessage::Logout => {
+ Event::Timeout(())
+ }
+ _ => Event::Sync(()),
+ })
+ .boxed();
+
let mut stream = futures_util::stream::select_all([
tokio_stream::wrappers::UnixListenerStream::new(listener)
.map(Event::Request)
@@ -92,6 +127,7 @@ impl Agent {
)
.map(Event::Sync)
.boxed(),
+ notifications,
]);
while let Some(event) = stream.next().await {
match event {
@@ -119,8 +155,24 @@ impl Agent {
Event::Sync(()) => {
// this could fail if we aren't logged in, but we don't
// care about that
+ let state = self.state.clone();
tokio::spawn(async move {
- let _ = crate::actions::sync(None).await;
+ let result = crate::actions::sync(None).await;
+ if let Err(e) = result {
+ eprintln!("failed to sync: {e:#}");
+ } else {
+ if !state
+ .write()
+ .await
+ .notifications_handler
+ .is_connected()
+ {
+ let err = crate::actions::subscribe_to_notifications(state).await;
+ if let Err(e) = err {
+ eprintln!("failed to subscribe to notifications: {e:#}")
+ }
+ }
+ }
});
self.state.write().await.set_sync_timeout();
}
diff --git a/src/bin/rbw-agent/main.rs b/src/bin/rbw-agent/main.rs
index 81eee3a..a9477df 100644
--- a/src/bin/rbw-agent/main.rs
+++ b/src/bin/rbw-agent/main.rs
@@ -19,6 +19,7 @@ mod actions;
mod agent;
mod daemon;
mod debugger;
+mod notifications;
mod sock;
mod timeout;
diff --git a/src/bin/rbw-agent/notifications.rs b/src/bin/rbw-agent/notifications.rs
new file mode 100644
index 0000000..e8f84b0
--- /dev/null
+++ b/src/bin/rbw-agent/notifications.rs
@@ -0,0 +1,187 @@
+use futures::stream::SplitSink;
+use futures_util::{SinkExt, StreamExt};
+use tokio::{net::TcpStream, task::JoinHandle};
+use tokio_tungstenite::{
+ connect_async, tungstenite::protocol::Message, MaybeTlsStream,
+ WebSocketStream,
+};
+
+#[derive(Copy, Clone)]
+pub enum NotificationMessage {
+ SyncCipherUpdate,
+ SyncCipherCreate,
+ SyncLoginDelete,
+ SyncFolderDelete,
+ SyncCiphers,
+
+ SyncVault,
+ SyncOrgKeys,
+ SyncFolderCreate,
+ SyncFolderUpdate,
+ SyncCipherDelete,
+ SyncSettings,
+
+ Logout,
+}
+
+fn parse_messagepack(data: &[u8]) -> Option<NotificationMessage> {
+ // the first few bytes with the 0x80 bit set, plus one byte terminating the length contain the length of the message
+ let len_buffer_length = data.iter().position(|&x| (x & 0x80) == 0)? + 1;
+
+ let unpacked_messagepack =
+ rmpv::decode::read_value(&mut &data[len_buffer_length..]).ok()?;
+ if !unpacked_messagepack.is_array() {
+ return None;
+ }
+
+ let unpacked_message = unpacked_messagepack.as_array().unwrap();
+ let message_type =
+ unpacked_message.iter().next().unwrap().as_u64().unwrap();
+
+ let message = match message_type {
+ 0 => Some(NotificationMessage::SyncCipherUpdate),
+ 1 => Some(NotificationMessage::SyncCipherCreate),
+ 2 => Some(NotificationMessage::SyncLoginDelete),
+ 3 => Some(NotificationMessage::SyncFolderDelete),
+ 4 => Some(NotificationMessage::SyncCiphers),
+ 5 => Some(NotificationMessage::SyncVault),
+ 6 => Some(NotificationMessage::SyncOrgKeys),
+ 7 => Some(NotificationMessage::SyncFolderCreate),
+ 8 => Some(NotificationMessage::SyncFolderUpdate),
+ 9 => Some(NotificationMessage::SyncCipherDelete),
+ 10 => Some(NotificationMessage::SyncSettings),
+ 11 => Some(NotificationMessage::Logout),
+ _ => None,
+ };
+
+ return message;
+}
+
+pub struct NotificationsHandler {
+ write: Option<
+ futures::stream::SplitSink<
+ tokio_tungstenite::WebSocketStream<
+ tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
+ >,
+ Message,
+ >,
+ >,
+ read_handle: Option<tokio::task::JoinHandle<()>>,
+ sending_channels: std::sync::Arc<
+ tokio::sync::RwLock<
+ Vec<tokio::sync::mpsc::UnboundedSender<NotificationMessage>>,
+ >,
+ >,
+}
+
+impl NotificationsHandler {
+ pub fn new() -> Self {
+ Self {
+ write: None,
+ read_handle: None,
+ sending_channels: std::sync::Arc::new(tokio::sync::RwLock::new(
+ Vec::new(),
+ )),
+ }
+ }
+
+ pub async fn connect(
+ &mut self,
+ url: String,
+ ) -> Result<(), Box<dyn std::error::Error>> {
+ if self.is_connected() {
+ self.disconnect().await?;
+ }
+
+ let (write, read_handle) =
+ subscribe_to_notifications(url, self.sending_channels.clone())
+ .await?;
+
+ self.write = Some(write);
+ self.read_handle = Some(read_handle);
+ return Ok(());
+ }
+
+ pub fn is_connected(&self) -> bool {
+ self.write.is_some()
+ && self.read_handle.is_some()
+ && !self.read_handle.as_ref().unwrap().is_finished()
+ }
+
+ pub async fn disconnect(
+ &mut self,
+ ) -> Result<(), Box<dyn std::error::Error>> {
+ self.sending_channels.write().await.clear();
+ if let Some(mut write) = self.write.take() {
+ write.send(Message::Close(None)).await?;
+ write.close().await?;
+ self.read_handle.take().unwrap().await?;
+ }
+ self.write = None;
+ self.read_handle = None;
+ Ok(())
+ }
+
+ pub async fn get_channel(
+ &mut self,
+ ) -> tokio::sync::mpsc::UnboundedReceiver<NotificationMessage> {
+ let (tx, rx) =
+ tokio::sync::mpsc::unbounded_channel::<NotificationMessage>();
+ self.sending_channels.write().await.push(tx);
+ return rx;
+ }
+}
+
+async fn subscribe_to_notifications(
+ url: String,
+ sending_channels: std::sync::Arc<
+ tokio::sync::RwLock<
+ Vec<tokio::sync::mpsc::UnboundedSender<NotificationMessage>>,
+ >,
+ >,
+) -> Result<
+ (
+ SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
+ JoinHandle<()>,
+ ),
+ Box<dyn std::error::Error>,
+> {
+ let url = url::Url::parse(url.as_str())?;
+ let (ws_stream, _response) = connect_async(url).await?;
+ let (mut write, read) = ws_stream.split();
+
+ write
+ .send(Message::Text(
+ "{\"protocol\":\"messagepack\",\"version\":1}\n".to_string(),
+ ))
+ .await
+ .unwrap();
+
+ let read_future = async move {
+ read.map(|message| {
+ (message, sending_channels.clone())
+ }).for_each(|(message, a)| async move {
+ let a = a.read().await;
+
+ match message {
+ Ok(Message::Binary(binary)) => {
+ let msgpack = parse_messagepack(&binary);
+ if let Some(msg) = msgpack {
+ for channel in a.iter() {
+ let res = channel.send(msg);
+ if res.is_err() {
+ eprintln!("error sending websocket message to channel");
+ }
+ }
+ }
+ },
+ Err(e) => {
+ eprintln!("websocket error: {:?}", e);
+ },
+ _ => {}
+ }
+ }).await;
+ };
+
+ return Ok((write, tokio::spawn(read_future)));
+}