diff --git a/go.mod b/go.mod index 03824d50..8dd2d312 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/go-quicktest/qt v1.101.0 github.com/google/go-cmp v0.6.0 github.com/google/renameio/v2 v2.0.0 + github.com/muesli/cancelreader v0.2.2 github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e github.com/rogpeppe/go-internal v1.12.0 golang.org/x/sync v0.6.0 diff --git a/go.sum b/go.sum index ff4788f1..4c15ee1a 100644 --- a/go.sum +++ b/go.sum @@ -11,6 +11,8 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= +github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e h1:aoZm08cpOy4WuID//EZDgcC4zIxODThtZNPirFr42+A= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= diff --git a/interp/builtin.go b/interp/builtin.go index b24a3d6a..0342a7f0 100644 --- a/interp/builtin.go +++ b/interp/builtin.go @@ -9,12 +9,13 @@ import ( "context" "errors" "fmt" - "io" "os" "path/filepath" "strconv" "strings" + "sync" + "github.com/muesli/cancelreader" "mvdan.cc/sh/v3/expand" "mvdan.cc/sh/v3/syntax" ) @@ -589,10 +590,7 @@ func (r *Runner) builtinCode(ctx context.Context, pos syntax.Pos, name string, a r.out(prompt) } - line, err := r.readLine(raw) - if err != nil { - return 1 - } + line, err := r.readLine(ctx, raw) if len(args) == 0 { args = append(args, shellReplyVar) } @@ -606,6 +604,12 @@ func (r *Runner) builtinCode(ctx context.Context, pos syntax.Pos, name string, a r.setVarString(name, val) } + // We can get data back from readLine and an error at the same time, so + // check err after we process the data. + if err != nil { + return 1 + } + return 0 case "getopts": @@ -917,7 +921,7 @@ func (r *Runner) printOptLine(name string, enabled, supported bool) { r.outf("%s\t%s\t(%q not supported)\n", name, state, r.optStatusText(!enabled)) } -func (r *Runner) readLine(raw bool) ([]byte, error) { +func (r *Runner) readLine(ctx context.Context, raw bool) ([]byte, error) { if r.stdin == nil { return nil, errors.New("interp: can't read, there's no stdin") } @@ -925,9 +929,38 @@ func (r *Runner) readLine(raw bool) ([]byte, error) { var line []byte esc := false + stdin := r.stdin + if osFile, ok := stdin.(*os.File); ok { + cr, err := cancelreader.NewReader(osFile) + if err != nil { + return nil, err + } + stdin = cr + done := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(1) + go func() { + select { + case <-ctx.Done(): + cr.Cancel() + case <-done: + } + wg.Done() + }() + defer func() { + close(done) + wg.Wait() + // Could put the Close in the above goroutine, but if "read" is + // immediately called again, the Close might overlap with creating a + // new cancelreader. Want this cancelreader to be completely closed + // by the time readLine returns. + cr.Close() + }() + } + for { var buf [1]byte - n, err := r.stdin.Read(buf[:]) + n, err := stdin.Read(buf[:]) if n > 0 { b := buf[0] switch { @@ -945,11 +978,8 @@ func (r *Runner) readLine(raw bool) ([]byte, error) { esc = false } } - if err == io.EOF && len(line) > 0 { - return line, nil - } if err != nil { - return nil, err + return line, err } } } diff --git a/interp/interp_test.go b/interp/interp_test.go index 8a6a528a..f65b721c 100644 --- a/interp/interp_test.go +++ b/interp/interp_test.go @@ -2903,6 +2903,22 @@ done <<< 2`, "read -r -p 'Prompt and raw flag together: ' a <<< '\\a\\b\\c'; echo $a", "Prompt and raw flag together: \\a\\b\\c\n #IGNORE bash requires a terminal", }, + { + `a=a; echo | (read a; echo -n "$a")`, + "", + }, + { + `a=b; read a < /dev/null; echo -n "$a"`, + "", + }, + { + "a=c; echo x | (read a; echo -n $a)", + "x", + }, + { + "a=d; echo -n y | (read a; echo -n $a)", + "y", + }, // getopts { @@ -3943,6 +3959,51 @@ func TestRunnerContext(t *testing.T) { } } +func TestCancelreader(t *testing.T) { + t.Parallel() + + p := syntax.NewParser() + file := parse(t, p, "read x") + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + // Make the linter happy, even though we deliberately wait for the + // timeout. + defer cancel() + + var stdinRead *os.File + if runtime.GOOS == "windows" { + // On Windows, the cancelreader only works on stdin + stdinRead = os.Stdin + } else { + var stdinWrite *os.File + var err error + stdinRead, stdinWrite, err = os.Pipe() + if err != nil { + t.Fatalf("Error calling os.Pipe: %v", err) + } + defer func() { + stdinWrite.Close() + stdinRead.Close() + }() + } + r, _ := interp.New(interp.StdIO(stdinRead, nil, nil)) + now := time.Now() + errChan := make(chan error) + go func() { + errChan <- r.Run(ctx, file) + }() + + timeout := 500 * time.Millisecond + select { + case err := <-errChan: + if err == nil || err.Error() != "exit status 1" || ctx.Err() != context.DeadlineExceeded { + t.Fatalf("'read x' did not timeout correctly; err: %v, ctx.Err(): %v; dur: %v", + err, ctx.Err(), time.Since(now)) + } + case <-time.After(timeout): + t.Fatalf("program was not killed in %s", timeout) + } +} + func TestRunnerAltNodes(t *testing.T) { t.Parallel() diff --git a/interp/runner.go b/interp/runner.go index 79bf0fd4..7df486fd 100644 --- a/interp/runner.go +++ b/interp/runner.go @@ -515,7 +515,7 @@ func (r *Runner) cmd(ctx context.Context, cm syntax.Command) { } r.errf("%s", ps3) - line, err := r.readLine(true) + line, err := r.readLine(ctx, true) if err != nil { r.exit = 1 return nil