diff --git a/internal/injector/aspect/join/directive.go b/internal/injector/aspect/join/directive.go index d7b7a7ec..673607aa 100644 --- a/internal/injector/aspect/join/directive.go +++ b/internal/injector/aspect/join/directive.go @@ -6,6 +6,7 @@ package join import ( + "go/token" "strings" "unicode" "unicode/utf8" @@ -19,6 +20,16 @@ import ( type directive string +// Directive matches nodes that are prefaced by a special pragma comment, which +// is a single-line style comment without any blanks between the leading // and +// the directive name. Directives apply to the node they are directly attached +// to, but also to certain nested nodes: +// - For assignments, it applies to the RHS only; unless it's a delcaration +// assignment (the := token), in which case it also applies to the LHS, +// - For call expressions, it applies only to the function part (not the +// arguments)n +// - For channel send operations, it only applies to the value being sent, +// - For defer, go, and return statements, it applies to the value side. func Directive(name string) directive { return directive(name) } @@ -49,9 +60,39 @@ func (d directive) matchesChain(chain *context.NodeChain) bool { } } - // If the parent is an assignment statement, so we also check it for directives. - if parent := chain.Parent(); parent != nil { - if _, isAssign := parent.Node().(*dst.AssignStmt); isAssign && d.matchesChain(parent) { + parent := chain.Parent() + if parent == nil { + return false + } + + switch node := parent.Node().(type) { + // Also check whether the parent carries the directive if it's one of the node types that would + // typically carry directives that applies to its nested node. + case *dst.AssignStmt: + // For assignments, the directive only applies downwards to the RHS, unless it's a declaration, + // then it also applies to any declared identifier. + checkParent := chain.PropertyName() == "Rhs" + checkParent = checkParent || (node.Tok == token.DEFINE && chain.PropertyName() == "Lhs") + if checkParent && d.matchesChain(parent) { + return true + } + case *dst.CallExpr: + // For call expressions, the directive only applies to the called function, not its type + // signature or arguments list. + if chain.PropertyName() == "Fun" && d.matchesChain(parent) { + return true + } + case *dst.SendStmt: + // For chanel send statements, the directive only applies to the value being sent, not to the + // receiving channel. + if chain.PropertyName() == "Value" && d.matchesChain(parent) { + return true + } + case *dst.DeferStmt, *dst.ExprStmt, *dst.GoStmt, *dst.ReturnStmt: + // Defer statements, go statements, and return statements all forward the directive to the + // value(s); and expression statements are just wrappers of expressions, so naturally directives + // that apply to the statement also apply to the expression. + if d.matchesChain(parent) { return true } } diff --git a/internal/injector/aspect/join/directive_test.go b/internal/injector/aspect/join/directive_test.go index 45c6df24..f3c6abf9 100644 --- a/internal/injector/aspect/join/directive_test.go +++ b/internal/injector/aspect/join/directive_test.go @@ -6,8 +6,18 @@ package join import ( + "bytes" + "go/parser" + "go/printer" + "go/token" + "strings" "testing" + "github.com/DataDog/orchestrion/internal/injector/aspect/context" + "github.com/dave/dst" + "github.com/dave/dst/decorator" + "github.com/dave/dst/dstutil" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -26,3 +36,157 @@ func TestDirectiveMatch(t *testing.T) { // Not a directive (not a single-line comment syntax) require.False(t, dir.matches("\t/*test:directive*/")) } + +func TestDirective(t *testing.T) { + type testCase struct { + preamble string + statement string + expectMatches []string // The [testCase.statement] is always matched (last), and must not be represented here. + } + tests := map[string]testCase{ + "assignment-declaration": { + statement: "foo := func(...int) {}(1, 2, 3)", + expectMatches: []string{ + "foo", // Matches because it's being defined here + "func(...int) {}", + "func(...int) {}(1, 2, 3)", + }, + }, + "assignment": { + preamble: "var foo func(...int)", + statement: "foo = func(...int) {}(1, 2, 3)", + expectMatches: []string{ + "func(...int) {}", + "func(...int) {}(1, 2, 3)", + }, + }, + "multi-assignment": { + preamble: "var foo func(...int)", + statement: "_, foo = nil, func(...int) {}(1, 2, 3)", + expectMatches: []string{ + "nil", + "func(...int) {}", + "func(...int) {}(1, 2, 3)", + }, + }, + "call": { + preamble: "var foo func(...int)", + statement: "foo(1, 2, 3)", + expectMatches: []string{ + "foo", + "foo(1, 2, 3)", // Quirck -- this is an ExprStmt, so it matches twice (the statement, the expression) + }, + }, + "immediately-invoked-function-expression": { + statement: "func(...int) {}(1, 2, 3)", + expectMatches: []string{ + "func(...int) {}", + "func(...int) {}(1, 2, 3)", // Quirck -- this is an ExprStmt, so it matches twice (the statement, the expression) + }, + }, + "defer": { + statement: "defer func(...int) {}(1, 2, 3)", + expectMatches: []string{ + "func(...int) {}", + "func(...int) {}(1, 2, 3)", + }, + }, + "go": { + statement: "go func(...int) {}(1, 2, 3)", + expectMatches: []string{ + "func(...int) {}", + "func(...int) {}(1, 2, 3)", + }, + }, + "return": { + statement: "return func(...int) int { return 0 }(1, 2, 3)", + expectMatches: []string{ + "func(...int) int { return 0 }", + "func(...int) int { return 0 }(1, 2, 3)", + }, + }, + "chan-send": { + preamble: "var ch chan <-int", + statement: "ch <- func(...int) int { return 0 }(1, 2, 3)", + expectMatches: []string{ + "func(...int) int { return 0 }", + "func(...int) int { return 0 }(1, 2, 3)", + }, + }, + } + + const pragma = "//test:directive" + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + source := strings.Join( + []string{ + "package main", + "func main() {", + tc.preamble, + pragma, + tc.statement, + "}", + }, + "\n", + ) + + fset := token.NewFileSet() + astFile, err := parser.ParseFile(fset, "input.go", source, parser.ParseComments) + require.NoError(t, err) + + dec := decorator.NewDecorator(fset) + + dstFile, err := dec.DecorateFile(astFile) + require.NoError(t, err) + + visitor := &visitor{pragma: Directive("test:directive")} + dstutil.Apply( + dstFile, + visitor.pre, + visitor.post, + ) + require.NotEmpty(t, visitor.matches) + + descriptors := make([]string, len(visitor.matches)) + for idx, match := range visitor.matches { + node := dec.Ast.Nodes[match.Node()] + + var str bytes.Buffer + printer.Fprint(&str, fset, node) + + descriptors[idx] = str.String() + } + assert.Equal(t, tc.statement, strings.TrimPrefix(descriptors[len(descriptors)-1], pragma+"\n"), + "the statement itself should always be matched") + assert.Equal(t, tc.expectMatches, descriptors[:len(descriptors)-1]) + }) + } +} + +type visitor struct { + pragma directive + file *dst.File + node *context.NodeChain + matches []*context.NodeChain +} + +func (v *visitor) pre(cursor *dstutil.Cursor) bool { + if cursor.Node() == nil { + return false + } + + if file, ok := cursor.Node().(*dst.File); ok { + v.file = file + } + + v.node = v.node.Child(cursor) + return true +} + +func (v *visitor) post(*dstutil.Cursor) bool { + if v.pragma.matchesChain(v.node) { + v.matches = append(v.matches, v.node) + } + v.node = v.node.Parent() + return true +}