From 2cc615d654362c9de911ea4329a86f86392b0142 Mon Sep 17 00:00:00 2001 From: Alex McGrath Date: Thu, 3 Mar 2022 15:33:03 +0000 Subject: [PATCH] Add tests for motd fixes Part of this includes renaming export_test.go to export.go so I could test the MOTD outside of lib/client/export.go --- lib/client/api.go | 4 +- lib/client/{export_test.go => export.go} | 0 tool/tsh/tsh.go | 13 +++++ tool/tsh/tsh_test.go | 74 +++++++++++++++++++++--- 4 files changed, 80 insertions(+), 11 deletions(-) rename lib/client/{export_test.go => export.go} (100%) diff --git a/lib/client/api.go b/lib/client/api.go index 9bae66df44d08..fb44ee7b1fbb8 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -2441,7 +2441,7 @@ func (tc *TeleportClient) LogoutAll() error { return nil } -// PingAndShowMOTD pings the Teleport Proxy and displays the MOTD if it's available. +// PingAndShowMOTD pings the Teleport Proxy and displays the Message Of The Day if it's available. func (tc *TeleportClient) PingAndShowMOTD(ctx context.Context) (*webclient.PingResponse, error) { pr, err := tc.Ping(ctx) if err != nil { @@ -2660,7 +2660,7 @@ func (tc *TeleportClient) ShowMOTD(ctx context.Context) error { // use might enter at the prompt. Whatever the user enters will // be simply discarded, and the user can still CTRL+C out if they // disagree. - _, err := passwordFromConsole() + _, err := passwordFromConsoleFn() if err != nil { return trace.Wrap(err) } diff --git a/lib/client/export_test.go b/lib/client/export.go similarity index 100% rename from lib/client/export_test.go rename to lib/client/export.go diff --git a/tool/tsh/tsh.go b/tool/tsh/tsh.go index 51e93821e48bb..572c9890ea859 100644 --- a/tool/tsh/tsh.go +++ b/tool/tsh/tsh.go @@ -260,6 +260,8 @@ type CLIConf struct { // overrideStdout allows to switch standard output source for resource command. Used in tests. overrideStdout io.Writer + // overrideStderr allows to switch standard error source for resource command. Used in tests. + overrideStderr io.Writer // mockSSOLogin used in tests to override sso login handler in teleport client. mockSSOLogin client.SSOLoginFunc @@ -306,6 +308,14 @@ func (c *CLIConf) Stdout() io.Writer { return os.Stdout } +// Stderr returns the stderr writer. +func (c *CLIConf) Stderr() io.Writer { + if c.overrideStderr != nil { + return c.overrideStderr + } + return os.Stderr +} + func main() { cmdLineOrig := os.Args[1:] var cmdLine []string @@ -2106,6 +2116,9 @@ func makeClient(cf *CLIConf, useProfileLogin bool) (*client.TeleportClient, erro } } + tc.Config.Stderr = cf.Stderr() + tc.Config.Stdout = cf.Stdout() + tc.Config.Reason = cf.Reason tc.Config.Invited = cf.Invited tc.Config.DisplayParticipantRequirements = cf.displayParticipantRequirements diff --git a/tool/tsh/tsh_test.go b/tool/tsh/tsh_test.go index c8556f67fb5ac..7ad4a1265623d 100644 --- a/tool/tsh/tsh_test.go +++ b/tool/tsh/tsh_test.go @@ -17,12 +17,15 @@ limitations under the License. package main import ( + "bufio" + "bytes" "context" "fmt" "io/ioutil" "net" "os" "path/filepath" + "strings" "testing" "time" @@ -174,7 +177,11 @@ func TestOIDCLogin(t *testing.T) { connector := mockConnector(t) - authProcess, proxyProcess := makeTestServers(t, withBootstrap(populist, dictator, connector, alice)) + motd := "MESSAGE_OF_THE_DAY_OIDC" + authProcess, proxyProcess := makeTestServers(t, + withBootstrap(populist, dictator, connector, alice), + withMOTD(t, motd), + ) authServer := authProcess.GetAuthServer() require.NotNil(t, authServer) @@ -212,6 +219,8 @@ func TestOIDCLogin(t *testing.T) { } }() + buf := bytes.NewBuffer([]byte{}) + sc := bufio.NewScanner(buf) err = Run([]string{ "login", "--insecure", @@ -222,6 +231,7 @@ func TestOIDCLogin(t *testing.T) { }, setHomePath(tmpHomePath), cliOption(func(cf *CLIConf) error { cf.mockSSOLogin = mockSSOLogin(t, authServer, alice) cf.SiteName = "localhost" + cf.overrideStderr = buf return nil })) @@ -230,11 +240,22 @@ func TestOIDCLogin(t *testing.T) { // verify that auto-request happened require.True(t, didAutoRequest.Load()) + findMOTD(t, sc, motd) // if we got this far, then tsh successfully registered name change from `alice` to // `alice@example.com`, since the correct name needed to be used for the access // request to be generated. } +func findMOTD(t *testing.T, sc *bufio.Scanner, motd string) { + t.Helper() + for sc.Scan() { + if strings.Contains(sc.Text(), motd) { + return + } + } + require.Fail(t, "Failed to find %q MOTD in the logs", motd) +} + // TestLoginIdentityOut makes sure that "tsh login --out " command // writes identity credentials to the specified path. func TestLoginIdentityOut(t *testing.T) { @@ -282,7 +303,11 @@ func TestRelogin(t *testing.T) { require.NoError(t, err) alice.SetRoles([]string{"access"}) - authProcess, proxyProcess := makeTestServers(t, withBootstrap(connector, alice)) + motd := "RELOGIN MOTD PRESENT" + authProcess, proxyProcess := makeTestServers(t, + withBootstrap(connector, alice), + withMOTD(t, motd), + ) authServer := authProcess.GetAuthServer() require.NotNil(t, authServer) @@ -290,6 +315,8 @@ func TestRelogin(t *testing.T) { proxyAddr, err := proxyProcess.ProxyWebAddr() require.NoError(t, err) + buf := bytes.NewBuffer([]byte{}) + sc := bufio.NewScanner(buf) err = Run([]string{ "login", "--insecure", @@ -298,9 +325,11 @@ func TestRelogin(t *testing.T) { "--proxy", proxyAddr.String(), }, setHomePath(tmpHomePath), cliOption(func(cf *CLIConf) error { cf.mockSSOLogin = mockSSOLogin(t, authServer, alice) + cf.overrideStderr = buf return nil })) require.NoError(t, err) + findMOTD(t, sc, motd) err = Run([]string{ "login", @@ -308,10 +337,20 @@ func TestRelogin(t *testing.T) { "--debug", "--proxy", proxyAddr.String(), "localhost", - }, setHomePath(tmpHomePath)) + }, setHomePath(tmpHomePath), + cliOption(func(cf *CLIConf) error { + cf.mockSSOLogin = mockSSOLogin(t, authServer, alice) + cf.overrideStderr = buf + return nil + })) require.NoError(t, err) + findMOTD(t, sc, motd) - err = Run([]string{"logout"}, setHomePath(tmpHomePath)) + err = Run([]string{"logout"}, setHomePath(tmpHomePath), + cliOption(func(cf *CLIConf) error { + cf.overrideStderr = buf + return nil + })) require.NoError(t, err) err = Run([]string{ @@ -323,8 +362,10 @@ func TestRelogin(t *testing.T) { "localhost", }, setHomePath(tmpHomePath), cliOption(func(cf *CLIConf) error { cf.mockSSOLogin = mockSSOLogin(t, authServer, alice) + cf.overrideStderr = buf return nil })) + findMOTD(t, sc, motd) require.NoError(t, err) } @@ -1194,8 +1235,8 @@ func TestSetX11Config(t *testing.T) { } type testServersOpts struct { - bootstrap []types.Resource - authConfigFunc func(cfg *service.AuthConfig) + bootstrap []types.Resource + authConfigFuncs []func(cfg *service.AuthConfig) } type testServerOptFunc func(o *testServersOpts) @@ -1208,7 +1249,11 @@ func withBootstrap(bootstrap ...types.Resource) testServerOptFunc { func withAuthConfig(fn func(cfg *service.AuthConfig)) testServerOptFunc { return func(o *testServersOpts) { - o.authConfigFunc = fn + if o.authConfigFuncs == nil { + o.authConfigFuncs = []func(cfg *service.AuthConfig){} + } + + o.authConfigFuncs = append(o.authConfigFuncs, fn) } } @@ -1223,6 +1268,17 @@ func withClusterName(t *testing.T, n string) testServerOptFunc { }) } +func withMOTD(t *testing.T, motd string) testServerOptFunc { + oldpass := client.PasswordFromConsoleFn + *client.PasswordFromConsoleFn = func() (string, error) { + return "", nil + } + t.Cleanup(func() { *client.PasswordFromConsoleFn = *oldpass }) + return withAuthConfig(func(cfg *service.AuthConfig) { + cfg.Preference.SetMessageOfTheDay(motd) + }) +} + func makeTestServers(t *testing.T, opts ...testServerOptFunc) (auth *service.TeleportProcess, proxy *service.TeleportProcess) { var options testServersOpts for _, opt := range opts { @@ -1255,8 +1311,8 @@ func makeTestServers(t *testing.T, opts ...testServerOptFunc) (auth *service.Tel cfg.Proxy.Enabled = false cfg.Log = utils.NewLoggerForTests() - if options.authConfigFunc != nil { - options.authConfigFunc(&cfg.Auth) + for _, fn := range options.authConfigFuncs { + fn(&cfg.Auth) } auth, err = service.NewTeleport(cfg)