aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJesse Luehrs <doy@tozt.net>2019-10-17 02:53:41 -0400
committerJesse Luehrs <doy@tozt.net>2019-10-17 02:53:41 -0400
commitf4e6745c78c016b6103bc595ff99514b5adf973b (patch)
tree8b29b42f75d68c6e8ccf990156ec9490de2ad541
parent5336c7ccbe7bee81b5c496f55f8fb9b1835a1003 (diff)
downloadteleterm-f4e6745c78c016b6103bc595ff99514b5adf973b.tar.gz
teleterm-f4e6745c78c016b6103bc595ff99514b5adf973b.zip
allow restricting the accepted auth type list
-rw-r--r--src/cmd/server.rs55
-rw-r--r--src/error.rs6
-rw-r--r--src/protocol.rs21
-rw-r--r--src/server.rs10
-rw-r--r--src/server/tls.rs10
5 files changed, 94 insertions, 8 deletions
diff --git a/src/cmd/server.rs b/src/cmd/server.rs
index 07d9321..be70da6 100644
--- a/src/cmd/server.rs
+++ b/src/cmd/server.rs
@@ -1,4 +1,5 @@
use crate::prelude::*;
+use std::convert::TryFrom as _;
use std::io::Read as _;
pub fn cmd<'a, 'b>(app: clap::App<'a, 'b>) -> clap::App<'a, 'b> {
@@ -23,6 +24,12 @@ pub fn cmd<'a, 'b>(app: clap::App<'a, 'b>) -> clap::App<'a, 'b> {
.long("tls-identity-file")
.takes_value(true),
)
+ .arg(
+ clap::Arg::with_name("allowed-login-methods")
+ .long("allowed-login-methods")
+ .use_delimiter(true)
+ .takes_value(true),
+ )
}
pub fn run<'a>(matches: &clap::ArgMatches<'a>) -> super::Result<()> {
@@ -46,7 +53,20 @@ pub fn run<'a>(matches: &clap::ArgMatches<'a>) -> super::Result<()> {
},
)?;
let tls_identity_file = matches.value_of("tls-identity-file");
- run_impl(address, buffer_size, read_timeout, tls_identity_file)
+ let allowed_login_methods =
+ matches.values_of("allowed-login-methods").map_or_else(
+ || Ok(crate::protocol::AuthType::iter().collect()),
+ |methods| {
+ methods.map(crate::protocol::AuthType::try_from).collect()
+ },
+ )?;
+ run_impl(
+ address,
+ buffer_size,
+ read_timeout,
+ tls_identity_file,
+ allowed_login_methods,
+ )
}
fn run_impl(
@@ -54,6 +74,9 @@ fn run_impl(
buffer_size: usize,
read_timeout: std::time::Duration,
tls_identity_file: Option<&str>,
+ allowed_login_methods: std::collections::HashSet<
+ crate::protocol::AuthType,
+ >,
) -> Result<()> {
let (acceptor, server) =
if let Some(tls_identity_file) = tls_identity_file {
@@ -62,9 +85,15 @@ fn run_impl(
buffer_size,
read_timeout,
tls_identity_file,
+ allowed_login_methods,
)?
} else {
- create_server(address, buffer_size, read_timeout)?
+ create_server(
+ address,
+ buffer_size,
+ read_timeout,
+ allowed_login_methods,
+ )?
};
tokio::run(futures::future::lazy(move || {
tokio::spawn(server.map_err(|e| {
@@ -82,6 +111,9 @@ fn create_server(
address: std::net::SocketAddr,
buffer_size: usize,
read_timeout: std::time::Duration,
+ allowed_login_methods: std::collections::HashSet<
+ crate::protocol::AuthType,
+ >,
) -> Result<(
Box<dyn futures::future::Future<Item = (), Error = Error> + Send>,
Box<dyn futures::future::Future<Item = (), Error = Error> + Send>,
@@ -97,8 +129,12 @@ fn create_server(
.try_send(sock)
.context(crate::error::SendSocketChannel)
});
- let server =
- crate::server::Server::new(buffer_size, read_timeout, sock_r);
+ let server = crate::server::Server::new(
+ buffer_size,
+ read_timeout,
+ sock_r,
+ allowed_login_methods,
+ );
Ok((Box::new(acceptor), Box::new(server)))
}
@@ -107,6 +143,9 @@ fn create_server_tls(
buffer_size: usize,
read_timeout: std::time::Duration,
tls_identity_file: &str,
+ allowed_login_methods: std::collections::HashSet<
+ crate::protocol::AuthType,
+ >,
) -> Result<(
Box<dyn futures::future::Future<Item = (), Error = Error> + Send>,
Box<dyn futures::future::Future<Item = (), Error = Error> + Send>,
@@ -135,7 +174,11 @@ fn create_server_tls(
.try_send(sock)
.map_err(|_| Error::SendSocketChannelTls {})
});
- let server =
- crate::server::tls::Server::new(buffer_size, read_timeout, sock_r);
+ let server = crate::server::tls::Server::new(
+ buffer_size,
+ read_timeout,
+ sock_r,
+ allowed_login_methods,
+ );
Ok((Box::new(acceptor), Box::new(server)))
}
diff --git a/src/error.rs b/src/error.rs
index c1d2bd7..4166cc4 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -4,6 +4,9 @@ pub enum Error {
#[snafu(display("failed to accept: {}", source))]
Acceptor { source: tokio::io::Error },
+ #[snafu(display("auth type not allowed: {:?}", ty))]
+ AuthTypeNotAllowed { ty: crate::protocol::AuthType },
+
#[snafu(display("failed to bind: {}", source))]
Bind { source: tokio::io::Error },
@@ -67,6 +70,9 @@ pub enum Error {
#[snafu(display("invalid auth type: {}", ty))]
InvalidAuthType { ty: u8 },
+ #[snafu(display("invalid auth type: {}", ty))]
+ InvalidAuthTypeStr { ty: String },
+
#[snafu(display("invalid message type: {}", ty))]
InvalidMessageType { ty: u8 },
diff --git a/src/protocol.rs b/src/protocol.rs
index c76a8a7..324fffd 100644
--- a/src/protocol.rs
+++ b/src/protocol.rs
@@ -74,6 +74,13 @@ impl AuthType {
Self::RecurseCenter => true,
}
}
+
+ pub fn iter() -> impl Iterator<Item = Self> {
+ (0..=255)
+ .map(Self::try_from)
+ .take_while(std::result::Result::is_ok)
+ .map(std::result::Result::unwrap)
+ }
}
impl std::convert::TryFrom<u8> for AuthType {
@@ -88,6 +95,18 @@ impl std::convert::TryFrom<u8> for AuthType {
}
}
+impl std::convert::TryFrom<&str> for AuthType {
+ type Error = Error;
+
+ fn try_from(s: &str) -> Result<Self> {
+ Ok(match s {
+ s if Self::Plain.name() == s => Self::Plain,
+ s if Self::RecurseCenter.name() == s => Self::RecurseCenter,
+ _ => return Err(Error::InvalidAuthTypeStr { ty: s.to_string() }),
+ })
+ }
+}
+
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Auth {
Plain { username: String },
@@ -115,7 +134,7 @@ impl Auth {
self.auth_type().name()
}
- fn auth_type(&self) -> AuthType {
+ pub fn auth_type(&self) -> AuthType {
match self {
Self::Plain { .. } => AuthType::Plain,
Self::RecurseCenter { .. } => AuthType::RecurseCenter,
diff --git a/src/server.rs b/src/server.rs
index 76dcfe9..0116fff 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -329,6 +329,7 @@ pub struct Server<
>,
connections: std::collections::HashMap<String, Connection<S>>,
rate_limiter: ratelimit_meter::KeyedRateLimiter<Option<String>>,
+ allowed_auth_types: std::collections::HashSet<crate::protocol::AuthType>,
}
impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + 'static>
@@ -338,6 +339,9 @@ impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + 'static>
buffer_size: usize,
read_timeout: std::time::Duration,
sock_r: tokio::sync::mpsc::Receiver<S>,
+ allowed_auth_types: std::collections::HashSet<
+ crate::protocol::AuthType,
+ >,
) -> Self {
let sock_stream = sock_r
.map(move |s| Connection::new(s, buffer_size))
@@ -351,6 +355,7 @@ impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + 'static>
std::num::NonZeroU32::new(300).unwrap(),
std::time::Duration::from_secs(60),
),
+ allowed_auth_types,
}
}
@@ -374,6 +379,11 @@ impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + 'static>
return Err(Error::TermTooBig { size });
}
+ let ty = auth.auth_type();
+ if !self.allowed_auth_types.contains(&ty) {
+ return Err(Error::AuthTypeNotAllowed { ty });
+ }
+
match &auth {
crate::protocol::Auth::Plain { username } => {
log::info!(
diff --git a/src/server/tls.rs b/src/server/tls.rs
index 095c26d..14effbc 100644
--- a/src/server/tls.rs
+++ b/src/server/tls.rs
@@ -17,10 +17,18 @@ impl Server {
sock_r: tokio::sync::mpsc::Receiver<
tokio_tls::Accept<tokio::net::TcpStream>,
>,
+ allowed_login_methods: std::collections::HashSet<
+ crate::protocol::AuthType,
+ >,
) -> Self {
let (tls_sock_w, tls_sock_r) = tokio::sync::mpsc::channel(100);
Self {
- server: super::Server::new(buffer_size, read_timeout, tls_sock_r),
+ server: super::Server::new(
+ buffer_size,
+ read_timeout,
+ tls_sock_r,
+ allowed_login_methods,
+ ),
sock_r,
sock_w: tls_sock_w,
accepting_sockets: vec![],