diff --git a/pkg/cli/cli.go b/pkg/cli/cli.go index d50d893..dd6f595 100644 --- a/pkg/cli/cli.go +++ b/pkg/cli/cli.go @@ -4,12 +4,14 @@ import ( "github.com/conplementag/cops-hq/internal" "github.com/spf13/cobra" "github.com/spf13/viper" + "os" ) type cli struct { - programName string - version string - rootCmd *cobra.Command + programName string + version string + rootCmd *cobra.Command + defaultCommand string } func (cli *cli) AddBaseCommand(use string, shortInfo string, longDescription string, runFunction func()) Command { @@ -48,6 +50,27 @@ func (cli *cli) AddBaseCommand(use string, shortInfo string, longDescription str } func (cli *cli) Run() error { + if cli.defaultCommand != "" { + rootCommand := cli.GetRootCommand() + registeredCommands := rootCommand.Commands() + + var isCommandSet = false + for _, a := range registeredCommands { + for _, b := range os.Args[1:] { + if a.Name() == b { + isCommandSet = true + break + } + } + } + + // if no command set on the command line, use the default command by extending the existing command line args + if !isCommandSet { + args := append([]string{cli.defaultCommand}, os.Args[1:]...) + rootCommand.SetArgs(args) + } + } + err := cli.rootCmd.Execute() return internal.ReturnErrorOrPanic(err) } @@ -55,3 +78,11 @@ func (cli *cli) Run() error { func (cli *cli) GetRootCommand() *cobra.Command { return cli.rootCmd } + +func (cli *cli) OnInitialize(initFunction func()) { + cobra.OnInitialize(initFunction) +} + +func (cli *cli) SetDefaultCommand(command string) { + cli.defaultCommand = command +} diff --git a/pkg/cli/cli_test.go b/pkg/cli/cli_test.go index d574a37..1e71ae8 100644 --- a/pkg/cli/cli_test.go +++ b/pkg/cli/cli_test.go @@ -115,7 +115,6 @@ func Test_PersistentParametersAreAvailableThroughViperInSubcommands(t *testing.T func Test_ParametersFromDifferentCommandsShouldNotOverwriteEachOtherInViper(t *testing.T) { // Arrange - cli := New("myprog", "0.0.1") // Act @@ -140,3 +139,61 @@ func Test_ParametersFromDifferentCommandsShouldNotOverwriteEachOtherInViper(t *t assert.Equal(t, "johndoe", viper.GetString("my-arg")) assert.Equal(t, true, viper.GetBool("truth")) } + +func Test_DefaultCommands(t *testing.T) { + // Arrange + cli := New("myprog", "0.0.1") + cli.SetDefaultCommand("first") + + wasCalled := false + + cli.AddBaseCommand("first", "Simple test command 1", "big description", func() { + wasCalled = true + }) + cli.AddBaseCommand("second", "Simple test command 2", "big description", func() {}) + + // Act + testing_utils.PrepareCommandForTesting(cli.GetRootCommand(), "") // intentionally no args, to test the default was called + cli.Run() + + // Assert + assert.True(t, wasCalled) +} + +func Test_InitializerFunctionCalledWhenCommandExecuted(t *testing.T) { + // Arrange + cli := New("myprog", "0.0.1") + wasCalled := false + + cli.OnInitialize(func() { + wasCalled = true + }) + + // Act + cli.AddBaseCommand("first", "Simple test command 1", "big description", nil) + + testing_utils.PrepareCommandForTesting(cli.GetRootCommand(), "first") + cli.Run() + + // Assert + assert.True(t, wasCalled) +} + +func Test_InitializerFunctionNotCalledWhenNoCommandMatching(t *testing.T) { + // Arrange + cli := New("myprog", "0.0.1") + wasCalled := false + + cli.OnInitialize(func() { + wasCalled = true + }) + + // Act + cli.AddBaseCommand("first", "Simple test command 1", "big description", nil) + + testing_utils.PrepareCommandForTesting(cli.GetRootCommand(), "non-existing") + cli.Run() + + // Assert + assert.False(t, wasCalled) +} diff --git a/pkg/cli/factory.go b/pkg/cli/factory.go index 00c8429..f6bfc9b 100644 --- a/pkg/cli/factory.go +++ b/pkg/cli/factory.go @@ -18,6 +18,12 @@ type Cli interface { // GetRootCommand returns the root top level command, directly as cobra.Command which is the library used // under the hood. GetRootCommand() *cobra.Command + + // OnInitialize sets the passed function to be run when each command is called. Consider this like a global initializer + // hook. + OnInitialize(initFunction func()) + + SetDefaultCommand(command string) } // New creates a new Cli instance