Skip to content

Commit

Permalink
interp: use os.Pipe when StdIO or OpenHandler produce non-file stdins
Browse files Browse the repository at this point in the history
A recent commit resolved the stdin draining issue with os/exec processes
when stdin was created via a here-document, as those used buffers.

Fix the same issue when stdin was provided via StdIO or Openandler,
which are not guaranteed to provide an os.File.
Much like the os/exec.Cmd.Stdin, we create a pipe and copy the contents
from the original reader via a goroutine when not a file.

This makes sense from an API standpoint: much like os/exec,
the sh/interp API allows any io.Reader as stdin, but using an os.File
will save a pipe and a goroutine as it can be used directly.

Note that this requires cmd/gosh's tests for runInteractive
to swap io.Pipe for os.Pipe, as otherwise we buffer stdin reads
which cause the tests to hang.
  • Loading branch information
mvdan committed Aug 8, 2024
1 parent 262cc0e commit 63f3119
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 42 deletions.
12 changes: 9 additions & 3 deletions cmd/gosh/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ package main
import (
"fmt"
"io"
"os"
"testing"

"github.com/go-quicktest/qt"
"mvdan.cc/sh/v3/interp"
)

Expand Down Expand Up @@ -179,14 +181,18 @@ func TestInteractive(t *testing.T) {
t.Parallel()
for _, tc := range interactiveTests {
t.Run("", func(t *testing.T) {
inReader, inWriter := io.Pipe()
outReader, outWriter := io.Pipe()
inReader, inWriter, err := os.Pipe()
qt.Assert(t, qt.IsNil(err))
outReader, outWriter, err := os.Pipe()
qt.Assert(t, qt.IsNil(err))
runner, _ := interp.New(interp.StdIO(inReader, outWriter, outWriter))
errc := make(chan error, 1)
go func() {
errc <- runInteractive(runner, inReader, outWriter, outWriter)
// Discard the rest of the input.
io.Copy(io.Discard, inReader)
inReader.Close()
outWriter.Close()
}()

if err := readString(outReader, "$ "); err != nil {
Expand Down Expand Up @@ -215,7 +221,7 @@ func TestInteractive(t *testing.T) {
// so that any remaining prompt writes get discarded.
outReader.Close()

err := <-errc
err = <-errc
if err != nil && tc.wantErr == "" {
t.Fatalf("unexpected error: %v", err)
} else if tc.wantErr != "" && fmt.Sprint(err) != tc.wantErr {
Expand Down
38 changes: 31 additions & 7 deletions interp/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,7 @@ type Runner struct {
// statHandler is a function responsible for getting file stat. It must be non-nil.
statHandler StatHandlerFunc

// TODO: we should force stdin to always be an *os.File,
// otherwise the first os/exec command we execute with stdin
// will always consume the entirety of the contents.

stdin io.Reader
stdin *os.File // e.g. the read end of a pipe
stdout io.Writer
stderr io.Writer

Expand Down Expand Up @@ -148,7 +144,7 @@ type Runner struct {
origDir string
origParams []string
origOpts runnerOpts
origStdin io.Reader
origStdin *os.File
origStdout io.Writer
origStderr io.Writer

Expand Down Expand Up @@ -425,12 +421,40 @@ func StatHandler(f StatHandlerFunc) RunnerOption {
}
}

func stdinFile(r io.Reader) (*os.File, error) {
switch r := r.(type) {
case *os.File:
return r, nil
case nil:
return nil, nil
default:
pr, pw, err := os.Pipe()
if err != nil {
return nil, err
}
go func() {
io.Copy(pw, r)
pw.Close()
}()
return pr, nil
}
}

// StdIO configures an interpreter's standard input, standard output, and
// standard error. If out or err are nil, they default to a writer that discards
// the output.
//
// Note that providing a non-nil standard input other than [os.File] will require
// an [os.Pipe] and spawning a goroutine to copy into it,
// as an [os.File] is the only way to share a reader with subprocesses.
// See [os/exec.Cmd.Stdin].
func StdIO(in io.Reader, out, err io.Writer) RunnerOption {
return func(r *Runner) error {
r.stdin = in
stdin, _err := stdinFile(in)
if _err != nil {
return _err
}
r.stdin = stdin
if out == nil {
out = io.Discard
}
Expand Down
52 changes: 24 additions & 28 deletions interp/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -930,38 +930,34 @@ func (r *Runner) readLine(ctx context.Context, raw bool) ([]byte, error) {
var line []byte
esc := false

stdin := r.stdin
if osFile, ok := stdin.(*os.File); ok {
cr, err := cancelreader.NewReader(osFile)
if err != nil {
return nil, err
}
stdin = cr
done := make(chan struct{})
var wg sync.WaitGroup
wg.Add(1)
go func() {
select {
case <-ctx.Done():
cr.Cancel()
case <-done:
}
wg.Done()
}()
defer func() {
close(done)
wg.Wait()
// Could put the Close in the above goroutine, but if "read" is
// immediately called again, the Close might overlap with creating a
// new cancelreader. Want this cancelreader to be completely closed
// by the time readLine returns.
cr.Close()
}()
cr, err := cancelreader.NewReader(r.stdin)
if err != nil {
return nil, err
}
done := make(chan struct{})
var wg sync.WaitGroup
wg.Add(1)
go func() {
select {
case <-ctx.Done():
cr.Cancel()
case <-done:
}
wg.Done()
}()
defer func() {
close(done)
wg.Wait()
// Could put the Close in the above goroutine, but if "read" is
// immediately called again, the Close might overlap with creating a
// new cancelreader. Want this cancelreader to be completely closed
// by the time readLine returns.
cr.Close()
}()

for {
var buf [1]byte
n, err := stdin.Read(buf[:])
n, err := cr.Read(buf[:])
if n > 0 {
b := buf[0]
switch {
Expand Down
3 changes: 3 additions & 0 deletions interp/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,9 @@ func pathExts(env expand.Environ) []string {
// Use a return error of type [*os.PathError] to have the error printed to
// stderr and the exit status set to 1. If the error is of any other type, the
// interpreter will come to a stop.
//
// Note that implementations which do not return [os.File] will cause
// extra files and goroutines for input redirections; see [StdIO].
type OpenHandlerFunc func(ctx context.Context, path string, flag int, perm os.FileMode) (io.ReadWriteCloser, error)

// DefaultOpenHandler returns the [OpenHandlerFunc] used by default.
Expand Down
2 changes: 1 addition & 1 deletion interp/interp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4418,5 +4418,5 @@ func TestRunnerNonFileStdin(t *testing.T) {
cb.WriteString(err.Error())
}
// TODO: just like with heredocs, the first exec_ok call consumes all stdin.
qt.Assert(t, qt.Equals(cb.String(), "a\nexec ok\n"))
qt.Assert(t, qt.Equals(cb.String(), "a\nexec ok\nb\nexec ok\nc\nexec ok\n"))
}
11 changes: 8 additions & 3 deletions interp/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,7 @@ func (r *Runner) stmts(ctx context.Context, stmts []*syntax.Stmt) {
}
}

func (r *Runner) hdocReader(rd *syntax.Redirect) (io.ReadCloser, error) {
func (r *Runner) hdocReader(rd *syntax.Redirect) (*os.File, error) {
pr, pw, err := os.Pipe()
if err != nil {
return nil, err
Expand Down Expand Up @@ -904,7 +904,11 @@ func (r *Runner) redir(ctx context.Context, rd *syntax.Redirect) (io.Closer, err
}
switch rd.Op {
case syntax.RdrIn:
r.stdin = f
stdin, err := stdinFile(f)
if err != nil {
return nil, err
}
r.stdin = stdin
case syntax.RdrOut, syntax.AppOut:
*orig = f
case syntax.RdrAll, syntax.AppAll:
Expand Down Expand Up @@ -1002,14 +1006,15 @@ func (r *Runner) open(ctx context.Context, path string, flags int, mode os.FileM
// TODO: support wrapped PathError returned from openHandler.
switch err.(type) {
case nil:
return f, nil
case *os.PathError:
if print {
r.errf("%v\n", err)
}
default: // handler's custom fatal error
r.setErr(err)
}
return f, err
return nil, err
}

func (r *Runner) stat(ctx context.Context, name string) (fs.FileInfo, error) {
Expand Down

0 comments on commit 63f3119

Please sign in to comment.