Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add typed nil validation to dsl.Security #3574

Merged
merged 3 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 25 additions & 26 deletions dsl/security.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,35 +228,34 @@ func JWTSecurity(name string, fn ...func()) *expr.SchemeExpr {
// })
func Security(args ...any) {
var dsl func()
{
if d, ok := args[len(args)-1].(func()); ok {
args = args[:len(args)-1]
dsl = d
}
}

var schemes []*expr.SchemeExpr
{
schemes = make([]*expr.SchemeExpr, len(args))
for i, arg := range args {
switch val := arg.(type) {
case string:
for _, s := range expr.Root.Schemes {
if s.SchemeName == val {
schemes[i] = expr.DupScheme(s)
break
}
}
if schemes[i] == nil {
eval.ReportError("security scheme %q not found", val)
return
if d, ok := args[len(args)-1].(func()); ok {
args = args[:len(args)-1]
dsl = d
}

schemes := make([]*expr.SchemeExpr, len(args))
for i, arg := range args {
switch val := arg.(type) {
case string:
for _, s := range expr.Root.Schemes {
if s.SchemeName == val {
schemes[i] = expr.DupScheme(s)
break
}
case *expr.SchemeExpr:
schemes[i] = expr.DupScheme(val)
default:
eval.InvalidArgError("security scheme or security scheme name", val)
}
if schemes[i] == nil {
eval.ReportError("security scheme %q not found", val)
return
}
case *expr.SchemeExpr:
if val == nil {
eval.InvalidArgError("security scheme", val)
return
}
schemes[i] = expr.DupScheme(val)
default:
eval.InvalidArgError("security scheme or security scheme name", val)
return
}
}

Expand Down
23 changes: 12 additions & 11 deletions eval/eval_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,18 @@ func TestInvalidArgError(t *testing.T) {
dsl func()
want string
}{
"Attribute": {func() { Type("name", func() { Attribute("name", String, "description", 1) }) }, "cannot use 1 (type int) as type func()"},
"Body": {func() { Service("s", func() { Method("m", func() { HTTP(func() { Body(1) }) }) }) }, "cannot use 1 (type int) as type attribute name, user type or DSL"},
"ErrorName (bool)": {func() { Type("name", func() { ErrorName(true) }) }, "cannot use true (type bool) as type name or position"},
"ErrorName (int)": {func() { Type("name", func() { ErrorName(1, 2) }) }, "cannot use 2 (type int) as type name"},
"Example": {func() { Example(1, 2) }, "cannot use 1 (type int) as type summary (string)"},
"Headers": {func() { Headers(1) }, "cannot use 1 (type int) as type function"},
"Param": {func() { API("name", func() { HTTP(func() { Params(1) }) }) }, "cannot use 1 (type int) as type function"},
"Response": {func() { Service("s", func() { HTTP(func() { Response(1) }) }) }, "cannot use 1 (type int) as type name of error"},
"ResultType": {func() { ResultType("identifier", 1) }, "cannot use 1 (type int) as type function or string"},
"Security": {func() { Security(1) }, "cannot use 1 (type int) as type security scheme or security scheme name"},
"Type": {func() { Type("name", 1) }, "cannot use 1 (type int) as type type or function"},
"Attribute": {func() { Type("name", func() { Attribute("name", String, "description", 1) }) }, "cannot use 1 (type int) as type func()"},
"Body": {func() { Service("s", func() { Method("m", func() { HTTP(func() { Body(1) }) }) }) }, "cannot use 1 (type int) as type attribute name, user type or DSL"},
"ErrorName (bool)": {func() { Type("name", func() { ErrorName(true) }) }, "cannot use true (type bool) as type name or position"},
"ErrorName (int)": {func() { Type("name", func() { ErrorName(1, 2) }) }, "cannot use 2 (type int) as type name"},
"Example": {func() { Example(1, 2) }, "cannot use 1 (type int) as type summary (string)"},
"Headers": {func() { Headers(1) }, "cannot use 1 (type int) as type function"},
"Param": {func() { API("name", func() { HTTP(func() { Params(1) }) }) }, "cannot use 1 (type int) as type function"},
"Response": {func() { Service("s", func() { HTTP(func() { Response(1) }) }) }, "cannot use 1 (type int) as type name of error"},
"ResultType": {func() { ResultType("identifier", 1) }, "cannot use 1 (type int) as type function or string"},
"Security": {func() { Security(1) }, "cannot use 1 (type int) as type security scheme or security scheme name"},
"Security (typed nil)": {func() { Security((*expr.SchemeExpr)(nil)) }, "cannot use (*expr.SchemeExpr)(nil) (type *expr.SchemeExpr) as type security scheme"},
"Type": {func() { Type("name", 1) }, "cannot use 1 (type int) as type type or function"},
}
for name, tc := range dsls {
t.Run(name, func(t *testing.T) {
Expand Down