From cc06e4fc8557f9c277c3875ca550ea2d567e8599 Mon Sep 17 00:00:00 2001 From: Jesse Luehrs Date: Wed, 24 Feb 2021 04:40:02 -0500 Subject: make the tokio pty backend actually work properly --- Cargo.toml | 5 ++- src/command.rs | 4 ++ src/error.rs | 3 ++ src/pty.rs | 1 + src/pty/async_io.rs | 4 ++ src/pty/std.rs | 4 ++ src/pty/tokio.rs | 117 ++++++++++++++++++++++++++++++++++++++++++++++++++-- 7 files changed, 133 insertions(+), 5 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c5840e5..005d002 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,8 @@ thiserror = "1.0" async-io = { version = "1.3", optional = true } async-process = { version = "1.0", optional = true } -tokio = { version = "1.2", optional = true, features = ["fs", "process"] } +tokio = { version = "1.2", optional = true, features = ["fs", "process", "net"] } +futures = { version = "0.3", optional = true } [dev-dependencies] smol = "*" @@ -22,4 +23,4 @@ default = ["backend-std"] backend-std = [] backend-async-std = ["async-io", "async-process"] backend-smol = ["async-io", "async-process"] -backend-tokio = ["tokio"] +backend-tokio = ["tokio", "futures"] diff --git a/src/command.rs b/src/command.rs index e908a06..2f68786 100644 --- a/src/command.rs +++ b/src/command.rs @@ -88,6 +88,10 @@ where self.pty.pt() } + pub fn pty_mut(&mut self) -> &mut P::Pt { + self.pty.pt_mut() + } + pub fn pty_resize(&self, size: &crate::pty::Size) -> Result<()> { self.pty.resize(size) } diff --git a/src/error.rs b/src/error.rs index 7521544..5253cd9 100644 --- a/src/error.rs +++ b/src/error.rs @@ -3,6 +3,9 @@ pub enum Error { #[error("error making pty async")] AsyncPty(#[source] std::io::Error), + #[error("error making pty async")] + AsyncPtyNix(#[source] nix::Error), + #[error("error creating pty")] CreatePty(#[source] nix::Error), diff --git a/src/pty.rs b/src/pty.rs index 07d6b2c..986b216 100644 --- a/src/pty.rs +++ b/src/pty.rs @@ -16,6 +16,7 @@ pub trait Pty { where Self: Sized; fn pt(&self) -> &Self::Pt; + fn pt_mut(&mut self) -> &mut Self::Pt; fn pts(&self) -> Result<::std::fs::File>; fn resize(&self, size: &super::Size) -> Result<()>; } diff --git a/src/pty/async_io.rs b/src/pty/async_io.rs index 3021f7a..7f45050 100644 --- a/src/pty/async_io.rs +++ b/src/pty/async_io.rs @@ -29,6 +29,10 @@ impl super::Pty for Pty { &self.pt } + fn pt_mut(&mut self) -> &mut Self::Pt { + &mut self.pt + } + fn pts(&self) -> Result { let fh = std::fs::OpenOptions::new() .read(true) diff --git a/src/pty/std.rs b/src/pty/std.rs index a68e3fc..e808cdf 100644 --- a/src/pty/std.rs +++ b/src/pty/std.rs @@ -27,6 +27,10 @@ impl super::Pty for Pty { &self.pt } + fn pt_mut(&mut self) -> &mut Self::Pt { + &mut self.pt + } + fn pts(&self) -> Result { let fh = std::fs::OpenOptions::new() .read(true) diff --git a/src/pty/tokio.rs b/src/pty/tokio.rs index 995a154..cd8c7a5 100644 --- a/src/pty/tokio.rs +++ b/src/pty/tokio.rs @@ -1,18 +1,123 @@ use crate::error::*; +use std::io::{Read as _, Write as _}; use std::os::unix::io::{AsRawFd as _, FromRawFd as _}; +// ideally i would just be able to use tokio::fs::File::from_std on the +// std::fs::File i create from the pty fd, but it appears that tokio::fs::File +// doesn't actually support having both a read and a write operation happening +// on it simultaneously - if you poll the future returned by .read() at any +// point, .write().await will never complete (because it is trying to wait for +// the read to finish before processing the write, which will never happen). +// this unfortunately shows up in patterns like select! pretty frequently, so +// we need to do this the complicated way/: +pub struct AsyncPty(tokio::io::unix::AsyncFd); + +impl std::ops::Deref for AsyncPty { + type Target = tokio::io::unix::AsyncFd; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::ops::DerefMut for AsyncPty { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl std::os::unix::io::AsRawFd for AsyncPty { + fn as_raw_fd(&self) -> std::os::unix::io::RawFd { + self.0.as_raw_fd() + } +} + +impl tokio::io::AsyncRead for AsyncPty { + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf, + ) -> std::task::Poll> { + loop { + let mut guard = futures::ready!(self.0.poll_read_ready(cx))?; + let mut b = [0u8; 4096]; + match guard.try_io(|inner| inner.get_ref().read(&mut b)) { + Ok(Ok(bytes)) => { + // XXX this is safe, but not particularly efficient + buf.clear(); + buf.initialize_unfilled_to(bytes); + buf.set_filled(bytes); + buf.filled_mut().copy_from_slice(&b[..bytes]); + return std::task::Poll::Ready(Ok(())); + } + Ok(Err(e)) => return std::task::Poll::Ready(Err(e)), + Err(_would_block) => continue, + } + } + } +} + +impl tokio::io::AsyncWrite for AsyncPty { + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + loop { + let mut guard = futures::ready!(self.0.poll_write_ready(cx))?; + match guard.try_io(|inner| inner.get_ref().write(buf)) { + Ok(result) => return std::task::Poll::Ready(result), + Err(_would_block) => continue, + } + } + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + loop { + let mut guard = futures::ready!(self.0.poll_write_ready(cx))?; + match guard.try_io(|inner| inner.get_ref().flush()) { + Ok(_) => return std::task::Poll::Ready(Ok(())), + Err(_would_block) => continue, + } + } + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } +} + pub struct Pty { - pt: tokio::fs::File, + pt: AsyncPty, ptsname: std::path::PathBuf, } impl super::Pty for Pty { - type Pt = tokio::fs::File; + type Pt = AsyncPty; fn new() -> Result { let (pt_fd, ptsname) = super::create_pt()?; + let bits = nix::fcntl::fcntl(pt_fd, nix::fcntl::FcntlArg::F_GETFL) + .map_err(Error::AsyncPtyNix)?; + // this should be safe because i am just using the return value of + // F_GETFL directly, but for whatever reason nix doesn't like + // from_bits(bits) (it claims it has an unknown field) + let opts = unsafe { + nix::fcntl::OFlag::from_bits_unchecked( + bits | nix::fcntl::OFlag::O_NONBLOCK.bits(), + ) + }; + nix::fcntl::fcntl(pt_fd, nix::fcntl::FcntlArg::F_SETFL(opts)) + .map_err(Error::AsyncPtyNix)?; + // safe because posix_openpt (or the previous functions operating on // the result) would have returned an Err (causing us to return early) // if the file descriptor was invalid. additionally, into_raw_fd gives @@ -20,7 +125,9 @@ impl super::Pty for Pty { // File object to take full ownership. let pt = unsafe { std::fs::File::from_raw_fd(pt_fd) }; - let pt = tokio::fs::File::from_std(pt); + let pt = AsyncPty( + tokio::io::unix::AsyncFd::new(pt).map_err(Error::AsyncPty)?, + ); Ok(Self { pt, ptsname }) } @@ -29,6 +136,10 @@ impl super::Pty for Pty { &self.pt } + fn pt_mut(&mut self) -> &mut Self::Pt { + &mut self.pt + } + fn pts(&self) -> Result { let fh = std::fs::OpenOptions::new() .read(true) -- cgit v1.2.3