diff options
author | Jesse Luehrs <doy@tozt.net> | 2019-11-15 13:11:07 -0500 |
---|---|---|
committer | Jesse Luehrs <doy@tozt.net> | 2019-11-15 13:11:07 -0500 |
commit | bbf15cfef8134da720a27bd71a93efcb8467025b (patch) | |
tree | aa58a5d7c1862fcdd6c8629651f664aa12c70f66 /teleterm/src/server.rs | |
parent | fe4fa53dbbb6030beae2094e33d1db008532ae3c (diff) | |
download | teleterm-bbf15cfef8134da720a27bd71a93efcb8467025b.tar.gz teleterm-bbf15cfef8134da720a27bd71a93efcb8467025b.zip |
use workspaces
Diffstat (limited to 'teleterm/src/server.rs')
-rw-r--r-- | teleterm/src/server.rs | 1073 |
1 files changed, 1073 insertions, 0 deletions
diff --git a/teleterm/src/server.rs b/teleterm/src/server.rs new file mode 100644 index 0000000..d18fa95 --- /dev/null +++ b/teleterm/src/server.rs @@ -0,0 +1,1073 @@ +use crate::prelude::*; +use tokio::util::FutureExt as _; + +pub mod tls; + +enum ReadSocket< + S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + 'static, +> { + Connected(crate::protocol::FramedReadHalf<S>), + Reading( + Box< + dyn futures::future::Future< + Item = ( + crate::protocol::Message, + crate::protocol::FramedReadHalf<S>, + ), + Error = Error, + > + Send, + >, + ), + Processing( + crate::protocol::FramedReadHalf<S>, + Box< + dyn futures::future::Future< + Item = (ConnectionState, crate::protocol::Message), + Error = Error, + > + Send, + >, + ), +} + +enum WriteSocket< + S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + 'static, +> { + Connected(crate::protocol::FramedWriteHalf<S>), + Writing( + Box< + dyn futures::future::Future< + Item = crate::protocol::FramedWriteHalf<S>, + Error = Error, + > + Send, + >, + ), +} + +#[derive(Debug, Clone)] +struct TerminalInfo { + term: String, + size: crate::term::Size, +} + +#[allow(clippy::large_enum_variant)] +// XXX https://github.com/rust-lang/rust/issues/64362 +#[allow(dead_code)] +enum ConnectionState { + Accepted, + LoggingIn { + term_info: TerminalInfo, + }, + LoggedIn { + username: String, + term_info: TerminalInfo, + }, + Streaming { + username: String, + term_info: TerminalInfo, + term: vt100::Parser, + }, + Watching { + username: String, + term_info: TerminalInfo, + watch_id: String, + }, +} + +impl ConnectionState { + fn new() -> Self { + Self::Accepted + } + + fn username(&self) -> Option<&str> { + match self { + Self::Accepted => None, + Self::LoggingIn { .. } => None, + Self::LoggedIn { username, .. } => Some(username), + Self::Streaming { username, .. } => Some(username), + Self::Watching { username, .. } => Some(username), + } + } + + fn term_info(&mut self) -> Option<&TerminalInfo> { + match self { + Self::Accepted => None, + Self::LoggingIn { term_info, .. } => Some(term_info), + Self::LoggedIn { term_info, .. } => Some(term_info), + Self::Streaming { term_info, .. } => Some(term_info), + Self::Watching { term_info, .. } => Some(term_info), + } + } + + fn term_info_mut(&mut self) -> Option<&mut TerminalInfo> { + match self { + Self::Accepted => None, + Self::LoggingIn { term_info, .. } => Some(term_info), + Self::LoggedIn { term_info, .. } => Some(term_info), + Self::Streaming { term_info, .. } => Some(term_info), + Self::Watching { term_info, .. } => Some(term_info), + } + } + + fn term(&self) -> Option<&vt100::Parser> { + match self { + Self::Accepted => None, + Self::LoggingIn { .. } => None, + Self::LoggedIn { .. } => None, + Self::Streaming { term, .. } => Some(term), + Self::Watching { .. } => None, + } + } + + fn term_mut(&mut self) -> Option<&mut vt100::Parser> { + match self { + Self::Accepted => None, + Self::LoggingIn { .. } => None, + Self::LoggedIn { .. } => None, + Self::Streaming { term, .. } => Some(term), + Self::Watching { .. } => None, + } + } + + fn watch_id(&self) -> Option<&str> { + match self { + Self::Accepted => None, + Self::LoggingIn { .. } => None, + Self::LoggedIn { .. } => None, + Self::Streaming { .. } => None, + Self::Watching { watch_id, .. } => Some(watch_id), + } + } + + fn login_plain( + &mut self, + username: &str, + term_type: &str, + size: crate::term::Size, + ) { + if let Self::Accepted = self { + *self = Self::LoggedIn { + username: username.to_string(), + term_info: TerminalInfo { + term: term_type.to_string(), + size, + }, + }; + } else { + unreachable!() + } + } + + fn login_oauth_start( + &mut self, + term_type: &str, + size: crate::term::Size, + ) { + if let Self::Accepted = self { + *self = Self::LoggingIn { + term_info: TerminalInfo { + term: term_type.to_string(), + size, + }, + }; + } else { + unreachable!() + } + } + + fn stream(&mut self) { + if let Self::LoggedIn { + username, + term_info, + } = std::mem::replace(self, Self::Accepted) + { + *self = Self::Streaming { + username, + term_info, + term: vt100::Parser::default(), + }; + } else { + unreachable!() + } + } + + fn watch(&mut self, id: &str) { + if let Self::LoggedIn { + username, + term_info, + } = std::mem::replace(self, Self::Accepted) + { + *self = Self::Watching { + username, + term_info, + watch_id: id.to_string(), + }; + } else { + unreachable!() + } + } +} + +struct Connection< + S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + 'static, +> { + id: String, + rsock: Option<ReadSocket<S>>, + wsock: Option<WriteSocket<S>>, + to_send: std::collections::VecDeque<crate::protocol::Message>, + closed: bool, + state: ConnectionState, + last_activity: std::time::Instant, + oauth_client: Option<Box<dyn crate::oauth::Oauth + Send>>, +} + +impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + 'static> + Connection<S> +{ + fn new(s: S) -> Self { + let (rs, ws) = s.split(); + let id = format!("{}", uuid::Uuid::new_v4()); + log::info!("{}: new connection", id); + + Self { + id, + rsock: Some(ReadSocket::Connected( + crate::protocol::FramedReader::new(rs), + )), + wsock: Some(WriteSocket::Connected( + crate::protocol::FramedWriter::new(ws), + )), + to_send: std::collections::VecDeque::new(), + closed: false, + state: ConnectionState::new(), + last_activity: std::time::Instant::now(), + oauth_client: None, + } + } + + fn session(&self, watchers: u32) -> Option<crate::protocol::Session> { + let (username, term_info) = match &self.state { + ConnectionState::Accepted => return None, + ConnectionState::LoggingIn { .. } => return None, + ConnectionState::LoggedIn { + username, + term_info, + } => (username, term_info), + ConnectionState::Streaming { + username, + term_info, + .. + } => (username, term_info), + ConnectionState::Watching { + username, + term_info, + .. + } => (username, term_info), + }; + let title = self + .state + .term() + .map_or("", |parser| parser.screen().title()); + + // i don't really care if things break for a connection that has been + // idle for 136 years + #[allow(clippy::cast_possible_truncation)] + Some(crate::protocol::Session { + id: self.id.clone(), + username: username.clone(), + term_type: term_info.term.clone(), + size: term_info.size, + idle_time: std::time::Instant::now() + .duration_since(self.last_activity) + .as_secs() as u32, + title: title.to_string(), + watchers, + }) + } + + fn send_message(&mut self, message: crate::protocol::Message) { + self.to_send.push_back(message); + } + + fn close(&mut self, res: Result<()>) { + let msg = match res { + Ok(()) => crate::protocol::Message::disconnected(), + Err(e) => crate::protocol::Message::error(&format!("{}", e)), + }; + self.send_message(msg); + self.closed = true; + } +} + +pub struct Server< + S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + 'static, +> { + read_timeout: std::time::Duration, + acceptor: + Box<dyn futures::stream::Stream<Item = S, Error = Error> + Send>, + connections: std::collections::HashMap<String, Connection<S>>, + rate_limiter: ratelimit_meter::KeyedRateLimiter<Option<String>>, + allowed_auth_types: std::collections::HashSet<crate::protocol::AuthType>, + oauth_configs: std::collections::HashMap< + crate::protocol::AuthType, + crate::oauth::Config, + >, +} + +impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + 'static> + Server<S> +{ + pub fn new( + acceptor: Box< + dyn futures::stream::Stream<Item = S, Error = Error> + Send, + >, + read_timeout: std::time::Duration, + allowed_auth_types: std::collections::HashSet< + crate::protocol::AuthType, + >, + oauth_configs: std::collections::HashMap< + crate::protocol::AuthType, + crate::oauth::Config, + >, + ) -> Self { + Self { + read_timeout, + acceptor, + connections: std::collections::HashMap::new(), + rate_limiter: ratelimit_meter::KeyedRateLimiter::new( + std::num::NonZeroU32::new(300).unwrap(), + std::time::Duration::from_secs(60), + ), + allowed_auth_types, + oauth_configs, + } + } + + fn handle_message_login( + &mut self, + conn: &mut Connection<S>, + auth: &crate::protocol::Auth, + term_type: &str, + size: crate::term::Size, + ) -> Result< + Option< + Box< + dyn futures::future::Future< + Item = (ConnectionState, crate::protocol::Message), + Error = Error, + > + Send, + >, + >, + > { + if size.rows >= 1000 || size.cols >= 1000 { + return Err(Error::TermTooBig { size }); + } + + let ty = auth.auth_type(); + if !self.allowed_auth_types.contains(&ty) { + return Err(Error::AuthTypeNotAllowed { ty }); + } + + match &auth { + crate::protocol::Auth::Plain { username } => { + log::info!( + "{}: login({}, {})", + auth.name(), + conn.id, + username + ); + conn.state.login_plain(username, term_type, size); + conn.send_message(crate::protocol::Message::logged_in( + username, + )); + } + oauth if oauth.is_oauth() => { + let config = self.oauth_configs.get(&ty).context( + crate::error::AuthTypeMissingOauthConfig { ty }, + )?; + let (refresh, client) = match oauth { + crate::protocol::Auth::RecurseCenter { id } => ( + id.is_some(), + Box::new(crate::oauth::RecurseCenter::new( + config.clone(), + &id.clone().unwrap_or_else(|| { + format!("{}", uuid::Uuid::new_v4()) + }), + )), + ), + _ => unreachable!(), + }; + + conn.oauth_client = Some(client); + let client = conn.oauth_client.as_ref().unwrap(); + + log::info!( + "{}: login(oauth({}), {:?})", + conn.id, + auth.name(), + client.user_id() + ); + + let token_filename = client.server_token_file(true); + if let (Some(token_filename), true) = + (token_filename, refresh) + { + let term_type = term_type.to_string(); + let client = conn.oauth_client.take().unwrap(); + let fut = tokio::fs::File::open(token_filename.clone()) + .with_context(move || crate::error::OpenFile { + filename: token_filename + .to_string_lossy() + .to_string(), + }) + .and_then(|file| { + tokio::io::lines(std::io::BufReader::new(file)) + .into_future() + .map_err(|(e, _)| e) + .context(crate::error::ReadFile) + }) + .and_then(|(refresh_token, _)| { + // XXX unwrap here isn't super safe + let refresh_token = refresh_token.unwrap(); + client + .get_access_token_from_refresh_token( + refresh_token.trim(), + ) + .and_then(|access_token| { + client.get_username_from_access_token( + &access_token, + ) + }) + }) + .map(move |username| { + ( + ConnectionState::LoggedIn { + username: username.clone(), + term_info: TerminalInfo { + term: term_type, + size, + }, + }, + crate::protocol::Message::logged_in( + &username, + ), + ) + }); + return Ok(Some(Box::new(fut))); + } else { + conn.state.login_oauth_start(term_type, size); + let authorize_url = client.generate_authorize_url(); + let user_id = client.user_id().to_string(); + conn.send_message( + crate::protocol::Message::oauth_request( + &authorize_url, + &user_id, + ), + ); + } + } + _ => unreachable!(), + } + + Ok(None) + } + + fn handle_message_start_streaming( + &mut self, + conn: &mut Connection<S>, + ) -> Result<()> { + let username = conn.state.username().unwrap(); + + log::info!("{}: stream({})", conn.id, username); + conn.state.stream(); + + Ok(()) + } + + fn handle_message_start_watching( + &mut self, + conn: &mut Connection<S>, + id: String, + ) -> Result<()> { + let username = conn.state.username().unwrap(); + + if let Some(stream_conn) = self.connections.get(&id) { + let data = stream_conn + .state + .term() + .map(|parser| parser.screen().contents_formatted()) + .ok_or_else(|| Error::InvalidWatchId { + id: id.to_string(), + })?; + + log::info!("{}: watch({}, {})", conn.id, username, id); + conn.state.watch(&id); + conn.send_message(crate::protocol::Message::terminal_output( + &data, + )); + + Ok(()) + } else { + Err(Error::InvalidWatchId { id }) + } + } + + fn handle_message_heartbeat( + &mut self, + conn: &mut Connection<S>, + ) -> Result<()> { + conn.send_message(crate::protocol::Message::heartbeat()); + + Ok(()) + } + + fn handle_message_terminal_output( + &mut self, + conn: &mut Connection<S>, + data: &[u8], + ) -> Result<()> { + let parser = conn.state.term_mut().unwrap(); + + let screen = parser.screen().clone(); + parser.process(data); + let diff = parser.screen().contents_diff(&screen); + for watch_conn in self.watchers_mut() { + let watch_id = watch_conn.state.watch_id().unwrap(); + if conn.id == watch_id { + watch_conn.send_message( + crate::protocol::Message::terminal_output(&diff), + ); + } + } + + conn.last_activity = std::time::Instant::now(); + + Ok(()) + } + + fn handle_message_list_sessions( + &mut self, + conn: &mut Connection<S>, + ) -> Result<()> { + let mut watcher_counts = std::collections::HashMap::new(); + for watcher in self.watchers() { + let watch_id = + if let ConnectionState::Watching { watch_id, .. } = + &watcher.state + { + watch_id + } else { + unreachable!() + }; + watcher_counts.insert( + watch_id, + *watcher_counts.get(&watch_id).unwrap_or(&0) + 1, + ); + } + let sessions: Vec<_> = self + .streamers() + .flat_map(|streamer| { + streamer + .session(*watcher_counts.get(&streamer.id).unwrap_or(&0)) + }) + .collect(); + conn.send_message(crate::protocol::Message::sessions(&sessions)); + + Ok(()) + } + + fn handle_message_resize( + &mut self, + conn: &mut Connection<S>, + size: crate::term::Size, + ) -> Result<()> { + let term_info = conn.state.term_info_mut().unwrap(); + term_info.size = size; + + if let Some(parser) = conn.state.term_mut() { + parser.set_size(size.rows, size.cols); + } + + Ok(()) + } + + fn handle_message_oauth_response( + &mut self, + conn: &mut Connection<S>, + code: &str, + ) -> Result< + Option< + Box< + dyn futures::future::Future< + Item = (ConnectionState, crate::protocol::Message), + Error = Error, + > + Send, + >, + >, + > { + let client = conn.oauth_client.take().ok_or_else(|| { + Error::UnexpectedMessage { + message: crate::protocol::Message::oauth_response(code), + } + })?; + + let term_info = conn.state.term_info().unwrap().clone(); + let fut = client + .get_access_token_from_auth_code(code) + .and_then(|token| client.get_username_from_access_token(&token)) + .map(|username| { + ( + ConnectionState::LoggedIn { + term_info, + username: username.clone(), + }, + crate::protocol::Message::logged_in(&username), + ) + }); + + Ok(Some(Box::new(fut))) + } + + fn handle_accepted_message( + &mut self, + conn: &mut Connection<S>, + message: crate::protocol::Message, + ) -> Result< + Option< + Box< + dyn futures::future::Future< + Item = (ConnectionState, crate::protocol::Message), + Error = Error, + > + Send, + >, + >, + > { + match message { + crate::protocol::Message::Login { + auth, + term_type, + size, + .. + } => self.handle_message_login(conn, &auth, &term_type, size), + m => Err(Error::UnauthenticatedMessage { message: m }), + } + } + + fn handle_logging_in_message( + &mut self, + conn: &mut Connection<S>, + message: crate::protocol::Message, + ) -> Result< + Option< + Box< + dyn futures::future::Future< + Item = (ConnectionState, crate::protocol::Message), + Error = Error, + > + Send, + >, + >, + > { + match message { + crate::protocol::Message::OauthResponse { code } => { + self.handle_message_oauth_response(conn, &code) + } + m => Err(Error::UnauthenticatedMessage { message: m }), + } + } + + fn handle_logged_in_message( + &mut self, + conn: &mut Connection<S>, + message: crate::protocol::Message, + ) -> Result<()> { + match message { + crate::protocol::Message::Heartbeat => { + self.handle_message_heartbeat(conn) + } + crate::protocol::Message::Resize { size } => { + self.handle_message_resize(conn, size) + } + crate::protocol::Message::ListSessions => { + self.handle_message_list_sessions(conn) + } + crate::protocol::Message::StartStreaming => { + self.handle_message_start_streaming(conn) + } + crate::protocol::Message::StartWatching { id } => { + self.handle_message_start_watching(conn, id) + } + m => Err(crate::error::Error::UnexpectedMessage { message: m }), + } + } + + fn handle_streaming_message( + &mut self, + conn: &mut Connection<S>, + message: crate::protocol::Message, + ) -> Result<()> { + match message { + crate::protocol::Message::Heartbeat => { + self.handle_message_heartbeat(conn) + } + crate::protocol::Message::Resize { size } => { + self.handle_message_resize(conn, size) + } + crate::protocol::Message::TerminalOutput { data } => { + self.handle_message_terminal_output(conn, &data) + } + m => Err(crate::error::Error::UnexpectedMessage { message: m }), + } + } + + fn handle_watching_message( + &mut self, + conn: &mut Connection<S>, + message: crate::protocol::Message, + ) -> Result<()> { + match message { + crate::protocol::Message::Heartbeat => { + self.handle_message_heartbeat(conn) + } + crate::protocol::Message::Resize { size } => { + self.handle_message_resize(conn, size) + } + m => Err(crate::error::Error::UnexpectedMessage { message: m }), + } + } + + fn handle_disconnect(&mut self, conn: &mut Connection<S>) { + if let Some(username) = conn.state.username() { + log::info!("{}: disconnect({})", conn.id, username); + } else { + log::info!("{}: disconnect", conn.id); + } + + for watch_conn in self.watchers_mut() { + let watch_id = watch_conn.state.watch_id().unwrap(); + if conn.id == watch_id { + watch_conn.close(Ok(())); + } + } + } + + fn handle_message( + &mut self, + conn: &mut Connection<S>, + message: crate::protocol::Message, + ) -> Result< + Option< + Box< + dyn futures::future::Future< + Item = (ConnectionState, crate::protocol::Message), + Error = Error, + > + Send, + >, + >, + > { + if let crate::protocol::Message::TerminalOutput { .. } = message { + // do nothing, we expect TerminalOutput spam + } else { + let username = + conn.state.username().map(std::string::ToString::to_string); + if self.rate_limiter.check(username).is_err() { + let display_name = + conn.state.username().unwrap_or("(non-logged-in users)"); + log::info!("{}: ratelimit({})", conn.id, display_name); + return Err(Error::RateLimited); + } + } + + log::debug!("{}: recv({})", conn.id, message.format_log()); + + match conn.state { + ConnectionState::Accepted { .. } => { + self.handle_accepted_message(conn, message) + } + ConnectionState::LoggingIn { .. } => { + self.handle_logging_in_message(conn, message) + } + ConnectionState::LoggedIn { .. } => { + self.handle_logged_in_message(conn, message).map(|_| None) + } + ConnectionState::Streaming { .. } => { + self.handle_streaming_message(conn, message).map(|_| None) + } + ConnectionState::Watching { .. } => { + self.handle_watching_message(conn, message).map(|_| None) + } + } + } + + fn poll_read_connection( + &mut self, + conn: &mut Connection<S>, + ) -> component_future::Poll<(), Error> { + match &mut conn.rsock { + Some(ReadSocket::Connected(..)) => { + if let Some(ReadSocket::Connected(s)) = conn.rsock.take() { + let fut = Box::new( + crate::protocol::Message::read_async(s) + .timeout(self.read_timeout) + .context(crate::error::ReadMessageWithTimeout), + ); + conn.rsock = Some(ReadSocket::Reading(fut)); + } else { + unreachable!() + } + Ok(component_future::Async::DidWork) + } + Some(ReadSocket::Reading(fut)) => match fut.poll() { + Ok(futures::Async::Ready((msg, s))) => { + let res = self.handle_message(conn, msg); + match res { + Ok(Some(fut)) => { + conn.rsock = Some(ReadSocket::Processing(s, fut)); + } + Ok(None) => { + conn.rsock = Some(ReadSocket::Connected(s)); + } + e @ Err(..) => { + conn.close(e.map(|_| ())); + conn.rsock = Some(ReadSocket::Connected(s)); + } + } + Ok(component_future::Async::DidWork) + } + Ok(futures::Async::NotReady) => { + Ok(component_future::Async::NotReady) + } + Err(e) => classify_connection_error(e), + }, + Some(ReadSocket::Processing(_, fut)) => { + let (state, msg) = component_future::try_ready!(fut.poll()); + if let Some(ReadSocket::Processing(s, _)) = conn.rsock.take() + { + conn.state = state; + conn.send_message(msg); + conn.rsock = Some(ReadSocket::Connected(s)); + } else { + unreachable!() + } + Ok(component_future::Async::DidWork) + } + _ => Ok(component_future::Async::NothingToDo), + } + } + + fn poll_write_connection( + &mut self, + conn: &mut Connection<S>, + ) -> component_future::Poll<(), Error> { + match &mut conn.wsock { + Some(WriteSocket::Connected(..)) => { + if let Some(msg) = conn.to_send.pop_front() { + if let Some(WriteSocket::Connected(s)) = conn.wsock.take() + { + log::debug!( + "{}: send({})", + conn.id, + msg.format_log() + ); + let fut = msg + .write_async(s) + .timeout(self.read_timeout) + .context(crate::error::WriteMessageWithTimeout); + conn.wsock = + Some(WriteSocket::Writing(Box::new(fut))); + } else { + unreachable!() + } + Ok(component_future::Async::DidWork) + } else if conn.closed { + Ok(component_future::Async::Ready(())) + } else { + Ok(component_future::Async::NothingToDo) + } + } + Some(WriteSocket::Writing(fut)) => match fut.poll() { + Ok(futures::Async::Ready(s)) => { + conn.wsock = Some(WriteSocket::Connected(s)); + Ok(component_future::Async::DidWork) + } + Ok(futures::Async::NotReady) => { + Ok(component_future::Async::NotReady) + } + Err(e) => classify_connection_error(e), + }, + _ => Ok(component_future::Async::NothingToDo), + } + } + + fn streamers(&self) -> impl Iterator<Item = &Connection<S>> { + self.connections.values().filter(|conn| match conn.state { + ConnectionState::Streaming { .. } => true, + _ => false, + }) + } + + fn watchers(&self) -> impl Iterator<Item = &Connection<S>> { + self.connections.values().filter(|conn| match conn.state { + ConnectionState::Watching { .. } => true, + _ => false, + }) + } + + fn watchers_mut(&mut self) -> impl Iterator<Item = &mut Connection<S>> { + self.connections + .values_mut() + .filter(|conn| match conn.state { + ConnectionState::Watching { .. } => true, + _ => false, + }) + } +} + +impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + 'static> + Server<S> +{ + const POLL_FNS: + &'static [&'static dyn for<'a> Fn( + &'a mut Self, + ) + -> component_future::Poll< + (), + Error, + >] = &[&Self::poll_accept, &Self::poll_read, &Self::poll_write]; + + fn poll_accept(&mut self) -> component_future::Poll<(), Error> { + if let Some(sock) = component_future::try_ready!(self.acceptor.poll()) + { + let conn = Connection::new(sock); + self.connections.insert(conn.id.to_string(), conn); + Ok(component_future::Async::DidWork) + } else { + unreachable!() + } + } + + fn poll_read(&mut self) -> component_future::Poll<(), Error> { + let mut did_work = false; + let mut not_ready = false; + + let keys: Vec<_> = self.connections.keys().cloned().collect(); + for key in keys { + let mut conn = self.connections.remove(&key).unwrap(); + match self.poll_read_connection(&mut conn) { + Ok(component_future::Async::Ready(())) => { + self.handle_disconnect(&mut conn); + continue; + } + Ok(component_future::Async::DidWork) => { + did_work = true; + } + Ok(component_future::Async::NotReady) => { + not_ready = true; + } + Err(e) => { + log::error!( + "error reading from active connection: {}", + e + ); + continue; + } + _ => {} + } + self.connections.insert(key.to_string(), conn); + } + + if did_work { + Ok(component_future::Async::DidWork) + } else if not_ready { + Ok(component_future::Async::NotReady) + } else { + Ok(component_future::Async::NothingToDo) + } + } + + fn poll_write(&mut self) -> component_future::Poll<(), Error> { + let mut did_work = false; + let mut not_ready = false; + + let keys: Vec<_> = self.connections.keys().cloned().collect(); + for key in keys { + let mut conn = self.connections.remove(&key).unwrap(); + match self.poll_write_connection(&mut conn) { + Ok(component_future::Async::Ready(())) => { + self.handle_disconnect(&mut conn); + continue; + } + Ok(component_future::Async::DidWork) => { + did_work = true; + } + Ok(component_future::Async::NotReady) => { + not_ready = true; + } + Err(e) => { + log::error!( + "error reading from active connection: {}", + e + ); + continue; + } + _ => {} + } + self.connections.insert(key.to_string(), conn); + } + + if did_work { + Ok(component_future::Async::DidWork) + } else if not_ready { + Ok(component_future::Async::NotReady) + } else { + Ok(component_future::Async::NothingToDo) + } + } +} + +fn classify_connection_error(e: Error) -> component_future::Poll<(), Error> { + let source = match e { + Error::ReadMessageWithTimeout { source } => source, + Error::WriteMessageWithTimeout { source } => source, + _ => return Err(e), + }; + + if source.is_inner() { + let source = source.into_inner().unwrap(); + let tokio_err = match source { + Error::ReadPacket { + source: ref tokio_err, + } => tokio_err, + Error::WritePacket { + source: ref tokio_err, + } => tokio_err, + Error::EOF => { + return Ok(component_future::Async::Ready(())); + } + _ => { + return Err(source); + } + }; + + if tokio_err.kind() == tokio::io::ErrorKind::UnexpectedEof { + Ok(component_future::Async::Ready(())) + } else { + Err(source) + } + } else if source.is_elapsed() { + Err(Error::Timeout) + } else { + let source = source.into_timer().unwrap(); + Err(Error::TimerReadTimeout { source }) + } +} + +#[must_use = "futures do nothing unless polled"] +impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + 'static> + futures::future::Future for Server<S> +{ + type Item = (); + type Error = Error; + + fn poll(&mut self) -> futures::Poll<Self::Item, Self::Error> { + component_future::poll_future(self, Self::POLL_FNS) + } +} |