diff options
author | Jesse Luehrs <doy@tozt.net> | 2019-09-05 03:34:17 -0400 |
---|---|---|
committer | Jesse Luehrs <doy@tozt.net> | 2019-09-05 03:34:17 -0400 |
commit | 9de982f9170b2e53f0cae2336bcd5dcd2b340ddc (patch) | |
tree | 1211554ba441091b342bd584f7170333f7a510f6 | |
parent | 0a1ed087030c4de7adc0c14ab41142271a7e79e0 (diff) | |
download | teleterm-9de982f9170b2e53f0cae2336bcd5dcd2b340ddc.tar.gz teleterm-9de982f9170b2e53f0cae2336bcd5dcd2b340ddc.zip |
simplify serialization/deserialization a bunch
-rw-r--r-- | src/protocol.rs | 276 |
1 files changed, 129 insertions, 147 deletions
diff --git a/src/protocol.rs b/src/protocol.rs index 90ece86..08ede9d 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -17,28 +17,19 @@ pub enum Error { #[snafu(display("failed to write packet: {}", source))] WriteAsync { source: tokio::io::Error }, - #[snafu(display("invalid StartCasting message: {}", source))] - ParseStartCastingMessage { source: std::string::FromUtf8Error }, + #[snafu(display("failed to parse string: {}", source))] + ParseString { source: std::string::FromUtf8Error }, - #[snafu(display("invalid StartWatching message: {}", source))] - ParseStartWatchingMessage { source: std::string::FromUtf8Error }, - - #[snafu(display("invalid Sessions message: {}", source))] - ParseMessageLen { + #[snafu(display("failed to parse int: {}", source))] + ParseInt { source: std::array::TryFromSliceError, }, - #[snafu(display("invalid Sessions message: {}", source))] - ParseSessionsMessageId { source: std::string::FromUtf8Error }, - - #[snafu(display("invalid WatchSession message: {}", source))] - ParseWatchSessionMessage { source: std::string::FromUtf8Error }, + #[snafu(display("failed to parse string: {:?}", data))] + ExtraMessageData { data: Vec<u8> }, #[snafu(display("invalid message type: {}", ty))] InvalidMessageType { ty: u32 }, - - #[snafu(display("invalid connection type: {}", ty))] - InvalidConnType { ty: u32 }, } pub type Result<T> = std::result::Result<T, Error>; @@ -205,6 +196,27 @@ impl Packet { impl From<&Message> for Packet { fn from(msg: &Message) -> Self { + fn u32_from_usize(n: usize) -> u32 { + // XXX this can actually panic + n.try_into().unwrap() + } + fn write_u32(val: u32, data: &mut Vec<u8>) { + data.extend_from_slice(&val.to_le_bytes()); + } + fn write_bytes(val: &[u8], data: &mut Vec<u8>) { + write_u32(u32_from_usize(val.len()), data); + data.extend_from_slice(val); + } + fn write_str(val: &str, data: &mut Vec<u8>) { + write_bytes(val.as_bytes(), data); + } + fn write_strvec(val: &[String], data: &mut Vec<u8>) { + write_u32(u32_from_usize(val.len()), data); + for s in val { + write_str(&s, data); + } + } + match msg { Message::StartCasting { proto_version, @@ -213,15 +225,9 @@ impl From<&Message> for Packet { } => { let mut data = vec![]; - data.extend_from_slice(&proto_version.to_le_bytes()); - - let len: u32 = username.len().try_into().unwrap(); - data.extend_from_slice(&len.to_le_bytes()); - data.extend_from_slice(username.as_bytes()); - - let len: u32 = term_type.len().try_into().unwrap(); - data.extend_from_slice(&len.to_le_bytes()); - data.extend_from_slice(term_type.as_bytes()); + write_u32(*proto_version, &mut data); + write_str(username, &mut data); + write_str(term_type, &mut data); Packet { ty: 0, data } } @@ -232,15 +238,9 @@ impl From<&Message> for Packet { } => { let mut data = vec![]; - data.extend_from_slice(&proto_version.to_le_bytes()); - - let len: u32 = username.len().try_into().unwrap(); - data.extend_from_slice(&len.to_le_bytes()); - data.extend_from_slice(username.as_bytes()); - - let len: u32 = term_type.len().try_into().unwrap(); - data.extend_from_slice(&len.to_le_bytes()); - data.extend_from_slice(term_type.as_bytes()); + write_u32(*proto_version, &mut data); + write_str(username, &mut data); + write_str(term_type, &mut data); Packet { ty: 1, data } } @@ -248,29 +248,34 @@ impl From<&Message> for Packet { ty: 2, data: vec![], }, - Message::TerminalOutput { data } => Packet { - ty: 3, - data: data.to_vec(), - }, + Message::TerminalOutput { data: output } => { + let mut data = vec![]; + + write_bytes(output, &mut data); + + Packet { + ty: 3, + data: data.to_vec(), + } + } Message::ListSessions => Packet { ty: 4, data: vec![], }, Message::Sessions { ids } => { let mut data = vec![]; - let len: u32 = ids.len().try_into().unwrap(); - data.extend_from_slice(&len.to_le_bytes()); - for id in ids { - let len: u32 = id.len().try_into().unwrap(); - data.extend_from_slice(&len.to_le_bytes()); - data.extend_from_slice(&id.as_bytes()); - } + + write_strvec(ids, &mut data); + Packet { ty: 5, data } } - Message::WatchSession { id } => Packet { - ty: 6, - data: id.as_bytes().to_vec(), - }, + Message::WatchSession { id } => { + let mut data = vec![]; + + write_str(id, &mut data); + + Packet { ty: 6, data } + } } } } @@ -279,111 +284,88 @@ impl std::convert::TryFrom<Packet> for Message { type Error = Error; fn try_from(packet: Packet) -> Result<Self> { - match packet.ty { + fn read_u32(data: &[u8]) -> Result<(u32, &[u8])> { + let (buf, rest) = data.split_at(std::mem::size_of::<u32>()); + let val = u32::from_le_bytes(buf.try_into().context(ParseInt)?); + Ok((val, rest)) + } + fn read_bytes(data: &[u8]) -> Result<(Vec<u8>, &[u8])> { + let (len, data) = read_u32(data)?; + let (buf, rest) = data.split_at(len.try_into().unwrap()); + let val = buf.to_vec(); + Ok((val, rest)) + } + fn read_str(data: &[u8]) -> Result<(String, &[u8])> { + let (bytes, rest) = read_bytes(data)?; + let val = String::from_utf8(bytes).context(ParseString)?; + Ok((val, rest)) + } + fn read_strvec(data: &[u8]) -> Result<(Vec<String>, &[u8])> { + let mut val = vec![]; + let (len, mut data) = read_u32(data)?; + for _ in 0..len { + let (subval, subdata) = read_str(data)?; + val.push(subval); + data = subdata; + } + Ok((val, data)) + } + + let data: &[u8] = packet.data.as_ref(); + let (msg, rest) = match packet.ty { 0 => { - let mut data: &[u8] = packet.data.as_ref(); - - let (buf, rest) = data.split_at(std::mem::size_of::<u32>()); - let proto_version = u32::from_le_bytes( - buf.try_into().context(ParseMessageLen)?, - ); - data = rest; - - let (buf, rest) = data.split_at(std::mem::size_of::<u32>()); - let len = u32::from_le_bytes( - buf.try_into().context(ParseMessageLen)?, - ); - data = rest; - let (buf, rest) = data.split_at(len.try_into().unwrap()); - let username = String::from_utf8(buf.to_vec()) - .context(ParseStartWatchingMessage)?; - data = rest; - - let (buf, rest) = data.split_at(std::mem::size_of::<u32>()); - let len = u32::from_le_bytes( - buf.try_into().context(ParseMessageLen)?, - ); - data = rest; - let (buf, _) = data.split_at(len.try_into().unwrap()); - let term_type = String::from_utf8(buf.to_vec()) - .context(ParseStartWatchingMessage)?; - - Ok(Message::StartCasting { - proto_version, - username, - term_type, - }) + let (proto_version, data) = read_u32(data)?; + let (username, data) = read_str(data)?; + let (term_type, data) = read_str(data)?; + + ( + Message::StartCasting { + proto_version, + username, + term_type, + }, + data, + ) } 1 => { - let mut data: &[u8] = packet.data.as_ref(); - - let (buf, rest) = data.split_at(std::mem::size_of::<u32>()); - let proto_version = u32::from_le_bytes( - buf.try_into().context(ParseMessageLen)?, - ); - data = rest; - - let (buf, rest) = data.split_at(std::mem::size_of::<u32>()); - let len = u32::from_le_bytes( - buf.try_into().context(ParseMessageLen)?, - ); - data = rest; - let (buf, rest) = data.split_at(len.try_into().unwrap()); - let username = String::from_utf8(buf.to_vec()) - .context(ParseStartWatchingMessage)?; - data = rest; - - let (buf, rest) = data.split_at(std::mem::size_of::<u32>()); - let len = u32::from_le_bytes( - buf.try_into().context(ParseMessageLen)?, - ); - data = rest; - let (buf, _) = data.split_at(len.try_into().unwrap()); - let term_type = String::from_utf8(buf.to_vec()) - .context(ParseStartWatchingMessage)?; - - Ok(Message::StartWatching { - proto_version, - username, - term_type, - }) + let (proto_version, data) = read_u32(data)?; + let (username, data) = read_str(data)?; + let (term_type, data) = read_str(data)?; + + ( + Message::StartWatching { + proto_version, + username, + term_type, + }, + data, + ) + } + 2 => (Message::Heartbeat, data), + 3 => { + let (output, data) = read_bytes(data)?; + + (Message::TerminalOutput { data: output }, data) } - 2 => Ok(Message::Heartbeat), - 3 => Ok(Message::TerminalOutput { data: packet.data }), - 4 => Ok(Message::ListSessions), + 4 => (Message::ListSessions, data), 5 => { - let mut ids = vec![]; - let mut data: &[u8] = packet.data.as_ref(); - - let (num_sessions_buf, rest) = - data.split_at(std::mem::size_of::<u32>()); - let num_sessions = u32::from_le_bytes( - num_sessions_buf.try_into().context(ParseMessageLen)?, - ); - data = rest; - - for _ in 0..num_sessions { - let (len_buf, rest) = - data.split_at(std::mem::size_of::<u32>()); - let len = u32::from_le_bytes( - len_buf.try_into().context(ParseMessageLen)?, - ); - data = rest; - - let (id_buf, rest) = - data.split_at(len.try_into().unwrap()); - let id = String::from_utf8(id_buf.to_vec()) - .context(ParseSessionsMessageId)?; - ids.push(id); - data = rest; - } - Ok(Message::Sessions { ids }) + let (ids, data) = read_strvec(data)?; + + (Message::Sessions { ids }, data) + } + 6 => { + let (id, data) = read_str(data)?; + (Message::WatchSession { id }, data) } - 6 => Ok(Message::WatchSession { - id: String::from_utf8(packet.data) - .context(ParseWatchSessionMessage)?, - }), - _ => Err(Error::InvalidMessageType { ty: packet.ty }), + _ => return Err(Error::InvalidMessageType { ty: packet.ty }), + }; + + if !rest.is_empty() { + return Err(Error::ExtraMessageData { + data: rest.to_vec(), + }); } + + Ok(msg) } } |