Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue #107: add basic zsh completion (command hierarchy only) #497

Merged
merged 3 commits into from
Jul 30, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions zsh_completions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package cobra

import (
"bytes"
"fmt"
"io"
"strings"
)

// GenZshCompletion generates a zsh completion file and writes to the passed writer.
func (cmd *Command) GenZshCompletion(w io.Writer) error {
buf := new(bytes.Buffer)

writeHeader(buf, cmd)
maxDepth := maxDepth(cmd)
writeLevelMapping(buf, maxDepth)
writeLevelCases(buf, maxDepth, cmd)

_, err := buf.WriteTo(w)
return err
}

func writeHeader(w io.Writer, cmd *Command) {
fmt.Fprintf(w, "#compdef %s\n\n", cmd.Name())
}

func maxDepth(c *Command) int {
if len(c.Commands()) == 0 {
return 0
}
maxDepthSub := 0
for _, s := range c.Commands() {
subDepth := maxDepth(s)
if subDepth > maxDepthSub {
maxDepthSub = subDepth
}
}
return 1 + maxDepthSub
}

func writeLevelMapping(w io.Writer, numLevels int) {
fmt.Fprintln(w, `_arguments \`)
for i := 1; i <= numLevels; i++ {
fmt.Fprintf(w, ` '%d: :->level%d' \`, i, i)
fmt.Fprintln(w)
}
fmt.Fprintf(w, ` '%d: :%s'`, numLevels+1, "_files")
fmt.Fprintln(w)
}

func writeLevelCases(w io.Writer, maxDepth int, root *Command) {
fmt.Fprintln(w, "case $state in")
defer fmt.Fprintln(w, "esac")

for i := 1; i <= maxDepth; i++ {
fmt.Fprintf(w, " level%d)\n", i)
writeLevel(w, root, i)
fmt.Fprintln(w, " ;;")
}
fmt.Fprintln(w, " *)")
fmt.Fprintln(w, " _arguments '*: :_files'")
fmt.Fprintln(w, " ;;")
}

func writeLevel(w io.Writer, root *Command, i int) {
fmt.Fprintf(w, " case $words[%d] in\n", i)
defer fmt.Fprintln(w, " esac")

commands := filterByLevel(root, i)
byParent := groupByParent(commands)

for p, c := range byParent {
names := names(c)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

variable names defined but not used.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is used two lines below?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh sorry, I didn't notice it.

fmt.Fprintf(w, " %s)\n", p)
fmt.Fprintf(w, " _arguments '%d: :(%s)'\n", i, strings.Join(names, " "))
fmt.Fprintln(w, " ;;")
}
fmt.Fprintln(w, " *)")
fmt.Fprintln(w, " _arguments '*: :_files'")
fmt.Fprintln(w, " ;;")

}

func filterByLevel(c *Command, l int) []*Command {
cs := make([]*Command, 0)
if l == 0 {
cs = append(cs, c)
return cs
}
for _, s := range c.Commands() {
cs = append(cs, filterByLevel(s, l-1)...)
}
return cs
}

func groupByParent(commands []*Command) map[string][]*Command {
m := make(map[string][]*Command)
for _, c := range commands {
parent := c.Parent()
if parent == nil {
continue
}
m[parent.Name()] = append(m[parent.Name()], c)
}
return m
}

func names(commands []*Command) []string {
ns := make([]string, len(commands))
for i, c := range commands {
ns[i] = c.Name()
}
return ns
}
88 changes: 88 additions & 0 deletions zsh_completions_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package cobra

import (
"bytes"
"strings"
"testing"
)

func TestZshCompletion(t *testing.T) {
tcs := []struct {
name string
root *Command
expectedExpressions []string
}{
{
name: "trivial",
root: &Command{Use: "trivialapp"},
expectedExpressions: []string{"#compdef trivial"},
},
{
name: "linear",
root: func() *Command {
r := &Command{Use: "linear"}

sub1 := &Command{Use: "sub1"}
r.AddCommand(sub1)

sub2 := &Command{Use: "sub2"}
sub1.AddCommand(sub2)

sub3 := &Command{Use: "sub3"}
sub2.AddCommand(sub3)
return r
}(),
expectedExpressions: []string{"sub1", "sub2", "sub3"},
},
{
name: "flat",
root: func() *Command {
r := &Command{Use: "flat"}
r.AddCommand(&Command{Use: "c1"})
r.AddCommand(&Command{Use: "c2"})
return r
}(),
expectedExpressions: []string{"(c1 c2)"},
},
{
name: "tree",
root: func() *Command {
r := &Command{Use: "tree"}

sub1 := &Command{Use: "sub1"}
r.AddCommand(sub1)

sub11 := &Command{Use: "sub11"}
sub12 := &Command{Use: "sub12"}

sub1.AddCommand(sub11)
sub1.AddCommand(sub12)

sub2 := &Command{Use: "sub2"}
r.AddCommand(sub2)

sub21 := &Command{Use: "sub21"}
sub22 := &Command{Use: "sub22"}

sub2.AddCommand(sub21)
sub2.AddCommand(sub22)

return r
}(),
expectedExpressions: []string{"(sub11 sub12)", "(sub21 sub22)"},
},
}

for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
buf := new(bytes.Buffer)
tc.root.GenZshCompletion(buf)
completion := buf.String()
for _, expectedExpression := range tc.expectedExpressions {
if !strings.Contains(completion, expectedExpression) {
t.Errorf("expected completion to contain '%v' somewhere; got '%v'", expectedExpression, completion)
}
}
})
}
}