From 0adfa2d8ca8be5c5277886590ff64080d0a74e28 Mon Sep 17 00:00:00 2001 From: bpicode Date: Tue, 25 Jul 2017 00:16:52 +0200 Subject: [PATCH 1/3] Issue #107: add basic zsh completion (command hierarchy only) --- zsh_completions.go | 123 ++++++++++++++++++++++++++++++++++++++++ zsh_completions_test.go | 88 ++++++++++++++++++++++++++++ 2 files changed, 211 insertions(+) create mode 100644 zsh_completions.go create mode 100644 zsh_completions_test.go diff --git a/zsh_completions.go b/zsh_completions.go new file mode 100644 index 000000000..e5d7e2b2e --- /dev/null +++ b/zsh_completions.go @@ -0,0 +1,123 @@ +package cobra + +import ( + "bytes" + "io" + "fmt" + "strings" +) + +type fWriter struct { + io.Writer +} + +func (fw *fWriter) fWriteLn(format string, a ...interface{}) (int, error) { + return io.WriteString(fw, fmt.Sprintf(format+"\n", a...)) +} + +// GenZshCompletion generates a zsh completion file and writes to the passed writer. +func (cmd *Command) GenZshCompletion(w io.Writer) error { + buf := new(bytes.Buffer) + fw := &fWriter{buf} + + writeHeader(fw, cmd) + maxDepth := maxDepth(cmd) + writeLevelMapping(fw, maxDepth) + writeLevelCases(fw, maxDepth, cmd) + + _, err := buf.WriteTo(w) + return err +} + +func writeHeader(fw *fWriter, cmd *Command) { + fw.fWriteLn("#compdef %s", cmd.Name()) + fw.fWriteLn("") +} + +func maxDepth(c *Command) int { + if len(c.Commands()) == 0 { + return 0 + } + maxDepthSub := 0 + for _, s := range c.Commands() { + subDepth := maxDepth(s) + if subDepth > maxDepthSub { + maxDepthSub = subDepth + } + } + return 1 + maxDepthSub +} + +func writeLevelMapping(fw *fWriter, numLevels int) { + fw.fWriteLn(`_arguments \`) + for i := 1; i <= numLevels; i++ { + fw.fWriteLn(` '%d: :->level%d' \`, i, i) + } + fw.fWriteLn(` '%d: :%s'`, numLevels+1, "_files") + fw.fWriteLn("") +} + +func writeLevelCases(fw *fWriter, maxDepth int, root *Command) { + fw.fWriteLn("case $state in") + defer fw.fWriteLn("esac") + + for i := 1; i <= maxDepth; i++ { + fw.fWriteLn(" level%d)", i) + writeLevel(fw, root, i) + fw.fWriteLn(" ;;") + } + fw.fWriteLn(" *)") + fw.fWriteLn(" _arguments '*: :_files'") + fw.fWriteLn(" ;;") +} + +func writeLevel(fw *fWriter, root *Command, i int) { + fw.fWriteLn(fmt.Sprintf(" case $words[%d] in", i)) + defer fw.fWriteLn(" esac") + + commands := filterByLevel(root, i) + byParent := groupByParent(commands) + + for p, c := range byParent { + names := names(c) + fw.fWriteLn(fmt.Sprintf(" %s)", p)) + fw.fWriteLn(fmt.Sprintf(" _arguments '%d: :(%s)'", i, strings.Join(names, " "))) + fw.fWriteLn(fmt.Sprintf(" ;;")) + } + fw.fWriteLn(" *)") + fw.fWriteLn(" _arguments '*: :_files'") + fw.fWriteLn(" ;;") + +} + +func filterByLevel(c *Command, l int) []*Command { + cs := make([]*Command, 0) + if l == 0 { + cs = append(cs, c) + return cs + } + for _, s := range c.Commands() { + cs = append(cs, filterByLevel(s, l-1)...) + } + return cs +} + +func groupByParent(commands []*Command) map[string][]*Command { + m := make(map[string][]*Command) + for _, c := range commands { + parent := c.Parent() + if parent == nil { + continue + } + m[parent.Name()] = append(m[parent.Name()], c) + } + return m +} + +func names(commands []*Command) []string { + ns := make([]string, len(commands)) + for i, c := range commands { + ns[i] = c.Name() + } + return ns +} diff --git a/zsh_completions_test.go b/zsh_completions_test.go new file mode 100644 index 000000000..8b3c08c5c --- /dev/null +++ b/zsh_completions_test.go @@ -0,0 +1,88 @@ +package cobra + +import ( + "testing" + "bytes" + "strings" +) + +func TestZshCompletion(t *testing.T) { + tcs := []struct { + name string + root *Command + expectedExpressions []string + }{ + { + name: "trivial", + root: &Command{Use: "trivialapp"}, + expectedExpressions: []string{"#compdef trivial"}, + }, + { + name: "linear", + root: func() *Command { + r := &Command{Use: "linear"} + + sub1 := &Command{Use: "sub1"} + r.AddCommand(sub1) + + sub2 := &Command{Use: "sub2"} + sub1.AddCommand(sub2) + + sub3 := &Command{Use: "sub3"} + sub2.AddCommand(sub3) + return r + }(), + expectedExpressions: []string{"sub1", "sub2", "sub3"}, + }, + { + name: "flat", + root: func() *Command { + r := &Command{Use: "flat"} + r.AddCommand(&Command{Use: "c1"}) + r.AddCommand(&Command{Use: "c2"}) + return r + }(), + expectedExpressions: []string{"(c1 c2)"}, + }, + { + name: "tree", + root: func() *Command { + r := &Command{Use: "tree"} + + sub1 := &Command{Use: "sub1"} + r.AddCommand(sub1) + + sub11 := &Command{Use: "sub11"} + sub12 := &Command{Use: "sub12"} + + sub1.AddCommand(sub11) + sub1.AddCommand(sub12) + + sub2 := &Command{Use: "sub2"} + r.AddCommand(sub2) + + sub21 := &Command{Use: "sub21"} + sub22 := &Command{Use: "sub22"} + + sub2.AddCommand(sub21) + sub2.AddCommand(sub22) + + return r + }(), + expectedExpressions: []string{"(sub11 sub12)", "(sub21 sub22)"}, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + buf := new(bytes.Buffer) + tc.root.GenZshCompletion(buf) + completion := buf.String() + for _, expectedExpression := range tc.expectedExpressions { + if !strings.Contains(completion, expectedExpression) { + t.Errorf("expected completion to contain '%v' somewhere; got '%v'", expectedExpression, completion) + } + } + }) + } +} From 52d469c0916e373c3501cd944b44bb9241c61b85 Mon Sep 17 00:00:00 2001 From: bpicode Date: Tue, 25 Jul 2017 00:25:38 +0200 Subject: [PATCH 2/3] Issue #107: fix import order --- zsh_completions.go | 2 +- zsh_completions_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/zsh_completions.go b/zsh_completions.go index e5d7e2b2e..d73cddfef 100644 --- a/zsh_completions.go +++ b/zsh_completions.go @@ -2,8 +2,8 @@ package cobra import ( "bytes" - "io" "fmt" + "io" "strings" ) diff --git a/zsh_completions_test.go b/zsh_completions_test.go index 8b3c08c5c..08b851591 100644 --- a/zsh_completions_test.go +++ b/zsh_completions_test.go @@ -1,9 +1,9 @@ package cobra import ( - "testing" "bytes" "strings" + "testing" ) func TestZshCompletion(t *testing.T) { From dcf38a7f948aa0f4ebb74dcd4256660cc4624277 Mon Sep 17 00:00:00 2001 From: bpicode Date: Wed, 26 Jul 2017 20:13:22 +0200 Subject: [PATCH 3/3] Issue #107: use fmt.Fprintf instead of wrapping the writer --- zsh_completions.go | 67 ++++++++++++++++++++-------------------------- 1 file changed, 29 insertions(+), 38 deletions(-) diff --git a/zsh_completions.go b/zsh_completions.go index d73cddfef..b350aeeca 100644 --- a/zsh_completions.go +++ b/zsh_completions.go @@ -7,31 +7,21 @@ import ( "strings" ) -type fWriter struct { - io.Writer -} - -func (fw *fWriter) fWriteLn(format string, a ...interface{}) (int, error) { - return io.WriteString(fw, fmt.Sprintf(format+"\n", a...)) -} - // GenZshCompletion generates a zsh completion file and writes to the passed writer. func (cmd *Command) GenZshCompletion(w io.Writer) error { buf := new(bytes.Buffer) - fw := &fWriter{buf} - writeHeader(fw, cmd) + writeHeader(buf, cmd) maxDepth := maxDepth(cmd) - writeLevelMapping(fw, maxDepth) - writeLevelCases(fw, maxDepth, cmd) + writeLevelMapping(buf, maxDepth) + writeLevelCases(buf, maxDepth, cmd) _, err := buf.WriteTo(w) return err } -func writeHeader(fw *fWriter, cmd *Command) { - fw.fWriteLn("#compdef %s", cmd.Name()) - fw.fWriteLn("") +func writeHeader(w io.Writer, cmd *Command) { + fmt.Fprintf(w, "#compdef %s\n\n", cmd.Name()) } func maxDepth(c *Command) int { @@ -48,45 +38,46 @@ func maxDepth(c *Command) int { return 1 + maxDepthSub } -func writeLevelMapping(fw *fWriter, numLevels int) { - fw.fWriteLn(`_arguments \`) +func writeLevelMapping(w io.Writer, numLevels int) { + fmt.Fprintln(w, `_arguments \`) for i := 1; i <= numLevels; i++ { - fw.fWriteLn(` '%d: :->level%d' \`, i, i) + fmt.Fprintf(w, ` '%d: :->level%d' \`, i, i) + fmt.Fprintln(w) } - fw.fWriteLn(` '%d: :%s'`, numLevels+1, "_files") - fw.fWriteLn("") + fmt.Fprintf(w, ` '%d: :%s'`, numLevels+1, "_files") + fmt.Fprintln(w) } -func writeLevelCases(fw *fWriter, maxDepth int, root *Command) { - fw.fWriteLn("case $state in") - defer fw.fWriteLn("esac") +func writeLevelCases(w io.Writer, maxDepth int, root *Command) { + fmt.Fprintln(w, "case $state in") + defer fmt.Fprintln(w, "esac") for i := 1; i <= maxDepth; i++ { - fw.fWriteLn(" level%d)", i) - writeLevel(fw, root, i) - fw.fWriteLn(" ;;") + fmt.Fprintf(w, " level%d)\n", i) + writeLevel(w, root, i) + fmt.Fprintln(w, " ;;") } - fw.fWriteLn(" *)") - fw.fWriteLn(" _arguments '*: :_files'") - fw.fWriteLn(" ;;") + fmt.Fprintln(w, " *)") + fmt.Fprintln(w, " _arguments '*: :_files'") + fmt.Fprintln(w, " ;;") } -func writeLevel(fw *fWriter, root *Command, i int) { - fw.fWriteLn(fmt.Sprintf(" case $words[%d] in", i)) - defer fw.fWriteLn(" esac") +func writeLevel(w io.Writer, root *Command, i int) { + fmt.Fprintf(w, " case $words[%d] in\n", i) + defer fmt.Fprintln(w, " esac") commands := filterByLevel(root, i) byParent := groupByParent(commands) for p, c := range byParent { names := names(c) - fw.fWriteLn(fmt.Sprintf(" %s)", p)) - fw.fWriteLn(fmt.Sprintf(" _arguments '%d: :(%s)'", i, strings.Join(names, " "))) - fw.fWriteLn(fmt.Sprintf(" ;;")) + fmt.Fprintf(w, " %s)\n", p) + fmt.Fprintf(w, " _arguments '%d: :(%s)'\n", i, strings.Join(names, " ")) + fmt.Fprintln(w, " ;;") } - fw.fWriteLn(" *)") - fw.fWriteLn(" _arguments '*: :_files'") - fw.fWriteLn(" ;;") + fmt.Fprintln(w, " *)") + fmt.Fprintln(w, " _arguments '*: :_files'") + fmt.Fprintln(w, " ;;") }