diff --git a/gazelle/common/treesitter/BUILD.bazel b/gazelle/common/treesitter/BUILD.bazel index 3cfdad0bb..f1ba9d7c8 100644 --- a/gazelle/common/treesitter/BUILD.bazel +++ b/gazelle/common/treesitter/BUILD.bazel @@ -6,6 +6,7 @@ go_library( "filters.go", "parser.go", "queries.go", + "query.go", "traversal.go", ], importpath = "aspect.build/cli/gazelle/common/treesitter", diff --git a/gazelle/common/treesitter/filters.go b/gazelle/common/treesitter/filters.go index 5fd9131d5..bfe52eb82 100644 --- a/gazelle/common/treesitter/filters.go +++ b/gazelle/common/treesitter/filters.go @@ -16,7 +16,7 @@ import ( // Predicates implemented here: // - eq? // - match? -func matchesAllPredicates(q *sitter.Query, m *sitter.QueryMatch, qc *sitter.QueryCursor, input []byte) bool { +func matchesAllPredicates(q *sitterQuery, m *sitter.QueryMatch, qc *sitter.QueryCursor, input []byte) bool { qm := &sitter.QueryMatch{ ID: m.ID, PatternIndex: m.PatternIndex, diff --git a/gazelle/common/treesitter/queries.go b/gazelle/common/treesitter/queries.go index b325bd432..3b743bd94 100644 --- a/gazelle/common/treesitter/queries.go +++ b/gazelle/common/treesitter/queries.go @@ -13,18 +13,18 @@ import ( var ErrorsQuery = `(ERROR) @error` // A cache of parsed queries per language -var queryCache = make(map[LanguageGrammar]map[string]*sitter.Query) +var queryCache = make(map[LanguageGrammar]map[string]*sitterQuery) var queryMutex sync.Mutex -func parseQuery(lang LanguageGrammar, queryStr string) *sitter.Query { +func parseQuery(lang LanguageGrammar, queryStr string) *sitterQuery { queryMutex.Lock() defer queryMutex.Unlock() if queryCache[lang] == nil { - queryCache[lang] = make(map[string]*sitter.Query) + queryCache[lang] = make(map[string]*sitterQuery) } if queryCache[lang][queryStr] == nil { - queryCache[lang][queryStr] = mustNewQuery(lang, []byte(queryStr)) + queryCache[lang][queryStr] = mustNewQuery(lang, queryStr) } return queryCache[lang][queryStr] @@ -39,7 +39,7 @@ func (tree TreeAst) QueryStrings(query, returnVar string) []string { // Execute the query. qc := sitter.NewQueryCursor() - qc.Exec(sitterQuery, rootNode) + qc.Exec(sitterQuery.q, rootNode) // Collect string from the query results. for { @@ -82,7 +82,7 @@ func (tree TreeAst) Query(query string) <-chan ASTQueryResult { // Execute the query. go func() { qc := sitter.NewQueryCursor() - qc.Exec(q, rootNode) + qc.Exec(q.q, rootNode) for { m, ok := qc.NextMatch() @@ -104,7 +104,7 @@ func (tree TreeAst) Query(query string) <-chan ASTQueryResult { return out } -func (tree TreeAst) mapQueryMatchCaptures(m *sitter.QueryMatch, q *sitter.Query) map[string]string { +func (tree TreeAst) mapQueryMatchCaptures(m *sitter.QueryMatch, q *sitterQuery) map[string]string { captures := make(map[string]string, len(m.Captures)) for _, c := range m.Captures { name := q.CaptureNameForId(c.Index) @@ -120,11 +120,11 @@ func (tree TreeAst) mapQueryMatchCaptures(m *sitter.QueryMatch, q *sitter.Query) // Find and read the `from` QueryCapture from a QueryMatch. // Filter matches based on captures value using "equals-{name}" vars. -func fetchQueryMatch(query *sitter.Query, name string, m *sitter.QueryMatch, sourceCode []byte) *sitter.QueryCapture { +func fetchQueryMatch(query *sitterQuery, name string, m *sitter.QueryMatch, sourceCode []byte) *sitter.QueryCapture { var result *sitter.QueryCapture - for ci, c := range m.Captures { - cn := query.CaptureNameForId(uint32(ci)) + for _, c := range m.Captures { + cn := query.CaptureNameForId(c.Index) // Filters where a capture must equal a specific value. if strings.HasPrefix(cn, "equals-") { @@ -145,8 +145,8 @@ func fetchQueryMatch(query *sitter.Query, name string, m *sitter.QueryMatch, sou return result } -func mustNewQuery(lang LanguageGrammar, query []byte) *sitter.Query { - treeQ, err := sitter.NewQuery(query, toSitterLanguage(lang)) +func mustNewTreeQuery(lang LanguageGrammar, query string) *sitter.Query { + treeQ, err := sitter.NewQuery([]byte(query), toSitterLanguage(lang)) if err != nil { BazelLog.Fatalf("Failed to create query for %q: %v", query, err) } @@ -166,7 +166,7 @@ func (tree TreeAst) QueryErrors() []error { // Execute the import query qc := sitter.NewQueryCursor() - qc.Exec(query, node) + qc.Exec(query.q, node) // Collect import statements from the query results for { diff --git a/gazelle/common/treesitter/query.go b/gazelle/common/treesitter/query.go new file mode 100644 index 000000000..fadb4fe10 --- /dev/null +++ b/gazelle/common/treesitter/query.go @@ -0,0 +1,53 @@ +package treesitter + +import sitter "github.com/smacker/go-tree-sitter" + +// Basic wrapper around sitter.Query to cache tree-sitter cgo calls. +type sitterQuery struct { + q *sitter.Query + + // Pre-computed and cached query data + stringValues []string + captureNames []string + predicatePatterns [][][]sitter.QueryPredicateStep +} + +func mustNewQuery(lang LanguageGrammar, query string) *sitterQuery { + q := mustNewTreeQuery(lang, query) + + captureNames := make([]string, q.CaptureCount()) + for i := uint32(0); i < q.CaptureCount(); i++ { + captureNames[i] = q.CaptureNameForId(i) + } + + stringValues := make([]string, q.StringCount()) + for i := uint32(0); i < q.StringCount(); i++ { + stringValues[i] = q.StringValueForId(i) + } + + predicatePatterns := make([][][]sitter.QueryPredicateStep, q.PatternCount()) + for i := uint32(0); i < q.PatternCount(); i++ { + predicatePatterns[i] = q.PredicatesForPattern(i) + } + + return &sitterQuery{ + q: q, + stringValues: stringValues, + captureNames: captureNames, + predicatePatterns: predicatePatterns, + } +} + +// Cached query data accessors mirroring the tree-sitter Query signatures. + +func (q *sitterQuery) StringValueForId(id uint32) string { + return q.stringValues[id] +} + +func (q *sitterQuery) CaptureNameForId(id uint32) string { + return q.captureNames[id] +} + +func (q *sitterQuery) PredicatesForPattern(patternIndex uint32) [][]sitter.QueryPredicateStep { + return q.predicatePatterns[patternIndex] +}