aboutsummaryrefslogblamecommitdiffstats
path: root/src/process.rs
blob: f5c76ed348b4909323cfcb4af29df5a59a01ca73 (plain) (tree)





















                                                                              


                                                                  













































































































































































































































































































































                                                                              
                                          





                                                              
use futures::future::Future as _;
use snafu::ResultExt as _;
use std::os::unix::io::AsRawFd as _;
use tokio::io::{AsyncRead as _, AsyncWrite as _};
use tokio_pty_process::{CommandExt as _, PtyMaster as _};

/// Size, in bytes, of the scratch buffer a `Process` keeps for its reads
/// (both from the caller-supplied input and from the pty).
const READ_BUFFER_SIZE: usize = 4 * 1024;

/// Represents events generated by the process.
///
/// Every payload is plain owned data, so events can be cloned and compared
/// freely (for example to fan a single event out to several consumers).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Event {
    /// Emitted once the command has been successfully spawned. Carries the
    /// command and arguments that were used to spawn it.
    CommandStart { cmd: String, args: Vec<String> },

    /// Emitted every time the command produces output. Note that when a
    /// process is running under a pty, both stdout and stderr are attached to
    /// the single pty input - there is no way to differentiate them when
    /// reading from the pty output.
    Output { data: Vec<u8> },

    /// Emitted when the command has exited, with the exit status reported
    /// for the child process.
    CommandExit { status: std::process::ExitStatus },

    /// Emitted by a `ResizingProcess` when a resize event happens.
    ///
    /// NOTE(review): `size` is presumably `(rows, cols)` - confirm against
    /// the code that constructs this variant.
    Resize { size: (u16, u16) },
}

/// Holds the pty and child process handles. Both are optional because
/// neither exists when the `State` is first created.
struct State {
    /// The pty master handle, once it has been opened.
    pty: Option<tokio_pty_process::AsyncPtyMaster>,
    /// The child process handle, once it has been spawned.
    process: Option<tokio_pty_process::Child>,
}

impl State {
    /// Creates an empty state with no pty opened and no process spawned.
    fn new() -> Self {
        Self {
            pty: None,
            process: None,
        }
    }

    /// Returns a shared reference to the pty.
    ///
    /// # Panics
    ///
    /// Panics if the pty has not been opened yet.
    fn pty(&self) -> &tokio_pty_process::AsyncPtyMaster {
        self.pty
            .as_ref()
            .expect("pty accessed before it was opened")
    }

    /// Returns an exclusive reference to the pty.
    ///
    /// # Panics
    ///
    /// Panics if the pty has not been opened yet.
    fn pty_mut(&mut self) -> &mut tokio_pty_process::AsyncPtyMaster {
        self.pty
            .as_mut()
            .expect("pty accessed before it was opened")
    }

    /// Returns an exclusive reference to the child process.
    ///
    /// # Panics
    ///
    /// Panics if the process has not been spawned yet.
    fn process(&mut self) -> &mut tokio_pty_process::Child {
        self.process
            .as_mut()
            .expect("process accessed before it was spawned")
    }
}

/// A spawned process.
///
/// Wraps `AsyncPtyMaster` and `Child` from `tokio-pty-process` to provide a
/// view of the process as a single stream which emits events. In particular,
/// the stream will return an event when the process starts, when it writes
/// output to the pty, and when it exits. See the `Event` type for more
/// details.
// `must_use` only takes effect on the type itself (it is ignored on impl
// blocks), so the hint is attached here.
#[must_use = "streams do nothing unless polled"]
pub struct Process<R: tokio::io::AsyncRead> {
    /// Handles to the pty and the child process.
    state: State,
    /// Source of the bytes that get forwarded to the process.
    input: R,
    /// Bytes read from `input` that have not been written to the pty yet.
    input_buf: std::collections::VecDeque<u8>,
    /// Command to run.
    cmd: String,
    /// Arguments passed to the command.
    args: Vec<String>,
    /// Scratch buffer shared by the reads from `input` and from the pty.
    buf: [u8; READ_BUFFER_SIZE],
    /// Set once the `CommandStart` event has been produced.
    started: bool,
    /// Set once the `CommandExit` event has been produced.
    exited: bool,
    /// Pending `(rows, cols)` terminal size, if a resize was requested.
    needs_resize: Option<(u16, u16)>,
    /// Set once `input` has reported end of file.
    stdin_closed: bool,
    /// Set once the pty has reported end of output.
    stdout_closed: bool,
}

impl<R: tokio::io::AsyncRead + 'static> Process<R> {
    /// Builds a process stream for `cmd` invoked with `args`.
    ///
    /// Construction is lazy: neither the pty nor the child process exists
    /// until the stream is polled for the first time.
    ///
    /// `input` is the `AsyncRead` whose bytes are forwarded to the process.
    /// This is usually something attached to stdin, but any reader works,
    /// which is handy for automation and tests.
    pub fn new(cmd: &str, args: &[String], input: R) -> Self {
        let cmd = String::from(cmd);
        let args = Vec::from(args);
        Self {
            cmd,
            args,
            input,
            input_buf: std::collections::VecDeque::new(),
            buf: [0; READ_BUFFER_SIZE],
            state: State::new(),
            needs_resize: None,
            started: false,
            exited: false,
            stdin_closed: false,
            stdout_closed: false,
        }
    }

    /// Records a new terminal size for the pty.
    ///
    /// The size is only stored here; it is applied to the pty the next time
    /// the stream is polled.
    pub fn resize(&mut self, rows: u16, cols: u16) {
        let size = (rows, cols);
        self.needs_resize = Some(size);
    }

    /// Gives mutable access to the input object passed to `new`.
    ///
    /// Useful when the input is being fed by hand instead of being wired
    /// straight to stdin.
    pub fn input(&mut self) -> &mut R {
        let input = &mut self.input;
        input
    }
}

impl<R: tokio::io::AsyncRead + 'static> Process<R> {
    const POLL_FNS:
        &'static [&'static dyn for<'a> Fn(
            &'a mut Self,
        )
            -> component_future::Poll<
            Option<Event>,
            crate::error::Error,
        >] = &[
        // order is important here - checking command_exit first so that we
        // don't try to read from a process that has already exited, which
        // causes an error. also, poll_resize needs to happen after
        // poll_command_start, or else the pty might not be initialized.
        &Self::poll_command_start,
        &Self::poll_command_exit,
        &Self::poll_resize,
        &Self::poll_read_stdin,
        &Self::poll_write_stdin,
        &Self::poll_read_stdout,
    ];

    fn poll_resize(
        &mut self,
    ) -> component_future::Poll<Option<Event>, crate::error::Error> {
        if let Some((rows, cols)) = &self.needs_resize {
            component_future::try_ready!(self
                .state
                .pty()
                .resize(*rows, *cols)
                .context(crate::error::ResizePty));
            log::debug!("resize({}x{})", cols, rows);
            self.needs_resize = None;
            Ok(component_future::Async::DidWork)
        } else {
            Ok(component_future::Async::NothingToDo)
        }
    }

    fn poll_command_start(
        &mut self,
    ) -> component_future::Poll<Option<Event>, crate::error::Error> {
        if self.started {
            return Ok(component_future::Async::NothingToDo);
        }

        if self.state.pty.is_none() {
            self.state.pty = Some(
                tokio_pty_process::AsyncPtyMaster::open()
                    .context(crate::error::OpenPty)?,
            );
            log::debug!(
                "openpty({})",
                self.state.pty.as_ref().unwrap().as_raw_fd()
            );
        }

        if self.state.process.is_none() {
            self.state.process = Some(
                std::process::Command::new(&self.cmd)
                    .args(&self.args)
                    .spawn_pty_async(self.state.pty())
                    .context(crate::error::SpawnProcess {
                        cmd: self.cmd.clone(),
                    })?,
            );
            log::debug!(
                "spawn({})",
                self.state.process.as_ref().unwrap().id()
            );
        }

        self.started = true;
        Ok(component_future::Async::Ready(Some(Event::CommandStart {
            cmd: self.cmd.clone(),
            args: self.args.clone(),
        })))
    }

    fn poll_read_stdin(
        &mut self,
    ) -> component_future::Poll<Option<Event>, crate::error::Error> {
        if self.exited || self.stdin_closed {
            return Ok(component_future::Async::NothingToDo);
        }

        let n = component_future::try_ready!(self
            .input
            .poll_read(&mut self.buf)
            .context(crate::error::ReadTerminal));
        log::debug!("read_stdin({})", n);
        if n > 0 {
            self.input_buf.extend(self.buf[..n].iter());
        } else {
            self.input_buf.push_back(b'\x04');
            self.stdin_closed = true;
        }
        Ok(component_future::Async::DidWork)
    }

    fn poll_write_stdin(
        &mut self,
    ) -> component_future::Poll<Option<Event>, crate::error::Error> {
        if self.exited || self.input_buf.is_empty() {
            return Ok(component_future::Async::NothingToDo);
        }

        let (a, b) = self.input_buf.as_slices();
        let buf = if a.is_empty() { b } else { a };
        let n = component_future::try_ready!(self
            .state
            .pty_mut()
            .poll_write(buf)
            .context(crate::error::WritePty));
        log::debug!("write_stdin({})", n);
        for _ in 0..n {
            self.input_buf.pop_front();
        }
        Ok(component_future::Async::DidWork)
    }

    fn poll_read_stdout(
        &mut self,
    ) -> component_future::Poll<Option<Event>, crate::error::Error> {
        match self
            .state
            .pty_mut()
            .poll_read(&mut self.buf)
            .context(crate::error::ReadPty)
        {
            Ok(futures::Async::Ready(n)) => {
                log::debug!("read_stdout({})", n);
                let bytes = self.buf[..n].to_vec();
                Ok(component_future::Async::Ready(Some(Event::Output {
                    data: bytes,
                })))
            }
            Ok(futures::Async::NotReady) => {
                Ok(component_future::Async::NotReady)
            }
            Err(e) => {
                // XXX this seems to be how eof is returned, but this seems...
                // wrong? i feel like there has to be a better way to do this
                if let crate::error::Error::ReadPty { source } = &e {
                    if source.kind() == std::io::ErrorKind::Other {
                        log::debug!("read_stdout(eof)");
                        self.stdout_closed = true;
                        return Ok(component_future::Async::DidWork);
                    }
                }
                Err(e)
            }
        }
    }

    fn poll_command_exit(
        &mut self,
    ) -> component_future::Poll<Option<Event>, crate::error::Error> {
        if self.exited {
            return Ok(component_future::Async::Ready(None));
        }
        if !self.stdout_closed {
            return Ok(component_future::Async::NothingToDo);
        }

        let status = component_future::try_ready!(self
            .state
            .process()
            .poll()
            .context(crate::error::ProcessExitPoll));
        log::debug!("exit({})", status);
        self.exited = true;
        Ok(component_future::Async::Ready(Some(Event::CommandExit {
            status,
        })))
    }
}

#[must_use = "streams do nothing unless polled"]
impl<R: tokio::io::AsyncRead + 'static> futures::stream::Stream
    for Process<R>
{
    type Item = Event;
    type Error = crate::error::Error;

    fn poll(&mut self) -> futures::Poll<Option<Self::Item>, Self::Error> {
        component_future::poll_stream(self, Self::POLL_FNS)
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use futures::sink::Sink as _;
    use futures::stream::Stream as _;

    /// Runs `cat` with one line of input and checks the events the stream
    /// produces: a start event first, then output and a successful exit.
    #[test]
    fn test_simple() {
        // everything the stream yields (events or the final error) is sent
        // through this channel so it can be inspected after the run
        let (result_tx, result_rx) = tokio::sync::mpsc::channel(100);
        let error_tx = result_tx.clone();
        let mut event_tx = result_tx.wait();
        let input = std::io::Cursor::new(b"hello world\n");
        let fut = Process::new("cat", &[], input)
            .for_each(move |e| {
                event_tx.send(Ok(e)).unwrap();
                Ok(())
            })
            .map_err(|e| {
                error_tx.wait().send(Err(e)).unwrap();
            });
        tokio::run(fut);

        let mut results = result_rx.wait();

        // the very first thing reported has to be the command start
        let first = results.next().unwrap().unwrap().unwrap();
        assert_eq!(
            first,
            Event::CommandStart {
                cmd: "cat".to_string(),
                args: vec![]
            }
        );

        let mut output: Vec<u8> = vec![];
        let mut exited = false;
        for result in results {
            // nothing may follow the exit event
            assert!(!exited);
            match result.unwrap().unwrap() {
                Event::CommandStart { .. } => {
                    panic!("unexpected CommandStart")
                }
                Event::Output { data } => {
                    output.extend_from_slice(&data);
                }
                Event::CommandExit { status } => {
                    assert!(status.success());
                    exited = true;
                }
                Event::Resize { .. } => {}
            }
        }
        assert!(exited);
        assert_eq!(output, b"hello world\r\nhello world\r\n");
    }
}