From 89f48f714c6a85ac730b7295255891a52e621cc2 Mon Sep 17 00:00:00 2001 From: Alan Parra Date: Fri, 24 Jun 2022 14:17:36 -0300 Subject: [PATCH] Backport lib/utils/prompt improvements to [v9] (#13822) `prompt.ContextReader`, along with various parts of `lib/utils/prompt`, where re-written in master so we can alleviate input swallowing issues. For various reasons I didn't backport all changes to v9, but the less-eager input-swallowing loop is now what stands between us and issue #13021. This isn't a backport of a specific PR, but instead a port of the entire `lib/utils/prompt` package, as the PRs that touch the package unfortunately do more than we want to backport. The interfaces are still compatible and I did test various `tsh login` and `tsh ssh` scenarios. I've thrown in #13382 for good measure, as well. Closes #13021. * Backport lib/utils/prompt improvements * Fix lib/client tests * Restore terminal state on exit (#13382) --- lib/client/api_login_test.go | 2 +- lib/utils/prompt/confirmation.go | 42 ++- lib/utils/prompt/context_reader.go | 346 ++++++++++++++++++++++++ lib/utils/prompt/context_reader_test.go | 273 +++++++++++++++++++ lib/utils/prompt/mock.go | 74 +++++ lib/utils/prompt/stdin.go | 141 +++------- lib/utils/prompt/stdin_test.go | 76 ------ tool/tsh/tsh.go | 6 +- 8 files changed, 768 insertions(+), 192 deletions(-) create mode 100644 lib/utils/prompt/context_reader.go create mode 100644 lib/utils/prompt/context_reader_test.go create mode 100644 lib/utils/prompt/mock.go delete mode 100644 lib/utils/prompt/stdin_test.go diff --git a/lib/client/api_login_test.go b/lib/client/api_login_test.go index 238e182f9789d..37c23de5cf275 100644 --- a/lib/client/api_login_test.go +++ b/lib/client/api_login_test.go @@ -103,7 +103,7 @@ func TestTeleportClient_Login_localMFALogin(t *testing.T) { promptWebauthn func(ctx context.Context, origin string, assertion *wanlib.CredentialAssertion) (*proto.MFAAuthenticateResponse, error) }{} var loginMocksMU sync.RWMutex - *client.PromptOTP = func(ctx context.Context, out io.Writer, in *prompt.ContextReader, question string) (string, error) { + *client.PromptOTP = func(ctx context.Context, out io.Writer, in prompt.Reader, question string) (string, error) { loginMocksMU.RLock() defer loginMocksMU.RUnlock() return loginMocks.promptOTP(ctx) diff --git a/lib/utils/prompt/confirmation.go b/lib/utils/prompt/confirmation.go index 1710f67ccb308..5d3cb19739c03 100644 --- a/lib/utils/prompt/confirmation.go +++ b/lib/utils/prompt/confirmation.go @@ -26,13 +26,27 @@ import ( "github.com/gravitational/trace" ) +// Reader is the interface for prompt readers. +type Reader interface { + // ReadContext reads from the underlying buffer, respecting context + // cancellation. + ReadContext(ctx context.Context) ([]byte, error) +} + +// SecureReader is the interface for password readers. +type SecureReader interface { + // ReadPassword reads from the underlying buffer, respecting context + // cancellation. + ReadPassword(ctx context.Context) ([]byte, error) +} + // Confirmation prompts the user for a yes/no confirmation for question. // The prompt is written to out and the answer is read from in. // -// question should be a plain sentece without "[yes/no]"-type hints at the end. +// question should be a plain sentence without "[yes/no]"-type hints at the end. // // ctx can be canceled to abort the prompt. -func Confirmation(ctx context.Context, out io.Writer, in *ContextReader, question string) (bool, error) { +func Confirmation(ctx context.Context, out io.Writer, in Reader, question string) (bool, error) { fmt.Fprintf(out, "%s [y/N]: ", question) answer, err := in.ReadContext(ctx) if err != nil { @@ -49,14 +63,14 @@ func Confirmation(ctx context.Context, out io.Writer, in *ContextReader, questio // PickOne prompts the user to pick one of the provided string options. // The prompt is written to out and the answer is read from in. // -// question should be a plain sentece without the list of provided options. +// question should be a plain sentence without the list of provided options. // // ctx can be canceled to abort the prompt. -func PickOne(ctx context.Context, out io.Writer, in *ContextReader, question string, options []string) (string, error) { +func PickOne(ctx context.Context, out io.Writer, in Reader, question string, options []string) (string, error) { fmt.Fprintf(out, "%s [%s]: ", question, strings.Join(options, ", ")) answerOrig, err := in.ReadContext(ctx) if err != nil { - return "", trace.WrapWithMessage(err, "failed reading prompt response") + return "", trace.Wrap(err, "failed reading prompt response") } answer := strings.ToLower(strings.TrimSpace(string(answerOrig))) for _, opt := range options { @@ -72,11 +86,25 @@ func PickOne(ctx context.Context, out io.Writer, in *ContextReader, question str // The prompt is written to out and the answer is read from in. // // ctx can be canceled to abort the prompt. -func Input(ctx context.Context, out io.Writer, in *ContextReader, question string) (string, error) { +func Input(ctx context.Context, out io.Writer, in Reader, question string) (string, error) { fmt.Fprintf(out, "%s: ", question) answer, err := in.ReadContext(ctx) if err != nil { - return "", trace.WrapWithMessage(err, "failed reading prompt response") + return "", trace.Wrap(err, "failed reading prompt response") } return strings.TrimSpace(string(answer)), nil } + +// Password prompts the user for a password. The prompt is written to out and +// the answer is read from in. +// The in reader has to be a terminal. +func Password(ctx context.Context, out io.Writer, in SecureReader, question string) (string, error) { + if question != "" { + fmt.Fprintf(out, "%s:\n", question) + } + answer, err := in.ReadPassword(ctx) + if err != nil { + return "", trace.Wrap(err, "failed reading prompt response") + } + return string(answer), nil // passwords not trimmed +} diff --git a/lib/utils/prompt/context_reader.go b/lib/utils/prompt/context_reader.go new file mode 100644 index 0000000000000..7e7b659b3213e --- /dev/null +++ b/lib/utils/prompt/context_reader.go @@ -0,0 +1,346 @@ +/* +Copyright 2021 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package prompt + +import ( + "bufio" + "context" + "errors" + "io" + "os" + "os/signal" + "sync" + + "github.com/gravitational/trace" + "golang.org/x/term" + + log "github.com/sirupsen/logrus" +) + +// ErrReaderClosed is returned from ContextReader.ReadContext after it is +// closed. +var ErrReaderClosed = errors.New("ContextReader has been closed") + +// ErrNotTerminal is returned by password reads attempted in non-terminal +// readers. +var ErrNotTerminal = errors.New("underlying reader is not a terminal") + +const bufferSize = 4096 + +type readOutcome struct { + value []byte + err error +} + +type readerState int + +const ( + readerStateIdle readerState = iota + readerStateClean + readerStatePassword + readerStateClosed +) + +// termI aggregates methods from golang.org/x/term for easy mocking. +type termI interface { + GetState(fd int) (*term.State, error) + IsTerminal(fd int) bool + ReadPassword(fd int) ([]byte, error) + Restore(fd int, oldState *term.State) error +} + +// gxTerm delegates method calls to golang.org/x/term methods. +type gxTerm struct{} + +func (gxTerm) GetState(fd int) (*term.State, error) { + return term.GetState(fd) +} + +func (gxTerm) IsTerminal(fd int) bool { + return term.IsTerminal(fd) +} + +func (gxTerm) ReadPassword(fd int) ([]byte, error) { + return term.ReadPassword(fd) +} + +func (gxTerm) Restore(fd int, oldState *term.State) error { + return term.Restore(fd, oldState) +} + +// ContextReader is a wrapper around an underlying io.Reader or terminal that +// allows reads to be abandoned. An abandoned read may be reclaimed by future +// callers. +// ContextReader instances are not safe for concurrent use, callers may block +// indefinitely and reads may be lost. +type ContextReader struct { + term termI + + // reader is used for clean reads. + reader io.Reader + // fd is used for password reads. + // Only present if the underlying reader is a terminal, otherwise set to -1. + fd int + + closed chan struct{} + reads chan readOutcome + + mu *sync.Mutex + cond *sync.Cond + previousTermState *term.State + state readerState +} + +// NewContextReader creates a new ContextReader wrapping rd. +// Callers should avoid reading from rd after the ContextReader is used, as +// abandoned calls may be in progress. It is safe to read from rd if one can +// guarantee that no calls where abandoned. +// Calling ContextReader.Close attempts to release resources, but note that +// ongoing reads cannot be interrupted. +func NewContextReader(rd io.Reader) *ContextReader { + term := gxTerm{} + + fd := -1 + if f, ok := rd.(*os.File); ok { + val := int(f.Fd()) + if term.IsTerminal(val) { + fd = val + } + } + + mu := &sync.Mutex{} + cond := sync.NewCond(mu) + cr := &ContextReader{ + term: term, + reader: bufio.NewReader(rd), + fd: fd, + closed: make(chan struct{}), + reads: make(chan readOutcome), // unbuffered + mu: mu, + cond: cond, + } + go cr.processReads() + return cr +} + +func (cr *ContextReader) processReads() { + defer close(cr.reads) + + for { + cr.mu.Lock() + for cr.state == readerStateIdle { + cr.cond.Wait() + } + // Stop the reading loop? Once closed, forever closed. + if cr.state == readerStateClosed { + cr.mu.Unlock() + return + } + // React to the state that took us out of idleness. + // We can't hold the lock during the entire read, so we obey the last state + // observed. + state := cr.state + cr.mu.Unlock() + + var value []byte + var err error + switch state { + case readerStateClean: + value = make([]byte, bufferSize) + var n int + n, err = cr.reader.Read(value) + value = value[:n] + case readerStatePassword: + value, err = cr.term.ReadPassword(cr.fd) + } + cr.mu.Lock() + cr.previousTermState = nil // A finalized read resets the terminal. + switch cr.state { + case readerStateClosed: // Don't transition from closed. + default: + cr.state = readerStateIdle + } + cr.mu.Unlock() + + select { + case <-cr.closed: + log.Warnf("ContextReader closed during ongoing read, dropping %v bytes", len(value)) + return + case cr.reads <- readOutcome{value: value, err: err}: + } + } +} + +// handleInterrupt restores terminal state on interrupts. +// Called only on global ContextReaders, such as Stdin. +func (cr *ContextReader) handleInterrupt() { + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt) + defer signal.Stop(c) + + for { + select { + case sig := <-c: + log.Debugf("Captured signal %s, attempting to restore terminal state", sig) + cr.mu.Lock() + _ = cr.maybeRestoreTerm(iAmHoldingTheLock{}) + cr.mu.Unlock() + case <-cr.closed: + return + } + } +} + +// iAmHoldingTheLock exists only to draw attention to the need to hold the lock. +type iAmHoldingTheLock struct{} + +// maybeRestoreTerm attempts to restore terminal state. +// Lock must be held before calling. +func (cr *ContextReader) maybeRestoreTerm(_ iAmHoldingTheLock) error { + if cr.state == readerStatePassword && cr.previousTermState != nil { + err := cr.term.Restore(cr.fd, cr.previousTermState) + cr.previousTermState = nil + return trace.Wrap(err) + } + + return nil +} + +// ReadContext returns the next chunk of output from the reader. +// If ctx is canceled before the read completes, the current read is abandoned +// and may be reclaimed by future callers. +// It is not safe to read from the underlying reader after a read is abandoned, +// nor is it safe to concurrently call ReadContext. +func (cr *ContextReader) ReadContext(ctx context.Context) ([]byte, error) { + if err := cr.fireCleanRead(); err != nil { + return nil, trace.Wrap(err) + } + + return cr.waitForRead(ctx) +} + +func (cr *ContextReader) fireCleanRead() error { + cr.mu.Lock() + defer cr.mu.Unlock() + + // Atempt to restore terminal state, so we transition to a clean read. + if err := cr.maybeRestoreTerm(iAmHoldingTheLock{}); err != nil { + return trace.Wrap(err) + } + + switch cr.state { + case readerStateIdle: // OK, transition and broadcast. + cr.state = readerStateClean + cr.cond.Broadcast() + case readerStateClean: // OK, ongoing read. + case readerStatePassword: // OK, ongoing read. + case readerStateClosed: + return ErrReaderClosed + } + return nil +} + +func (cr *ContextReader) waitForRead(ctx context.Context) ([]byte, error) { + select { + case <-ctx.Done(): + return nil, trace.Wrap(ctx.Err()) + case <-cr.closed: + return nil, ErrReaderClosed + case read := <-cr.reads: + return read.value, read.err + } +} + +// ReadPassword reads a password from the underlying reader, provided that the +// reader is a terminal. +// It follows the semantics of ReadContext. +func (cr *ContextReader) ReadPassword(ctx context.Context) ([]byte, error) { + if cr.fd == -1 { + return nil, ErrNotTerminal + } + if err := cr.firePasswordRead(); err != nil { + return nil, trace.Wrap(err) + } + + return cr.waitForRead(ctx) +} + +func (cr *ContextReader) firePasswordRead() error { + cr.mu.Lock() + defer cr.mu.Unlock() + + switch cr.state { + case readerStateIdle: // OK, transition and broadcast. + // Save present terminal state, so it may be restored in case the read goes + // from password to clean. + state, err := cr.term.GetState(cr.fd) + if err != nil { + return trace.Wrap(err) + } + cr.previousTermState = state + cr.state = readerStatePassword + cr.cond.Broadcast() + case readerStateClean: // OK, ongoing clean read. + // TODO(codingllama): Transition the terminal to password read? + log.Warn("prompt: Clean read reused by password read") + case readerStatePassword: // OK, ongoing password read. + case readerStateClosed: + return ErrReaderClosed + } + return nil +} + +// Close closes the context reader, attempting to release resources and aborting +// ongoing and future ReadContext calls. +// Background reads that are already blocked cannot be interrupted, thus Close +// doesn't guarantee a release of all resources. +func (cr *ContextReader) Close() error { + cr.mu.Lock() + defer cr.mu.Unlock() + + switch cr.state { + case readerStateClosed: // OK, already closed. + default: + // Attempt to restore terminal state on close. + _ = cr.maybeRestoreTerm(iAmHoldingTheLock{}) + + cr.state = readerStateClosed + close(cr.closed) // interrupt blocked sends. + cr.cond.Broadcast() + } + + return nil +} + +// PasswordReader is a ContextReader that reads passwords from the underlying +// terminal. +type PasswordReader ContextReader + +// Password returns a PasswordReader from a ContextReader. +// The returned PasswordReader is only functional if the underlying reader is a +// terminal. +func (cr *ContextReader) Password() *PasswordReader { + return (*PasswordReader)(cr) +} + +// ReadContext reads a password from the underlying reader, provided that the +// reader is a terminal. It is equivalent to ContextReader.ReadPassword. +// It follows the semantics of ReadContext. +func (pr *PasswordReader) ReadContext(ctx context.Context) ([]byte, error) { + cr := (*ContextReader)(pr) + return cr.ReadPassword(ctx) +} diff --git a/lib/utils/prompt/context_reader_test.go b/lib/utils/prompt/context_reader_test.go new file mode 100644 index 0000000000000..e58fc1c58f714 --- /dev/null +++ b/lib/utils/prompt/context_reader_test.go @@ -0,0 +1,273 @@ +/* +Copyright 2021 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package prompt + +import ( + "context" + "io" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/term" +) + +func TestContextReader(t *testing.T) { + pr, pw := io.Pipe() + t.Cleanup(func() { pr.Close() }) + t.Cleanup(func() { pw.Close() }) + + write := func(t *testing.T, s string) { + t.Helper() + _, err := pw.Write([]byte(s)) + assert.NoError(t, err, "Write failed") + } + + ctx := context.Background() + cr := NewContextReader(pr) + + t.Run("simple read", func(t *testing.T) { + go write(t, "hello") + buf, err := cr.ReadContext(ctx) + require.NoError(t, err) + require.Equal(t, string(buf), "hello") + }) + + t.Run("reclaim abandoned read", func(t *testing.T) { + done := make(chan struct{}) + cancelCtx, cancel := context.WithCancel(ctx) + go func() { + time.Sleep(1 * time.Millisecond) // give ReadContext time to block + cancel() + write(t, "after cancel") + close(done) + }() + buf, err := cr.ReadContext(cancelCtx) + require.ErrorIs(t, err, context.Canceled) + require.Empty(t, buf) + + <-done // wait for write + buf, err = cr.ReadContext(ctx) + require.NoError(t, err) + require.Equal(t, string(buf), "after cancel") + }) + + t.Run("close ContextReader", func(t *testing.T) { + go func() { + time.Sleep(1 * time.Millisecond) // give ReadContext time to block + assert.NoError(t, cr.Close(), "Close errored") + }() + _, err := cr.ReadContext(ctx) + require.ErrorIs(t, err, ErrReaderClosed) + + // Subsequent reads fail. + _, err = cr.ReadContext(ctx) + require.ErrorIs(t, err, ErrReaderClosed) + + // Ongoing read after Close is dropped. + write(t, "unblock goroutine") + buf, err := cr.ReadContext(ctx) + assert.ErrorIs(t, err, ErrReaderClosed) + assert.Empty(t, buf, "buf not empty") + + // Multiple closes are fine. + assert.NoError(t, cr.Close(), "2nd Close failed") + }) + + // Re-creating is safe because the tests above leave no "pending" reads. + cr = NewContextReader(pr) + + t.Run("close underlying reader", func(t *testing.T) { + go func() { + write(t, "before close") + pw.CloseWithError(io.EOF) + }() + + // Read the last chunk of data successfully. + buf, err := cr.ReadContext(ctx) + require.NoError(t, err) + require.Equal(t, string(buf), "before close") + + // Next read fails because underlying reader is closed. + buf, err = cr.ReadContext(ctx) + require.ErrorIs(t, err, io.EOF) + require.Empty(t, buf) + }) +} + +func TestContextReader_ReadPassword(t *testing.T) { + pr, pw := io.Pipe() + write := func(t *testing.T, s string) { + t.Helper() + _, err := pw.Write([]byte(s)) + assert.NoError(t, err, "Write failed") + } + + devNull, err := os.OpenFile(os.DevNull, os.O_RDWR, 0666) + require.NoError(t, err, "Failed to open %v", os.DevNull) + defer devNull.Close() + + term := &fakeTerm{reader: pr} + cr := NewContextReader(pr) + cr.term = term + cr.fd = int(devNull.Fd()) // arbitrary, doesn't matter because term functions are mocked. + + ctx := context.Background() + t.Run("read password", func(t *testing.T) { + const want = "llama45" + go write(t, want) + + got, err := cr.ReadPassword(ctx) + require.NoError(t, err, "ReadPassword failed") + assert.Equal(t, want, string(got), "ReadPassword mismatch") + }) + + t.Run("intertwine reads", func(t *testing.T) { + const want1 = "hello, world" + go write(t, want1) + got, err := cr.ReadPassword(ctx) + require.NoError(t, err, "ReadPassword failed") + assert.Equal(t, want1, string(got), "ReadPassword mismatch") + + const want2 = "goodbye, world" + go write(t, want2) + got, err = cr.ReadContext(ctx) + require.NoError(t, err, "ReadContext failed") + assert.Equal(t, want2, string(got), "ReadContext mismatch") + }) + + t.Run("password read turned clean", func(t *testing.T) { + require.False(t, term.restoreCalled, "restoreCalled sanity check failed") + + // Give ReadPassword time to block. + cancelCtx, cancel := context.WithTimeout(ctx, 1*time.Millisecond) + defer cancel() + got, err := cr.ReadPassword(cancelCtx) + require.ErrorIs(t, err, context.DeadlineExceeded, "ReadPassword returned unexpected error") + require.Empty(t, got, "ReadPassword mismatch") + + // Reclaim as clean read. + const want = "abandoned pwd read" + go func() { + // Once again, give ReadContext time to block. + // This way we force a restore. + time.Sleep(1 * time.Millisecond) + write(t, want) + }() + got, err = cr.ReadContext(ctx) + require.NoError(t, err, "ReadContext failed") + assert.Equal(t, want, string(got), "ReadContext mismatch") + assert.True(t, term.restoreCalled, "term.Restore not called") + }) + + t.Run("Close", func(t *testing.T) { + require.NoError(t, cr.Close(), "Close errored") + + _, err := cr.ReadPassword(ctx) + require.ErrorIs(t, err, ErrReaderClosed, "ReadPassword returned unexpected error") + }) +} + +func TestNotifyExit_restoresTerminal(t *testing.T) { + oldStdin := Stdin() + t.Cleanup(func() { SetStdin(oldStdin) }) + + pr, _ := io.Pipe() + + devNull, err := os.OpenFile(os.DevNull, os.O_RDWR, 0666) + require.NoError(t, err, "Failed to open %v", os.DevNull) + defer devNull.Close() + + term := &fakeTerm{reader: pr} + ctx := context.Background() + + tests := []struct { + name string + doRead func(ctx context.Context, cr *ContextReader) error + wantRestore bool + }{ + { + name: "no pending read", + doRead: func(ctx context.Context, cr *ContextReader) error { + <-ctx.Done() + return ctx.Err() + }, + }, + { + name: "pending clean read", + doRead: func(ctx context.Context, cr *ContextReader) error { + _, err := cr.ReadContext(ctx) + return err + }, + }, + { + name: "pending password read", + doRead: func(ctx context.Context, cr *ContextReader) error { + _, err := cr.ReadPassword(ctx) + return err + }, + wantRestore: true, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + term.restoreCalled = false // reset state between tests + + cr := NewContextReader(pr) + cr.term = term + cr.fd = int(devNull.Fd()) // arbitrary + SetStdin(cr) + + // Give the read time to block. + ctx, cancel := context.WithTimeout(ctx, 1*time.Millisecond) + defer cancel() + err := test.doRead(ctx, cr) + require.ErrorIs(t, err, context.DeadlineExceeded, "unexpected read error") + + NotifyExit() // closes Stdin + assert.Equal(t, test.wantRestore, term.restoreCalled, "term.Restore mismatch") + }) + } +} + +type fakeTerm struct { + reader io.Reader + restoreCalled bool +} + +func (t *fakeTerm) GetState(fd int) (*term.State, error) { + return &term.State{}, nil +} + +func (t *fakeTerm) IsTerminal(fd int) bool { + return true +} + +func (t *fakeTerm) ReadPassword(fd int) ([]byte, error) { + const bufLen = 1024 // arbitrary, big enough for test data + data := make([]byte, bufLen) + n, err := t.reader.Read(data) + data = data[:n] + return data, err +} + +func (t *fakeTerm) Restore(fd int, oldState *term.State) error { + t.restoreCalled = true + return nil +} diff --git a/lib/utils/prompt/mock.go b/lib/utils/prompt/mock.go new file mode 100644 index 0000000000000..455a700da3d7d --- /dev/null +++ b/lib/utils/prompt/mock.go @@ -0,0 +1,74 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package prompt + +import ( + "context" + "errors" + "sync" +) + +type FakeReplyFunc func(context.Context) (string, error) + +type FakeReader struct { + mu sync.Mutex + replies []FakeReplyFunc +} + +// NewFakeReader returns a fake that can be used in place of a ContextReader. +// Call Add functions in the desired order to configure responses. Each call +// represents a read reply, in order. +func NewFakeReader() *FakeReader { + return &FakeReader{} +} + +func (r *FakeReader) AddReply(fn FakeReplyFunc) *FakeReader { + r.mu.Lock() + defer r.mu.Unlock() + r.replies = append(r.replies, fn) + return r +} + +func (r *FakeReader) AddString(s string) *FakeReader { + return r.AddReply(func(context.Context) (string, error) { + return s, nil + }) +} + +func (r *FakeReader) AddError(err error) *FakeReader { + return r.AddReply(func(context.Context) (string, error) { + return "", err + }) +} + +func (r *FakeReader) ReadContext(ctx context.Context) ([]byte, error) { + r.mu.Lock() + if len(r.replies) == 0 { + r.mu.Unlock() + return nil, errors.New("no fake replies available") + } + + // Pop first reply. + fn := r.replies[0] + r.replies = r.replies[1:] + r.mu.Unlock() + + val, err := fn(ctx) + return []byte(val), err +} + +func (r *FakeReader) ReadPassword(ctx context.Context) ([]byte, error) { + return r.ReadContext(ctx) +} diff --git a/lib/utils/prompt/stdin.go b/lib/utils/prompt/stdin.go index 56c672f2e0e28..49a35d55f7da9 100644 --- a/lib/utils/prompt/stdin.go +++ b/lib/utils/prompt/stdin.go @@ -17,127 +17,54 @@ limitations under the License. package prompt import ( - "context" - "errors" - "io" "os" "sync" ) var ( - stdinOnce = &sync.Once{} - stdin *ContextReader + stdinMU sync.Mutex + stdin StdinReader ) -// Stdin returns a singleton ContextReader wrapped around os.Stdin. -// -// os.Stdin should not be used directly after the first call to this function -// to avoid losing data. Closing this ContextReader will prevent all future -// reads for all callers. -func Stdin() *ContextReader { - stdinOnce.Do(func() { - stdin = NewContextReader(os.Stdin) - }) - return stdin +// StdinReader contains ContextReader methods applicable to stdin. +type StdinReader interface { + Reader + SecureReader } -// ErrReaderClosed is returned from ContextReader.Read after it was closed. -var ErrReaderClosed = errors.New("ContextReader has been closed") - -// ContextReader is a wrapper around io.Reader where each individual -// ReadContext call can be canceled using a context. -type ContextReader struct { - r io.Reader - data chan []byte - close chan struct{} - - mu sync.RWMutex - err error -} - -// NewContextReader creates a new ContextReader wrapping r. Callers should not -// use r after creating this ContextReader to avoid loss of data (the last read -// will be lost). +// Stdin returns a singleton ContextReader wrapped around os.Stdin. // -// Callers are responsible for closing the ContextReader to release associated -// resources. -func NewContextReader(r io.Reader) *ContextReader { - cr := &ContextReader{ - r: r, - data: make(chan []byte), - close: make(chan struct{}), - } - go cr.read() - return cr -} - -func (r *ContextReader) setErr(err error) { - r.mu.Lock() - defer r.mu.Unlock() - if r.err != nil { - // Keep only the first encountered error. - return - } - r.err = err -} - -func (r *ContextReader) getErr() error { - r.mu.RLock() - defer r.mu.RUnlock() - return r.err -} - -func (r *ContextReader) read() { - defer close(r.data) - - for { - // Allocate a new buffer for every read because we need to send it to - // another goroutine. - buf := make([]byte, 4*1024) // 4kB, matches Linux page size. - n, err := r.r.Read(buf) - r.setErr(err) - buf = buf[:n] - if n == 0 { - return - } - select { - case <-r.close: - return - case r.data <- buf: - } +// os.Stdin should not be used directly after the first call to this function +// to avoid losing data. +func Stdin() StdinReader { + stdinMU.Lock() + defer stdinMU.Unlock() + if stdin == nil { + cr := NewContextReader(os.Stdin) + go cr.handleInterrupt() + stdin = cr } + return stdin } -// ReadContext returns the next chunk of output from the reader. If ctx is -// canceled before any data is available, ReadContext will return too. If r -// was closed, ReadContext will return immediately with ErrReaderClosed. -func (r *ContextReader) ReadContext(ctx context.Context) ([]byte, error) { - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-r.close: - // Close was called, unblock immediately. - // r.data might still be blocked if it's blocked on the Read call. - return nil, r.getErr() - case buf, ok := <-r.data: - if !ok { - // r.data was closed, so the read goroutine has finished. - // No more data will be available, return the latest error. - return nil, r.getErr() - } - return buf, nil - } +// SetStdin allows callers to change the Stdin reader. +// Useful to replace Stdin for tests, but should be avoided in production code. +func SetStdin(rd StdinReader) { + stdinMU.Lock() + defer stdinMU.Unlock() + stdin = rd } -// Close releases the background resources of r. All ReadContext calls will -// unblock immediately. -func (r *ContextReader) Close() { - select { - case <-r.close: - // Already closed, do nothing. - return - default: - close(r.close) - r.setErr(ErrReaderClosed) +// NotifyExit notifies prompt singletons, such as Stdin, that the program is +// about to exit. This allows singletons to perform actions such as restoring +// terminal state. +// Once NotifyExit is called the singletons will be closed. +func NotifyExit() { + // Note: don't call methods such as Stdin() here, we don't want to + // inadvertently hijack the prompts on exit. + stdinMU.Lock() + if cr, ok := stdin.(*ContextReader); ok { + _ = cr.Close() } + stdinMU.Unlock() } diff --git a/lib/utils/prompt/stdin_test.go b/lib/utils/prompt/stdin_test.go deleted file mode 100644 index 32d25457b836d..0000000000000 --- a/lib/utils/prompt/stdin_test.go +++ /dev/null @@ -1,76 +0,0 @@ -/* -Copyright 2021 Gravitational, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package prompt - -import ( - "context" - "io" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestContextReader(t *testing.T) { - pr, pw := io.Pipe() - t.Cleanup(func() { pr.Close() }) - t.Cleanup(func() { pw.Close() }) - - write := func(t *testing.T, s string) { - _, err := pw.Write([]byte(s)) - require.NoError(t, err) - } - ctx := context.Background() - - r := NewContextReader(pr) - - t.Run("simple read", func(t *testing.T) { - go write(t, "hello") - buf, err := r.ReadContext(ctx) - require.NoError(t, err) - require.Equal(t, string(buf), "hello") - }) - - t.Run("cancelled read", func(t *testing.T) { - cancelCtx, cancel := context.WithCancel(ctx) - go cancel() - buf, err := r.ReadContext(cancelCtx) - require.ErrorIs(t, err, context.Canceled) - require.Empty(t, buf) - - go write(t, "after cancel") - buf, err = r.ReadContext(ctx) - require.NoError(t, err) - require.Equal(t, string(buf), "after cancel") - }) - - t.Run("close underlying reader", func(t *testing.T) { - go func() { - write(t, "before close") - pw.CloseWithError(io.EOF) - }() - - // Read the last chunk of data successfully. - buf, err := r.ReadContext(ctx) - require.NoError(t, err) - require.Equal(t, string(buf), "before close") - - // Next read fails because underlying reader is closed. - buf, err = r.ReadContext(ctx) - require.ErrorIs(t, err, io.EOF) - require.Empty(t, buf) - }) -} diff --git a/tool/tsh/tsh.go b/tool/tsh/tsh.go index ba3135376381d..ece6fc15d13b1 100644 --- a/tool/tsh/tsh.go +++ b/tool/tsh/tsh.go @@ -60,6 +60,7 @@ import ( "github.com/gravitational/teleport/lib/sshutils/x11" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/teleport/lib/utils/prompt" "github.com/gravitational/kingpin" "github.com/gravitational/trace" @@ -363,7 +364,10 @@ func main() { default: cmdLine = cmdLineOrig } - if err := Run(cmdLine); err != nil { + + err := Run(cmdLine) + prompt.NotifyExit() // Allow prompt to restore terminal state on exit. + if err != nil { var exitError *exitCodeError if errors.As(err, &exitError) { os.Exit(exitError.code)