Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: pragma directives not picked up on return statement #540

Merged
merged 4 commits into from
Feb 12, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 32 additions & 3 deletions internal/injector/aspect/join/directive.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package join

import (
"go/token"
"strings"
"unicode"
"unicode/utf8"
Expand All @@ -19,6 +20,16 @@ import (

type directive string

// Directive matches nodes that are prefaced by a special pragma comment, which
// is a single-line style comment without any blanks between the leading // and
// the directive name. Directives apply to the node they are directly attached
// to, but also to certain nested nodes:
// - For assignments, it applies to the RHS only; unless it's a delcaration
// assignment (the := token), in which case it also applies to the LHS,
// - For call expressions, it applies only to the function part (not the
// arguments)n
// - For channel send operations, it only applies to the value being sent,
// - For defer, go, and return statements, it applies to the value side.
func Directive(name string) directive {
return directive(name)
}
Expand Down Expand Up @@ -49,10 +60,28 @@ func (d directive) matchesChain(chain *context.NodeChain) bool {
}
}

// If the parent is an assignment statement, so we also check it for directives.
if parent := chain.Parent(); parent != nil {
if _, isAssign := parent.Node().(*dst.AssignStmt); isAssign && d.matchesChain(parent) {
return true
switch node := parent.Node().(type) {
// Also check whether the parent carries the directive if it's one of the node types that would
// typically carry directives that applies to its nested node.
case *dst.AssignStmt:
checkParent := chain.PropertyName() == "Rhs"
checkParent = checkParent || (node.Tok == token.DEFINE && chain.PropertyName() == "Lhs")
if checkParent && d.matchesChain(parent) {
return true
}
case *dst.CallExpr:
if chain.PropertyName() == "Fun" && d.matchesChain(parent) {
return true
}
case *dst.SendStmt:
if chain.PropertyName() == "Value" && d.matchesChain(parent) {
return true
}
case *dst.DeferStmt, *dst.ExprStmt, *dst.GoStmt, *dst.ReturnStmt:
if d.matchesChain(parent) {
return true
}
}
}

Expand Down
164 changes: 164 additions & 0 deletions internal/injector/aspect/join/directive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,18 @@
package join

import (
"bytes"
"go/parser"
"go/printer"
"go/token"
"strings"
"testing"

"github.com/DataDog/orchestrion/internal/injector/aspect/context"
"github.com/dave/dst"
"github.com/dave/dst/decorator"
"github.com/dave/dst/dstutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand All @@ -26,3 +36,157 @@ func TestDirectiveMatch(t *testing.T) {
// Not a directive (not a single-line comment syntax)
require.False(t, dir.matches("\t/*test:directive*/"))
}

func TestDirective(t *testing.T) {
type testCase struct {
preamble string
statement string
expectMatches []string // The [testCase.statement] is always matched (last), and must not be represented here.
}
tests := map[string]testCase{
"assignment-declaration": {
statement: "foo := func(...int) {}(1, 2, 3)",
expectMatches: []string{
"foo", // Matches because it's being defined here
"func(...int) {}",
"func(...int) {}(1, 2, 3)",
},
},
"assignment": {
preamble: "var foo func(...int)",
statement: "foo = func(...int) {}(1, 2, 3)",
expectMatches: []string{
"func(...int) {}",
"func(...int) {}(1, 2, 3)",
},
},
"multi-assignment": {
preamble: "var foo func(...int)",
statement: "_, foo = nil, func(...int) {}(1, 2, 3)",
expectMatches: []string{
"nil",
"func(...int) {}",
"func(...int) {}(1, 2, 3)",
},
},
"call": {
preamble: "var foo func(...int)",
statement: "foo(1, 2, 3)",
expectMatches: []string{
"foo",
"foo(1, 2, 3)", // Quirck -- this is an ExprStmt, so it matches twice (the statement, the expression)
},
},
"immediately-invoked-function-expression": {
statement: "func(...int) {}(1, 2, 3)",
expectMatches: []string{
"func(...int) {}",
"func(...int) {}(1, 2, 3)", // Quirck -- this is an ExprStmt, so it matches twice (the statement, the expression)
},
},
"defer": {
statement: "defer func(...int) {}(1, 2, 3)",
expectMatches: []string{
"func(...int) {}",
"func(...int) {}(1, 2, 3)",
},
},
"go": {
statement: "go func(...int) {}(1, 2, 3)",
expectMatches: []string{
"func(...int) {}",
"func(...int) {}(1, 2, 3)",
},
},
"return": {
statement: "return func(...int) int { return 0 }(1, 2, 3)",
expectMatches: []string{
"func(...int) int { return 0 }",
"func(...int) int { return 0 }(1, 2, 3)",
},
},
"chan-send": {
preamble: "var ch chan <-int",
statement: "ch <- func(...int) int { return 0 }(1, 2, 3)",
expectMatches: []string{
"func(...int) int { return 0 }",
"func(...int) int { return 0 }(1, 2, 3)",
},
},
}

const pragma = "//test:directive"
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
source := strings.Join(
[]string{
"package main",
"func main() {",
tc.preamble,
pragma,
tc.statement,
"}",
},
"\n",
)

fset := token.NewFileSet()
astFile, err := parser.ParseFile(fset, "input.go", source, parser.ParseComments)
require.NoError(t, err)

dec := decorator.NewDecorator(fset)

dstFile, err := dec.DecorateFile(astFile)
require.NoError(t, err)

visitor := &visitor{pragma: Directive("test:directive")}
dstutil.Apply(
dstFile,
visitor.pre,
visitor.post,
)
require.NotEmpty(t, visitor.matches)

descriptors := make([]string, len(visitor.matches))
for idx, match := range visitor.matches {
node := dec.Ast.Nodes[match.Node()]

var str bytes.Buffer
printer.Fprint(&str, fset, node)

descriptors[idx] = str.String()
}
assert.Equal(t, tc.statement, strings.TrimPrefix(descriptors[len(descriptors)-1], pragma+"\n"),
"the statement itself should always be matched")
assert.Equal(t, tc.expectMatches, descriptors[:len(descriptors)-1])
})
}
}

type visitor struct {
pragma directive
file *dst.File
node *context.NodeChain
matches []*context.NodeChain
}

func (v *visitor) pre(cursor *dstutil.Cursor) bool {
if cursor.Node() == nil {
return false
}

if file, ok := cursor.Node().(*dst.File); ok {
v.file = file
}

v.node = v.node.Child(cursor)
return true
}

func (v *visitor) post(*dstutil.Cursor) bool {
if v.pragma.matchesChain(v.node) {
v.matches = append(v.matches, v.node)
}
v.node = v.node.Parent()
return true
}
Loading