aboutsummaryrefslogtreecommitdiffstats
path: root/src/cmd/server.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/cmd/server.rs')
-rw-r--r--src/cmd/server.rs55
1 files changed, 49 insertions, 6 deletions
diff --git a/src/cmd/server.rs b/src/cmd/server.rs
index 07d9321..be70da6 100644
--- a/src/cmd/server.rs
+++ b/src/cmd/server.rs
@@ -1,4 +1,5 @@
use crate::prelude::*;
+use std::convert::TryFrom as _;
use std::io::Read as _;
pub fn cmd<'a, 'b>(app: clap::App<'a, 'b>) -> clap::App<'a, 'b> {
@@ -23,6 +24,12 @@ pub fn cmd<'a, 'b>(app: clap::App<'a, 'b>) -> clap::App<'a, 'b> {
.long("tls-identity-file")
.takes_value(true),
)
+ .arg(
+ clap::Arg::with_name("allowed-login-methods")
+ .long("allowed-login-methods")
+ .use_delimiter(true)
+ .takes_value(true),
+ )
}
pub fn run<'a>(matches: &clap::ArgMatches<'a>) -> super::Result<()> {
@@ -46,7 +53,20 @@ pub fn run<'a>(matches: &clap::ArgMatches<'a>) -> super::Result<()> {
},
)?;
let tls_identity_file = matches.value_of("tls-identity-file");
- run_impl(address, buffer_size, read_timeout, tls_identity_file)
+ let allowed_login_methods =
+ matches.values_of("allowed-login-methods").map_or_else(
+ || Ok(crate::protocol::AuthType::iter().collect()),
+ |methods| {
+ methods.map(crate::protocol::AuthType::try_from).collect()
+ },
+ )?;
+ run_impl(
+ address,
+ buffer_size,
+ read_timeout,
+ tls_identity_file,
+ allowed_login_methods,
+ )
}
fn run_impl(
@@ -54,6 +74,9 @@ fn run_impl(
buffer_size: usize,
read_timeout: std::time::Duration,
tls_identity_file: Option<&str>,
+ allowed_login_methods: std::collections::HashSet<
+ crate::protocol::AuthType,
+ >,
) -> Result<()> {
let (acceptor, server) =
if let Some(tls_identity_file) = tls_identity_file {
@@ -62,9 +85,15 @@ fn run_impl(
buffer_size,
read_timeout,
tls_identity_file,
+ allowed_login_methods,
)?
} else {
- create_server(address, buffer_size, read_timeout)?
+ create_server(
+ address,
+ buffer_size,
+ read_timeout,
+ allowed_login_methods,
+ )?
};
tokio::run(futures::future::lazy(move || {
tokio::spawn(server.map_err(|e| {
@@ -82,6 +111,9 @@ fn create_server(
address: std::net::SocketAddr,
buffer_size: usize,
read_timeout: std::time::Duration,
+ allowed_login_methods: std::collections::HashSet<
+ crate::protocol::AuthType,
+ >,
) -> Result<(
Box<dyn futures::future::Future<Item = (), Error = Error> + Send>,
Box<dyn futures::future::Future<Item = (), Error = Error> + Send>,
@@ -97,8 +129,12 @@ fn create_server(
.try_send(sock)
.context(crate::error::SendSocketChannel)
});
- let server =
- crate::server::Server::new(buffer_size, read_timeout, sock_r);
+ let server = crate::server::Server::new(
+ buffer_size,
+ read_timeout,
+ sock_r,
+ allowed_login_methods,
+ );
Ok((Box::new(acceptor), Box::new(server)))
}
@@ -107,6 +143,9 @@ fn create_server_tls(
buffer_size: usize,
read_timeout: std::time::Duration,
tls_identity_file: &str,
+ allowed_login_methods: std::collections::HashSet<
+ crate::protocol::AuthType,
+ >,
) -> Result<(
Box<dyn futures::future::Future<Item = (), Error = Error> + Send>,
Box<dyn futures::future::Future<Item = (), Error = Error> + Send>,
@@ -135,7 +174,11 @@ fn create_server_tls(
.try_send(sock)
.map_err(|_| Error::SendSocketChannelTls {})
});
- let server =
- crate::server::tls::Server::new(buffer_size, read_timeout, sock_r);
+ let server = crate::server::tls::Server::new(
+ buffer_size,
+ read_timeout,
+ sock_r,
+ allowed_login_methods,
+ );
Ok((Box::new(acceptor), Box::new(server)))
}