diff options
Diffstat (limited to 'src/protocol.rs')
-rw-r--r-- | src/protocol.rs | 74 |
1 files changed, 59 insertions, 15 deletions
diff --git a/src/protocol.rs b/src/protocol.rs index 381883e..059abd0 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -53,6 +53,41 @@ impl<T: tokio::io::AsyncWrite> FramedWriter<T> { pub const PROTO_VERSION: u8 = 1; +#[repr(u8)] +#[derive(Copy, Clone, Debug)] +pub enum AuthType { + Plain = 0, + RecurseCenter, +} + +impl AuthType { + pub fn name(self) -> &'static str { + match self { + Self::Plain => "plain", + Self::RecurseCenter => "recurse_center", + } + } + + pub fn is_oauth(self) -> bool { + match self { + Self::Plain => false, + Self::RecurseCenter => true, + } + } +} + +impl std::convert::TryFrom<u8> for AuthType { + type Error = Error; + + fn try_from(n: u8) -> Result<Self> { + Ok(match n { + 0 => Self::Plain, + 1 => Self::RecurseCenter, + _ => return Err(Error::InvalidAuthType { ty: n }), + }) + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub enum Auth { Plain { username: String }, @@ -60,24 +95,34 @@ pub enum Auth { } impl Auth { - pub fn is_oauth(&self) -> bool { - match self { - Self::Plain { .. } => false, - Self::RecurseCenter { .. } => true, + pub fn plain(username: &str) -> Self { + Self::Plain { + username: username.to_string(), + } + } + + pub fn recurse_center(id: Option<&str>) -> Self { + Self::RecurseCenter { + id: id.map(std::string::ToString::to_string), } } - pub fn name(&self) -> &str { + pub fn is_oauth(&self) -> bool { + self.auth_type().is_oauth() + } + + pub fn name(&self) -> &'static str { + self.auth_type().name() + } + + fn auth_type(&self) -> AuthType { match self { - Self::Plain { .. } => "plain", - Self::RecurseCenter { .. } => "recurse_center", + Self::Plain { .. } => AuthType::Plain, + Self::RecurseCenter { .. } => AuthType::RecurseCenter, } } } -const AUTH_PLAIN: u8 = 0; -const AUTH_RECURSE_CENTER: u8 = 1; - // XXX https://github.com/rust-lang/rust/issues/64362 #[allow(dead_code)] #[derive(Debug, Clone, PartialEq, Eq)] @@ -385,14 +430,13 @@ impl From<&Message> for Packet { } } fn write_auth(val: &Auth, data: &mut Vec<u8>) { + write_u8(val.auth_type() as u8, data); match val { Auth::Plain { username } => { - write_u8(AUTH_PLAIN, data); write_str(username, data); } Auth::RecurseCenter { id } => { let id = id.as_ref().map_or("", |s| s.as_str()); - write_u8(AUTH_RECURSE_CENTER, data); write_str(id, data); } } @@ -617,19 +661,19 @@ impl std::convert::TryFrom<Packet> for Message { } fn read_auth(data: &[u8]) -> Result<(Auth, &[u8])> { let (ty, data) = read_u8(data)?; + let ty = AuthType::try_from(ty)?; let (auth, data) = match ty { - AUTH_PLAIN => { + AuthType::Plain => { let (username, data) = read_str(data)?; let auth = Auth::Plain { username }; (auth, data) } - AUTH_RECURSE_CENTER => { + AuthType::RecurseCenter => { let (id, data) = read_str(data)?; let id = if id == "" { None } else { Some(id) }; let auth = Auth::RecurseCenter { id }; (auth, data) } - _ => return Err(Error::InvalidAuthType { ty }), }; Ok((auth, data)) } |