diff options
-rw-r--r-- | src/agent.rs | 10 | ||||
-rw-r--r-- | src/bin/agent.rs | 82 | ||||
-rw-r--r-- | src/bin/rbw.rs | 82 |
3 files changed, 131 insertions, 43 deletions
diff --git a/src/agent.rs b/src/agent.rs index 4b9ac20..c64acc9 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -1,5 +1,5 @@ #[derive(serde::Serialize, serde::Deserialize, Debug)] -pub struct Message { +pub struct Request { pub tty: Option<String>, pub action: Action, } @@ -15,3 +15,11 @@ pub enum Action { // update // remove } + +#[derive(serde::Serialize, serde::Deserialize, Debug)] +#[serde(tag = "type")] +pub enum Response { + Ack, + Error { error: String }, + Decrypt { plaintext: String }, +} diff --git a/src/bin/agent.rs b/src/bin/agent.rs index f21abcc..9081988 100644 --- a/src/bin/agent.rs +++ b/src/bin/agent.rs @@ -1,4 +1,4 @@ -use tokio::io::AsyncBufReadExt as _; +use tokio::io::{AsyncBufReadExt as _, AsyncWriteExt as _}; use tokio::stream::StreamExt as _; fn make_socket() -> anyhow::Result<tokio::net::UnixListener> { @@ -12,14 +12,28 @@ fn make_socket() -> anyhow::Result<tokio::net::UnixListener> { Ok(sock) } -async fn ensure_login(state: std::sync::Arc<tokio::sync::RwLock<State>>) { +async fn send_response( + sock: &mut tokio::net::UnixStream, + res: &rbw::agent::Response, +) { + sock.write_all(serde_json::to_string(res).unwrap().as_bytes()) + .await + .unwrap(); + sock.write_all(b"\n").await.unwrap(); +} + +async fn ensure_login( + sock: &mut tokio::net::UnixStream, + state: std::sync::Arc<tokio::sync::RwLock<State>>, +) { let rstate = state.read().await; if rstate.access_token.is_none() { - login(state.clone(), None).await; // tty + login(sock, state.clone(), None).await; // tty } } async fn login( + sock: &mut tokio::net::UnixStream, state: std::sync::Arc<tokio::sync::RwLock<State>>, tty: Option<&str>, ) { @@ -36,16 +50,22 @@ async fn login( .await .unwrap(); state.priv_key = Some((enc_key, mac_key)); + + send_response(sock, &rbw::agent::Response::Ack).await; } -async fn ensure_unlock(state: std::sync::Arc<tokio::sync::RwLock<State>>) { +async fn ensure_unlock( + sock: &mut tokio::net::UnixStream, + state: std::sync::Arc<tokio::sync::RwLock<State>>, +) { let rstate = state.read().await; if rstate.priv_key.is_none() { - unlock(state.clone(), None).await; // tty + unlock(sock, state.clone(), None).await; // tty } } async fn unlock( + sock: &mut tokio::net::UnixStream, state: std::sync::Arc<tokio::sync::RwLock<State>>, tty: Option<&str>, ) { @@ -62,10 +82,15 @@ async fn unlock( .await .unwrap(); state.priv_key = Some((enc_key, mac_key)); + + send_response(sock, &rbw::agent::Response::Ack).await; } -async fn sync(state: std::sync::Arc<tokio::sync::RwLock<State>>) { - ensure_login(state.clone()).await; +async fn sync( + sock: &mut tokio::net::UnixStream, + state: std::sync::Arc<tokio::sync::RwLock<State>>, +) { + ensure_login(sock, state.clone()).await; let mut state = state.write().await; let (protected_key, ciphers) = rbw::actions::sync(state.access_token.as_ref().unwrap()) @@ -74,41 +99,46 @@ async fn sync(state: std::sync::Arc<tokio::sync::RwLock<State>>) { state.protected_key = Some(protected_key); println!("{}", serde_json::to_string(&ciphers).unwrap()); state.ciphers = ciphers; + + send_response(sock, &rbw::agent::Response::Ack).await; } async fn decrypt( + sock: &mut tokio::net::UnixStream, state: std::sync::Arc<tokio::sync::RwLock<State>>, cipherstring: &str, ) { - ensure_unlock(state.clone()).await; + ensure_unlock(sock, state.clone()).await; let state = state.read().await; let (enc_key, mac_key) = state.priv_key.as_ref().unwrap(); let cipherstring = rbw::cipherstring::CipherString::new(cipherstring).unwrap(); - let plain = cipherstring.decrypt(&enc_key, &mac_key).unwrap(); - println!("{}", String::from_utf8(plain).unwrap()); + let plaintext = + String::from_utf8(cipherstring.decrypt(&enc_key, &mac_key).unwrap()) + .unwrap(); + + send_response(sock, &rbw::agent::Response::Decrypt { plaintext }).await; } async fn handle_sock( sock: tokio::net::UnixStream, state: std::sync::Arc<tokio::sync::RwLock<State>>, ) { - let buf = tokio::io::BufStream::new(sock); - let mut lines = buf.lines(); - while let Some(line) = lines.next().await { - let line = line.unwrap(); - let msg: rbw::agent::Message = serde_json::from_str(&line).unwrap(); - match msg.action { - rbw::agent::Action::Login => { - login(state.clone(), msg.tty.as_deref()).await - } - rbw::agent::Action::Unlock => { - unlock(state.clone(), msg.tty.as_deref()).await - } - rbw::agent::Action::Sync => sync(state.clone()).await, - rbw::agent::Action::Decrypt { cipherstring } => { - decrypt(state.clone(), &cipherstring).await - } + let mut buf = tokio::io::BufStream::new(sock); + let mut line = String::new(); + buf.read_line(&mut line).await.unwrap(); + let mut sock = buf.into_inner(); + let msg: rbw::agent::Request = serde_json::from_str(&line).unwrap(); + match msg.action { + rbw::agent::Action::Login => { + login(&mut sock, state.clone(), msg.tty.as_deref()).await + } + rbw::agent::Action::Unlock => { + unlock(&mut sock, state.clone(), msg.tty.as_deref()).await + } + rbw::agent::Action::Sync => sync(&mut sock, state.clone()).await, + rbw::agent::Action::Decrypt { cipherstring } => { + decrypt(&mut sock, state.clone(), &cipherstring).await } } } diff --git a/src/bin/rbw.rs b/src/bin/rbw.rs index a110176..b309550 100644 --- a/src/bin/rbw.rs +++ b/src/bin/rbw.rs @@ -1,4 +1,4 @@ -use std::io::Write as _; +use std::io::{BufRead as _, Write as _}; fn ensure_agent() { let agent_path = std::env::var("RBW_AGENT"); @@ -16,34 +16,84 @@ fn ensure_agent() { } } -fn send(msg: &rbw::agent::Message) { - let mut sock = std::os::unix::net::UnixStream::connect( +fn connect() -> std::os::unix::net::UnixStream { + std::os::unix::net::UnixStream::connect( rbw::dirs::runtime_dir().join("socket"), ) - .unwrap(); + .unwrap() +} + +fn send( + sock: &mut std::os::unix::net::UnixStream, + msg: &rbw::agent::Request, +) { sock.write_all(serde_json::to_string(msg).unwrap().as_bytes()) .unwrap(); + sock.write_all(b"\n").unwrap(); +} + +fn recv(sock: &mut std::os::unix::net::UnixStream) -> rbw::agent::Response { + let mut buf = std::io::BufReader::new(sock); + let mut line = String::new(); + buf.read_line(&mut line).unwrap(); + serde_json::from_str(&line).unwrap() } fn login() { - send(&rbw::agent::Message { - tty: std::env::var("TTY").ok(), - action: rbw::agent::Action::Login, - }) + let mut sock = connect(); + send( + &mut sock, + &rbw::agent::Request { + tty: std::env::var("TTY").ok(), + action: rbw::agent::Action::Login, + }, + ); + let res = recv(&mut sock); + match res { + rbw::agent::Response::Ack => (), + rbw::agent::Response::Error { error } => { + panic!("failed to login: {}", error) + } + _ => panic!("unexpected message: {:?}", res), + } } fn unlock() { - send(&rbw::agent::Message { - tty: std::env::var("TTY").ok(), - action: rbw::agent::Action::Unlock, - }) + let mut sock = connect(); + send( + &mut sock, + &rbw::agent::Request { + tty: std::env::var("TTY").ok(), + action: rbw::agent::Action::Unlock, + }, + ); + let res = recv(&mut sock); + match res { + rbw::agent::Response::Ack => (), + rbw::agent::Response::Error { error } => { + panic!("failed to login: {}", error) + } + _ => panic!("unexpected message: {:?}", res), + } } fn sync() { - send(&rbw::agent::Message { - tty: std::env::var("TTY").ok(), - action: rbw::agent::Action::Sync, - }) + let mut sock = connect(); + send( + &mut sock, + &rbw::agent::Request { + tty: std::env::var("TTY").ok(), + action: rbw::agent::Action::Sync, + }, + ); + let res = recv(&mut sock); + match res { + rbw::agent::Response::Ack => (), + rbw::agent::Response::Error { error } => { + panic!("failed to login: {}", error) + } + _ => panic!("unexpected message: {:?}", res), + } } fn list() { |