From 1e2ffb20113b230e04e6ca1b60ad4ca2ac89a3e8 Mon Sep 17 00:00:00 2001 From: Jesse Luehrs Date: Thu, 5 Sep 2019 05:59:41 -0400 Subject: use tokio codecs for some of the protocol parsing --- Cargo.lock | 1 + Cargo.toml | 1 + src/cmd/cast.rs | 24 ++++++---- src/cmd/server.rs | 74 ++++++++++++++++++------------ src/protocol.rs | 133 +++++++++++++++++++++++++++++++++++------------------- 5 files changed, 147 insertions(+), 86 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b521c95..923217c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -648,6 +648,7 @@ dependencies = [ name = "termcast" version = "0.1.0" dependencies = [ + "bytes 0.4.12 (registry+https://github.com/rust-lang/crates.io-index)", "clap 2.33.0 (registry+https://github.com/rust-lang/crates.io-index)", "crossterm 0.10.2 (registry+https://github.com/rust-lang/crates.io-index)", "futures 0.1.28 (registry+https://github.com/rust-lang/crates.io-index)", diff --git a/Cargo.toml b/Cargo.toml index 76ac973..5a709a3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,7 @@ authors = ["Jesse Luehrs "] edition = "2018" [dependencies] +bytes = "0.4" clap = "2" crossterm = "0.10" futures = "0.1" diff --git a/src/cmd/cast.rs b/src/cmd/cast.rs index 6264ddc..0654e17 100644 --- a/src/cmd/cast.rs +++ b/src/cmd/cast.rs @@ -66,13 +66,13 @@ fn run_impl() -> Result<()> { enum ReadSocket { NotConnected, - Connected(tokio::io::ReadHalf), + Connected(crate::protocol::FramedReader), ReadingMessage( Box< dyn futures::future::Future< Item = ( crate::protocol::Message, - tokio::io::ReadHalf, + crate::protocol::FramedReader, ), Error = Error, > + Send, @@ -93,16 +93,16 @@ enum WriteSocket { LoggingIn( Box< dyn futures::future::Future< - Item = tokio::io::WriteHalf, + Item = crate::protocol::FramedWriter, Error = Error, > + Send, >, ), - Connected(tokio::io::WriteHalf), + Connected(crate::protocol::FramedWriter), SendingOutput( Box< dyn futures::future::Future< - Item = tokio::io::WriteHalf, + Item = crate::protocol::FramedWriter, Error = Error, > + Send, >, @@ -111,7 +111,7 @@ enum WriteSocket { SendingHeartbeat( Box< dyn futures::future::Future< - Item = tokio::io::WriteHalf, + Item = crate::protocol::FramedWriter, Error = Error, > + Send, >, @@ -202,13 +202,17 @@ impl CastSession { Ok(futures::Async::Ready(s)) => { let (rs, ws) = s.split(); self.last_server_time = std::time::Instant::now(); - let term = - std::env::var("TERM").unwrap_or("".to_string()); + let term = std::env::var("TERM") + .unwrap_or_else(|_| "".to_string()); let fut = crate::protocol::Message::start_casting("doy", &term) - .write_async(ws) + .write_async(crate::protocol::FramedWriter::new( + ws, + )) .context(Write); - self.rsock = ReadSocket::Connected(rs); + self.rsock = ReadSocket::Connected( + crate::protocol::FramedReader::new(rs), + ); self.wsock = WriteSocket::LoggingIn(Box::new(fut)); Ok(true) } diff --git a/src/cmd/server.rs b/src/cmd/server.rs index 5de1b68..cd202bb 100644 --- a/src/cmd/server.rs +++ b/src/cmd/server.rs @@ -79,13 +79,13 @@ enum SockType { } enum ReadSocket { - Connected(tokio::io::ReadHalf), + Connected(crate::protocol::FramedReader), Reading( Box< dyn futures::future::Future< Item = ( crate::protocol::Message, - tokio::io::ReadHalf, + crate::protocol::FramedReader, ), Error = Error, > + Send, @@ -94,11 +94,11 @@ enum ReadSocket { } enum WriteSocket { - Connected(tokio::io::WriteHalf), + Connected(crate::protocol::FramedWriter), Writing( Box< dyn futures::future::Future< - Item = tokio::io::WriteHalf, + Item = crate::protocol::FramedWriter, Error = Error, > + Send, >, @@ -122,8 +122,12 @@ impl Connection { fn new(s: tokio::net::tcp::TcpStream) -> Self { let (rs, ws) = s.split(); Self { - rsock: Some(ReadSocket::Connected(rs)), - wsock: Some(WriteSocket::Connected(ws)), + rsock: Some(ReadSocket::Connected( + crate::protocol::FramedReader::new(rs), + )), + wsock: Some(WriteSocket::Connected( + crate::protocol::FramedWriter::new(ws), + )), ty: SockType::Unknown, id: format!("{}", uuid::Uuid::new_v4()), @@ -201,20 +205,25 @@ impl ConnectionHandler { i += 1; } Err(e) => { - if let Error::ReadMessage { - source: + if let Error::ReadMessage { ref source } = e { + match source { crate::protocol::Error::ReadAsync { source: ref tokio_err, - }, - } = e - { - if tokio_err.kind() - == tokio::io::ErrorKind::UnexpectedEof - { - println!("disconnect"); - self.connections.swap_remove(i); - } else { - return Err(e); + } => { + if tokio_err.kind() + == tokio::io::ErrorKind::UnexpectedEof + { + println!("disconnect"); + self.connections.swap_remove(i); + } else { + return Err(e); + } + } + crate::protocol::Error::EOF => { + println!("disconnect"); + self.connections.swap_remove(i); + } + _ => return Err(e), } } else { return Err(e); @@ -262,20 +271,25 @@ impl ConnectionHandler { i += 1; } Err(e) => { - if let Error::ReadMessage { - source: + if let Error::WriteMessage { ref source } = e { + match source { crate::protocol::Error::WriteAsync { source: ref tokio_err, - }, - } = e - { - if tokio_err.kind() - == tokio::io::ErrorKind::UnexpectedEof - { - println!("disconnect"); - self.connections.swap_remove(i); - } else { - return Err(e); + } => { + if tokio_err.kind() + == tokio::io::ErrorKind::UnexpectedEof + { + println!("disconnect"); + self.connections.swap_remove(i); + } else { + return Err(e); + } + } + crate::protocol::Error::EOF => { + println!("disconnect"); + self.connections.swap_remove(i); + } + _ => return Err(e), } } else { return Err(e); diff --git a/src/protocol.rs b/src/protocol.rs index 08ede9d..82c809e 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -1,4 +1,6 @@ use futures::future::Future as _; +use futures::sink::Sink as _; +use futures::stream::Stream as _; use snafu::futures01::FutureExt as _; use snafu::ResultExt as _; use std::convert::{TryFrom as _, TryInto as _}; @@ -30,10 +32,47 @@ pub enum Error { #[snafu(display("invalid message type: {}", ty))] InvalidMessageType { ty: u32 }, + + #[snafu(display("eof"))] + EOF, } pub type Result = std::result::Result; +pub struct FramedReader( + tokio::codec::FramedRead< + tokio::io::ReadHalf, + tokio::codec::length_delimited::LengthDelimitedCodec, + >, +); + +impl FramedReader { + pub fn new(rs: tokio::io::ReadHalf) -> Self { + Self( + tokio::codec::length_delimited::Builder::new() + .length_field_length(4) + .new_read(rs), + ) + } +} + +pub struct FramedWriter( + tokio::codec::FramedWrite< + tokio::io::WriteHalf, + tokio::codec::length_delimited::LengthDelimitedCodec, + >, +); + +impl FramedWriter { + pub fn new(ws: tokio::io::WriteHalf) -> Self { + Self( + tokio::codec::length_delimited::Builder::new() + .length_field_length(4) + .new_write(ws), + ) + } +} + pub const PROTO_VERSION: u32 = 1; #[derive(Debug)] @@ -104,9 +143,10 @@ impl Message { Packet::read(r).and_then(Self::try_from) } - pub fn read_async( - r: R, - ) -> impl futures::future::Future { + pub fn read_async( + r: FramedReader, + ) -> impl futures::future::Future + { Packet::read_async(r).and_then(|(packet, r)| { Self::try_from(packet).map(|msg| (msg, r)) }) @@ -116,10 +156,11 @@ impl Message { Packet::from(self).write(w) } - pub fn write_async( + pub fn write_async( &self, - w: W, - ) -> impl futures::future::Future { + w: FramedWriter, + ) -> impl futures::future::Future + { Packet::from(self).write_async(w) } } @@ -131,63 +172,63 @@ struct Packet { impl Packet { fn read(mut r: R) -> Result { - let mut header_buf = [0u8; std::mem::size_of::() * 2]; - r.read_exact(&mut header_buf).context(Read)?; + let mut len_buf = [0u8; std::mem::size_of::()]; + r.read_exact(&mut len_buf).context(Read)?; + let len = u32::from_be_bytes(len_buf.try_into().unwrap()); - let (len_buf, ty_buf) = - header_buf.split_at(std::mem::size_of::()); - let len = u32::from_le_bytes(len_buf.try_into().unwrap()); - let ty = u32::from_le_bytes(ty_buf.try_into().unwrap()); let mut data = vec![0u8; len.try_into().unwrap()]; r.read_exact(&mut data).context(Read)?; + let (ty_buf, rest) = data.split_at(std::mem::size_of::()); + let ty = u32::from_be_bytes(ty_buf.try_into().unwrap()); - Ok(Packet { ty, data }) + Ok(Packet { + ty, + data: rest.to_vec(), + }) } - fn read_async( - r: R, - ) -> impl futures::future::Future { - let header_buf = [0u8; std::mem::size_of::() * 2]; - tokio::io::read_exact(r, header_buf) - .and_then(|(r, buf)| { - let (len_buf, ty_buf) = + fn read_async( + r: FramedReader, + ) -> impl futures::future::Future + { + r.0.into_future() + .map_err(|(e, _)| Error::ReadAsync { source: e }) + .and_then(|(data, r)| match data { + Some(data) => Ok((data, r)), + None => Err(Error::EOF), + }) + .map(|(buf, r)| { + let (ty_buf, data_buf) = buf.split_at(std::mem::size_of::()); - let len = u32::from_le_bytes(len_buf.try_into().unwrap()); - let ty = u32::from_le_bytes(ty_buf.try_into().unwrap()); - let body_buf = vec![0u8; len.try_into().unwrap()]; - tokio::io::read_exact(r, body_buf).map(move |(r, buf)| { - ( - Packet { - ty, - data: buf.to_vec(), - }, - r, - ) - }) + let ty = u32::from_be_bytes(ty_buf.try_into().unwrap()); + let data = data_buf.to_vec(); + (Packet { ty, data }, FramedReader(r)) }) - .context(ReadAsync) } fn write(&self, mut w: W) -> Result<()> { - Ok(w.write_all(&self.as_bytes()).context(Write)?) + let bytes = self.as_bytes(); + let len: u32 = bytes.len().try_into().unwrap(); + let len_buf = len.to_be_bytes(); + let buf: Vec = + len_buf.iter().chain(bytes.iter()).copied().collect(); + Ok(w.write_all(&buf).context(Write)?) } - fn write_async( + fn write_async( &self, - w: W, - ) -> impl futures::future::Future { - tokio::io::write_all(w, self.as_bytes()) - .map(|(w, _)| w) + w: FramedWriter, + ) -> impl futures::future::Future + { + w.0.send(bytes::Bytes::from(self.as_bytes())) + .map(FramedWriter) .context(WriteAsync) } fn as_bytes(&self) -> Vec { - let len: u32 = self.data.len().try_into().unwrap(); - let len_buf = len.to_le_bytes(); - let ty = self.ty.to_le_bytes(); - len_buf + self.ty + .to_be_bytes() .iter() - .chain(ty.iter()) .chain(self.data.iter()) .cloned() .collect() @@ -201,7 +242,7 @@ impl From<&Message> for Packet { n.try_into().unwrap() } fn write_u32(val: u32, data: &mut Vec) { - data.extend_from_slice(&val.to_le_bytes()); + data.extend_from_slice(&val.to_be_bytes()); } fn write_bytes(val: &[u8], data: &mut Vec) { write_u32(u32_from_usize(val.len()), data); @@ -286,7 +327,7 @@ impl std::convert::TryFrom for Message { fn try_from(packet: Packet) -> Result { fn read_u32(data: &[u8]) -> Result<(u32, &[u8])> { let (buf, rest) = data.split_at(std::mem::size_of::()); - let val = u32::from_le_bytes(buf.try_into().context(ParseInt)?); + let val = u32::from_be_bytes(buf.try_into().context(ParseInt)?); Ok((val, rest)) } fn read_bytes(data: &[u8]) -> Result<(Vec, &[u8])> { -- cgit v1.2.3-54-g00ecf