Skip to content

Commit

Permalink
Support pointer flags (#22)
Browse files Browse the repository at this point in the history
Change-Id: I97d37bdaeab5b878b8064a555e3a07a131b175f0
  • Loading branch information
jxskiss authored Aug 16, 2023
1 parent cbeeac2 commit 9c3107d
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 26 deletions.
2 changes: 1 addition & 1 deletion cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -664,7 +664,7 @@ func (p *App) parseArgs(v any, opts ...ParseOpt) (fs *flag.FlagSet, err error) {
if err = ctx.checkRequired(); err != nil {
return fs, err
}
tidyFlagSet(fs, ctx.flags, nonflagArgs)
tidyFlags(fs, ctx.flags, nonflagArgs)
return fs, err
}

Expand Down
150 changes: 127 additions & 23 deletions cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ import (
"flag"
"fmt"
"os"
"reflect"
"strings"
"testing"
"time"

"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -72,17 +74,18 @@ func TestParsing_WithoutCallingRun(t *testing.T) {
func TestParsing_CheckFlagSetValues(t *testing.T) {
resetDefaultApp()
var args struct {
A bool `cli:"-a, -a-flag, description a flag"`
A1 bool `cli:"-1, -a1-flag"`
B int32 `cli:"-b, -b-flag, description b flag"`
C int64 `cli:"-c, --c-flag, description c flag"`
D float32 `cli:"-D, --d-flag, description d flag"`
E float64 `cli:"-E, --e-flag, description e flag"`
F string `cli:"-f, -f-flag, description f flag"`
G uint `cli:"-g, --g-flag, description g flag"`
H []bool `cli:"-H, --h-flag, description h flag"`
I []uint `cli:"-i, -i-flag, description i flag"`
J []string `cli:"-j, -j-flag, description j flag"`
A bool `cli:"-a, -a-flag, description a flag"`
A1 bool `cli:"-1, -a1-flag"`
B int32 `cli:"-b, -b-flag, description b flag"`
C int64 `cli:"-c, --c-flag, description c flag"`
D float32 `cli:"-D, --d-flag, description d flag"`
E float64 `cli:"-E, --e-flag, description e flag"`
F string `cli:"-f, -f-flag, description f flag"`
G uint `cli:"-g, --g-flag, description g flag"`
H []bool `cli:"-H, --h-flag, description h flag"`
I []uint `cli:"-i, -i-flag, description i flag"`
J []string `cli:"-j, -j-flag, description j flag"`
K time.Duration `cli:"-k, --k-flag, description k flag"`

ValueImpl2 flagValueImpl2 `cli:"-v, -v-flag, description v flag"`

Expand All @@ -109,6 +112,7 @@ func TestParsing_CheckFlagSetValues(t *testing.T) {
"-j-flag", "j2",
"-j-flag", "j,3",
"-j-flag", "j,4,5",
"-k", "1.5s",
"-v", "abc",
"-v", "123",

Expand All @@ -128,10 +132,11 @@ func TestParsing_CheckFlagSetValues(t *testing.T) {
assert.Equal(t, []bool{true, false, true, false}, args.H)
assert.Equal(t, []uint{5, 6, 7, 8}, args.I)
assert.Equal(t, []string{"j1", "j2", `j,3`, `j,4,5`}, args.J)
assert.Equal(t, 1500*time.Millisecond, args.K)
assert.Equal(t, []string{"some-args 0", "some-args 1"}, args.Args)
assert.Equal(t, []byte("abc123"), args.ValueImpl2.Data)

flagCount := 12 * 2
flagCount := 13 * 2
fs.Visit(func(flag *flag.Flag) {
flagCount--
})
Expand Down Expand Up @@ -163,8 +168,10 @@ func TestParsing_CheckFlagSetValues(t *testing.T) {
{"i-flag", "[5,6,7,8]", []uint{5, 6, 7, 8}},
{"j", `["j1","j2","j,3","j,4,5"]`, []string{"j1", "j2", "j,3", "j,4,5"}},
{"j-flag", `["j1","j2","j,3","j,4,5"]`, []string{"j1", "j2", "j,3", "j,4,5"}},
{"v", "flagValueImpl2", []byte("abc123")},
{"v-flag", "flagValueImpl2", []byte("abc123")},
{"k", "1.5s", 1500 * time.Millisecond},
{"k-flag", "1.5s", 1500 * time.Millisecond},
{"v", "abc123", []byte("abc123")},
{"v-flag", "abc123", []byte("abc123")},
} {
got := fs.Lookup(tt.flag).Value.String()
assert.Equalf(t, tt.want, got, "flag= %v", tt.flag)
Expand All @@ -173,6 +180,106 @@ func TestParsing_CheckFlagSetValues(t *testing.T) {
}
}

func TestParsing_PointerValues(t *testing.T) {
var args struct {
A *bool `cli:"-a, -a-flag, description a flag"`
A1 *bool `cli:"-1, -a1-flag"`
B *int32 `cli:"-b, -b-flag, description b flag"`
C *int64 `cli:"-c, --c-flag, description c flag"`
D *float32 `cli:"-D, --d-flag, description d flag"`
E *float64 `cli:"-E, --e-flag, description e flag"`
F *string `cli:"-f, -f-flag, description f flag"`
G *uint `cli:"-g, --g-flag, description g flag"`
K *time.Duration `cli:"-k, --k-flag, description k flag"`

ValueImpl2 *flagValueImpl2 `cli:"-v, -v-flag, description v flag"`

Arg1 *string `cli:"arg1"`
}

t.Run("all empty", func(t *testing.T) {
resetDefaultApp()
fs, err := Parse(&args, WithArgs([]string{}))
assert.Nil(t, err)
for _, x := range []any{
args.A, args.A1, args.B, args.C, args.D, args.E, args.F, args.G, args.K, args.ValueImpl2,
args.Arg1,
} {
assert.True(t, reflect.ValueOf(x).IsNil())
}
_ = fs
})

t.Run("set flags", func(t *testing.T) {
resetDefaultApp()
fs, err := Parse(&args, WithArgs([]string{
"-a-flag",
"-1=false",
"-b", "1",
"-c-flag", "2",
"--D", "3",
"--e-flag", "4",
"-f", "fstr",
"-g-flag", "5",
"-k", "1.5s",
"-v", "abc",
"-v", "123",

"arg1 value",
}))
assert.Nil(t, err)

assert.True(t, args.A != nil && *args.A)
assert.True(t, args.A1 != nil && !*args.A1)
assert.Equal(t, int32(1), *args.B)
assert.Equal(t, int64(2), *args.C)
assert.Equal(t, float32(3), *args.D)
assert.Equal(t, float64(4), *args.E)
assert.Equal(t, "fstr", *args.F)
assert.Equal(t, uint(5), *args.G)
assert.Equal(t, 1500*time.Millisecond, *args.K)
assert.Equal(t, []byte("abc123"), args.ValueImpl2.Data)
assert.Equal(t, "arg1 value", *args.Arg1)

flagCount := 10 * 2
fs.Visit(func(flag *flag.Flag) {
flagCount--
})
assert.Zero(t, flagCount)
for _, tt := range []struct {
flag string
want string
value any
}{
{"a", "true", true},
{"a-flag", "true", true},
{"1", "false", false},
{"a1-flag", "false", false},
{"b", "1", int32(1)},
{"b-flag", "1", int32(1)},
{"c", "2", int64(2)},
{"c-flag", "2", int64(2)},
{"D", "3", float32(3)},
{"d-flag", "3", float32(3)},
{"E", "4", float64(4)},
{"e-flag", "4", float64(4)},
{"f", "fstr", "fstr"},
{"f-flag", "fstr", "fstr"},
{"g", "5", uint(5)},
{"g-flag", "5", uint(5)},
{"k", "1.5s", 1500 * time.Millisecond},
{"k-flag", "1.5s", 1500 * time.Millisecond},
{"v", "abc123", []byte("abc123")},
{"v-flag", "abc123", []byte("abc123")},
} {
got := fs.Lookup(tt.flag).Value.String()
assert.Equalf(t, tt.want, got, "flag= %v", tt.flag)
gotValue := fs.Lookup(tt.flag).Value.(flag.Getter).Get()
assert.Equalf(t, tt.value, gotValue, "flag= %v", tt.flag)
}
})
}

func TestParsing_DefaultValues(t *testing.T) {
resetDefaultApp()
var args struct {
Expand Down Expand Up @@ -563,8 +670,11 @@ func (f *flagValueImpl2) Get() any {
return f.Data
}

func (f flagValueImpl2) String() string {
return "flagValueImpl2"
func (f *flagValueImpl2) String() string {
if f == nil {
return ""
}
return string(f.Data)
}

func (f *flagValueImpl2) Set(s string) error {
Expand All @@ -591,16 +701,10 @@ type SomeComplexType struct {
func TestParse_UnsupportedType(t *testing.T) {
resetDefaultApp()
var args1 struct {
A *bool `cli:"-a"`
}
var args2 struct {
B *string `cli:"-b"`
}
var args3 struct {
C *SomeComplexType `cli:"-c"`
}

for _, args := range []any{&args1, &args2, &args3} {
for _, args := range []any{&args1} {
assert.Panics(t, func() {
Parse(args)
})
Expand Down
45 changes: 44 additions & 1 deletion flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,13 @@ func (f *_flag) Get() any {
if f.rv.CanAddr() && f.rv.Addr().Type().Implements(flagGetterTyp) {
return f.rv.Addr().Interface().(flag.Getter).Get()
}
if f.rv.Kind() == reflect.Pointer {
if f.rv.Elem().IsValid() {
return f.rv.Elem().Interface()
}
zero := reflect.New(f.rv.Type().Elem()).Elem()
return zero.Interface()
}
return f.rv.Interface()
}

Expand All @@ -119,9 +126,20 @@ func formatValue(rv reflect.Value) string {
b, _ := rv.Addr().Interface().(textValue).MarshalText()
return string(b)
}
if rv.Kind() == reflect.Pointer {
return formatValueOfBasicTypePtr(rv)
}
return formatValueOfBasicType(rv)
}

func formatValueOfBasicTypePtr(rv reflect.Value) string {
if rv.Elem().IsValid() {
return formatValueOfBasicType(rv.Elem())
}
zero := reflect.New(rv.Type().Elem()).Elem()
return formatValueOfBasicType(zero)
}

func formatValueOfBasicType(rv reflect.Value) string {
switch rv.Kind() {
case reflect.Bool:
Expand Down Expand Up @@ -176,6 +194,11 @@ func applyValue(rv reflect.Value, s string) error {
}

func applyValueOfBasicType(rv reflect.Value, s string) error {
if isSupportedBasicTypePtr(rv.Type()) {
rv.Set(reflect.New(rv.Type().Elem()))
return applyValueOfBasicType(rv.Elem(), s)
}

if isIntegerValue(rv) {
return applyIntegerValue(rv, s)
}
Expand Down Expand Up @@ -252,6 +275,11 @@ func applyIntegerValue(rv reflect.Value, s string) error {
return nil
}

func (f *_flag) isBooleanPtr() bool {
return f.rv.Kind() == reflect.Pointer &&
f.rv.Type().Elem().Kind() == reflect.Bool
}

func (f *_flag) isBoolean() bool {
return f.rv.Kind() == reflect.Bool
}
Expand Down Expand Up @@ -515,14 +543,22 @@ func (p *flagParser) appendFlag(f *_flag) {

func (p *flagParser) addToFlagSet(f *_flag, fv reflect.Value) {
fs := p.fs
if fv.Kind() == reflect.Bool {
if f.isBoolean() {
ptr := fv.Addr().Interface().(*bool)
fs.BoolVar(ptr, f.name, f.rv.Bool(), f.description)
if f.short != "" {
fs.BoolVar(ptr, f.short, f.rv.Bool(), f.description)
}
return
}
if f.isBooleanPtr() {
ptr := new(bool)
fs.BoolVar(ptr, f.name, false, f.description)
if f.short != "" {
fs.BoolVar(ptr, f.short, false, f.description)
}
return
}
fs.Var(f, f.name, f.description)
if f.short != "" {
fs.Var(f, f.short, f.description)
Expand Down Expand Up @@ -718,6 +754,9 @@ func isSupportedType(rv reflect.Value) bool {
if isFlagValueImpl(rv) || isTextValueImpl(rv) {
return true
}
if isSupportedBasicTypePtr(rv.Type()) {
return true
}
if isSupportedBasicType(rv.Kind()) {
return true
}
Expand Down Expand Up @@ -763,6 +802,10 @@ func zeroTextValueStr(rv reflect.Value) string {
return string(b)
}

func isSupportedBasicTypePtr(typ reflect.Type) bool {
return typ.Kind() == reflect.Pointer && isSupportedBasicType(typ.Elem().Kind())
}

func isSupportedBasicType(kind reflect.Kind) bool {
switch kind {
case reflect.Bool,
Expand Down
9 changes: 8 additions & 1 deletion flag_unsafe.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"unsafe"
)

func tidyFlagSet(fs *flag.FlagSet, flags []*_flag, nonflagArgs []string) {
func tidyFlags(fs *flag.FlagSet, flags []*_flag, nonflagArgs []string) {
m := make(map[string]*_flag)
for _, f := range flags {
m[f.name] = f
Expand All @@ -27,6 +27,13 @@ func tidyFlagSet(fs *flag.FlagSet, flags []*_flag, nonflagArgs []string) {
if f == nil {
return
}

// Special processing for *bool value.
if f.isBooleanPtr() {
f.rv.Set(reflect.New(f.rv.Type().Elem()))
f.rv.Elem().SetBool(ff.Value.String() == "true")
}

if f.name != ff.Name {
formal[f.name].Value = ff.Value
actual[f.name] = formal[f.name]
Expand Down

0 comments on commit 9c3107d

Please sign in to comment.