aboutsummaryrefslogtreecommitdiffstats
path: root/src/server/tls.rs
blob: a05608afd4b3f2a7af5fd5e5c890ead2f12c0bfd (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
use crate::prelude::*;

/// TLS front-end for [`super::Server`].
///
/// Pending TLS accepts arrive on `sock_r`. Each handshake is driven to
/// completion here, and the resulting streams are handed to the wrapped
/// server through an internal channel (`sock_w` is its sending half).
// `#[must_use]` only takes effect on the type definition, not on an `impl`
// block, so it is placed here.
#[must_use = "futures do nothing unless polled"]
pub struct Server {
    /// Inner server; constructed with the receiving half of `sock_w`'s
    /// channel.
    server: super::Server<tokio_tls::TlsStream<tokio::net::TcpStream>>,
    /// Source of newly accepted connections whose handshake is still pending.
    sock_r:
        tokio::sync::mpsc::Receiver<tokio_tls::Accept<tokio::net::TcpStream>>,
    /// Hands completed TLS streams to `server`.
    sock_w: tokio::sync::mpsc::Sender<
        tokio_tls::TlsStream<tokio::net::TcpStream>,
    >,
    /// Handshakes currently in progress.
    accepting_sockets: Vec<tokio_tls::Accept<tokio::net::TcpStream>>,
}

impl Server {
    /// Builds a TLS server wrapping a plain [`super::Server`].
    ///
    /// `buffer_size`, `read_timeout` and `allowed_login_methods` are passed
    /// straight through to the inner server. `sock_r` delivers TLS accepts
    /// whose handshakes this wrapper will drive. Finished streams reach the
    /// inner server over a freshly created channel with capacity 100.
    pub fn new(
        buffer_size: usize,
        read_timeout: std::time::Duration,
        sock_r: tokio::sync::mpsc::Receiver<
            tokio_tls::Accept<tokio::net::TcpStream>,
        >,
        allowed_login_methods: std::collections::HashSet<
            crate::protocol::AuthType,
        >,
    ) -> Self {
        // The write half stays here; the read half feeds the inner server.
        let (sock_w, established_r) = tokio::sync::mpsc::channel(100);
        let server = super::Server::new(
            buffer_size,
            read_timeout,
            established_r,
            allowed_login_methods,
        );

        Self {
            server,
            sock_r,
            sock_w,
            accepting_sockets: Vec::new(),
        }
    }
}

impl Server {
    /// Sub-pollers handed to `component_future::poll_future`, in order.
    const POLL_FNS:
        &'static [&'static dyn for<'a> Fn(
            &'a mut Self,
        )
            -> crate::component_future::Poll<
            (),
            Error,
        >] = &[
        &Self::poll_new_connections,
        &Self::poll_handshake_connections,
        &Self::poll_server,
    ];

    /// Moves one newly accepted (still handshaking) connection from `sock_r`
    /// into `accepting_sockets`.
    ///
    /// # Errors
    ///
    /// Returns `Error::SocketChannelClosed` when `sock_r` yields `None`, and
    /// a `SocketChannelReceive` error if polling the receiver fails.
    fn poll_new_connections(
        &mut self,
    ) -> crate::component_future::Poll<(), Error> {
        if let Some(sock) = try_ready!(self
            .sock_r
            .poll()
            .context(crate::error::SocketChannelReceive))
        {
            self.accepting_sockets.push(sock);
            Ok(crate::component_future::Async::DidWork)
        } else {
            Err(Error::SocketChannelClosed)
        }
    }

    /// Drives every in-progress TLS handshake and forwards each completed
    /// stream to the inner server through `sock_w`.
    ///
    /// A slot in `sock_w` is reserved with `poll_ready` before a handshake
    /// is polled. A finished connection therefore always has somewhere to
    /// go, instead of being dropped because the bounded channel is full.
    /// While no slot is available, the remaining handshakes are left alone
    /// and `NotReady` is reported. Failed handshakes are logged and
    /// discarded.
    fn poll_handshake_connections(
        &mut self,
    ) -> crate::component_future::Poll<(), Error> {
        let mut did_work = false;
        let mut not_ready = false;

        let mut i = 0;
        while i < self.accepting_sockets.len() {
            // Reserve capacity first: once an `Accept` resolves, its stream
            // must be handed off immediately or it is lost.
            match self.sock_w.poll_ready() {
                Ok(futures::Async::Ready(())) => {}
                Ok(futures::Async::NotReady) => {
                    // Channel full; a wakeup is registered for when the
                    // inner server frees capacity.
                    not_ready = true;
                    break;
                }
                // Receiving side is gone: `try_send` below will fail and
                // log, as it would have without the capacity check.
                Err(_) => {}
            }

            match self.accepting_sockets[i].poll() {
                Ok(futures::Async::Ready(sock)) => {
                    // `swap_remove` moves the last element into slot `i`,
                    // so `i` is re-examined rather than advanced.
                    self.accepting_sockets.swap_remove(i);
                    self.sock_w.try_send(sock).unwrap_or_else(|e| {
                        log::warn!(
                            "failed to send connected tls socket: {}",
                            e
                        );
                    });
                    did_work = true;
                }
                Ok(futures::Async::NotReady) => {
                    not_ready = true;
                    i += 1;
                }
                Err(e) => {
                    log::warn!("failed to accept tls connection: {}", e);
                    self.accepting_sockets.swap_remove(i);
                }
            }
        }

        if did_work {
            Ok(crate::component_future::Async::DidWork)
        } else if not_ready {
            Ok(crate::component_future::Async::NotReady)
        } else {
            Ok(crate::component_future::Async::NothingToDo)
        }
    }

    /// Polls the inner server; once it resolves, this future resolves too.
    fn poll_server(&mut self) -> crate::component_future::Poll<(), Error> {
        try_ready!(self.server.poll());
        Ok(crate::component_future::Async::Ready(()))
    }
}

#[must_use = "futures do nothing unless polled"]
impl futures::future::Future for Server {
    type Item = ();
    type Error = Error;

    fn poll(&mut self) -> futures::Poll<Self::Item, Self::Error> {
        crate::component_future::poll_future(self, Self::POLL_FNS)
    }
}