diff options
-rw-r--r-- | src/cmd/stream.rs | 5 | ||||
-rw-r--r-- | src/cmd/watch.rs | 5 | ||||
-rw-r--r-- | src/config.rs | 26 | ||||
-rw-r--r-- | src/config/wizard.rs | 159 | ||||
-rw-r--r-- | src/dirs.rs | 10 | ||||
-rw-r--r-- | src/error.rs | 9 |
6 files changed, 199 insertions, 15 deletions
diff --git a/src/cmd/stream.rs b/src/cmd/stream.rs index 402f6e7..ed9b980 100644 --- a/src/cmd/stream.rs +++ b/src/cmd/stream.rs @@ -95,8 +95,11 @@ pub fn cmd<'a, 'b>(app: clap::App<'a, 'b>) -> clap::App<'a, 'b> { } pub fn config( - config: Option<config::Config>, + mut config: Option<config::Config>, ) -> Result<Box<dyn crate::config::Config>> { + if config.is_none() { + config = crate::config::wizard::run()?; + } let config: Config = if let Some(config) = config { config .try_into() diff --git a/src/cmd/watch.rs b/src/cmd/watch.rs index 202e570..f9fc7c3 100644 --- a/src/cmd/watch.rs +++ b/src/cmd/watch.rs @@ -87,8 +87,11 @@ pub fn cmd<'a, 'b>(app: clap::App<'a, 'b>) -> clap::App<'a, 'b> { } pub fn config( - config: Option<config::Config>, + mut config: Option<config::Config>, ) -> Result<Box<dyn crate::config::Config>> { + if config.is_none() { + config = crate::config::wizard::run()?; + } let config: Config = if let Some(config) = config { config .try_into() diff --git a/src/config.rs b/src/config.rs index b954969..6c47a71 100644 --- a/src/config.rs +++ b/src/config.rs @@ -3,6 +3,8 @@ use serde::de::Deserialize as _; use std::convert::TryFrom as _; use std::net::ToSocketAddrs as _; +pub mod wizard; + const CONFIG_FILENAME: &str = "config.toml"; const ALLOWED_LOGIN_METHODS_OPTION: &str = "allowed-login-methods"; @@ -47,17 +49,21 @@ pub fn config( } Some(filename.to_path_buf()) } else { - crate::dirs::Dirs::new().config_file(CONFIG_FILENAME) + crate::dirs::Dirs::new().config_file(CONFIG_FILENAME, true) }; - if let Some(config_filename) = config_filename { - let mut config = config::Config::default(); - config - .merge(config::File::from(config_filename)) - .context(crate::error::ParseConfigFile)?; - Ok(Some(config)) - } else { - Ok(None) - } + config_filename + .map(|config_filename| config_from_filename(&config_filename)) + .transpose() +} + +fn config_from_filename( + filename: &std::path::Path, +) -> Result<config::Config> { + let mut config = config::Config::default(); + config + .merge(config::File::from(filename)) + .context(crate::error::ParseConfigFile)?; + Ok(config) } #[derive(serde::Deserialize, Debug)] diff --git a/src/config/wizard.rs b/src/config/wizard.rs new file mode 100644 index 0000000..b13d0e0 --- /dev/null +++ b/src/config/wizard.rs @@ -0,0 +1,159 @@ +use crate::prelude::*; +use std::io::Write as _; + +pub fn run() -> Result<Option<config::Config>> { + println!("No configuration file found."); + let run_wizard = prompt( + "Would you like me to ask you some questions to generate one?", + )?; + if !run_wizard { + let shouldnt_touch = prompt( + "Would you like me to ask this question again in the future?", + )?; + if !shouldnt_touch { + touch_config_file()?; + } + return Ok(None); + } + + let connect_address = + prompt_addr("Which server would you like to connect to?")?; + let tls = prompt("Does this server require a TLS connection?")?; + let auth_type = prompt_auth_type( + "How would you like to authenticate to this server?", + )?; + + write_config_file(&connect_address, tls, &auth_type).and_then( + |config_filename| { + Some(super::config_from_filename(&config_filename)).transpose() + }, + ) +} + +fn touch_config_file() -> Result<()> { + let config_filename = crate::dirs::Dirs::new() + .config_file(super::CONFIG_FILENAME, false) + .unwrap(); + std::fs::File::create(config_filename.clone()).context( + crate::error::CreateFileSync { + filename: config_filename.to_string_lossy(), + }, + )?; + Ok(()) +} + +fn write_config_file( + connect_address: &str, + tls: bool, + auth_type: &str, +) -> Result<std::path::PathBuf> { + let contents = format!( + r#"[client] +connect_address = "{}" +tls = {} +auth = "{}" +"#, + connect_address, tls, auth_type + ); + let config_filename = crate::dirs::Dirs::new() + .config_file(super::CONFIG_FILENAME, false) + .unwrap(); + let mut file = std::fs::File::create(config_filename.clone()).context( + crate::error::CreateFileSync { + filename: config_filename.to_string_lossy(), + }, + )?; + file.write_all(contents.as_bytes()) + .context(crate::error::WriteFileSync)?; + Ok(config_filename) +} + +fn prompt(msg: &str) -> Result<bool> { + print!("{} [y/n]: ", msg); + std::io::stdout() + .flush() + .context(crate::error::FlushTerminal)?; + let mut response = String::new(); + std::io::stdin() + .read_line(&mut response) + .context(crate::error::ReadTerminal)?; + + loop { + match response.trim() { + "y" | "yes" => { + return Ok(true); + } + "n" | "no" => { + return Ok(false); + } + _ => { + print!("Please answer [y]es or [n]o: "); + std::io::stdout() + .flush() + .context(crate::error::FlushTerminal)?; + std::io::stdin() + .read_line(&mut response) + .context(crate::error::ReadTerminal)?; + } + } + } +} + +fn prompt_addr(msg: &str) -> Result<String> { + loop { + print!("{} [addr:port]: ", msg); + std::io::stdout() + .flush() + .context(crate::error::FlushTerminal)?; + let mut response = String::new(); + std::io::stdin() + .read_line(&mut response) + .context(crate::error::ReadTerminal)?; + + match response.trim() { + addr if addr.contains(':') => { + match super::to_connect_address(addr) { + Ok(..) => return Ok(addr.to_string()), + _ => { + println!("Couldn't parse '{}'.", addr); + } + }; + } + _ => { + println!("Please include a port number."); + } + } + } +} + +fn prompt_auth_type(msg: &str) -> Result<String> { + let auth_type_names: Vec<_> = crate::protocol::AuthType::iter() + .map(crate::protocol::AuthType::name) + .collect(); + + loop { + println!("{}", msg); + println!("Options are:"); + for (i, name) in auth_type_names.iter().enumerate() { + println!("{}: {}", i + 1, name); + } + print!("Choose [1-{}]: ", auth_type_names.len()); + std::io::stdout() + .flush() + .context(crate::error::FlushTerminal)?; + let mut response = String::new(); + std::io::stdin() + .read_line(&mut response) + .context(crate::error::ReadTerminal)?; + + let num: Option<usize> = response.trim().parse().ok(); + if let Some(num) = num { + if num > 0 && num <= auth_type_names.len() { + let name = auth_type_names[num - 1]; + return Ok(name.to_string()); + } + } + + println!("Invalid response '{}'", response.trim()); + } +} diff --git a/src/dirs.rs b/src/dirs.rs index feea8e4..937cfcb 100644 --- a/src/dirs.rs +++ b/src/dirs.rs @@ -32,16 +32,20 @@ impl Dirs { .map(directories::ProjectDirs::config_dir) } - pub fn config_file(&self, name: &str) -> Option<std::path::PathBuf> { + pub fn config_file( + &self, + name: &str, + must_exist: bool, + ) -> Option<std::path::PathBuf> { if let Some(config_dir) = self.config_dir() { let file = config_dir.join(name); - if file.exists() { + if !must_exist || file.exists() { return Some(file); } } let file = self.global_config_dir().join(name); - if file.exists() { + if !must_exist || file.exists() { return Some(file); } diff --git a/src/error.rs b/src/error.rs index 2634472..1d14250 100644 --- a/src/error.rs +++ b/src/error.rs @@ -65,6 +65,12 @@ pub enum Error { source: tokio::io::Error, }, + #[snafu(display("failed to create file {}: {}", filename, source))] + CreateFileSync { + filename: String, + source: std::io::Error, + }, + #[snafu(display("received EOF from server"))] EOF, @@ -418,6 +424,9 @@ pub enum Error { #[snafu(display("failed to write to file: {}", source))] WriteFile { source: tokio::io::Error }, + #[snafu(display("failed to write to file: {}", source))] + WriteFileSync { source: std::io::Error }, + #[snafu(display("{}", source))] WriteMessageWithTimeout { #[snafu(source(from(tokio::timer::timeout::Error<Error>, Box::new)))] |