Skip to content

Commit

Permalink
Fixing reflect
Browse files Browse the repository at this point in the history
  • Loading branch information
kenshaw committed Jan 15, 2025
1 parent 8f9651f commit 1f506d1
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 42 deletions.
1 change: 1 addition & 0 deletions _examples/reflect/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ func main() {

func run(args *args) func(context.Context, []string) error {
return func(ctx context.Context, v []string) error {
fmt.Println("args:", v)
fmt.Println("arg:", args.Arg)
fmt.Println("url set:", args.URLSet)
if args.URL != nil {
Expand Down
43 changes: 34 additions & 9 deletions cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,9 @@ func NewFlag(name, usage string, opts ...Option) (*Flag, error) {
// determined by the result of calling [DefaultFlagNameMapper], or it can be
// set with the `name:` option (see below).
//
// Use the [CommandOption] [From] to add flags when creating a command with
// [Run]/[RunContext]/[Sub]/[NewCommand].
//
// Example:
//
// args := struct{
Expand All @@ -717,7 +720,7 @@ func NewFlag(name, usage string, opts ...Option) (*Flag, error) {
// MyFloat float64 `ox:"my float,hidden,name:MYF"`
// }{}
//
// ox.FromFlags(&args)
// ox.FlagsFrom(&args)
//
// Recognized options:
//
Expand All @@ -742,7 +745,7 @@ func NewFlag(name, usage string, opts ...Option) (*Flag, error) {
// The `default:` option will be expanded by [Context.Expand] when the
// command's flags are populated.
//
// The tag name (`ox`) can be changed by setting the [DefaultTagName] variable
// The tag name (`ox`) can be changed by setting the [DefaultStructTagName] variable
// if necessary.
func FlagsFrom[T *E, E any](val T) ([]*Flag, error) {
return appendFlags(nil, reflect.ValueOf(val), nil)
Expand Down Expand Up @@ -836,7 +839,7 @@ func NewExec[T ExecType](f T) (ExecFunc, error) {
return nil, fmt.Errorf("%w: invalid exec func %T", ErrInvalidType, f)
}

// appendFlags builds flags for a value, appending them to flags.
// appendFlags builds and appends flags for value v to flags.
func appendFlags(flags []*Flag, v reflect.Value, parents []string) ([]*Flag, error) {
if v.Kind() != reflect.Pointer || v.Elem().Kind() != reflect.Struct {
return nil, fmt.Errorf("%w: %s is not a *struct", ErrInvalidType, v.Type())
Expand All @@ -846,16 +849,18 @@ func appendFlags(flags []*Flag, v reflect.Value, parents []string) ([]*Flag, err
for i := range typ.NumField() {
// check field is exported
f := typ.Field(i)
tag, ok := f.Tag.Lookup(DefaultTagName)
tag, ok := f.Tag.Lookup(DefaultStructTagName)
if !ok {
continue
}
if r := []rune(f.Name); !unicode.IsUpper(r[0]) {
return nil, fmt.Errorf("%w: unexported field `%s` has tag `%s`", ErrInvalidType, f.Name, DefaultTagName)
return nil, fmt.Errorf("%w: unexported field `%s` has tag `%s`", ErrInvalidType, f.Name, DefaultStructTagName)
}
field := v.Field(i)
tags := SplitBy(tag, ',')
switch name, kind, field := tags[0], f.Type.Kind(), v.Field(i); {
case kind == reflect.Pointer && f.Type.Elem().Kind() == reflect.Struct: // *struct
isField := isField(field)
switch name, kind := tags[0], f.Type.Kind(); {
case !isField && kind == reflect.Pointer && f.Type.Elem().Kind() == reflect.Struct: // *struct
switch {
case name == "-":
continue
Expand All @@ -869,7 +874,7 @@ func appendFlags(flags []*Flag, v reflect.Value, parents []string) ([]*Flag, err
if flags, err = appendFlags(flags, field, append(parents, name)); err != nil {
return nil, err
}
case kind == reflect.Struct: // struct
case !isField && kind == reflect.Struct: // struct
switch {
case name == "-":
continue
Expand Down Expand Up @@ -977,10 +982,30 @@ func buildFlagOpts(parent, value reflect.Value, tags []string) ([]Option, error)
return nil, fmt.Errorf("%w: %q", ErrUnknownTagOption, key)
}
}
// prepend bind to opt
// prepend bind to options
return prepend(opts, BindRef(value, set)), nil
}

// isField returns true if v is not a struct or the struct contains at least
// one `ox` tag.
func isField(v reflect.Value) bool {
typ := v.Type()
kind := typ.Kind()
if kind == reflect.Pointer {
typ = typ.Elem()
kind = typ.Kind()
}
if kind != reflect.Struct {
return false
}
for i := range typ.NumField() {
if _, ok := typ.Field(i).Tag.Lookup(DefaultStructTagName); ok {
return false
}
}
return true
}

// setField returns the pointer to the bool for name.
func setField(value reflect.Value, name string) (*bool, error) {
if r := []rune(name); !unicode.IsUpper(r[0]) {
Expand Down
37 changes: 25 additions & 12 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package ox_test

import (
"context"
"errors"
"net/netip"
"net/url"
"time"

Expand Down Expand Up @@ -42,13 +44,17 @@ func Example() {
// myApp [flags] [args]
//
// Flags:
// -s, --my-string string a string
// -b, --my-bool a bool
// -i, --ints int a slice of ints
// -d, --date date formatted date
// -v, --verbose enable verbose
// --version show version, then exit
// -h, --help show help, then exit
// -s, --my-string string a string
// -b, --my-bool a bool
// -i, --ints int a slice of ints
// -d, --date date formatted date
// -u, --url url a url
// -v, --verbose enable verbose
// -y, --sub-bools string=bool bool map
// -S, --x-some-other-string string another string
// --x-a-really-long-name string long arg
// --version show version, then exit
// -h, --help show help, then exit
//
// See: https://github.com/xo/ox for more information.
}
Expand All @@ -60,7 +66,8 @@ func Example_argsTest() {
Number float64 `ox:"a number"`
}{}
subArgs := struct {
URL *url.URL `ox:"a url,short:u"`
URL *url.URL `ox:"a url,short:u"`
Addr *netip.Addr `ox:"an ip address"`
}{}
ox.RunContext(
context.Background(),
Expand All @@ -69,12 +76,17 @@ func Example_argsTest() {
ox.Defaults(),
ox.From(&args),
ox.Sub(
// ox.Exec(mySubFunc),
ox.Exec(func() error {
// return an error to show that this func is not called
return errors.New("oops!")
}),
ox.Usage("sub", "a sub command to test"),
ox.From(&subArgs),
ox.Sort(true),
),
ox.Sort(true),
// the command line args to test
ox.Args("sub", "--help"),
ox.Args("help", "sub"),
)
// Output:
// sub a sub command to test
Expand All @@ -83,8 +95,9 @@ func Example_argsTest() {
// extest sub [flags] [args]
//
// Flags:
// -u, --url url a url
// -h, --help show help, then exit
// -u, --url url a url
// --addr addr an ip address
// -h, --help show help, then exit
}

// Example_psql demonstrates building complex help output, based on original
Expand Down
40 changes: 19 additions & 21 deletions ox.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ import (
var (
// DefaultContext is the default [context.Context].
DefaultContext = context.Background()
// DefaultTagName is the default struct tag name used in [FromFlags] and
// related func's.
DefaultTagName = "ox"
// DefaultStructTagName is the default struct tag name used in [FromFlags]
// and related func's.
DefaultStructTagName = "ox"
// DefaultLayout is the default timestamp layout used for formatting and
// parsing [Time] values.
DefaultLayout = time.RFC3339
Expand Down Expand Up @@ -422,12 +422,10 @@ func (ctx *Context) Run(parent context.Context) error {
return nil
}

// Populate populates the context's vars with all the command's flags values,
// overwriting any set variables if applicable. When all is true, all flag
// values will be populated, otherwise only flags with default values will be.
//
// When overwrite is true, existing vars will be set either to flag's empty or
// default value.
// Populate populates the context's vars with the command's flag's default
// values when not already set. When all is true, all flag values will be
// populated with the default or zero value. When overwrite is true, any
// existing flag vars will be overwritten.
func (ctx *Context) Populate(cmd *Command, all, overwrite bool) error {
if cmd.Flags == nil {
return nil
Expand Down Expand Up @@ -455,16 +453,16 @@ func (ctx *Context) Populate(cmd *Command, all, overwrite bool) error {

// Expand expands variables in v, where any variable can be the following:
//
// $APPNAME - the root command's name
// $HOME - the current user's home directory
// $USER - the current user's user name
// $CONFIG - the current user's config directory
// $APPCONFIG - the current user's config directory, with the root command's name added as a subdir
// $CACHE - the current user's cache directory
// $APPCACHE - the current user's cache directory, with the root command's name added as a subdir
// $NUMCPU - the value of [runtime.NumCPU]
// $ARCH - the value of [runtime.GOARCH]
// $OS - the value of [runtime.GOOS]
// $HOME - the current user's home directory (ex: ~/)
// $USER - the current user's user name (ex: user)
// $APPNAME - the root command's name (ex: appName)
// $CONFIG - the current user's config directory (ex: ~/.config)
// $APPCONFIG - the current user's config directory, with the root command's name added as a subdir (ex: ~/.config/appName)
// $CACHE - the current user's cache directory (ex: ~/.cache)
// $APPCACHE - the current user's cache directory, with the root command's name added as a subdir (ex: ~/.cache/appName)
// $NUMCPU - the value of [runtime.NumCPU] (ex: 4)
// $ARCH - the value of [runtime.GOARCH] (ex: amd64)
// $OS - the value of [runtime.GOOS] (ex: windows)
// $ENV{KEY} - the environment value for $KEY
// $CONFIG_TYPE{KEY} - the registered config file loader type and key value, for example: `$YAML{my_key}`, `$TOML{my_key}`
//
Expand All @@ -491,6 +489,8 @@ func (ctx *Context) Expand(v any) (string, error) {
}
return u.Username, nil
}
case "$APPNAME":
return ctx.Root.Name, nil
case "$CONFIG":
f = userConfigDir
case "$APPCONFIG":
Expand All @@ -511,8 +511,6 @@ func (ctx *Context) Expand(v any) (string, error) {
}
return filepath.Join(dir, ctx.Root.Name), nil
}
case "$APPNAME":
return ctx.Root.Name, nil
case "$NUMCPU":
return strconv.Itoa(runtime.NumCPU()), nil
case "$ARCH":
Expand Down

0 comments on commit 1f506d1

Please sign in to comment.