diff options
author | Jesse Luehrs <doy@tozt.net> | 2019-10-17 02:53:41 -0400 |
---|---|---|
committer | Jesse Luehrs <doy@tozt.net> | 2019-10-17 02:53:41 -0400 |
commit | f4e6745c78c016b6103bc595ff99514b5adf973b (patch) | |
tree | 8b29b42f75d68c6e8ccf990156ec9490de2ad541 | |
parent | 5336c7ccbe7bee81b5c496f55f8fb9b1835a1003 (diff) | |
download | teleterm-f4e6745c78c016b6103bc595ff99514b5adf973b.tar.gz teleterm-f4e6745c78c016b6103bc595ff99514b5adf973b.zip |
allow restricting the accepted auth type list
-rw-r--r-- | src/cmd/server.rs | 55 | ||||
-rw-r--r-- | src/error.rs | 6 | ||||
-rw-r--r-- | src/protocol.rs | 21 | ||||
-rw-r--r-- | src/server.rs | 10 | ||||
-rw-r--r-- | src/server/tls.rs | 10 |
5 files changed, 94 insertions, 8 deletions
diff --git a/src/cmd/server.rs b/src/cmd/server.rs index 07d9321..be70da6 100644 --- a/src/cmd/server.rs +++ b/src/cmd/server.rs @@ -1,4 +1,5 @@ use crate::prelude::*; +use std::convert::TryFrom as _; use std::io::Read as _; pub fn cmd<'a, 'b>(app: clap::App<'a, 'b>) -> clap::App<'a, 'b> { @@ -23,6 +24,12 @@ pub fn cmd<'a, 'b>(app: clap::App<'a, 'b>) -> clap::App<'a, 'b> { .long("tls-identity-file") .takes_value(true), ) + .arg( + clap::Arg::with_name("allowed-login-methods") + .long("allowed-login-methods") + .use_delimiter(true) + .takes_value(true), + ) } pub fn run<'a>(matches: &clap::ArgMatches<'a>) -> super::Result<()> { @@ -46,7 +53,20 @@ pub fn run<'a>(matches: &clap::ArgMatches<'a>) -> super::Result<()> { }, )?; let tls_identity_file = matches.value_of("tls-identity-file"); - run_impl(address, buffer_size, read_timeout, tls_identity_file) + let allowed_login_methods = + matches.values_of("allowed-login-methods").map_or_else( + || Ok(crate::protocol::AuthType::iter().collect()), + |methods| { + methods.map(crate::protocol::AuthType::try_from).collect() + }, + )?; + run_impl( + address, + buffer_size, + read_timeout, + tls_identity_file, + allowed_login_methods, + ) } fn run_impl( @@ -54,6 +74,9 @@ fn run_impl( buffer_size: usize, read_timeout: std::time::Duration, tls_identity_file: Option<&str>, + allowed_login_methods: std::collections::HashSet< + crate::protocol::AuthType, + >, ) -> Result<()> { let (acceptor, server) = if let Some(tls_identity_file) = tls_identity_file { @@ -62,9 +85,15 @@ fn run_impl( buffer_size, read_timeout, tls_identity_file, + allowed_login_methods, )? } else { - create_server(address, buffer_size, read_timeout)? + create_server( + address, + buffer_size, + read_timeout, + allowed_login_methods, + )? }; tokio::run(futures::future::lazy(move || { tokio::spawn(server.map_err(|e| { @@ -82,6 +111,9 @@ fn create_server( address: std::net::SocketAddr, buffer_size: usize, read_timeout: std::time::Duration, + allowed_login_methods: std::collections::HashSet< + crate::protocol::AuthType, + >, ) -> Result<( Box<dyn futures::future::Future<Item = (), Error = Error> + Send>, Box<dyn futures::future::Future<Item = (), Error = Error> + Send>, @@ -97,8 +129,12 @@ fn create_server( .try_send(sock) .context(crate::error::SendSocketChannel) }); - let server = - crate::server::Server::new(buffer_size, read_timeout, sock_r); + let server = crate::server::Server::new( + buffer_size, + read_timeout, + sock_r, + allowed_login_methods, + ); Ok((Box::new(acceptor), Box::new(server))) } @@ -107,6 +143,9 @@ fn create_server_tls( buffer_size: usize, read_timeout: std::time::Duration, tls_identity_file: &str, + allowed_login_methods: std::collections::HashSet< + crate::protocol::AuthType, + >, ) -> Result<( Box<dyn futures::future::Future<Item = (), Error = Error> + Send>, Box<dyn futures::future::Future<Item = (), Error = Error> + Send>, @@ -135,7 +174,11 @@ fn create_server_tls( .try_send(sock) .map_err(|_| Error::SendSocketChannelTls {}) }); - let server = - crate::server::tls::Server::new(buffer_size, read_timeout, sock_r); + let server = crate::server::tls::Server::new( + buffer_size, + read_timeout, + sock_r, + allowed_login_methods, + ); Ok((Box::new(acceptor), Box::new(server))) } diff --git a/src/error.rs b/src/error.rs index c1d2bd7..4166cc4 100644 --- a/src/error.rs +++ b/src/error.rs @@ -4,6 +4,9 @@ pub enum Error { #[snafu(display("failed to accept: {}", source))] Acceptor { source: tokio::io::Error }, + #[snafu(display("auth type not allowed: {:?}", ty))] + AuthTypeNotAllowed { ty: crate::protocol::AuthType }, + #[snafu(display("failed to bind: {}", source))] Bind { source: tokio::io::Error }, @@ -67,6 +70,9 @@ pub enum Error { #[snafu(display("invalid auth type: {}", ty))] InvalidAuthType { ty: u8 }, + #[snafu(display("invalid auth type: {}", ty))] + InvalidAuthTypeStr { ty: String }, + #[snafu(display("invalid message type: {}", ty))] InvalidMessageType { ty: u8 }, diff --git a/src/protocol.rs b/src/protocol.rs index c76a8a7..324fffd 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -74,6 +74,13 @@ impl AuthType { Self::RecurseCenter => true, } } + + pub fn iter() -> impl Iterator<Item = Self> { + (0..=255) + .map(Self::try_from) + .take_while(std::result::Result::is_ok) + .map(std::result::Result::unwrap) + } } impl std::convert::TryFrom<u8> for AuthType { @@ -88,6 +95,18 @@ impl std::convert::TryFrom<u8> for AuthType { } } +impl std::convert::TryFrom<&str> for AuthType { + type Error = Error; + + fn try_from(s: &str) -> Result<Self> { + Ok(match s { + s if Self::Plain.name() == s => Self::Plain, + s if Self::RecurseCenter.name() == s => Self::RecurseCenter, + _ => return Err(Error::InvalidAuthTypeStr { ty: s.to_string() }), + }) + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub enum Auth { Plain { username: String }, @@ -115,7 +134,7 @@ impl Auth { self.auth_type().name() } - fn auth_type(&self) -> AuthType { + pub fn auth_type(&self) -> AuthType { match self { Self::Plain { .. } => AuthType::Plain, Self::RecurseCenter { .. } => AuthType::RecurseCenter, diff --git a/src/server.rs b/src/server.rs index 76dcfe9..0116fff 100644 --- a/src/server.rs +++ b/src/server.rs @@ -329,6 +329,7 @@ pub struct Server< >, connections: std::collections::HashMap<String, Connection<S>>, rate_limiter: ratelimit_meter::KeyedRateLimiter<Option<String>>, + allowed_auth_types: std::collections::HashSet<crate::protocol::AuthType>, } impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + 'static> @@ -338,6 +339,9 @@ impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + 'static> buffer_size: usize, read_timeout: std::time::Duration, sock_r: tokio::sync::mpsc::Receiver<S>, + allowed_auth_types: std::collections::HashSet< + crate::protocol::AuthType, + >, ) -> Self { let sock_stream = sock_r .map(move |s| Connection::new(s, buffer_size)) @@ -351,6 +355,7 @@ impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + 'static> std::num::NonZeroU32::new(300).unwrap(), std::time::Duration::from_secs(60), ), + allowed_auth_types, } } @@ -374,6 +379,11 @@ impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + 'static> return Err(Error::TermTooBig { size }); } + let ty = auth.auth_type(); + if !self.allowed_auth_types.contains(&ty) { + return Err(Error::AuthTypeNotAllowed { ty }); + } + match &auth { crate::protocol::Auth::Plain { username } => { log::info!( diff --git a/src/server/tls.rs b/src/server/tls.rs index 095c26d..14effbc 100644 --- a/src/server/tls.rs +++ b/src/server/tls.rs @@ -17,10 +17,18 @@ impl Server { sock_r: tokio::sync::mpsc::Receiver< tokio_tls::Accept<tokio::net::TcpStream>, >, + allowed_login_methods: std::collections::HashSet< + crate::protocol::AuthType, + >, ) -> Self { let (tls_sock_w, tls_sock_r) = tokio::sync::mpsc::channel(100); Self { - server: super::Server::new(buffer_size, read_timeout, tls_sock_r), + server: super::Server::new( + buffer_size, + read_timeout, + tls_sock_r, + allowed_login_methods, + ), sock_r, sock_w: tls_sock_w, accepting_sockets: vec![], |