diff options
author | Jesse Luehrs <doy@tozt.net> | 2019-10-16 00:51:36 -0400 |
---|---|---|
committer | Jesse Luehrs <doy@tozt.net> | 2019-10-16 00:51:36 -0400 |
commit | 436c270260b3ba2ca7c7633c88435282e4c9b2b8 (patch) | |
tree | 1b09014b7d5ed80b9a0ce502980836538701f789 | |
parent | b0a488a351ac1c7dbf21e2231a8b553e124aee60 (diff) | |
download | teleterm-436c270260b3ba2ca7c7633c88435282e4c9b2b8.tar.gz teleterm-436c270260b3ba2ca7c7633c88435282e4c9b2b8.zip |
don't block the main thread when waiting for oauth response
-rw-r--r-- | src/client.rs | 209 | ||||
-rw-r--r-- | src/error.rs | 5 |
2 files changed, 142 insertions, 72 deletions
diff --git a/src/client.rs b/src/client.rs index ca94623..a32c92a 100644 --- a/src/client.rs +++ b/src/client.rs @@ -27,6 +27,15 @@ enum ReadSocket< > + Send, >, ), + Processing( + crate::protocol::FramedReadHalf<S>, + Box< + dyn futures::future::Future< + Item = crate::protocol::Message, + Error = Error, + > + Send, + >, + ), } enum WriteSocket< @@ -257,7 +266,17 @@ impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + 'static> fn handle_message( &mut self, msg: crate::protocol::Message, - ) -> Result<crate::component_future::Poll<Event>> { + ) -> Result<( + crate::component_future::Poll<Event>, + Option< + Box< + dyn futures::future::Future< + Item = crate::protocol::Message, + Error = Error, + > + Send, + >, + >, + )> { msg.log("recv"); match msg { @@ -271,25 +290,31 @@ impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + 'static> } } open::that(url).context(crate::error::OpenLink)?; - let code = self - .wait_for_oauth_response(state.map(|s| s.to_string()))?; - self.send_message(crate::protocol::Message::OauthResponse { - code, - }); - Ok(crate::component_future::Poll::DidWork) + Ok(( + crate::component_future::Poll::DidWork, + Some(self.wait_for_oauth_response( + state.map(|s| s.to_string()), + )?), + )) } crate::protocol::Message::LoggedIn { .. } => { self.reset_reconnect_timer(); for msg in &self.on_login { self.to_send.push_back(msg.clone()); } - Ok(crate::component_future::Poll::Event(Event::Connect())) + Ok(( + crate::component_future::Poll::Event(Event::Connect()), + None, + )) } crate::protocol::Message::Heartbeat => { - Ok(crate::component_future::Poll::DidWork) + Ok((crate::component_future::Poll::DidWork, None)) } - _ => Ok(crate::component_future::Poll::Event( - Event::ServerMessage(msg), + _ => Ok(( + crate::component_future::Poll::Event(Event::ServerMessage( + msg, + )), + None, )), } } @@ -297,7 +322,14 @@ impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + 'static> fn wait_for_oauth_response( &self, state: Option<String>, - ) -> Result<String> { + ) -> Result< + Box< + dyn futures::future::Future< + Item = crate::protocol::Message, + Error = Error, + > + Send, + >, + > { lazy_static::lazy_static! { static ref RE: regex::Regex = regex::Regex::new( r"^GET (/[^ ]*) HTTP/[0-9.]+$" @@ -309,70 +341,66 @@ impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + 'static> .context(crate::error::ParseAddr)?; let listener = tokio::net::TcpListener::bind(&addr) .context(crate::error::Bind)?; - let (wcode, rcode) = tokio::sync::mpsc::channel(1); - let wcode2 = wcode.clone(); - let fut = listener - .incoming() - .into_future() - .map_err(|(e, _)| e) - .context(crate::error::Acceptor) - .and_then(|(sock, _)| { - let sock = sock.unwrap(); - tokio::io::lines(std::io::BufReader::new(sock)) - .into_future() - .map_err(|(e, _)| e) - .context(crate::error::ReadSocket) - }) - .and_then(move |(buf, lines)| { - let buf = buf.unwrap(); - let path = &RE.captures(&buf).unwrap()[1]; - let base = url::Url::parse(&format!( - "http://{}", - OAUTH_LISTEN_ADDRESS - )) - .unwrap(); - let url = base.join(path).unwrap(); - let mut req_code = None; - let mut req_state = None; - for (k, v) in url.query_pairs() { - if k == "code" { - req_code = Some(v.to_string()); - } - if k == "state" { - req_state = Some(v.to_string()); + Ok(Box::new( + listener + .incoming() + .into_future() + .map_err(|(e, _)| e) + .context(crate::error::Acceptor) + .and_then(|(sock, _)| { + let sock = sock.unwrap(); + tokio::io::lines(std::io::BufReader::new(sock)) + .into_future() + .map_err(|(e, _)| e) + .context(crate::error::ReadSocket) + }) + .and_then(move |(buf, lines)| { + let buf = buf.unwrap(); + let path = &RE.captures(&buf).unwrap()[1]; + let base = url::Url::parse(&format!( + "http://{}", + OAUTH_LISTEN_ADDRESS + )) + .unwrap(); + let url = base.join(path).unwrap(); + let mut req_code = None; + let mut req_state = None; + for (k, v) in url.query_pairs() { + if k == "code" { + req_code = Some(v.to_string()); + } + if k == "state" { + req_state = Some(v.to_string()); + } } - } - let res = if let Some(auth_state) = state { - if req_state.is_none() || req_state.unwrap() != auth_state - { - unimplemented!() + let res = if let Some(auth_state) = state { + if req_state.is_none() + || req_state.unwrap() != auth_state + { + unimplemented!() + } else { + Ok(req_code.unwrap()) + } } else { Ok(req_code.unwrap()) - } - } else { - Ok(req_code.unwrap()) - }; - wcode - .send(res) - .context(crate::error::SendResultChannel) - .map(|_| lines.into_inner().into_inner()) - }) - .and_then(|sock| { - let response = r"HTTP/1.1 200 OK + }; + res.map(|code| { + ( + crate::protocol::Message::oauth_response(&code), + lines.into_inner().into_inner(), + ) + }) + }) + .and_then(|(msg, sock)| { + let response = r"HTTP/1.1 200 OK authenticated successfully! now close this page and return to your terminal. "; - tokio::io::write_all(sock, response) - .context(crate::error::WriteSocket) - }) - .map(|_| ()) - .map_err(|e| { - wcode2.wait().send(Err(e)).unwrap(); - }); - tokio::spawn(fut); - // XXX we don't actually want to block the main thread here - move - // this to a background thing that we poll instead - rcode.wait().next().unwrap().unwrap() + tokio::io::write_all(sock, response) + .context(crate::error::WriteSocket) + .map(|_| msg) + }), + )) } } @@ -454,8 +482,45 @@ impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + 'static> ReadSocket::Reading(ref mut fut) => match fut.poll() { Ok(futures::Async::Ready((msg, s))) => { self.last_server_time = std::time::Instant::now(); - self.rsock = ReadSocket::Connected(s); - self.handle_message(msg) + match self.handle_message(msg) { + Ok((poll, fut)) => { + if let Some(fut) = fut { + self.rsock = ReadSocket::Processing(s, fut); + } else { + self.rsock = ReadSocket::Connected(s); + } + Ok(poll) + } + Err(..) => { + self.reconnect(); + Ok(crate::component_future::Poll::Event( + Event::Disconnect, + )) + } + } + } + Ok(futures::Async::NotReady) => { + Ok(crate::component_future::Poll::NotReady) + } + Err(..) => { + self.reconnect(); + Ok(crate::component_future::Poll::Event( + Event::Disconnect, + )) + } + }, + ReadSocket::Processing(_, fut) => match fut.poll() { + Ok(futures::Async::Ready(msg)) => { + if let ReadSocket::Processing(s, _) = std::mem::replace( + &mut self.rsock, + ReadSocket::NotConnected, + ) { + self.rsock = ReadSocket::Connected(s); + self.send_message(msg); + } else { + unreachable!() + } + Ok(crate::component_future::Poll::DidWork) } Ok(futures::Async::NotReady) => { Ok(crate::component_future::Poll::NotReady) diff --git a/src/error.rs b/src/error.rs index 2836f32..24cd93d 100644 --- a/src/error.rs +++ b/src/error.rs @@ -158,6 +158,11 @@ pub enum Error { source: tokio::sync::mpsc::error::UnboundedRecvError, }, + #[snafu(display("failed to read from channel: {}", source))] + ReadChannelBounded { + source: tokio::sync::mpsc::error::RecvError, + }, + #[snafu(display("failed to read from file: {}", source))] ReadFile { source: tokio::io::Error }, |