diff options
author | Jesse Luehrs <doy@tozt.net> | 2019-10-17 10:05:28 -0400 |
---|---|---|
committer | Jesse Luehrs <doy@tozt.net> | 2019-10-17 10:06:23 -0400 |
commit | 0ff2ba5d413834d04a24ce3f29a0172c82a31c35 (patch) | |
tree | 9ea695ea24c5543703ecefc1cf4efd741541b92a | |
parent | 006802454b01f6f20168e55cb2427db7b89039a1 (diff) | |
download | teleterm-0ff2ba5d413834d04a24ce3f29a0172c82a31c35.tar.gz teleterm-0ff2ba5d413834d04a24ce3f29a0172c82a31c35.zip |
refactor config handling
-rw-r--r-- | src/cmd.rs | 33 | ||||
-rw-r--r-- | src/cmd/play.rs | 50 | ||||
-rw-r--r-- | src/cmd/record.rs | 104 | ||||
-rw-r--r-- | src/cmd/server.rs | 185 | ||||
-rw-r--r-- | src/cmd/stream.rs | 239 | ||||
-rw-r--r-- | src/cmd/watch.rs | 191 | ||||
-rw-r--r-- | src/config.rs | 179 | ||||
-rw-r--r-- | src/main.rs | 2 | ||||
-rw-r--r-- | src/util.rs | 31 |
9 files changed, 688 insertions, 326 deletions
@@ -9,7 +9,7 @@ mod watch; struct Command { name: &'static str, cmd: &'static dyn for<'a, 'b> Fn(clap::App<'a, 'b>) -> clap::App<'a, 'b>, - run: &'static dyn for<'a> Fn(&clap::ArgMatches<'a>) -> Result<()>, + config: &'static dyn Fn() -> Box<dyn crate::config::Config>, log_level: &'static str, } @@ -17,37 +17,37 @@ const COMMANDS: &[Command] = &[ Command { name: "stream", cmd: &stream::cmd, - run: &stream::run, + config: &stream::config, log_level: "error", }, Command { name: "server", cmd: &server::cmd, - run: &server::run, + config: &server::config, log_level: "info", }, Command { name: "watch", cmd: &watch::cmd, - run: &watch::run, + config: &watch::config, log_level: "error", }, Command { name: "record", cmd: &record::cmd, - run: &record::run, + config: &record::config, log_level: "error", }, Command { name: "play", cmd: &play::cmd, - run: &play::run, + config: &play::config, log_level: "error", }, ]; pub fn parse<'a>() -> Result<clap::ArgMatches<'a>> { - let mut app = clap::App::new(crate::util::program_name()?) + let mut app = clap::App::new(program_name()?) .about("Stream your terminal for other people to watch") .author(clap::crate_authors!()) .version(clap::crate_version!()); @@ -69,9 +69,26 @@ pub fn run(matches: &clap::ArgMatches<'_>) -> Result<()> { chosen_submatches = submatches; } } + env_logger::from_env( env_logger::Env::default().default_filter_or(chosen_cmd.log_level), ) .init(); - (chosen_cmd.run)(chosen_submatches) + + let mut config = (chosen_cmd.config)(); + config.merge_args(chosen_submatches)?; + config.run() +} + +fn program_name() -> Result<String> { + let program = + std::env::args().next().context(crate::error::MissingArgv)?; + let path = std::path::Path::new(&program); + let filename = path.file_name(); + Ok(filename + .ok_or_else(|| Error::NotAFileName { + path: path.to_string_lossy().to_string(), + })? + .to_string_lossy() + .to_string()) } diff --git a/src/cmd/play.rs b/src/cmd/play.rs index 996fd90..08d911f 100644 --- a/src/cmd/play.rs +++ b/src/cmd/play.rs @@ -1,26 +1,50 @@ use crate::prelude::*; use std::io::Write as _; +#[derive(serde::Deserialize)] +pub struct Config { + #[serde(default = "crate::config::default_ttyrec_filename")] + filename: String, +} + +impl crate::config::Config for Config { + fn merge_args<'a>( + &mut self, + matches: &clap::ArgMatches<'a>, + ) -> Result<()> { + if matches.is_present("filename") { + self.filename = matches.value_of("filename").unwrap().to_string(); + } + Ok(()) + } + + fn run(&self) -> Result<()> { + let fut = PlaySession::new(&self.filename); + tokio::run(fut.map_err(|e| { + eprintln!("{}", e); + })); + Ok(()) + } +} + +impl Default for Config { + fn default() -> Self { + Self { + filename: crate::config::default_ttyrec_filename(), + } + } +} + pub fn cmd<'a, 'b>(app: clap::App<'a, 'b>) -> clap::App<'a, 'b> { app.about("Play recorded terminal sessions").arg( clap::Arg::with_name("filename") .long("filename") - .takes_value(true) - .required(true), + .takes_value(true), ) } -pub fn run<'a>(matches: &clap::ArgMatches<'a>) -> Result<()> { - let filename = matches.value_of("filename").unwrap(); - run_impl(filename) -} - -fn run_impl(filename: &str) -> Result<()> { - let fut = PlaySession::new(filename); - tokio::run(fut.map_err(|e| { - eprintln!("{}", e); - })); - Ok(()) +pub fn config() -> Box<dyn crate::config::Config> { + Box::new(Config::default()) } #[allow(clippy::large_enum_variant)] diff --git a/src/cmd/record.rs b/src/cmd/record.rs index c233dc4..83651f5 100644 --- a/src/cmd/record.rs +++ b/src/cmd/record.rs @@ -1,13 +1,79 @@ use crate::prelude::*; use tokio::io::AsyncWrite as _; +#[derive(serde::Deserialize)] +pub struct Config { + #[serde(default = "crate::config::default_ttyrec_filename")] + filename: String, + + #[serde(default = "crate::config::default_connection_buffer_size")] + buffer_size: usize, + + #[serde(default = "crate::config::default_command")] + command: String, + + #[serde(default = "crate::config::default_args")] + args: Vec<String>, +} + +impl crate::config::Config for Config { + fn merge_args<'a>( + &mut self, + matches: &clap::ArgMatches<'a>, + ) -> Result<()> { + if matches.is_present("filename") { + self.filename = matches.value_of("filename").unwrap().to_string(); + } + if matches.is_present("buffer-size") { + let buffer_size = matches.value_of("buffer-size").unwrap(); + self.buffer_size = buffer_size.parse().context( + crate::error::ParseBufferSize { input: buffer_size }, + )?; + } + if matches.is_present("command") { + self.command = matches.value_of("command").unwrap().to_string(); + } + if matches.is_present("args") { + self.args = matches + .values_of("args") + .unwrap() + .map(std::string::ToString::to_string) + .collect(); + } + Ok(()) + } + + fn run(&self) -> Result<()> { + let fut = RecordSession::new( + &self.filename, + self.buffer_size, + &self.command, + &self.args, + ); + tokio::run(fut.map_err(|e| { + eprintln!("{}", e); + })); + Ok(()) + } +} + +impl Default for Config { + fn default() -> Self { + Self { + filename: crate::config::default_ttyrec_filename(), + buffer_size: crate::config::default_connection_buffer_size(), + command: crate::config::default_command(), + args: crate::config::default_args(), + } + } +} + pub fn cmd<'a, 'b>(app: clap::App<'a, 'b>) -> clap::App<'a, 'b> { app.about("Record a terminal session to a file") .arg( clap::Arg::with_name("filename") .long("filename") - .takes_value(true) - .required(true), + .takes_value(true), ) .arg( clap::Arg::with_name("buffer-size") @@ -18,38 +84,8 @@ pub fn cmd<'a, 'b>(app: clap::App<'a, 'b>) -> clap::App<'a, 'b> { .arg(clap::Arg::with_name("args").index(2).multiple(true)) } -pub fn run<'a>(matches: &clap::ArgMatches<'a>) -> Result<()> { - let filename = matches.value_of("filename").unwrap(); - let buffer_size = - matches - .value_of("buffer-size") - .map_or(Ok(4 * 1024 * 1024), |s| { - s.parse() - .context(crate::error::ParseBufferSize { input: s }) - })?; - let command = matches.value_of("command").map_or_else( - || std::env::var("SHELL").unwrap_or_else(|_| "/bin/bash".to_string()), - std::string::ToString::to_string, - ); - let args = if let Some(args) = matches.values_of("args") { - args.map(std::string::ToString::to_string).collect() - } else { - vec![] - }; - run_impl(filename, buffer_size, &command, &args) -} - -fn run_impl( - filename: &str, - buffer_size: usize, - command: &str, - args: &[String], -) -> Result<()> { - let fut = RecordSession::new(filename, buffer_size, command, args); - tokio::run(fut.map_err(|e| { - eprintln!("{}", e); - })); - Ok(()) +pub fn config() -> Box<dyn crate::config::Config> { + Box::new(Config::default()) } #[allow(clippy::large_enum_variant)] diff --git a/src/cmd/server.rs b/src/cmd/server.rs index be70da6..ae96a86 100644 --- a/src/cmd/server.rs +++ b/src/cmd/server.rs @@ -2,6 +2,116 @@ use crate::prelude::*; use std::convert::TryFrom as _; use std::io::Read as _; +#[derive(serde::Deserialize)] +pub struct Config { + #[serde( + deserialize_with = "crate::config::listen_address", + default = "crate::config::default_listen_address" + )] + address: std::net::SocketAddr, + + #[serde(default = "crate::config::default_connection_buffer_size")] + buffer_size: usize, + + #[serde(default = "crate::config::default_read_timeout")] + read_timeout: std::time::Duration, + + tls_identity_file: Option<String>, + + #[serde( + deserialize_with = "crate::config::allowed_login_methods", + default = "crate::config::default_allowed_login_methods" + )] + allowed_login_methods: + std::collections::HashSet<crate::protocol::AuthType>, +} + +impl crate::config::Config for Config { + fn merge_args<'a>( + &mut self, + matches: &clap::ArgMatches<'a>, + ) -> Result<()> { + if matches.is_present("address") { + self.address = matches + .value_of("address") + .unwrap() + .parse() + .context(crate::error::ParseAddr)?; + } + if matches.is_present("buffer-size") { + let s = matches.value_of("buffer-size").unwrap(); + self.buffer_size = s + .parse() + .context(crate::error::ParseBufferSize { input: s })?; + } + if matches.is_present("read-timeout") { + let s = matches.value_of("read-timeout").unwrap(); + self.read_timeout = s + .parse() + .map(std::time::Duration::from_secs) + .context(crate::error::ParseReadTimeout { input: s })?; + } + if matches.is_present("tls-identity-file") { + self.tls_identity_file = Some( + matches.value_of("tls-identity-file").unwrap().to_string(), + ); + } + if matches.is_present("allowed-login-methods") { + self.allowed_login_methods = matches + .values_of("allowed-login-methods") + .unwrap() + .map(crate::protocol::AuthType::try_from) + .collect::<Result< + std::collections::HashSet<crate::protocol::AuthType>, + >>()?; + } + Ok(()) + } + + fn run(&self) -> Result<()> { + let (acceptor, server) = + if let Some(tls_identity_file) = &self.tls_identity_file { + create_server_tls( + self.address, + self.buffer_size, + self.read_timeout, + tls_identity_file, + self.allowed_login_methods.clone(), + )? + } else { + create_server( + self.address, + self.buffer_size, + self.read_timeout, + self.allowed_login_methods.clone(), + )? + }; + tokio::run(futures::future::lazy(move || { + tokio::spawn(server.map_err(|e| { + eprintln!("{}", e); + })); + + acceptor.map_err(|e| { + eprintln!("{}", e); + }) + })); + Ok(()) + } +} + +impl Default for Config { + fn default() -> Self { + Self { + address: crate::config::default_listen_address(), + buffer_size: crate::config::default_connection_buffer_size(), + read_timeout: crate::config::default_read_timeout(), + tls_identity_file: None, + allowed_login_methods: + crate::config::default_allowed_login_methods(), + } + } +} + pub fn cmd<'a, 'b>(app: clap::App<'a, 'b>) -> clap::App<'a, 'b> { app.about("Run a teleterm server") .arg( @@ -32,79 +142,8 @@ pub fn cmd<'a, 'b>(app: clap::App<'a, 'b>) -> clap::App<'a, 'b> { ) } -pub fn run<'a>(matches: &clap::ArgMatches<'a>) -> super::Result<()> { - let address = matches.value_of("address").map_or_else( - || Ok("0.0.0.0:4144".parse().unwrap()), - |s| s.parse().context(crate::error::ParseAddr), - )?; - let buffer_size = - matches - .value_of("buffer-size") - .map_or(Ok(4 * 1024 * 1024), |s| { - s.parse() - .context(crate::error::ParseBufferSize { input: s }) - })?; - let read_timeout = matches.value_of("read-timeout").map_or( - Ok(std::time::Duration::from_secs(120)), - |s| { - s.parse() - .map(std::time::Duration::from_secs) - .context(crate::error::ParseReadTimeout { input: s }) - }, - )?; - let tls_identity_file = matches.value_of("tls-identity-file"); - let allowed_login_methods = - matches.values_of("allowed-login-methods").map_or_else( - || Ok(crate::protocol::AuthType::iter().collect()), - |methods| { - methods.map(crate::protocol::AuthType::try_from).collect() - }, - )?; - run_impl( - address, - buffer_size, - read_timeout, - tls_identity_file, - allowed_login_methods, - ) -} - -fn run_impl( - address: std::net::SocketAddr, - buffer_size: usize, - read_timeout: std::time::Duration, - tls_identity_file: Option<&str>, - allowed_login_methods: std::collections::HashSet< - crate::protocol::AuthType, - >, -) -> Result<()> { - let (acceptor, server) = - if let Some(tls_identity_file) = tls_identity_file { - create_server_tls( - address, - buffer_size, - read_timeout, - tls_identity_file, - allowed_login_methods, - )? - } else { - create_server( - address, - buffer_size, - read_timeout, - allowed_login_methods, - )? - }; - tokio::run(futures::future::lazy(move || { - tokio::spawn(server.map_err(|e| { - eprintln!("{}", e); - })); - - acceptor.map_err(|e| { - eprintln!("{}", e); - }) - })); - Ok(()) +pub fn config() -> Box<dyn crate::config::Config> { + Box::new(Config::default()) } fn create_server( diff --git a/src/cmd/stream.rs b/src/cmd/stream.rs index 24e2a41..805ed8e 100644 --- a/src/cmd/stream.rs +++ b/src/cmd/stream.rs @@ -1,6 +1,150 @@ use crate::prelude::*; use tokio::io::AsyncWrite as _; +#[derive(serde::Deserialize)] +pub struct Config { + #[serde( + deserialize_with = "crate::config::auth", + default = "crate::config::default_auth" + )] + auth: crate::protocol::Auth, + + #[serde( + deserialize_with = "crate::config::connect_address", + default = "crate::config::default_connect_address" + )] + address: (String, std::net::SocketAddr), + + #[serde(default = "crate::config::default_tls")] + tls: bool, + + #[serde(default = "crate::config::default_connection_buffer_size")] + buffer_size: usize, + + #[serde(default = "crate::config::default_command")] + command: String, + + #[serde(default = "crate::config::default_args")] + args: Vec<String>, +} + +impl Config { + fn host(&self) -> &str { + &self.address.0 + } + + fn addr(&self) -> &std::net::SocketAddr { + &self.address.1 + } +} + +impl crate::config::Config for Config { + fn merge_args<'a>( + &mut self, + matches: &clap::ArgMatches<'a>, + ) -> Result<()> { + if matches.is_present("login-recurse-center") { + let id = crate::oauth::load_client_auth_id( + crate::protocol::AuthType::RecurseCenter, + ); + self.auth = crate::protocol::Auth::recurse_center( + id.as_ref().map(std::string::String::as_str), + ); + } + if matches.is_present("login-plain") { + let username = + matches.value_of("login-plain").unwrap().to_string(); + self.auth = crate::protocol::Auth::plain(&username); + } + if matches.is_present("address") { + let address = matches.value_of("address").unwrap(); + self.address = crate::config::to_connect_address(address)?; + } + if matches.is_present("tls") { + self.tls = true; + } + if matches.is_present("buffer-size") { + let buffer_size = matches.value_of("buffer-size").unwrap(); + self.buffer_size = buffer_size.parse().context( + crate::error::ParseBufferSize { input: buffer_size }, + )?; + } + if matches.is_present("command") { + self.command = matches.value_of("command").unwrap().to_string(); + } + if matches.is_present("args") { + self.args = matches + .values_of("args") + .unwrap() + .map(std::string::ToString::to_string) + .collect(); + } + Ok(()) + } + + fn run(&self) -> Result<()> { + let host = self.host().to_string(); + let address = *self.addr(); + let fut: Box< + dyn futures::future::Future<Item = (), Error = Error> + Send, + > = if self.tls { + let connector = native_tls::TlsConnector::new() + .context(crate::error::CreateConnector)?; + let connect: crate::client::Connector<_> = Box::new(move || { + let host = host.clone(); + let connector = connector.clone(); + let connector = tokio_tls::TlsConnector::from(connector); + let stream = tokio::net::tcp::TcpStream::connect(&address); + Box::new(stream.context(crate::error::Connect).and_then( + move |stream| { + connector + .connect(&host, stream) + .context(crate::error::ConnectTls) + }, + )) + }); + Box::new(StreamSession::new( + &self.command, + &self.args, + connect, + self.buffer_size, + &self.auth, + )) + } else { + let connect: crate::client::Connector<_> = Box::new(move || { + Box::new( + tokio::net::tcp::TcpStream::connect(&address) + .context(crate::error::Connect), + ) + }); + Box::new(StreamSession::new( + &self.command, + &self.args, + connect, + self.buffer_size, + &self.auth, + )) + }; + tokio::run(fut.map_err(|e| { + eprintln!("{}", e); + })); + Ok(()) + } +} + +impl Default for Config { + fn default() -> Self { + Self { + auth: crate::config::default_auth(), + address: crate::config::default_connect_address(), + tls: crate::config::default_tls(), + buffer_size: crate::config::default_connection_buffer_size(), + command: crate::config::default_command(), + args: crate::config::default_args(), + } + } +} + pub fn cmd<'a, 'b>(app: clap::App<'a, 'b>) -> clap::App<'a, 'b> { app.about("Stream your terminal") .arg( @@ -28,99 +172,8 @@ pub fn cmd<'a, 'b>(app: clap::App<'a, 'b>) -> clap::App<'a, 'b> { .arg(clap::Arg::with_name("args").index(2).multiple(true)) } -pub fn run<'a>(matches: &clap::ArgMatches<'a>) -> super::Result<()> { - let auth = if matches.is_present("login-recurse-center") { - let id = crate::oauth::load_client_auth_id( - crate::protocol::AuthType::RecurseCenter, - ); - crate::protocol::Auth::recurse_center( - id.as_ref().map(std::string::String::as_str), - ) - } else { - let username = matches - .value_of("login-plain") - .map(std::string::ToString::to_string) - .or_else(|| std::env::var("USER").ok()) - .context(crate::error::CouldntFindUsername)?; - crate::protocol::Auth::plain(&username) - }; - let address = matches.value_of("address").unwrap_or("127.0.0.1:4144"); - let (host, address) = crate::util::resolve_address(address)?; - let tls = matches.is_present("tls"); - let buffer_size = - matches - .value_of("buffer-size") - .map_or(Ok(4 * 1024 * 1024), |s| { - s.parse() - .context(crate::error::ParseBufferSize { input: s }) - })?; - let command = matches.value_of("command").map_or_else( - || std::env::var("SHELL").unwrap_or_else(|_| "/bin/bash".to_string()), - std::string::ToString::to_string, - ); - let args = if let Some(args) = matches.values_of("args") { - args.map(std::string::ToString::to_string).collect() - } else { - vec![] - }; - run_impl(&auth, &host, address, tls, buffer_size, &command, &args) -} - -fn run_impl( - auth: &crate::protocol::Auth, - host: &str, - address: std::net::SocketAddr, - tls: bool, - buffer_size: usize, - command: &str, - args: &[String], -) -> Result<()> { - let host = host.to_string(); - let fut: Box< - dyn futures::future::Future<Item = (), Error = Error> + Send, - > = if tls { - let connector = native_tls::TlsConnector::new() - .context(crate::error::CreateConnector)?; - let connect: crate::client::Connector<_> = Box::new(move || { - let host = host.clone(); - let connector = connector.clone(); - let connector = tokio_tls::TlsConnector::from(connector); - let stream = tokio::net::tcp::TcpStream::connect(&address); - Box::new(stream.context(crate::error::Connect).and_then( - move |stream| { - connector - .connect(&host, stream) - .context(crate::error::ConnectTls) - }, - )) - }); - Box::new(StreamSession::new( - command, - args, - connect, - buffer_size, - auth, - )) - } else { - let connect: crate::client::Connector<_> = Box::new(move || { - Box::new( - tokio::net::tcp::TcpStream::connect(&address) - .context(crate::error::Connect), - ) - }); - Box::new(StreamSession::new( - command, - args, - connect, - buffer_size, - auth, - )) - }; - tokio::run(fut.map_err(|e| { - eprintln!("{}", e); - })); - - Ok(()) +pub fn config() -> Box<dyn crate::config::Config> { + Box::new(Config::default()) } struct StreamSession< diff --git a/src/cmd/watch.rs b/src/cmd/watch.rs index 200b9fe..75eba03 100644 --- a/src/cmd/watch.rs +++ b/src/cmd/watch.rs @@ -1,6 +1,122 @@ use crate::prelude::*; use std::io::Write as _; +#[derive(serde::Deserialize)] +pub struct Config { + #[serde( + deserialize_with = "crate::config::auth", + default = "crate::config::default_auth" + )] + auth: crate::protocol::Auth, + + #[serde( + deserialize_with = "crate::config::connect_address", + default = "crate::config::default_connect_address" + )] + address: (String, std::net::SocketAddr), + + #[serde(default = "crate::config::default_tls")] + tls: bool, +} + +impl Config { + fn host(&self) -> &str { + &self.address.0 + } + + fn addr(&self) -> &std::net::SocketAddr { + &self.address.1 + } +} + +impl crate::config::Config for Config { + fn merge_args<'a>( + &mut self, + matches: &clap::ArgMatches<'a>, + ) -> Result<()> { + if matches.is_present("login-recurse-center") { + let id = crate::oauth::load_client_auth_id( + crate::protocol::AuthType::RecurseCenter, + ); + self.auth = crate::protocol::Auth::recurse_center( + id.as_ref().map(std::string::String::as_str), + ); + } + if matches.is_present("login-plain") { + let username = + matches.value_of("login-plain").unwrap().to_string(); + self.auth = crate::protocol::Auth::plain(&username); + } + if matches.is_present("address") { + let address = matches.value_of("address").unwrap(); + self.address = crate::config::to_connect_address(address)?; + } + if matches.is_present("tls") { + self.tls = true; + } + Ok(()) + } + + fn run(&self) -> Result<()> { + let host = self.host().to_string(); + let address = *self.addr(); + let auth = self.auth.clone(); + let fut: Box< + dyn futures::future::Future<Item = (), Error = Error> + Send, + > = if self.tls { + let connector = native_tls::TlsConnector::new() + .context(crate::error::CreateConnector)?; + let make_connector: Box< + dyn Fn() -> crate::client::Connector<_> + Send, + > = Box::new(move || { + let host = host.clone(); + let connector = connector.clone(); + Box::new(move || { + let host = host.clone(); + let connector = connector.clone(); + let connector = tokio_tls::TlsConnector::from(connector); + let stream = + tokio::net::tcp::TcpStream::connect(&address); + Box::new(stream.context(crate::error::Connect).and_then( + move |stream| { + connector + .connect(&host, stream) + .context(crate::error::ConnectTls) + }, + )) + }) + }); + Box::new(WatchSession::new(make_connector, &auth)) + } else { + let make_connector: Box< + dyn Fn() -> crate::client::Connector<_> + Send, + > = Box::new(move || { + Box::new(move || { + Box::new( + tokio::net::tcp::TcpStream::connect(&address) + .context(crate::error::Connect), + ) + }) + }); + Box::new(WatchSession::new(make_connector, &auth)) + }; + tokio::run(fut.map_err(|e| { + eprintln!("{}", e); + })); + Ok(()) + } +} + +impl Default for Config { + fn default() -> Self { + Self { + auth: crate::config::default_auth(), + address: crate::config::default_connect_address(), + tls: crate::config::default_tls(), + } + } +} + pub fn cmd<'a, 'b>(app: clap::App<'a, 'b>) -> clap::App<'a, 'b> { app.about("Watch teleterm streams") .arg( @@ -21,79 +137,8 @@ pub fn cmd<'a, 'b>(app: clap::App<'a, 'b>) -> clap::App<'a, 'b> { .arg(clap::Arg::with_name("tls").long("tls")) } -pub fn run<'a>(matches: &clap::ArgMatches<'a>) -> super::Result<()> { - let auth = if matches.is_present("login-recurse-center") { - let id = crate::oauth::load_client_auth_id( - crate::protocol::AuthType::RecurseCenter, - ); - crate::protocol::Auth::recurse_center( - id.as_ref().map(std::string::String::as_str), - ) - } else { - let username = matches - .value_of("login-plain") - .map(std::string::ToString::to_string) - .or_else(|| std::env::var("USER").ok()) - .context(crate::error::CouldntFindUsername)?; - crate::protocol::Auth::plain(&username) - }; - let address = matches.value_of("address").unwrap_or("127.0.0.1:4144"); - let (host, address) = crate::util::resolve_address(address)?; - let tls = matches.is_present("tls"); - run_impl(&auth, &host, address, tls) -} - -fn run_impl( - auth: &crate::protocol::Auth, - host: &str, - address: std::net::SocketAddr, - tls: bool, -) -> Result<()> { - let host = host.to_string(); - let auth = auth.clone(); - let fut: Box< - dyn futures::future::Future<Item = (), Error = Error> + Send, - > = if tls { - let connector = native_tls::TlsConnector::new() - .context(crate::error::CreateConnector)?; - let make_connector: Box< - dyn Fn() -> crate::client::Connector<_> + Send, - > = Box::new(move || { - let host = host.clone(); - let connector = connector.clone(); - Box::new(move || { - let host = host.clone(); - let connector = connector.clone(); - let connector = tokio_tls::TlsConnector::from(connector); - let stream = tokio::net::tcp::TcpStream::connect(&address); - Box::new(stream.context(crate::error::Connect).and_then( - move |stream| { - connector - .connect(&host, stream) - .context(crate::error::ConnectTls) - }, - )) - }) - }); - Box::new(WatchSession::new(make_connector, &auth)) - } else { - let make_connector: Box< - dyn Fn() -> crate::client::Connector<_> + Send, - > = Box::new(move || { - Box::new(move || { - Box::new( - tokio::net::tcp::TcpStream::connect(&address) - .context(crate::error::Connect), - ) - }) - }); - Box::new(WatchSession::new(make_connector, &auth)) - }; - tokio::run(fut.map_err(|e| { - eprintln!("{}", e); - })); - - Ok(()) +pub fn config() -> Box<dyn crate::config::Config> { + Box::new(Config::default()) } // XXX https://github.com/rust-lang/rust/issues/64362 diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..03a5061 --- /dev/null +++ b/src/config.rs @@ -0,0 +1,179 @@ +use crate::prelude::*; +use serde::de::Deserialize as _; +use std::convert::TryFrom as _; +use std::net::ToSocketAddrs as _; + +const DEFAULT_LISTEN_ADDRESS: &str = "127.0.0.1:4144"; +const DEFAULT_CONNECT_ADDRESS: &str = "127.0.0.1:4144"; +const DEFAULT_CONNECTION_BUFFER_SIZE: usize = 4 * 1024 * 1024; +const DEFAULT_READ_TIMEOUT: std::time::Duration = + std::time::Duration::from_secs(120); +const DEFAULT_AUTH_TYPE: crate::protocol::AuthType = + crate::protocol::AuthType::Plain; +const DEFAULT_TLS: bool = false; +const DEFAULT_TTYREC_FILENAME: &str = "teleterm.ttyrec"; + +pub trait Config { + fn merge_args<'a>( + &mut self, + matches: &clap::ArgMatches<'a>, + ) -> Result<()>; + fn run(&self) -> Result<()>; +} + +pub fn listen_address<'a, D>( + deserializer: D, +) -> std::result::Result<std::net::SocketAddr, D::Error> +where + D: serde::de::Deserializer<'a>, +{ + to_listen_address(<&str>::deserialize(deserializer)?) + .map_err(serde::de::Error::custom) +} + +pub fn default_listen_address() -> std::net::SocketAddr { + to_listen_address(DEFAULT_LISTEN_ADDRESS).unwrap() +} + +pub fn to_listen_address(address: &str) -> Result<std::net::SocketAddr> { + address.parse().context(crate::error::ParseAddr) +} + +pub fn connect_address<'a, D>( + deserializer: D, +) -> std::result::Result<(String, std::net::SocketAddr), D::Error> +where + D: serde::de::Deserializer<'a>, +{ + to_connect_address(<&str>::deserialize(deserializer)?) + .map_err(serde::de::Error::custom) +} + +pub fn default_connect_address() -> (String, std::net::SocketAddr) { + to_connect_address(DEFAULT_CONNECT_ADDRESS).unwrap() +} + +// XXX this does a blocking dns lookup - should try to find an async version +pub fn to_connect_address( + address: &str, +) -> Result<(String, std::net::SocketAddr)> { + let mut address_parts = address.split(':'); + let host = address_parts.next().context(crate::error::ParseAddress)?; + let port = address_parts.next().context(crate::error::ParseAddress)?; + let port: u16 = port.parse().context(crate::error::ParsePort)?; + let socket_addr = (host, port) + .to_socket_addrs() + .context(crate::error::ResolveAddress)? + .next() + .context(crate::error::HasResolvedAddr)?; + Ok((host.to_string(), socket_addr)) +} + +pub fn default_connection_buffer_size() -> usize { + DEFAULT_CONNECTION_BUFFER_SIZE +} + +pub fn default_read_timeout() -> std::time::Duration { + DEFAULT_READ_TIMEOUT +} + +pub fn default_tls() -> bool { + DEFAULT_TLS +} + +pub fn default_command() -> String { + std::env::var("SHELL").unwrap_or_else(|_| "/bin/bash".to_string()) +} + +pub fn default_args() -> Vec<String> { + vec![] +} + +pub fn default_ttyrec_filename() -> String { + DEFAULT_TTYREC_FILENAME.to_string() +} + +pub fn default_allowed_login_methods( +) -> std::collections::HashSet<crate::protocol::AuthType> { + crate::protocol::AuthType::iter().collect() +} + +pub fn allowed_login_methods<'a, D>( + deserializer: D, +) -> std::result::Result< + std::collections::HashSet<crate::protocol::AuthType>, + D::Error, +> +where + D: serde::de::Deserializer<'a>, +{ + Option::<Vec<&str>>::deserialize(deserializer)? + .map_or_else( + || Ok(default_allowed_login_methods()), + |methods| { + methods + .iter() + .copied() + .map(crate::protocol::AuthType::try_from) + .collect() + }, + ) + .map_err(serde::de::Error::custom) +} + +pub fn auth<'a, D>( + deserializer: D, +) -> std::result::Result<crate::protocol::Auth, D::Error> +where + D: serde::de::Deserializer<'a>, +{ + LoginType::deserialize(deserializer).and_then(|login_type| { + match login_type.login_type { + crate::protocol::AuthType::Plain => login_type + .username + .map(std::string::ToString::to_string) + .or_else(default_username) + .ok_or_else(|| Error::CouldntFindUsername) + .map(|username| crate::protocol::Auth::Plain { username }) + .map_err(serde::de::Error::custom), + crate::protocol::AuthType::RecurseCenter => { + Ok(crate::protocol::Auth::RecurseCenter { + id: login_type.id.map(std::string::ToString::to_string), + }) + } + } + }) +} + +pub fn default_auth() -> crate::protocol::Auth { + let username = default_username() + .ok_or_else(|| Error::CouldntFindUsername) + .unwrap(); + crate::protocol::Auth::Plain { username } +} + +#[derive(serde::Deserialize)] +struct LoginType<'a> { + #[serde(deserialize_with = "auth_type", default = "default_auth_type")] + login_type: crate::protocol::AuthType, + username: Option<&'a str>, + id: Option<&'a str>, +} + +fn auth_type<'a, D>( + deserializer: D, +) -> std::result::Result<crate::protocol::AuthType, D::Error> +where + D: serde::de::Deserializer<'a>, +{ + crate::protocol::AuthType::try_from(<&str>::deserialize(deserializer)?) + .map_err(serde::de::Error::custom) +} + +fn default_auth_type() -> crate::protocol::AuthType { + DEFAULT_AUTH_TYPE +} + +fn default_username() -> Option<String> { + std::env::var("USER").ok() +} diff --git a/src/main.rs b/src/main.rs index 20172af..f44dba0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,6 +13,7 @@ mod async_stdin; mod client; mod cmd; mod component_future; +mod config; mod dirs; mod error; mod key_reader; @@ -23,7 +24,6 @@ mod server; mod session_list; mod term; mod ttyrec; -mod util; fn main() { dirs::Dirs::new().create_all().unwrap(); diff --git a/src/util.rs b/src/util.rs deleted file mode 100644 index 6e814ab..0000000 --- a/src/util.rs +++ /dev/null @@ -1,31 +0,0 @@ -use crate::prelude::*; -use std::net::ToSocketAddrs as _; - -pub fn program_name() -> Result<String> { - let program = - std::env::args().next().context(crate::error::MissingArgv)?; - let path = std::path::Path::new(&program); - let filename = path.file_name(); - Ok(filename - .ok_or_else(|| Error::NotAFileName { - path: path.to_string_lossy().to_string(), - })? - .to_string_lossy() - .to_string()) -} - -// XXX this does a blocking dns lookup - should try to find an async version -pub fn resolve_address( - address: &str, -) -> Result<(String, std::net::SocketAddr)> { - let mut address_parts = address.split(':'); - let host = address_parts.next().context(crate::error::ParseAddress)?; - let port = address_parts.next().context(crate::error::ParseAddress)?; - let port: u16 = port.parse().context(crate::error::ParsePort)?; - let socket_addr = (host, port) - .to_socket_addrs() - .context(crate::error::ResolveAddress)? - .next() - .context(crate::error::HasResolvedAddr)?; - Ok((host.to_string(), socket_addr)) -} |