diff options
Diffstat (limited to 'src/runner/mod.rs')
-rw-r--r-- | src/runner/mod.rs | 101 |
1 files changed, 49 insertions, 52 deletions
diff --git a/src/runner/mod.rs b/src/runner/mod.rs index 1a5003f..d06b332 100644 --- a/src/runner/mod.rs +++ b/src/runner/mod.rs @@ -70,7 +70,7 @@ enum Frame { pub async fn run( commands: &str, - shell_write: Option<&async_std::fs::File>, + shell_write: &mut Option<tokio::fs::File>, ) -> anyhow::Result<i32> { let mut env = Env::new_from_env()?; run_commands(commands, &mut env, shell_write).await?; @@ -86,7 +86,7 @@ pub async fn run( async fn run_commands( commands: &str, env: &mut Env, - shell_write: Option<&async_std::fs::File>, + shell_write: &mut Option<tokio::fs::File>, ) -> anyhow::Result<()> { let commands = crate::parse::ast::Commands::parse(commands)?; let commands = commands.commands(); @@ -152,7 +152,7 @@ async fn run_commands( .map(IntoIterator::into_iter) }) .collect::<futures_util::stream::FuturesOrdered<_>>() - .collect::<Result<Vec<_>, _>>().await? + .try_collect::<Vec<_>>().await? .into_iter() .flatten() .collect() @@ -231,7 +231,7 @@ async fn run_commands( async fn run_pipeline( pipeline: crate::parse::ast::Pipeline, env: &mut Env, - shell_write: Option<&async_std::fs::File>, + shell_write: &mut Option<tokio::fs::File>, ) -> anyhow::Result<()> { write_event(shell_write, Event::RunPipeline(env.idx(), pipeline.span())) .await?; @@ -240,9 +240,9 @@ async fn run_pipeline( // level would not be safe, because in the case of a command line like // "echo foo; ls", we would pass the stdout fd to the ls process while it // is still open here, and may still have data buffered. - let stdin = unsafe { async_std::fs::File::from_raw_fd(0) }; - let stdout = unsafe { async_std::fs::File::from_raw_fd(1) }; - let stderr = unsafe { async_std::fs::File::from_raw_fd(2) }; + let stdin = unsafe { std::fs::File::from_raw_fd(0) }; + let stdout = unsafe { std::fs::File::from_raw_fd(1) }; + let stderr = unsafe { std::fs::File::from_raw_fd(2) }; let mut io = builtins::Io::new(); io.set_stdin(stdin); io.set_stdout(stdout); @@ -265,10 +265,10 @@ async fn run_pipeline( } async fn write_event( - fh: Option<&async_std::fs::File>, + fh: &mut Option<tokio::fs::File>, event: Event, ) -> anyhow::Result<()> { - if let Some(mut fh) = fh { + if let Some(fh) = fh { fh.write_all(&bincode::serialize(&event)?).await?; fh.flush().await?; } @@ -322,11 +322,11 @@ async fn wait_children( pg: Option<nix::unistd::Pid>, env: &Env, io: &builtins::Io, - shell_write: Option<&async_std::fs::File>, + shell_write: &mut Option<tokio::fs::File>, ) -> std::process::ExitStatus { enum Res { Child(nix::Result<nix::sys::wait::WaitStatus>), - Builtin(Option<(anyhow::Result<std::process::ExitStatus>, bool)>), + Builtin((anyhow::Result<std::process::ExitStatus>, bool)), } macro_rules! bail { @@ -353,7 +353,8 @@ async fn wait_children( (sys::id_to_pid(child.id().unwrap()), (child, i == count - 1)) }) .collect(); - let mut builtins: futures_util::stream::FuturesUnordered<_> = + let mut builtin_count = builtins.len(); + let builtins: futures_util::stream::FuturesUnordered<_> = builtins .into_iter() .map(|(i, child)| async move { @@ -361,47 +362,40 @@ async fn wait_children( }) .collect(); - let (wait_w, wait_r) = async_std::channel::unbounded(); - let new_wait = move || { - if let Some(pg) = pg { - let wait_w = wait_w.clone(); - async_std::task::spawn(async move { - let res = blocking::unblock(move || { - nix::sys::wait::waitpid( - sys::neg_pid(pg), - Some(nix::sys::wait::WaitPidFlag::WUNTRACED), - ) - }) - .await; - if wait_w.is_closed() { - // we shouldn't be able to drop real process terminations + let (wait_w, wait_r) = tokio::sync::mpsc::unbounded_channel(); + if let Some(pg) = pg { + tokio::task::spawn_blocking(move || loop { + let res = nix::sys::wait::waitpid( + sys::neg_pid(pg), + Some(nix::sys::wait::WaitPidFlag::WUNTRACED), + ); + match wait_w.send(res) { + Ok(_) => {} + Err(tokio::sync::mpsc::error::SendError(res)) => { + // we should never drop wait_r while there are still valid + // things to read assert!(res.is_err()); - } else { - wait_w.send(res).await.unwrap(); + break; } - }); - } - }; - - new_wait(); - loop { - if children.is_empty() && builtins.is_empty() { - break; - } + } + }); + } - let child = async { Res::Child(wait_r.recv().await.unwrap()) }; - let builtin = async { - Res::Builtin(if builtins.is_empty() { - std::future::pending().await - } else { - builtins.next().await - }) - }; - match child.race(builtin).await { + let mut stream: futures_util::stream::SelectAll<_> = [ + tokio_stream::wrappers::UnboundedReceiverStream::new(wait_r) + .map(Res::Child) + .boxed(), + builtins.map(Res::Builtin).boxed(), + ] + .into_iter() + .collect(); + while let Some(res) = stream.next().await { + match res { Res::Child(Ok(status)) => { match status { - // we can't call child.status() here to unify these branches - // because our waitpid call already collected the status + // we can't call child.status() here to unify these + // branches because our waitpid call already collected the + // status nix::sys::wait::WaitStatus::Exited(pid, code) => { let (_, last) = children.remove(&pid).unwrap(); if last { @@ -449,12 +443,11 @@ async fn wait_children( } _ => {} } - new_wait(); } Res::Child(Err(e)) => { bail!(e); } - Res::Builtin(Some((Ok(status), last))) => { + Res::Builtin((Ok(status), last)) => { // this conversion is safe because the Signal enum is // repr(i32) #[allow(clippy::as_conversions)] @@ -470,11 +463,15 @@ async fn wait_children( if last { final_status = Some(status); } + builtin_count -= 1; } - Res::Builtin(Some((Err(e), _))) => { + Res::Builtin((Err(e), _)) => { bail!(e); } - Res::Builtin(None) => {} + } + + if children.is_empty() && builtin_count == 0 { + break; } } |