aboutsummaryrefslogtreecommitdiffstats
path: root/src/protocol.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/protocol.rs')
-rw-r--r--src/protocol.rs74
1 files changed, 59 insertions, 15 deletions
diff --git a/src/protocol.rs b/src/protocol.rs
index 381883e..059abd0 100644
--- a/src/protocol.rs
+++ b/src/protocol.rs
@@ -53,6 +53,41 @@ impl<T: tokio::io::AsyncWrite> FramedWriter<T> {
pub const PROTO_VERSION: u8 = 1;
+#[repr(u8)]
+#[derive(Copy, Clone, Debug)]
+pub enum AuthType {
+ Plain = 0,
+ RecurseCenter,
+}
+
+impl AuthType {
+ pub fn name(self) -> &'static str {
+ match self {
+ Self::Plain => "plain",
+ Self::RecurseCenter => "recurse_center",
+ }
+ }
+
+ pub fn is_oauth(self) -> bool {
+ match self {
+ Self::Plain => false,
+ Self::RecurseCenter => true,
+ }
+ }
+}
+
+impl std::convert::TryFrom<u8> for AuthType {
+ type Error = Error;
+
+ fn try_from(n: u8) -> Result<Self> {
+ Ok(match n {
+ 0 => Self::Plain,
+ 1 => Self::RecurseCenter,
+ _ => return Err(Error::InvalidAuthType { ty: n }),
+ })
+ }
+}
+
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Auth {
Plain { username: String },
@@ -60,24 +95,34 @@ pub enum Auth {
}
impl Auth {
- pub fn is_oauth(&self) -> bool {
- match self {
- Self::Plain { .. } => false,
- Self::RecurseCenter { .. } => true,
+ pub fn plain(username: &str) -> Self {
+ Self::Plain {
+ username: username.to_string(),
+ }
+ }
+
+ pub fn recurse_center(id: Option<&str>) -> Self {
+ Self::RecurseCenter {
+ id: id.map(std::string::ToString::to_string),
}
}
- pub fn name(&self) -> &str {
+ pub fn is_oauth(&self) -> bool {
+ self.auth_type().is_oauth()
+ }
+
+ pub fn name(&self) -> &'static str {
+ self.auth_type().name()
+ }
+
+ fn auth_type(&self) -> AuthType {
match self {
- Self::Plain { .. } => "plain",
- Self::RecurseCenter { .. } => "recurse_center",
+ Self::Plain { .. } => AuthType::Plain,
+ Self::RecurseCenter { .. } => AuthType::RecurseCenter,
}
}
}
-const AUTH_PLAIN: u8 = 0;
-const AUTH_RECURSE_CENTER: u8 = 1;
-
// XXX https://github.com/rust-lang/rust/issues/64362
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq, Eq)]
@@ -385,14 +430,13 @@ impl From<&Message> for Packet {
}
}
fn write_auth(val: &Auth, data: &mut Vec<u8>) {
+ write_u8(val.auth_type() as u8, data);
match val {
Auth::Plain { username } => {
- write_u8(AUTH_PLAIN, data);
write_str(username, data);
}
Auth::RecurseCenter { id } => {
let id = id.as_ref().map_or("", |s| s.as_str());
- write_u8(AUTH_RECURSE_CENTER, data);
write_str(id, data);
}
}
@@ -617,19 +661,19 @@ impl std::convert::TryFrom<Packet> for Message {
}
fn read_auth(data: &[u8]) -> Result<(Auth, &[u8])> {
let (ty, data) = read_u8(data)?;
+ let ty = AuthType::try_from(ty)?;
let (auth, data) = match ty {
- AUTH_PLAIN => {
+ AuthType::Plain => {
let (username, data) = read_str(data)?;
let auth = Auth::Plain { username };
(auth, data)
}
- AUTH_RECURSE_CENTER => {
+ AuthType::RecurseCenter => {
let (id, data) = read_str(data)?;
let id = if id == "" { None } else { Some(id) };
let auth = Auth::RecurseCenter { id };
(auth, data)
}
- _ => return Err(Error::InvalidAuthType { ty }),
};
Ok((auth, data))
}