Skip to content

Commit

Permalink
refactor: refactor config management for testability
Browse files Browse the repository at this point in the history
  • Loading branch information
moyiz committed Dec 13, 2023
1 parent ff8f67a commit f822ad1
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 71 deletions.
6 changes: 3 additions & 3 deletions cmd/add.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ It will create the config directory if at does not exist.`,
// Create directories if not exist
configDir := path.Dir(configFile)
if _, err := os.Stat(configDir); err != nil {
os.MkdirAll(configDir, 0775)
os.MkdirAll(configDir, 0o775)
}

// Separate alias key and command at `--` or set the command to the last argument
Expand All @@ -40,8 +40,8 @@ It will create the config directory if at does not exist.`,
sep = len(args) - 1
}

c := config.GetFromFiles(configFile)
if err := c.SetAlias(args[:sep], args[sep:]); err != nil {
config.LoadFiles(configFile)
if err := config.SetAlias(args[:sep], args[sep:]); err != nil {
fmt.Println("na: add:", err)
os.Exit(1)
}
Expand Down
5 changes: 3 additions & 2 deletions cmd/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"os"
"slices"

"github.com/moyiz/na/internal/cli"
"github.com/moyiz/na/internal/config"
"github.com/moyiz/na/internal/consts"
"github.com/spf13/cobra"
Expand All @@ -28,8 +29,8 @@ func validListArgs(cmd *cobra.Command, args []string, toComplete string) ([]stri
if slices.Contains(os.Args, "--") {
return []string{}, cobra.ShellCompDirectiveDefault
}
config.GetFromFiles(AllConfigFiles()...)
return config.ListNextParts(args), cobra.ShellCompDirectiveNoFileComp
config.LoadFiles(AllConfigFiles()...)
return cli.ListNextParts(config.ListAliases(), args), cobra.ShellCompDirectiveNoFileComp
}

func listRun(cmd *cobra.Command, args []string) {
Expand Down
9 changes: 5 additions & 4 deletions cmd/remove.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"slices"
"strings"

"github.com/moyiz/na/internal/cli"
"github.com/moyiz/na/internal/config"
"github.com/moyiz/na/internal/consts"
"github.com/spf13/cobra"
Expand All @@ -23,8 +24,8 @@ By default, the global (home directory config) configuration is used.`,
Args: cobra.MinimumNArgs(1),
ValidArgsFunction: validRemoveArgs,
Run: func(cmd *cobra.Command, args []string) {
c := config.GetFromFiles(ActiveConfigFile())
if err := c.UnsetAlias(args...); err != nil {
config.LoadFiles(ActiveConfigFile())
if err := config.UnsetAlias(args...); err != nil {
fmt.Println("na:", strings.Join(args, " ")+":", err)
os.Exit(1)
}
Expand All @@ -35,6 +36,6 @@ func validRemoveArgs(cmd *cobra.Command, args []string, toComplete string) ([]st
if slices.Contains(os.Args, "--") {
return []string{}, cobra.ShellCompDirectiveDefault
}
config.GetFromFiles(ActiveConfigFile())
return config.ListNextParts(args), cobra.ShellCompDirectiveNoFileComp
config.LoadFiles(ActiveConfigFile())
return cli.ListNextParts(config.ListAliases(), args), cobra.ShellCompDirectiveNoFileComp
}
2 changes: 1 addition & 1 deletion cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func AllConfigFiles() []string {

func init() {
cobra.EnableCommandSorting = true
cobra.OnInitialize(func() { config.GetFromFiles(AllConfigFiles()...) })
cobra.OnInitialize(func() { config.LoadFiles(AllConfigFiles()...) })
rootCmd.PersistentFlags().StringP("config", "c", "", "Path of the config file to use")
rootCmd.PersistentFlags().BoolP("local", "l", false, "Use local config (.na.yaml)")
rootCmd.MarkFlagsMutuallyExclusive("local", "config")
Expand Down
9 changes: 5 additions & 4 deletions cmd/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"slices"
"strings"

"github.com/moyiz/na/internal/cli"
"github.com/moyiz/na/internal/config"
"github.com/moyiz/na/internal/consts"
"github.com/moyiz/na/internal/utils"
Expand Down Expand Up @@ -34,8 +35,8 @@ func validRunArgs(cmd *cobra.Command, args []string, toComplete string) ([]strin
// Potential location to auto complete commands
return []string{}, cobra.ShellCompDirectiveDefault
}
config.GetFromFiles(AllConfigFiles()...)
return config.ListNextParts(args), cobra.ShellCompDirectiveNoFileComp
config.LoadFiles(AllConfigFiles()...)
return cli.ListNextParts(config.ListAliases(), args), cobra.ShellCompDirectiveNoFileComp
}

func runRun(cmd *cobra.Command, args []string) {
Expand All @@ -52,8 +53,8 @@ func runRun(cmd *cobra.Command, args []string) {
aliasParts = args
}

c := config.GetFromFiles(AllConfigFiles()...)
if alias, err := c.GetAlias(aliasParts...); err != nil {
config.LoadFiles(AllConfigFiles()...)
if alias, err := config.GetAlias(aliasParts...); err != nil {
fmt.Println("na:", strings.Join(aliasParts, " ")+":", err)
} else {
utils.RunInCurrentShell(alias.Command, extraArgs)
Expand Down
31 changes: 31 additions & 0 deletions internal/cli/completion.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package cli

import (
"slices"
"strings"

"github.com/moyiz/na/internal/config"
)

// Given a list of alias name parts, return a list of valid next parts.
// Example:
//
// my:
// aliases:
// one: cmd1
// two: cmd2
//
// ListNextParts([]string{"my"}) -> []string{"aliases"}
// ListNextParts([]string{"my", "aliases"}) -> []string{"cmd1", "cmd2"}
func ListNextParts(aliases []config.Alias, parts []string) []string {
currentPrefix := strings.Join(parts, " ")
suggestions := make([]string, 0)
for _, a := range aliases {
trail, found := strings.CutPrefix(a.Name, currentPrefix)
if tf := strings.Fields(trail); found && len(tf) > 0 && !slices.Contains(suggestions, tf[0]) {
suggestions = append(suggestions, tf[0])
}
}

return suggestions
}
108 changes: 51 additions & 57 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,73 +3,87 @@ package config
import (
"bytes"
"encoding/json"
"errors"
"path"
"sort"
"strings"

"github.com/spf13/viper"
)

func GetFromFiles(filePath ...string) *Config {
type Config struct {
v *viper.Viper
}

var c Config

func init() {
c = Config{v: viper.New()}
}

type Alias struct {
Name string
Command string
}

func LoadFiles(filePath ...string) map[string]any {
return c.LoadFiles(filePath...)
}

func (c *Config) LoadFiles(filePath ...string) map[string]any {
for i, p := range filePath {
viper.SetConfigFile(path.Clean(p))
c.v.SetConfigFile(path.Clean(p))
if i == 0 {
viper.ReadInConfig()
c.v.ReadInConfig()
} else {
viper.MergeInConfig()
c.v.MergeInConfig()
}
}
return &Config{aliases: viper.AllSettings()}
return c.v.AllSettings()
}

type Config struct {
aliases map[string]interface{}
func SetAlias(key []string, command []string) error {
return c.SetAlias(key, command)
}

func (c *Config) Write() error {
encoded, _ := json.MarshalIndent(c.aliases, "", " ")
viper.ReadConfig(bytes.NewReader(encoded))
return viper.WriteConfig()
func (c *Config) SetAlias(name []string, command []string) error {
c.v.Set(strings.Join(name, "."), strings.Join(command, " "))
return c.v.WriteConfig()
}

func (c *Config) SetAlias(key []string, command []string) error {
viper.Set(strings.Join(key, "."), strings.Join(command, " "))
return viper.WriteConfig()
func UnsetAlias(key ...string) error {
return c.UnsetAlias(key...)
}

func (c *Config) UnsetAlias(key ...string) error {
var parent map[string]interface{}
var keyIsMap, keyExists bool
keySize := len(key)
aliasPointer := c.aliases
settings := c.v.AllSettings()
aliasWalker := settings
for i, k := range key {
parent = aliasPointer
_, keyExists = aliasPointer[k]
aliasPointer, keyIsMap = aliasPointer[k].(map[string]interface{})
parent = aliasWalker
_, keyExists = aliasWalker[k]
aliasWalker, keyIsMap = aliasWalker[k].(map[string]interface{})
if !keyExists {
return errors.New("not found")
return ErrAliasNotFound
} else if !keyIsMap && i < keySize-1 {
return errors.New("key is invalid. Did you mean `" + strings.Join(key[:i+1], " ") + "`?")
return ErrInvalidAliasKey{strings.Join(key[:i+1], " ")}
} else if i == keySize-1 {
break
}
}
delete(parent, key[keySize-1])
return c.Write()
encoded, _ := json.MarshalIndent(settings, "", " ")
c.v.ReadConfig(bytes.NewReader(encoded))
return c.v.WriteConfig()
}

type Alias struct {
Name string
Command string
func ListAliases(prefix ...string) []Alias {
return c.ListAliases(prefix...)
}

func (c *Config) ListAliases(prefix ...string) []Alias {
return ListAliases(prefix...)
}

func ListAliases(prefix ...string) []Alias {
keys := viper.AllKeys()
keys := c.v.AllKeys()
sort.Strings(keys)

aliases := make([]Alias, 0)
Expand All @@ -78,43 +92,23 @@ func ListAliases(prefix ...string) []Alias {
if strings.HasPrefix(k, aliasPrefix) {
aliases = append(aliases, Alias{
Name: strings.ReplaceAll(k, ".", " "),
Command: viper.GetString(k),
Command: c.v.GetString(k),
})
}
}

return aliases
}

func GetAlias(part ...string) (Alias, error) {
return c.GetAlias(part...)
}

func (c *Config) GetAlias(part ...string) (Alias, error) {
key := strings.Join(part, ".")
if command := viper.GetString(key); command == "" {
return Alias{}, errors.New("not found")
if command := c.v.GetString(key); command == "" {
return Alias{}, ErrAliasNotFound
} else {
return Alias{Name: strings.ReplaceAll(key, ".", " "), Command: command}, nil
}
}

// Given a list of alias name parts, return a list of valid next parts.
// Example:
//
// my:
// aliases:
// one: cmd1
// two: cmd2
//
// ListNextParts([]string{"my"}) -> []string{"aliases"}
// ListNextParts([]string{"my", "aliases"}) -> []string{"cmd1", "cmd2"}
func ListNextParts(parts []string) []string {
currentPrefix := strings.Join(parts, " ")
suggestions := make([]string, 0)
for _, a := range ListAliases(parts...) {
trail, _ := strings.CutPrefix(a.Name, currentPrefix)
if trailFields := strings.Fields(trail); len(trailFields) > 0 {
suggestions = append(suggestions, trailFields[0])
} else {
break
}
}
return suggestions
}
16 changes: 16 additions & 0 deletions internal/config/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package config

import (
"errors"
"fmt"
)

var ErrAliasNotFound = errors.New("not found")

type ErrInvalidAliasKey struct {
value string
}

func (e ErrInvalidAliasKey) Error() string {
return fmt.Sprintf("Key is invalid. Did you mean `" + e.value + "`?")
}

0 comments on commit f822ad1

Please sign in to comment.