aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorBernd Schoolmann <mail@quexten.com>2023-04-16 13:41:52 +0200
committerBernd Schoolmann <mail@quexten.com>2023-04-16 13:41:52 +0200
commit355e17dc29244856454db3bdaeed082cf33231e6 (patch)
treeeca0a48f8816dd1c560ffd0b117f1332908b9ff8
parentd6339933d54974952721659c3de2b2871a086c1a (diff)
downloadrbw-355e17dc29244856454db3bdaeed082cf33231e6.tar.gz
rbw-355e17dc29244856454db3bdaeed082cf33231e6.zip
Restructure code
-rw-r--r--src/bin/rbw-agent/actions.rs19
-rw-r--r--src/bin/rbw-agent/agent.rs36
-rw-r--r--src/bin/rbw-agent/notifications.rs210
3 files changed, 136 insertions, 129 deletions
diff --git a/src/bin/rbw-agent/actions.rs b/src/bin/rbw-agent/actions.rs
index 7b5dc58..006f7ca 100644
--- a/src/bin/rbw-agent/actions.rs
+++ b/src/bin/rbw-agent/actions.rs
@@ -130,7 +130,7 @@ pub async fn login(
protected_key,
)) => {
login_success(
- state,
+ state.clone(),
access_token,
refresh_token,
kdf,
@@ -169,7 +169,7 @@ pub async fn login(
)
.await?;
login_success(
- state,
+ state.clone(),
access_token,
refresh_token,
kdf,
@@ -205,6 +205,8 @@ pub async fn login(
}
}
+ subscribe_to_notifications(state.clone()).await.expect("could not subscribe");
+
respond_ack(sock).await?;
Ok(())
@@ -655,3 +657,16 @@ async fn config_pinentry() -> anyhow::Result<String> {
let config = rbw::config::Config::load_async().await?;
Ok(config.pinentry)
}
+
+pub async fn subscribe_to_notifications(state: std::sync::Arc<tokio::sync::RwLock<crate::agent::State>>) -> anyhow::Result<()> {
+ let config = rbw::config::Config::load_async().await.expect("Config is missing");
+ let mut websocket_url = config.base_url.clone().expect("Config is missing base url").replace("https://", "wss://") + "/notifications/hub?access_token=";
+ let email = config.email.clone().expect("Config is missing email");
+ let db = rbw::db::Db::load_async(&config.server_name().as_str(), &email).await.expect("Error loading db");
+ let access_token = db.access_token.expect("Error getting access token");
+ websocket_url = websocket_url + &access_token;
+ let mut state = state.write().await;
+ state.notifications_handler.connect(websocket_url).await.expect("Error connecting to websocket");
+
+ Ok(())
+} \ No newline at end of file
diff --git a/src/bin/rbw-agent/agent.rs b/src/bin/rbw-agent/agent.rs
index 9523c78..c025880 100644
--- a/src/bin/rbw-agent/agent.rs
+++ b/src/bin/rbw-agent/agent.rs
@@ -1,6 +1,8 @@
use anyhow::Context as _;
use futures_util::StreamExt as _;
+use crate::notifications;
+
pub struct State {
pub priv_key: Option<rbw::locked::Keys>,
pub org_keys:
@@ -9,6 +11,7 @@ pub struct State {
pub timeout_duration: std::time::Duration,
pub sync_timeout: crate::timeout::Timeout,
pub sync_timeout_duration: std::time::Duration,
+ pub notifications_handler: crate::notifications::NotificationsHandler,
}
impl State {
@@ -55,6 +58,7 @@ impl Agent {
if sync_timeout_duration > std::time::Duration::ZERO {
sync_timeout.set(sync_timeout_duration);
}
+ let notifications_handler = crate::notifications::NotificationsHandler::new();
Ok(Self {
timer_r,
sync_timer_r,
@@ -65,6 +69,7 @@ impl Agent {
timeout_duration,
sync_timeout,
sync_timeout_duration,
+ notifications_handler,
})),
})
}
@@ -73,22 +78,32 @@ impl Agent {
self,
listener: tokio::net::UnixListener,
) -> anyhow::Result<()> {
- tokio::spawn(async move {
- let config = rbw::config::Config::load_async().await.expect("Error loading config");
- let mut websocket_url = config.base_url.clone().expect("Config is missing base url").replace("https://", "wss://") + "/notifications/hub?access_token=";
- if let Some(email) = &config.email {
- let db = rbw::db::Db::load_async(&config.server_name().as_str(), email).await.expect("Error loading db");
- let access_token = db.access_token.expect("Error getting access token");
- websocket_url = websocket_url + &access_token;
- crate::notifications::subscribe_to_notifications(websocket_url).await;
- }
- });
+ crate::actions::subscribe_to_notifications(self.state.clone()).await.expect("could not subscribe");
enum Event {
Request(std::io::Result<tokio::net::UnixStream>),
Timeout(()),
Sync(()),
}
+
+ let c: tokio::sync::mpsc::UnboundedReceiver<notifications::NotificationMessage> = {
+ self.state.write().await.notifications_handler.get_channel().await
+ };
+ let notifications = tokio_stream::wrappers::UnboundedReceiverStream::new(
+ c,
+ )
+ .map(|message| {
+ match message {
+ notifications::NotificationMessage::Logout => {
+ Event::Timeout(())
+ }
+ _ => {
+ Event::Sync(())
+ }
+ }
+ })
+ .boxed();
+
let mut stream = futures_util::stream::select_all([
tokio_stream::wrappers::UnixListenerStream::new(listener)
.map(Event::Request)
@@ -103,6 +118,7 @@ impl Agent {
)
.map(Event::Sync)
.boxed(),
+ notifications,
]);
while let Some(event) = stream.next().await {
match event {
diff --git a/src/bin/rbw-agent/notifications.rs b/src/bin/rbw-agent/notifications.rs
index ffdefe9..c72fe38 100644
--- a/src/bin/rbw-agent/notifications.rs
+++ b/src/bin/rbw-agent/notifications.rs
@@ -1,17 +1,12 @@
-use tokio_tungstenite::{connect_async, tungstenite::protocol::Message};
+use futures::{stream::SplitSink};
+use tokio::{net::{TcpStream}, task::JoinHandle};
+use tokio_tungstenite::{connect_async, tungstenite::protocol::Message, WebSocketStream, MaybeTlsStream};
use futures_util::{StreamExt, SinkExt};
-struct SyncCipherUpdate {
- id: String
-}
-
-struct SyncCipherCreate {
- id: String
-}
-
-enum NotificationMessage {
- SyncCipherUpdate(SyncCipherUpdate),
- SyncCipherCreate(SyncCipherCreate),
+#[derive(Copy, Clone)]
+pub enum NotificationMessage {
+ SyncCipherUpdate,
+ SyncCipherCreate,
SyncLoginDelete,
SyncFolderDelete,
SyncCiphers,
@@ -24,59 +19,25 @@ enum NotificationMessage {
SyncSettings,
Logout,
+}
- SyncSendCreate,
- SyncSendUpdate,
- SyncSendDelete,
-
- AuthRequest,
- AuthRequestResponse,
- None,
-}
fn parse_messagepack(data: &[u8]) -> Option<NotificationMessage> {
- if data.len() < 2 {
- return None;
- }
-
- // the first few bytes with th 0x80 bit set, plus one byte terminating the length contain the length of the message
+ // the first few bytes with the 0x80 bit set, plus one byte terminating the length contain the length of the message
let len_buffer_length = data.iter().position(|&x| (x & 0x80) == 0 )? + 1;
- println!("len_buffer_length: {:?}", len_buffer_length);
- println!("data: {:?}", data);
- let unpacked_messagepack = rmpv::decode::read_value(&mut &data[len_buffer_length..]).ok().unwrap();
- println!("unpacked_messagepack: {:?}", unpacked_messagepack);
+ let unpacked_messagepack = rmpv::decode::read_value(&mut &data[len_buffer_length..]).ok()?;
if !unpacked_messagepack.is_array() {
return None;
}
+
let unpacked_message = unpacked_messagepack.as_array().unwrap();
- println!("unpacked_message: {:?}", unpacked_message);
- let message_type = unpacked_message.iter().next()?.as_u64()?;
- let message = unpacked_message.iter().skip(4).next()?.as_array()?.first()?.as_map()?;
- let payload = message.iter().filter(|x| x.0.as_str().unwrap() == "Payload").next()?.1.as_map()?;
- println!("message_type: {:?}", message_type);
- println!("payload: {:?}", payload);
+ let message_type = unpacked_message.iter().next().unwrap().as_u64().unwrap();
let message = match message_type {
- 0 => {
- let id = payload.iter().filter(|x| x.0.as_str().unwrap() == "Id").next()?.1.as_str()?;
-
- Some(NotificationMessage::SyncCipherUpdate(
- SyncCipherUpdate {
- id: id.to_string()
- }
- ))
- },
- 1 => {
- let id = payload.iter().filter(|x| x.0.as_str().unwrap() == "Id").next()?.1.as_str()?;
-
- Some(NotificationMessage::SyncCipherCreate(
- SyncCipherCreate {
- id: id.to_string()
- }
- ))
- },
+ 0 => Some(NotificationMessage::SyncCipherUpdate),
+ 1 => Some(NotificationMessage::SyncCipherCreate),
2 => Some(NotificationMessage::SyncLoginDelete),
3 => Some(NotificationMessage::SyncFolderDelete),
4 => Some(NotificationMessage::SyncCiphers),
@@ -87,80 +48,95 @@ fn parse_messagepack(data: &[u8]) -> Option<NotificationMessage> {
9 => Some(NotificationMessage::SyncCipherDelete),
10 => Some(NotificationMessage::SyncSettings),
11 => Some(NotificationMessage::Logout),
- 12 => Some(NotificationMessage::SyncSendCreate),
- 13 => Some(NotificationMessage::SyncSendUpdate),
- 14 => Some(NotificationMessage::SyncSendDelete),
- 15 => Some(NotificationMessage::AuthRequest),
- 16 => Some(NotificationMessage::AuthRequestResponse),
- 100 => Some(NotificationMessage::None),
_ => None
};
return message;
}
-pub async fn subscribe_to_notifications(url: String) {
- let url = url::Url::parse(url.as_str()).unwrap();
+pub struct NotificationsHandler {
+ write: Option<futures::stream::SplitSink<tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>, Message>>,
+ read_handle: Option<tokio::task::JoinHandle<()>>,
+ sending_channels : std::sync::Arc<tokio::sync::RwLock<Vec<tokio::sync::mpsc::UnboundedSender<NotificationMessage>>>>,
+}
- let (ws_stream, _response) = connect_async(url).await.expect("Failed to connect");
+impl NotificationsHandler {
+ pub fn new() -> Self {
+ Self {
+ write: None,
+ read_handle: None,
+ sending_channels: std::sync::Arc::new(tokio::sync::RwLock::new(Vec::new())),
+ }
+ }
- let (mut write, read) = ws_stream.split();
+ pub async fn connect(&mut self, url: String) -> Result<(), Box<dyn std::error::Error>> {
+ if self.is_connected() {
+ self.disconnect().await?;
+ }
- write.send(Message::Text("{\"protocol\":\"messagepack\",\"version\":1}\n".to_string())).await.unwrap();
+ //subscribe_to_notifications(url, self.sending_channels.clone()).await?;
+ let (write, read_handle) = subscribe_to_notifications(url, self.sending_channels.clone()).await?;
+
+ self.write = Some(write);
+ self.read_handle = Some(read_handle);
+ return Ok(());
+ }
+
+ pub fn is_connected(&self) -> bool {
+ self.write.is_some()
+ }
- let read_future = read.for_each(|message| async {
- match message {
- Ok(Message::Binary(binary)) => {
- let msg = parse_messagepack(&binary);
- match msg {
- Some(NotificationMessage::SyncCipherUpdate(update)) => {
- println!("Websocket sent SyncCipherUpdate for id: {:?}", update.id);
- crate::actions::sync(None).await.unwrap();
- println!("Synced")
- },
- Some(NotificationMessage::SyncCipherCreate(update)) => {
- println!("Websocket sent SyncCipherUpdate for id: {:?}", update.id);
- crate::actions::sync(None).await.unwrap();
- println!("Synced")
- },
- Some(NotificationMessage::SyncLoginDelete) => {
- crate::actions::sync(None).await.unwrap();
- },
- Some(NotificationMessage::SyncFolderDelete) => {
- crate::actions::sync(None).await.unwrap();
- },
- Some(NotificationMessage::SyncCiphers) => {
- crate::actions::sync(None).await.unwrap();
- },
- Some(NotificationMessage::SyncVault) => {
- crate::actions::sync(None).await.unwrap();
- },
- Some(NotificationMessage::SyncOrgKeys) => {
- crate::actions::sync(None).await.unwrap();
- },
- Some(NotificationMessage::SyncFolderCreate) => {
- crate::actions::sync(None).await.unwrap();
- },
- Some(NotificationMessage::SyncFolderUpdate) => {
- crate::actions::sync(None).await.unwrap();
- },
- Some(NotificationMessage::SyncCipherDelete) => {
- crate::actions::sync(None).await.unwrap();
- },
- Some(NotificationMessage::Logout) => {
- println!("Websocket sent Logout");
- // todo: proper logout?
- std::process::exit(0);
- },
- _ => {}
- }
- },
- Err(e) => {
- println!("websocket error: {:?}", e);
- },
- _ => {}
+ pub async fn disconnect(&mut self) -> Result<(), Box<dyn std::error::Error>> {
+ self.sending_channels.write().await.clear();
+ if let Some(mut write) = self.write.take() {
+ write.send(Message::Close(None)).await?;
+ write.close().await?;
}
- });
+ Ok(())
+ }
+
+ pub async fn get_channel(&mut self) -> tokio::sync::mpsc::UnboundedReceiver<NotificationMessage> {
+ let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<NotificationMessage>();
+ self.sending_channels.write().await.push(tx);
+ return rx;
+ }
- read_future.await;
}
+
+async fn subscribe_to_notifications(url: String, sending_channels: std::sync::Arc<tokio::sync::RwLock<Vec<tokio::sync::mpsc::UnboundedSender<NotificationMessage>>>>) -> Result<(SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>, JoinHandle<()>), Box<dyn std::error::Error>> {
+ let url = url::Url::parse(url.as_str())?;
+ println!("Connecting to {}", url);
+ let (ws_stream, _response) = connect_async(url).await.expect("Failed to connect");
+ let (mut write, read) = ws_stream.split();
+
+ write.send(Message::Text("{\"protocol\":\"messagepack\",\"version\":1}\n".to_string())).await.unwrap();
+
+ let read_future = async move {
+ read.map(|message| {
+ (message, sending_channels.clone())
+ }).for_each(|(message, a)| async move {
+ let a = a.read().await;
+
+ match message {
+ Ok(Message::Binary(binary)) => {
+ if binary.len() < 4 {
+ return;
+ }
+
+ let msg1 = parse_messagepack(&binary);
+ if let Some(msg) = msg1 {
+ for channel in a.iter() {
+ let res = channel.send(msg);
+ }
+ }
+ },
+ Err(e) => {
+ println!("websocket error: {:?}", e);
+ },
+ _ => {}
+ }
+ }).await;
+ };
+
+ return Ok((write, tokio::spawn(read_future)));
+} \ No newline at end of file