diff --git a/cmd/protogetter/main.go b/cmd/protogetter/main.go index e03fdcf..fdc8c14 100644 --- a/cmd/protogetter/main.go +++ b/cmd/protogetter/main.go @@ -7,9 +7,5 @@ import ( ) func main() { - cfg := &protogetter.Config{ - Mode: protogetter.StandaloneMode, - } - - singlechecker.Main(protogetter.NewAnalyzer(cfg)) + singlechecker.Main(protogetter.NewAnalyzer(nil)) } diff --git a/go.mod b/go.mod index a4edd0e..f03c8aa 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,11 @@ module github.com/ghostiam/protogetter go 1.19 -require golang.org/x/tools v0.12.0 +require ( + github.com/gobwas/glob v0.2.3 + github.com/jessevdk/go-flags v1.5.0 + golang.org/x/tools v0.12.0 +) require ( golang.org/x/mod v0.12.0 // indirect diff --git a/go.sum b/go.sum index 14a6610..1a588b6 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,11 @@ +github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= +github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= +github.com/jessevdk/go-flags v1.5.0 h1:1jKYvbxEjfUl0fmqTCOfonvskHHXMjBySTLW4y9LFvc= +github.com/jessevdk/go-flags v1.5.0/go.mod h1:Fw0T6WPc1dYxT4mKEZRfG5kJhaTDP9pj1c2EWnYs/m4= golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= +golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/tools v0.12.0 h1:YW6HUoUmYBpwSgyaGaZq1fHjrBjX1rlpZ54T6mu2kss= diff --git a/protogetter.go b/protogetter.go index 3651da6..22e72be 100644 --- a/protogetter.go +++ b/protogetter.go @@ -2,13 +2,16 @@ package protogetter import ( "bytes" + "flag" "fmt" "go/ast" "go/format" "go/token" "log" + "path/filepath" "strings" + "github.com/gobwas/glob" "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/ast/inspector" ) @@ -23,31 +26,73 @@ const ( const msgFormat = "avoid direct access to proto field %s, use %s instead" func NewAnalyzer(cfg *Config) *analysis.Analyzer { + if cfg == nil { + cfg = &Config{} + } + return &analysis.Analyzer{ - Name: "protogetter", - Doc: "Reports direct reads from proto message fields when getters should be used", + Name: "protogetter", + Doc: "Reports direct reads from proto message fields when getters should be used", + Flags: flags(cfg), Run: func(pass *analysis.Pass) (any, error) { - Run(pass, cfg) - return nil, nil + _, err := Run(pass, cfg) + return nil, err }, } } +func flags(opts *Config) flag.FlagSet { + fs := flag.NewFlagSet("protogetter", flag.ContinueOnError) + + fs.Func("skip-generated-by", "skip files generated with the given prefixes", func(s string) error { + for _, prefix := range strings.Split(s, ",") { + opts.SkipGeneratedBy = append(opts.SkipGeneratedBy, prefix) + } + return nil + }) + fs.Func("skip-files", "skip files with the given glob patterns", func(s string) error { + for _, pattern := range strings.Split(s, ",") { + opts.SkipFiles = append(opts.SkipFiles, pattern) + } + return nil + }) + + return *fs +} + type Config struct { - Mode Mode - SkipGeneratedBy []string + Mode Mode // Zero value is StandaloneMode. + SkipGeneratedBy []string `short:"g" long:"skip-generated-by" description:"Skip files generated with the given prefixes"` + SkipFiles []string `short:"f" long:"skip-files" description:"Skip files with the given glob patterns"` } -func Run(pass *analysis.Pass, cfg *Config) []Issue { - // Always skip files generated by protoc-gen-go and protoc-gen-grpc-gateway. - skipGeneratedBy := []string{"protoc-gen-go", "protoc-gen-grpc-gateway"} +func Run(pass *analysis.Pass, cfg *Config) ([]Issue, error) { + skipGeneratedBy := make([]string, 0, len(cfg.SkipGeneratedBy)+3) + // Always skip files generated by protoc-gen-go, protoc-gen-go-grpc and protoc-gen-grpc-gateway. + skipGeneratedBy = append(skipGeneratedBy, "protoc-gen-go", "protoc-gen-go-grpc", "protoc-gen-grpc-gateway") for _, s := range cfg.SkipGeneratedBy { - if strings.TrimSpace(s) == "" { + s = strings.TrimSpace(s) + if s == "" { continue } skipGeneratedBy = append(skipGeneratedBy, s) } + skipFilesGlobPatterns := make([]glob.Glob, 0, len(cfg.SkipFiles)) + for _, s := range cfg.SkipFiles { + s = strings.TrimSpace(s) + if s == "" { + continue + } + + compile, err := glob.Compile(s) + if err != nil { + return nil, fmt.Errorf("invalid glob pattern: %w", err) + } + + skipFilesGlobPatterns = append(skipFilesGlobPatterns, compile) + } + nodeTypes := []ast.Node{ (*ast.AssignStmt)(nil), (*ast.CallExpr)(nil), @@ -56,14 +101,20 @@ func Run(pass *analysis.Pass, cfg *Config) []Issue { (*ast.UnaryExpr)(nil), } - // Skip protoc-generated files. + // Skip filtered files. var files []*ast.File for _, f := range pass.Files { - if !skipGeneratedFile(f, skipGeneratedBy) { - files = append(files, f) + if skipGeneratedFile(f, skipGeneratedBy) { + continue + } - // ast.Print(pass.Fset, f) + if skipFilesByGlob(pass.Fset.File(f.Pos()).Name(), skipFilesGlobPatterns) { + continue } + + files = append(files, f) + + // ast.Print(pass.Fset, f) } ins := inspector.New(files) @@ -85,7 +136,7 @@ func Run(pass *analysis.Pass, cfg *Config) []Issue { } }) - return issues + return issues, nil } func analyse(pass *analysis.Pass, filter *PosFilter, n ast.Node) *Report { @@ -196,6 +247,16 @@ func skipGeneratedFile(f *ast.File, prefixes []string) bool { return false } +func skipFilesByGlob(filename string, patterns []glob.Glob) bool { + for _, pattern := range patterns { + if pattern.Match(filename) || pattern.Match(filepath.Base(filename)) { + return true + } + } + + return false +} + func formatNode(node ast.Node) string { buf := new(bytes.Buffer) if err := format.Node(buf, token.NewFileSet(), node); err != nil {