diff options
Diffstat (limited to 'src/cmd/server.rs')
-rw-r--r-- | src/cmd/server.rs | 55 |
1 files changed, 49 insertions, 6 deletions
diff --git a/src/cmd/server.rs b/src/cmd/server.rs index 07d9321..be70da6 100644 --- a/src/cmd/server.rs +++ b/src/cmd/server.rs @@ -1,4 +1,5 @@ use crate::prelude::*; +use std::convert::TryFrom as _; use std::io::Read as _; pub fn cmd<'a, 'b>(app: clap::App<'a, 'b>) -> clap::App<'a, 'b> { @@ -23,6 +24,12 @@ pub fn cmd<'a, 'b>(app: clap::App<'a, 'b>) -> clap::App<'a, 'b> { .long("tls-identity-file") .takes_value(true), ) + .arg( + clap::Arg::with_name("allowed-login-methods") + .long("allowed-login-methods") + .use_delimiter(true) + .takes_value(true), + ) } pub fn run<'a>(matches: &clap::ArgMatches<'a>) -> super::Result<()> { @@ -46,7 +53,20 @@ pub fn run<'a>(matches: &clap::ArgMatches<'a>) -> super::Result<()> { }, )?; let tls_identity_file = matches.value_of("tls-identity-file"); - run_impl(address, buffer_size, read_timeout, tls_identity_file) + let allowed_login_methods = + matches.values_of("allowed-login-methods").map_or_else( + || Ok(crate::protocol::AuthType::iter().collect()), + |methods| { + methods.map(crate::protocol::AuthType::try_from).collect() + }, + )?; + run_impl( + address, + buffer_size, + read_timeout, + tls_identity_file, + allowed_login_methods, + ) } fn run_impl( @@ -54,6 +74,9 @@ fn run_impl( buffer_size: usize, read_timeout: std::time::Duration, tls_identity_file: Option<&str>, + allowed_login_methods: std::collections::HashSet< + crate::protocol::AuthType, + >, ) -> Result<()> { let (acceptor, server) = if let Some(tls_identity_file) = tls_identity_file { @@ -62,9 +85,15 @@ fn run_impl( buffer_size, read_timeout, tls_identity_file, + allowed_login_methods, )? } else { - create_server(address, buffer_size, read_timeout)? + create_server( + address, + buffer_size, + read_timeout, + allowed_login_methods, + )? }; tokio::run(futures::future::lazy(move || { tokio::spawn(server.map_err(|e| { @@ -82,6 +111,9 @@ fn create_server( address: std::net::SocketAddr, buffer_size: usize, read_timeout: std::time::Duration, + allowed_login_methods: std::collections::HashSet< + crate::protocol::AuthType, + >, ) -> Result<( Box<dyn futures::future::Future<Item = (), Error = Error> + Send>, Box<dyn futures::future::Future<Item = (), Error = Error> + Send>, @@ -97,8 +129,12 @@ fn create_server( .try_send(sock) .context(crate::error::SendSocketChannel) }); - let server = - crate::server::Server::new(buffer_size, read_timeout, sock_r); + let server = crate::server::Server::new( + buffer_size, + read_timeout, + sock_r, + allowed_login_methods, + ); Ok((Box::new(acceptor), Box::new(server))) } @@ -107,6 +143,9 @@ fn create_server_tls( buffer_size: usize, read_timeout: std::time::Duration, tls_identity_file: &str, + allowed_login_methods: std::collections::HashSet< + crate::protocol::AuthType, + >, ) -> Result<( Box<dyn futures::future::Future<Item = (), Error = Error> + Send>, Box<dyn futures::future::Future<Item = (), Error = Error> + Send>, @@ -135,7 +174,11 @@ fn create_server_tls( .try_send(sock) .map_err(|_| Error::SendSocketChannelTls {}) }); - let server = - crate::server::tls::Server::new(buffer_size, read_timeout, sock_r); + let server = crate::server::tls::Server::new( + buffer_size, + read_timeout, + sock_r, + allowed_login_methods, + ); Ok((Box::new(acceptor), Box::new(server))) } |