diff options
Diffstat (limited to 'src/cmd/server.rs')
-rw-r--r-- | src/cmd/server.rs | 148 |
1 files changed, 69 insertions, 79 deletions
diff --git a/src/cmd/server.rs b/src/cmd/server.rs index bd1b4f4..8acd011 100644 --- a/src/cmd/server.rs +++ b/src/cmd/server.rs @@ -27,43 +27,30 @@ impl crate::config::Config for Config { fn run( &self, - ) -> Box<dyn futures::future::Future<Item = (), Error = Error> + Send> { - let (acceptor, server) = - if let Some(tls_identity_file) = &self.server.tls_identity_file { - match create_server_tls( - self.server.listen_address, - self.server.buffer_size, - self.server.read_timeout, - tls_identity_file, - self.server.allowed_login_methods.clone(), - self.oauth_configs.clone(), - self.server.uid, - self.server.gid, - ) { - Ok(futs) => futs, - Err(e) => return Box::new(futures::future::err(e)), - } - } else { - match create_server( - self.server.listen_address, - self.server.buffer_size, - self.server.read_timeout, - self.server.allowed_login_methods.clone(), - self.oauth_configs.clone(), - self.server.uid, - self.server.gid, - ) { - Ok(futs) => futs, - Err(e) => return Box::new(futures::future::err(e)), - } - }; - Box::new(futures::future::lazy(move || { - tokio::spawn(server.map_err(|e| { - log::error!("{}", e); - })); - - acceptor - })) + ) -> Box<dyn futures::future::Future<Item = (), Error = Error> + Send> + { + if let Some(tls_identity_file) = &self.server.tls_identity_file { + create_server_tls( + self.server.listen_address, + self.server.buffer_size, + self.server.read_timeout, + tls_identity_file, + self.server.allowed_login_methods.clone(), + self.oauth_configs.clone(), + self.server.uid, + self.server.gid, + ) + } else { + create_server( + self.server.listen_address, + self.server.buffer_size, + self.server.read_timeout, + self.server.allowed_login_methods.clone(), + self.oauth_configs.clone(), + self.server.uid, + self.server.gid, + ) + } } } @@ -97,31 +84,22 @@ fn create_server( >, uid: Option<users::uid_t>, gid: Option<users::gid_t>, -) -> Result<( - Box<dyn futures::future::Future<Item = (), Error = Error> + Send>, - Box<dyn futures::future::Future<Item = (), Error = Error> + Send>, -)> { - let (mut sock_w, sock_r) = tokio::sync::mpsc::channel(100); - let listener = tokio::net::TcpListener::bind(&address) - .context(crate::error::Bind { address })?; - drop_privs(uid, gid)?; - log::info!("Listening on {}", address); - let acceptor = listener - .incoming() - .context(crate::error::Acceptor) - .for_each(move |sock| { - sock_w - .try_send(sock) - .context(crate::error::SendSocketChannel) - }); +) -> Box<dyn futures::future::Future<Item = (), Error = Error> + Send> { + let listener = match listen(address, uid, gid) { + Ok(listener) => listener, + Err(e) => return Box::new(futures::future::err(e)), + }; + + let acceptor = listener.incoming().context(crate::error::Acceptor); let server = crate::server::Server::new( + Box::new(acceptor), buffer_size, read_timeout, - sock_r, allowed_login_methods, oauth_configs, ); - Ok((Box::new(acceptor), Box::new(server))) + + Box::new(server) } fn create_server_tls( @@ -138,16 +116,45 @@ fn create_server_tls( >, uid: Option<users::uid_t>, gid: Option<users::gid_t>, -) -> Result<( - Box<dyn futures::future::Future<Item = (), Error = Error> + Send>, - Box<dyn futures::future::Future<Item = (), Error = Error> + Send>, -)> { - let (mut sock_w, sock_r) = tokio::sync::mpsc::channel(100); +) -> Box<dyn futures::future::Future<Item = (), Error = Error> + Send> { + let listener = match listen(address, uid, gid) { + Ok(listener) => listener, + Err(e) => return Box::new(futures::future::err(e)), + }; + + let tls_acceptor = match accept_tls(tls_identity_file) { + Ok(acceptor) => acceptor, + Err(e) => return Box::new(futures::future::err(e)), + }; + + let acceptor = listener + .incoming() + .context(crate::error::Acceptor) + .map(move |sock| tls_acceptor.accept(sock)); + let server = crate::server::tls::Server::new( + Box::new(acceptor), + buffer_size, + read_timeout, + allowed_login_methods, + oauth_configs, + ); + + Box::new(server) +} + +fn listen( + address: std::net::SocketAddr, + uid: Option<users::uid_t>, + gid: Option<users::gid_t>, +) -> Result<tokio::net::TcpListener> { let listener = tokio::net::TcpListener::bind(&address) .context(crate::error::Bind { address })?; drop_privs(uid, gid)?; log::info!("Listening on {}", address); + Ok(listener) +} +fn accept_tls(tls_identity_file: &str) -> Result<tokio_tls::TlsAcceptor> { let mut file = std::fs::File::open(tls_identity_file).context( crate::error::OpenFileSync { filename: tls_identity_file, @@ -160,25 +167,8 @@ fn create_server_tls( .context(crate::error::ParseIdentity)?; let acceptor = native_tls::TlsAcceptor::new(identity) .context(crate::error::CreateAcceptor)?; - let acceptor = tokio_tls::TlsAcceptor::from(acceptor); - let acceptor = listener - .incoming() - .context(crate::error::Acceptor) - .for_each(move |sock| { - let sock = acceptor.accept(sock); - sock_w - .try_send(sock) - .map_err(|_| Error::SendSocketChannelTls {}) - }); - let server = crate::server::tls::Server::new( - buffer_size, - read_timeout, - sock_r, - allowed_login_methods, - oauth_configs, - ); - Ok((Box::new(acceptor), Box::new(server))) + Ok(tokio_tls::TlsAcceptor::from(acceptor)) } fn drop_privs( |