diff options
Diffstat (limited to 'src/config.rs')
-rw-r--r-- | src/config.rs | 340 |
1 files changed, 260 insertions, 80 deletions
diff --git a/src/config.rs b/src/config.rs index 5fed3d5..b59bbcf 100644 --- a/src/config.rs +++ b/src/config.rs @@ -5,7 +5,7 @@ use std::net::ToSocketAddrs as _; const DEFAULT_LISTEN_ADDRESS: &str = "127.0.0.1:4144"; const DEFAULT_CONNECT_ADDRESS: &str = "127.0.0.1:4144"; -const DEFAULT_CONNECTION_BUFFER_SIZE: usize = 4 * 1024 * 1024; +const DEFAULT_BUFFER_SIZE: usize = 4 * 1024 * 1024; const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(120); const DEFAULT_AUTH_TYPE: crate::protocol::AuthType = @@ -21,25 +21,90 @@ pub trait Config: std::fmt::Debug { fn run(&self) -> Result<()>; } -pub fn listen_address<'a, D>( +#[derive(serde::Deserialize, Debug)] +pub struct Client { + #[serde(deserialize_with = "auth_type", default = "default_auth_type")] + pub auth: crate::protocol::AuthType, + + #[serde(default = "default_username")] + pub username: Option<String>, + + #[serde( + deserialize_with = "connect_address", + default = "default_connect_address" + )] + pub connect_address: (String, std::net::SocketAddr), + + #[serde(default = "default_tls")] + pub tls: bool, +} + +impl Client { + pub fn host(&self) -> &str { + &self.connect_address.0 + } + + pub fn addr(&self) -> &std::net::SocketAddr { + &self.connect_address.1 + } + + pub fn merge_args<'a>( + &mut self, + matches: &clap::ArgMatches<'a>, + ) -> Result<()> { + if matches.is_present("login-recurse-center") { + self.auth = crate::protocol::AuthType::RecurseCenter; + } + if matches.is_present("login-plain") { + let username = matches + .value_of("login-plain") + .map(std::string::ToString::to_string); + self.auth = crate::protocol::AuthType::Plain; + self.username = username; + } + if matches.is_present("address") { + let address = matches.value_of("address").unwrap(); + self.connect_address = to_connect_address(address)?; + } + if matches.is_present("tls") { + self.tls = true; + } + Ok(()) + } +} + +impl Default for Client { + fn default() -> Self { + Self { + auth: default_auth_type(), + username: default_username(), + connect_address: default_connect_address(), + tls: default_tls(), + } + } +} + +fn auth_type<'a, D>( deserializer: D, -) -> std::result::Result<std::net::SocketAddr, D::Error> +) -> std::result::Result<crate::protocol::AuthType, D::Error> where D: serde::de::Deserializer<'a>, { - to_listen_address(&<String>::deserialize(deserializer)?) - .map_err(serde::de::Error::custom) + crate::protocol::AuthType::try_from( + <String>::deserialize(deserializer)?.as_ref(), + ) + .map_err(serde::de::Error::custom) } -pub fn default_listen_address() -> std::net::SocketAddr { - to_listen_address(DEFAULT_LISTEN_ADDRESS).unwrap() +fn default_auth_type() -> crate::protocol::AuthType { + DEFAULT_AUTH_TYPE } -pub fn to_listen_address(address: &str) -> Result<std::net::SocketAddr> { - address.parse().context(crate::error::ParseAddr) +fn default_username() -> Option<String> { + std::env::var("USER").ok() } -pub fn connect_address<'a, D>( +fn connect_address<'a, D>( deserializer: D, ) -> std::result::Result<(String, std::net::SocketAddr), D::Error> where @@ -49,12 +114,12 @@ where .map_err(serde::de::Error::custom) } -pub fn default_connect_address() -> (String, std::net::SocketAddr) { +fn default_connect_address() -> (String, std::net::SocketAddr) { to_connect_address(DEFAULT_CONNECT_ADDRESS).unwrap() } // XXX this does a blocking dns lookup - should try to find an async version -pub fn to_connect_address( +fn to_connect_address( address: &str, ) -> Result<(String, std::net::SocketAddr)> { let mut address_parts = address.split(':'); @@ -72,47 +137,131 @@ pub fn to_connect_address( Ok((host.to_string(), socket_addr)) } -pub fn default_connection_buffer_size() -> usize { - DEFAULT_CONNECTION_BUFFER_SIZE +fn default_tls() -> bool { + DEFAULT_TLS } -pub fn read_timeout<'a, D>( +#[derive(serde::Deserialize, Debug)] +pub struct Server { + #[serde( + deserialize_with = "listen_address", + default = "default_listen_address" + )] + pub listen_address: std::net::SocketAddr, + + #[serde(default = "default_buffer_size")] + pub buffer_size: usize, + + #[serde( + rename = "read_timeout_secs", + deserialize_with = "read_timeout", + default = "default_read_timeout" + )] + pub read_timeout: std::time::Duration, + + pub tls_identity_file: Option<String>, + + #[serde( + deserialize_with = "allowed_login_methods", + default = "default_allowed_login_methods" + )] + pub allowed_login_methods: + std::collections::HashSet<crate::protocol::AuthType>, +} + +impl Server { + pub fn merge_args<'a>( + &mut self, + matches: &clap::ArgMatches<'a>, + ) -> Result<()> { + if matches.is_present("address") { + self.listen_address = matches + .value_of("address") + .unwrap() + .parse() + .context(crate::error::ParseAddr)?; + } + if matches.is_present("buffer-size") { + let s = matches.value_of("buffer-size").unwrap(); + self.buffer_size = s + .parse() + .context(crate::error::ParseBufferSize { input: s })?; + } + if matches.is_present("read-timeout") { + let s = matches.value_of("read-timeout").unwrap(); + self.read_timeout = s + .parse() + .map(std::time::Duration::from_secs) + .context(crate::error::ParseReadTimeout { input: s })?; + } + if matches.is_present("tls-identity-file") { + self.tls_identity_file = Some( + matches.value_of("tls-identity-file").unwrap().to_string(), + ); + } + if matches.is_present("allowed-login-methods") { + self.allowed_login_methods = matches + .values_of("allowed-login-methods") + .unwrap() + .map(crate::protocol::AuthType::try_from) + .collect::<Result< + std::collections::HashSet<crate::protocol::AuthType>, + >>()?; + } + Ok(()) + } +} + +impl Default for Server { + fn default() -> Self { + Self { + listen_address: default_listen_address(), + buffer_size: default_buffer_size(), + read_timeout: default_read_timeout(), + tls_identity_file: None, + allowed_login_methods: default_allowed_login_methods(), + } + } +} + +fn listen_address<'a, D>( deserializer: D, -) -> std::result::Result<std::time::Duration, D::Error> +) -> std::result::Result<std::net::SocketAddr, D::Error> where D: serde::de::Deserializer<'a>, { - Ok(std::time::Duration::from_secs(u64::deserialize( - deserializer, - )?)) -} - -pub fn default_read_timeout() -> std::time::Duration { - DEFAULT_READ_TIMEOUT + to_listen_address(&<String>::deserialize(deserializer)?) + .map_err(serde::de::Error::custom) } -pub fn default_tls() -> bool { - DEFAULT_TLS +fn default_listen_address() -> std::net::SocketAddr { + to_listen_address(DEFAULT_LISTEN_ADDRESS).unwrap() } -pub fn default_command() -> String { - std::env::var("SHELL").unwrap_or_else(|_| "/bin/bash".to_string()) +fn to_listen_address(address: &str) -> Result<std::net::SocketAddr> { + address.parse().context(crate::error::ParseAddr) } -pub fn default_args() -> Vec<String> { - vec![] +fn default_buffer_size() -> usize { + DEFAULT_BUFFER_SIZE } -pub fn default_ttyrec_filename() -> String { - DEFAULT_TTYREC_FILENAME.to_string() +fn read_timeout<'a, D>( + deserializer: D, +) -> std::result::Result<std::time::Duration, D::Error> +where + D: serde::de::Deserializer<'a>, +{ + Ok(std::time::Duration::from_secs(u64::deserialize( + deserializer, + )?)) } -pub fn default_allowed_login_methods( -) -> std::collections::HashSet<crate::protocol::AuthType> { - crate::protocol::AuthType::iter().collect() +fn default_read_timeout() -> std::time::Duration { + DEFAULT_READ_TIMEOUT } -pub fn allowed_login_methods<'a, D>( +fn allowed_login_methods<'a, D>( deserializer: D, ) -> std::result::Result< std::collections::HashSet<crate::protocol::AuthType>, @@ -169,61 +318,92 @@ where .collect() } -pub fn auth<'a, D>( - deserializer: D, -) -> std::result::Result<crate::protocol::Auth, D::Error> -where - D: serde::de::Deserializer<'a>, -{ - LoginType::deserialize(deserializer).and_then(|login_type| { - match login_type.login_type { - crate::protocol::AuthType::Plain => login_type - .username +fn default_allowed_login_methods( +) -> std::collections::HashSet<crate::protocol::AuthType> { + crate::protocol::AuthType::iter().collect() +} + +#[derive(serde::Deserialize, Debug)] +pub struct Command { + #[serde(default = "default_buffer_size")] + pub buffer_size: usize, + + #[serde(skip, default = "default_command")] + pub command: String, + + #[serde(skip, default = "default_args")] + pub args: Vec<String>, +} + +impl Command { + pub fn merge_args<'a>( + &mut self, + matches: &clap::ArgMatches<'a>, + ) -> Result<()> { + if matches.is_present("buffer-size") { + let buffer_size = matches.value_of("buffer-size").unwrap(); + self.buffer_size = buffer_size.parse().context( + crate::error::ParseBufferSize { input: buffer_size }, + )?; + } + if matches.is_present("command") { + self.command = matches.value_of("command").unwrap().to_string(); + } + if matches.is_present("args") { + self.args = matches + .values_of("args") + .unwrap() .map(std::string::ToString::to_string) - .or_else(default_username) - .ok_or_else(|| Error::CouldntFindUsername) - .map(|username| crate::protocol::Auth::Plain { username }) - .map_err(serde::de::Error::custom), - crate::protocol::AuthType::RecurseCenter => { - Ok(crate::protocol::Auth::RecurseCenter { - id: login_type.id.map(std::string::ToString::to_string), - }) - } + .collect(); } - }) + Ok(()) + } } -pub fn default_auth() -> crate::protocol::Auth { - let username = default_username() - .ok_or_else(|| Error::CouldntFindUsername) - .unwrap(); - crate::protocol::Auth::Plain { username } +impl Default for Command { + fn default() -> Self { + Self { + buffer_size: default_buffer_size(), + command: default_command(), + args: default_args(), + } + } } -#[derive(serde::Deserialize)] -struct LoginType<'a> { - #[serde(deserialize_with = "auth_type", default = "default_auth_type")] - login_type: crate::protocol::AuthType, - username: Option<&'a str>, - id: Option<&'a str>, +fn default_command() -> String { + std::env::var("SHELL").unwrap_or_else(|_| "/bin/bash".to_string()) } -fn auth_type<'a, D>( - deserializer: D, -) -> std::result::Result<crate::protocol::AuthType, D::Error> -where - D: serde::de::Deserializer<'a>, -{ - crate::protocol::AuthType::try_from( - <String>::deserialize(deserializer)?.as_ref(), - ) - .map_err(serde::de::Error::custom) +fn default_args() -> Vec<String> { + vec![] } -fn default_auth_type() -> crate::protocol::AuthType { - DEFAULT_AUTH_TYPE +#[derive(serde::Deserialize, Debug)] +pub struct Ttyrec { + #[serde(default = "default_ttyrec_filename")] + pub filename: String, } -fn default_username() -> Option<String> { - std::env::var("USER").ok() +impl Ttyrec { + pub fn merge_args<'a>( + &mut self, + matches: &clap::ArgMatches<'a>, + ) -> Result<()> { + if matches.is_present("filename") { + self.filename = matches.value_of("filename").unwrap().to_string(); + } + Ok(()) + } +} + +impl Default for Ttyrec { + fn default() -> Self { + Self { + filename: default_ttyrec_filename(), + } + } +} + +fn default_ttyrec_filename() -> String { + DEFAULT_TTYREC_FILENAME.to_string() } |