aboutsummaryrefslogtreecommitdiffstats
path: root/src/config.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/config.rs')
-rw-r--r--src/config.rs340
1 files changed, 260 insertions, 80 deletions
diff --git a/src/config.rs b/src/config.rs
index 5fed3d5..b59bbcf 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -5,7 +5,7 @@ use std::net::ToSocketAddrs as _;
const DEFAULT_LISTEN_ADDRESS: &str = "127.0.0.1:4144";
const DEFAULT_CONNECT_ADDRESS: &str = "127.0.0.1:4144";
-const DEFAULT_CONNECTION_BUFFER_SIZE: usize = 4 * 1024 * 1024;
+const DEFAULT_BUFFER_SIZE: usize = 4 * 1024 * 1024;
const DEFAULT_READ_TIMEOUT: std::time::Duration =
std::time::Duration::from_secs(120);
const DEFAULT_AUTH_TYPE: crate::protocol::AuthType =
@@ -21,25 +21,90 @@ pub trait Config: std::fmt::Debug {
fn run(&self) -> Result<()>;
}
-pub fn listen_address<'a, D>(
+#[derive(serde::Deserialize, Debug)]
+pub struct Client {
+ #[serde(deserialize_with = "auth_type", default = "default_auth_type")]
+ pub auth: crate::protocol::AuthType,
+
+ #[serde(default = "default_username")]
+ pub username: Option<String>,
+
+ #[serde(
+ deserialize_with = "connect_address",
+ default = "default_connect_address"
+ )]
+ pub connect_address: (String, std::net::SocketAddr),
+
+ #[serde(default = "default_tls")]
+ pub tls: bool,
+}
+
+impl Client {
+ pub fn host(&self) -> &str {
+ &self.connect_address.0
+ }
+
+ pub fn addr(&self) -> &std::net::SocketAddr {
+ &self.connect_address.1
+ }
+
+ pub fn merge_args<'a>(
+ &mut self,
+ matches: &clap::ArgMatches<'a>,
+ ) -> Result<()> {
+ if matches.is_present("login-recurse-center") {
+ self.auth = crate::protocol::AuthType::RecurseCenter;
+ }
+ if matches.is_present("login-plain") {
+ let username = matches
+ .value_of("login-plain")
+ .map(std::string::ToString::to_string);
+ self.auth = crate::protocol::AuthType::Plain;
+ self.username = username;
+ }
+ if matches.is_present("address") {
+ let address = matches.value_of("address").unwrap();
+ self.connect_address = to_connect_address(address)?;
+ }
+ if matches.is_present("tls") {
+ self.tls = true;
+ }
+ Ok(())
+ }
+}
+
+impl Default for Client {
+ fn default() -> Self {
+ Self {
+ auth: default_auth_type(),
+ username: default_username(),
+ connect_address: default_connect_address(),
+ tls: default_tls(),
+ }
+ }
+}
+
+fn auth_type<'a, D>(
deserializer: D,
-) -> std::result::Result<std::net::SocketAddr, D::Error>
+) -> std::result::Result<crate::protocol::AuthType, D::Error>
where
D: serde::de::Deserializer<'a>,
{
- to_listen_address(&<String>::deserialize(deserializer)?)
- .map_err(serde::de::Error::custom)
+ crate::protocol::AuthType::try_from(
+ <String>::deserialize(deserializer)?.as_ref(),
+ )
+ .map_err(serde::de::Error::custom)
}
-pub fn default_listen_address() -> std::net::SocketAddr {
- to_listen_address(DEFAULT_LISTEN_ADDRESS).unwrap()
+fn default_auth_type() -> crate::protocol::AuthType {
+ DEFAULT_AUTH_TYPE
}
-pub fn to_listen_address(address: &str) -> Result<std::net::SocketAddr> {
- address.parse().context(crate::error::ParseAddr)
+fn default_username() -> Option<String> {
+ std::env::var("USER").ok()
}
-pub fn connect_address<'a, D>(
+fn connect_address<'a, D>(
deserializer: D,
) -> std::result::Result<(String, std::net::SocketAddr), D::Error>
where
@@ -49,12 +114,12 @@ where
.map_err(serde::de::Error::custom)
}
-pub fn default_connect_address() -> (String, std::net::SocketAddr) {
+fn default_connect_address() -> (String, std::net::SocketAddr) {
to_connect_address(DEFAULT_CONNECT_ADDRESS).unwrap()
}
// XXX this does a blocking dns lookup - should try to find an async version
-pub fn to_connect_address(
+fn to_connect_address(
address: &str,
) -> Result<(String, std::net::SocketAddr)> {
let mut address_parts = address.split(':');
@@ -72,47 +137,131 @@ pub fn to_connect_address(
Ok((host.to_string(), socket_addr))
}
-pub fn default_connection_buffer_size() -> usize {
- DEFAULT_CONNECTION_BUFFER_SIZE
+fn default_tls() -> bool {
+ DEFAULT_TLS
}
-pub fn read_timeout<'a, D>(
+#[derive(serde::Deserialize, Debug)]
+pub struct Server {
+ #[serde(
+ deserialize_with = "listen_address",
+ default = "default_listen_address"
+ )]
+ pub listen_address: std::net::SocketAddr,
+
+ #[serde(default = "default_buffer_size")]
+ pub buffer_size: usize,
+
+ #[serde(
+ rename = "read_timeout_secs",
+ deserialize_with = "read_timeout",
+ default = "default_read_timeout"
+ )]
+ pub read_timeout: std::time::Duration,
+
+ pub tls_identity_file: Option<String>,
+
+ #[serde(
+ deserialize_with = "allowed_login_methods",
+ default = "default_allowed_login_methods"
+ )]
+ pub allowed_login_methods:
+ std::collections::HashSet<crate::protocol::AuthType>,
+}
+
+impl Server {
+ pub fn merge_args<'a>(
+ &mut self,
+ matches: &clap::ArgMatches<'a>,
+ ) -> Result<()> {
+ if matches.is_present("address") {
+ self.listen_address = matches
+ .value_of("address")
+ .unwrap()
+ .parse()
+ .context(crate::error::ParseAddr)?;
+ }
+ if matches.is_present("buffer-size") {
+ let s = matches.value_of("buffer-size").unwrap();
+ self.buffer_size = s
+ .parse()
+ .context(crate::error::ParseBufferSize { input: s })?;
+ }
+ if matches.is_present("read-timeout") {
+ let s = matches.value_of("read-timeout").unwrap();
+ self.read_timeout = s
+ .parse()
+ .map(std::time::Duration::from_secs)
+ .context(crate::error::ParseReadTimeout { input: s })?;
+ }
+ if matches.is_present("tls-identity-file") {
+ self.tls_identity_file = Some(
+ matches.value_of("tls-identity-file").unwrap().to_string(),
+ );
+ }
+ if matches.is_present("allowed-login-methods") {
+ self.allowed_login_methods = matches
+ .values_of("allowed-login-methods")
+ .unwrap()
+ .map(crate::protocol::AuthType::try_from)
+ .collect::<Result<
+ std::collections::HashSet<crate::protocol::AuthType>,
+ >>()?;
+ }
+ Ok(())
+ }
+}
+
+impl Default for Server {
+ fn default() -> Self {
+ Self {
+ listen_address: default_listen_address(),
+ buffer_size: default_buffer_size(),
+ read_timeout: default_read_timeout(),
+ tls_identity_file: None,
+ allowed_login_methods: default_allowed_login_methods(),
+ }
+ }
+}
+
+fn listen_address<'a, D>(
deserializer: D,
-) -> std::result::Result<std::time::Duration, D::Error>
+) -> std::result::Result<std::net::SocketAddr, D::Error>
where
D: serde::de::Deserializer<'a>,
{
- Ok(std::time::Duration::from_secs(u64::deserialize(
- deserializer,
- )?))
-}
-
-pub fn default_read_timeout() -> std::time::Duration {
- DEFAULT_READ_TIMEOUT
+ to_listen_address(&<String>::deserialize(deserializer)?)
+ .map_err(serde::de::Error::custom)
}
-pub fn default_tls() -> bool {
- DEFAULT_TLS
+fn default_listen_address() -> std::net::SocketAddr {
+ to_listen_address(DEFAULT_LISTEN_ADDRESS).unwrap()
}
-pub fn default_command() -> String {
- std::env::var("SHELL").unwrap_or_else(|_| "/bin/bash".to_string())
+fn to_listen_address(address: &str) -> Result<std::net::SocketAddr> {
+ address.parse().context(crate::error::ParseAddr)
}
-pub fn default_args() -> Vec<String> {
- vec![]
+fn default_buffer_size() -> usize {
+ DEFAULT_BUFFER_SIZE
}
-pub fn default_ttyrec_filename() -> String {
- DEFAULT_TTYREC_FILENAME.to_string()
+fn read_timeout<'a, D>(
+ deserializer: D,
+) -> std::result::Result<std::time::Duration, D::Error>
+where
+ D: serde::de::Deserializer<'a>,
+{
+ Ok(std::time::Duration::from_secs(u64::deserialize(
+ deserializer,
+ )?))
}
-pub fn default_allowed_login_methods(
-) -> std::collections::HashSet<crate::protocol::AuthType> {
- crate::protocol::AuthType::iter().collect()
+fn default_read_timeout() -> std::time::Duration {
+ DEFAULT_READ_TIMEOUT
}
-pub fn allowed_login_methods<'a, D>(
+fn allowed_login_methods<'a, D>(
deserializer: D,
) -> std::result::Result<
std::collections::HashSet<crate::protocol::AuthType>,
@@ -169,61 +318,92 @@ where
.collect()
}
-pub fn auth<'a, D>(
- deserializer: D,
-) -> std::result::Result<crate::protocol::Auth, D::Error>
-where
- D: serde::de::Deserializer<'a>,
-{
- LoginType::deserialize(deserializer).and_then(|login_type| {
- match login_type.login_type {
- crate::protocol::AuthType::Plain => login_type
- .username
+fn default_allowed_login_methods(
+) -> std::collections::HashSet<crate::protocol::AuthType> {
+ crate::protocol::AuthType::iter().collect()
+}
+
+#[derive(serde::Deserialize, Debug)]
+pub struct Command {
+ #[serde(default = "default_buffer_size")]
+ pub buffer_size: usize,
+
+ #[serde(skip, default = "default_command")]
+ pub command: String,
+
+ #[serde(skip, default = "default_args")]
+ pub args: Vec<String>,
+}
+
+impl Command {
+ pub fn merge_args<'a>(
+ &mut self,
+ matches: &clap::ArgMatches<'a>,
+ ) -> Result<()> {
+ if matches.is_present("buffer-size") {
+ let buffer_size = matches.value_of("buffer-size").unwrap();
+ self.buffer_size = buffer_size.parse().context(
+ crate::error::ParseBufferSize { input: buffer_size },
+ )?;
+ }
+ if matches.is_present("command") {
+ self.command = matches.value_of("command").unwrap().to_string();
+ }
+ if matches.is_present("args") {
+ self.args = matches
+ .values_of("args")
+ .unwrap()
.map(std::string::ToString::to_string)
- .or_else(default_username)
- .ok_or_else(|| Error::CouldntFindUsername)
- .map(|username| crate::protocol::Auth::Plain { username })
- .map_err(serde::de::Error::custom),
- crate::protocol::AuthType::RecurseCenter => {
- Ok(crate::protocol::Auth::RecurseCenter {
- id: login_type.id.map(std::string::ToString::to_string),
- })
- }
+ .collect();
}
- })
+ Ok(())
+ }
}
-pub fn default_auth() -> crate::protocol::Auth {
- let username = default_username()
- .ok_or_else(|| Error::CouldntFindUsername)
- .unwrap();
- crate::protocol::Auth::Plain { username }
+impl Default for Command {
+ fn default() -> Self {
+ Self {
+ buffer_size: default_buffer_size(),
+ command: default_command(),
+ args: default_args(),
+ }
+ }
}
-#[derive(serde::Deserialize)]
-struct LoginType<'a> {
- #[serde(deserialize_with = "auth_type", default = "default_auth_type")]
- login_type: crate::protocol::AuthType,
- username: Option<&'a str>,
- id: Option<&'a str>,
+fn default_command() -> String {
+ std::env::var("SHELL").unwrap_or_else(|_| "/bin/bash".to_string())
}
-fn auth_type<'a, D>(
- deserializer: D,
-) -> std::result::Result<crate::protocol::AuthType, D::Error>
-where
- D: serde::de::Deserializer<'a>,
-{
- crate::protocol::AuthType::try_from(
- <String>::deserialize(deserializer)?.as_ref(),
- )
- .map_err(serde::de::Error::custom)
+fn default_args() -> Vec<String> {
+ vec![]
}
-fn default_auth_type() -> crate::protocol::AuthType {
- DEFAULT_AUTH_TYPE
+#[derive(serde::Deserialize, Debug)]
+pub struct Ttyrec {
+ #[serde(default = "default_ttyrec_filename")]
+ pub filename: String,
}
-fn default_username() -> Option<String> {
- std::env::var("USER").ok()
+impl Ttyrec {
+ pub fn merge_args<'a>(
+ &mut self,
+ matches: &clap::ArgMatches<'a>,
+ ) -> Result<()> {
+ if matches.is_present("filename") {
+ self.filename = matches.value_of("filename").unwrap().to_string();
+ }
+ Ok(())
+ }
+}
+
+impl Default for Ttyrec {
+ fn default() -> Self {
+ Self {
+ filename: default_ttyrec_filename(),
+ }
+ }
+}
+
+fn default_ttyrec_filename() -> String {
+ DEFAULT_TTYREC_FILENAME.to_string()
}