aboutsummaryrefslogtreecommitdiffstats
path: root/src/protocol.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/protocol.rs')
-rw-r--r--src/protocol.rs133
1 files changed, 87 insertions, 46 deletions
diff --git a/src/protocol.rs b/src/protocol.rs
index 08ede9d..82c809e 100644
--- a/src/protocol.rs
+++ b/src/protocol.rs
@@ -1,4 +1,6 @@
use futures::future::Future as _;
+use futures::sink::Sink as _;
+use futures::stream::Stream as _;
use snafu::futures01::FutureExt as _;
use snafu::ResultExt as _;
use std::convert::{TryFrom as _, TryInto as _};
@@ -30,10 +32,47 @@ pub enum Error {
#[snafu(display("invalid message type: {}", ty))]
InvalidMessageType { ty: u32 },
+
+ #[snafu(display("eof"))]
+ EOF,
}
pub type Result<T> = std::result::Result<T, Error>;
+pub struct FramedReader(
+ tokio::codec::FramedRead<
+ tokio::io::ReadHalf<tokio::net::tcp::TcpStream>,
+ tokio::codec::length_delimited::LengthDelimitedCodec,
+ >,
+);
+
+impl FramedReader {
+ pub fn new(rs: tokio::io::ReadHalf<tokio::net::tcp::TcpStream>) -> Self {
+ Self(
+ tokio::codec::length_delimited::Builder::new()
+ .length_field_length(4)
+ .new_read(rs),
+ )
+ }
+}
+
+pub struct FramedWriter(
+ tokio::codec::FramedWrite<
+ tokio::io::WriteHalf<tokio::net::tcp::TcpStream>,
+ tokio::codec::length_delimited::LengthDelimitedCodec,
+ >,
+);
+
+impl FramedWriter {
+ pub fn new(ws: tokio::io::WriteHalf<tokio::net::tcp::TcpStream>) -> Self {
+ Self(
+ tokio::codec::length_delimited::Builder::new()
+ .length_field_length(4)
+ .new_write(ws),
+ )
+ }
+}
+
pub const PROTO_VERSION: u32 = 1;
#[derive(Debug)]
@@ -104,9 +143,10 @@ impl Message {
Packet::read(r).and_then(Self::try_from)
}
- pub fn read_async<R: tokio::io::AsyncRead>(
- r: R,
- ) -> impl futures::future::Future<Item = (Self, R), Error = Error> {
+ pub fn read_async(
+ r: FramedReader,
+ ) -> impl futures::future::Future<Item = (Self, FramedReader), Error = Error>
+ {
Packet::read_async(r).and_then(|(packet, r)| {
Self::try_from(packet).map(|msg| (msg, r))
})
@@ -116,10 +156,11 @@ impl Message {
Packet::from(self).write(w)
}
- pub fn write_async<W: tokio::io::AsyncWrite>(
+ pub fn write_async(
&self,
- w: W,
- ) -> impl futures::future::Future<Item = W, Error = Error> {
+ w: FramedWriter,
+ ) -> impl futures::future::Future<Item = FramedWriter, Error = Error>
+ {
Packet::from(self).write_async(w)
}
}
@@ -131,63 +172,63 @@ struct Packet {
impl Packet {
fn read<R: std::io::Read>(mut r: R) -> Result<Self> {
- let mut header_buf = [0u8; std::mem::size_of::<u32>() * 2];
- r.read_exact(&mut header_buf).context(Read)?;
+ let mut len_buf = [0u8; std::mem::size_of::<u32>()];
+ r.read_exact(&mut len_buf).context(Read)?;
+ let len = u32::from_be_bytes(len_buf.try_into().unwrap());
- let (len_buf, ty_buf) =
- header_buf.split_at(std::mem::size_of::<u32>());
- let len = u32::from_le_bytes(len_buf.try_into().unwrap());
- let ty = u32::from_le_bytes(ty_buf.try_into().unwrap());
let mut data = vec![0u8; len.try_into().unwrap()];
r.read_exact(&mut data).context(Read)?;
+ let (ty_buf, rest) = data.split_at(std::mem::size_of::<u32>());
+ let ty = u32::from_be_bytes(ty_buf.try_into().unwrap());
- Ok(Packet { ty, data })
+ Ok(Packet {
+ ty,
+ data: rest.to_vec(),
+ })
}
- fn read_async<R: tokio::io::AsyncRead>(
- r: R,
- ) -> impl futures::future::Future<Item = (Self, R), Error = Error> {
- let header_buf = [0u8; std::mem::size_of::<u32>() * 2];
- tokio::io::read_exact(r, header_buf)
- .and_then(|(r, buf)| {
- let (len_buf, ty_buf) =
+ fn read_async(
+ r: FramedReader,
+ ) -> impl futures::future::Future<Item = (Self, FramedReader), Error = Error>
+ {
+ r.0.into_future()
+ .map_err(|(e, _)| Error::ReadAsync { source: e })
+ .and_then(|(data, r)| match data {
+ Some(data) => Ok((data, r)),
+ None => Err(Error::EOF),
+ })
+ .map(|(buf, r)| {
+ let (ty_buf, data_buf) =
buf.split_at(std::mem::size_of::<u32>());
- let len = u32::from_le_bytes(len_buf.try_into().unwrap());
- let ty = u32::from_le_bytes(ty_buf.try_into().unwrap());
- let body_buf = vec![0u8; len.try_into().unwrap()];
- tokio::io::read_exact(r, body_buf).map(move |(r, buf)| {
- (
- Packet {
- ty,
- data: buf.to_vec(),
- },
- r,
- )
- })
+ let ty = u32::from_be_bytes(ty_buf.try_into().unwrap());
+ let data = data_buf.to_vec();
+ (Packet { ty, data }, FramedReader(r))
})
- .context(ReadAsync)
}
fn write<W: std::io::Write>(&self, mut w: W) -> Result<()> {
- Ok(w.write_all(&self.as_bytes()).context(Write)?)
+ let bytes = self.as_bytes();
+ let len: u32 = bytes.len().try_into().unwrap();
+ let len_buf = len.to_be_bytes();
+ let buf: Vec<u8> =
+ len_buf.iter().chain(bytes.iter()).copied().collect();
+ Ok(w.write_all(&buf).context(Write)?)
}
- fn write_async<W: tokio::io::AsyncWrite>(
+ fn write_async(
&self,
- w: W,
- ) -> impl futures::future::Future<Item = W, Error = Error> {
- tokio::io::write_all(w, self.as_bytes())
- .map(|(w, _)| w)
+ w: FramedWriter,
+ ) -> impl futures::future::Future<Item = FramedWriter, Error = Error>
+ {
+ w.0.send(bytes::Bytes::from(self.as_bytes()))
+ .map(FramedWriter)
.context(WriteAsync)
}
fn as_bytes(&self) -> Vec<u8> {
- let len: u32 = self.data.len().try_into().unwrap();
- let len_buf = len.to_le_bytes();
- let ty = self.ty.to_le_bytes();
- len_buf
+ self.ty
+ .to_be_bytes()
.iter()
- .chain(ty.iter())
.chain(self.data.iter())
.cloned()
.collect()
@@ -201,7 +242,7 @@ impl From<&Message> for Packet {
n.try_into().unwrap()
}
fn write_u32(val: u32, data: &mut Vec<u8>) {
- data.extend_from_slice(&val.to_le_bytes());
+ data.extend_from_slice(&val.to_be_bytes());
}
fn write_bytes(val: &[u8], data: &mut Vec<u8>) {
write_u32(u32_from_usize(val.len()), data);
@@ -286,7 +327,7 @@ impl std::convert::TryFrom<Packet> for Message {
fn try_from(packet: Packet) -> Result<Self> {
fn read_u32(data: &[u8]) -> Result<(u32, &[u8])> {
let (buf, rest) = data.split_at(std::mem::size_of::<u32>());
- let val = u32::from_le_bytes(buf.try_into().context(ParseInt)?);
+ let val = u32::from_be_bytes(buf.try_into().context(ParseInt)?);
Ok((val, rest))
}
fn read_bytes(data: &[u8]) -> Result<(Vec<u8>, &[u8])> {