From 63f3119ecdbe87cf562c786df8ce237247474274 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Mart=C3=AD?= Date: Thu, 8 Aug 2024 11:19:55 +0100 Subject: [PATCH] interp: use os.Pipe when StdIO or OpenHandler produce non-file stdins A recent commit resolved the stdin draining issue with os/exec processes when stdin was created via a here-document, as those used buffers. Fix the same issue when stdin was provided via StdIO or Openandler, which are not guaranteed to provide an os.File. Much like the os/exec.Cmd.Stdin, we create a pipe and copy the contents from the original reader via a goroutine when not a file. This makes sense from an API standpoint: much like os/exec, the sh/interp API allows any io.Reader as stdin, but using an os.File will save a pipe and a goroutine as it can be used directly. Note that this requires cmd/gosh's tests for runInteractive to swap io.Pipe for os.Pipe, as otherwise we buffer stdin reads which cause the tests to hang. --- cmd/gosh/main_test.go | 12 +++++++--- interp/api.go | 38 +++++++++++++++++++++++++------ interp/builtin.go | 52 ++++++++++++++++++++----------------------- interp/handler.go | 3 +++ interp/interp_test.go | 2 +- interp/runner.go | 11 ++++++--- 6 files changed, 76 insertions(+), 42 deletions(-) diff --git a/cmd/gosh/main_test.go b/cmd/gosh/main_test.go index e3b00ae38..397b8282d 100644 --- a/cmd/gosh/main_test.go +++ b/cmd/gosh/main_test.go @@ -6,8 +6,10 @@ package main import ( "fmt" "io" + "os" "testing" + "github.com/go-quicktest/qt" "mvdan.cc/sh/v3/interp" ) @@ -179,14 +181,18 @@ func TestInteractive(t *testing.T) { t.Parallel() for _, tc := range interactiveTests { t.Run("", func(t *testing.T) { - inReader, inWriter := io.Pipe() - outReader, outWriter := io.Pipe() + inReader, inWriter, err := os.Pipe() + qt.Assert(t, qt.IsNil(err)) + outReader, outWriter, err := os.Pipe() + qt.Assert(t, qt.IsNil(err)) runner, _ := interp.New(interp.StdIO(inReader, outWriter, outWriter)) errc := make(chan error, 1) go func() { errc <- runInteractive(runner, inReader, outWriter, outWriter) // Discard the rest of the input. io.Copy(io.Discard, inReader) + inReader.Close() + outWriter.Close() }() if err := readString(outReader, "$ "); err != nil { @@ -215,7 +221,7 @@ func TestInteractive(t *testing.T) { // so that any remaining prompt writes get discarded. outReader.Close() - err := <-errc + err = <-errc if err != nil && tc.wantErr == "" { t.Fatalf("unexpected error: %v", err) } else if tc.wantErr != "" && fmt.Sprint(err) != tc.wantErr { diff --git a/interp/api.go b/interp/api.go index 92498a194..8bb2937f2 100644 --- a/interp/api.go +++ b/interp/api.go @@ -90,11 +90,7 @@ type Runner struct { // statHandler is a function responsible for getting file stat. It must be non-nil. statHandler StatHandlerFunc - // TODO: we should force stdin to always be an *os.File, - // otherwise the first os/exec command we execute with stdin - // will always consume the entirety of the contents. - - stdin io.Reader + stdin *os.File // e.g. the read end of a pipe stdout io.Writer stderr io.Writer @@ -148,7 +144,7 @@ type Runner struct { origDir string origParams []string origOpts runnerOpts - origStdin io.Reader + origStdin *os.File origStdout io.Writer origStderr io.Writer @@ -425,12 +421,40 @@ func StatHandler(f StatHandlerFunc) RunnerOption { } } +func stdinFile(r io.Reader) (*os.File, error) { + switch r := r.(type) { + case *os.File: + return r, nil + case nil: + return nil, nil + default: + pr, pw, err := os.Pipe() + if err != nil { + return nil, err + } + go func() { + io.Copy(pw, r) + pw.Close() + }() + return pr, nil + } +} + // StdIO configures an interpreter's standard input, standard output, and // standard error. If out or err are nil, they default to a writer that discards // the output. +// +// Note that providing a non-nil standard input other than [os.File] will require +// an [os.Pipe] and spawning a goroutine to copy into it, +// as an [os.File] is the only way to share a reader with subprocesses. +// See [os/exec.Cmd.Stdin]. func StdIO(in io.Reader, out, err io.Writer) RunnerOption { return func(r *Runner) error { - r.stdin = in + stdin, _err := stdinFile(in) + if _err != nil { + return _err + } + r.stdin = stdin if out == nil { out = io.Discard } diff --git a/interp/builtin.go b/interp/builtin.go index 6c5e7be82..6462b0405 100644 --- a/interp/builtin.go +++ b/interp/builtin.go @@ -930,38 +930,34 @@ func (r *Runner) readLine(ctx context.Context, raw bool) ([]byte, error) { var line []byte esc := false - stdin := r.stdin - if osFile, ok := stdin.(*os.File); ok { - cr, err := cancelreader.NewReader(osFile) - if err != nil { - return nil, err - } - stdin = cr - done := make(chan struct{}) - var wg sync.WaitGroup - wg.Add(1) - go func() { - select { - case <-ctx.Done(): - cr.Cancel() - case <-done: - } - wg.Done() - }() - defer func() { - close(done) - wg.Wait() - // Could put the Close in the above goroutine, but if "read" is - // immediately called again, the Close might overlap with creating a - // new cancelreader. Want this cancelreader to be completely closed - // by the time readLine returns. - cr.Close() - }() + cr, err := cancelreader.NewReader(r.stdin) + if err != nil { + return nil, err } + done := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(1) + go func() { + select { + case <-ctx.Done(): + cr.Cancel() + case <-done: + } + wg.Done() + }() + defer func() { + close(done) + wg.Wait() + // Could put the Close in the above goroutine, but if "read" is + // immediately called again, the Close might overlap with creating a + // new cancelreader. Want this cancelreader to be completely closed + // by the time readLine returns. + cr.Close() + }() for { var buf [1]byte - n, err := stdin.Read(buf[:]) + n, err := cr.Read(buf[:]) if n > 0 { b := buf[0] switch { diff --git a/interp/handler.go b/interp/handler.go index b86cd7e95..e3924c2ed 100644 --- a/interp/handler.go +++ b/interp/handler.go @@ -295,6 +295,9 @@ func pathExts(env expand.Environ) []string { // Use a return error of type [*os.PathError] to have the error printed to // stderr and the exit status set to 1. If the error is of any other type, the // interpreter will come to a stop. +// +// Note that implementations which do not return [os.File] will cause +// extra files and goroutines for input redirections; see [StdIO]. type OpenHandlerFunc func(ctx context.Context, path string, flag int, perm os.FileMode) (io.ReadWriteCloser, error) // DefaultOpenHandler returns the [OpenHandlerFunc] used by default. diff --git a/interp/interp_test.go b/interp/interp_test.go index fbdc47803..a08a3bb4b 100644 --- a/interp/interp_test.go +++ b/interp/interp_test.go @@ -4418,5 +4418,5 @@ func TestRunnerNonFileStdin(t *testing.T) { cb.WriteString(err.Error()) } // TODO: just like with heredocs, the first exec_ok call consumes all stdin. - qt.Assert(t, qt.Equals(cb.String(), "a\nexec ok\n")) + qt.Assert(t, qt.Equals(cb.String(), "a\nexec ok\nb\nexec ok\nc\nexec ok\n")) } diff --git a/interp/runner.go b/interp/runner.go index 1d403bfe2..d3d1a7c8c 100644 --- a/interp/runner.go +++ b/interp/runner.go @@ -796,7 +796,7 @@ func (r *Runner) stmts(ctx context.Context, stmts []*syntax.Stmt) { } } -func (r *Runner) hdocReader(rd *syntax.Redirect) (io.ReadCloser, error) { +func (r *Runner) hdocReader(rd *syntax.Redirect) (*os.File, error) { pr, pw, err := os.Pipe() if err != nil { return nil, err @@ -904,7 +904,11 @@ func (r *Runner) redir(ctx context.Context, rd *syntax.Redirect) (io.Closer, err } switch rd.Op { case syntax.RdrIn: - r.stdin = f + stdin, err := stdinFile(f) + if err != nil { + return nil, err + } + r.stdin = stdin case syntax.RdrOut, syntax.AppOut: *orig = f case syntax.RdrAll, syntax.AppAll: @@ -1002,6 +1006,7 @@ func (r *Runner) open(ctx context.Context, path string, flags int, mode os.FileM // TODO: support wrapped PathError returned from openHandler. switch err.(type) { case nil: + return f, nil case *os.PathError: if print { r.errf("%v\n", err) @@ -1009,7 +1014,7 @@ func (r *Runner) open(ctx context.Context, path string, flags int, mode os.FileM default: // handler's custom fatal error r.setErr(err) } - return f, err + return nil, err } func (r *Runner) stat(ctx context.Context, name string) (fs.FileInfo, error) {