From 8e8fd555bd88c5492f1caef026c7691c99ae33f2 Mon Sep 17 00:00:00 2001 From: Romain Marcadier Date: Mon, 10 Feb 2025 12:26:25 +0100 Subject: [PATCH 1/4] fix: pragma directives not picked up on return statement The `//dd:span` directive is not honored when it's presented on a `return` statement (where the expression is a literal function expression). This changes the directive lookup function so it crawls up to the parent if it's a node type that wraps a single expression; effectively: - `name := ` (and other assignment styles) - `` - `defer ` - `` (the expression statement) - `go ` - `label: ` - `return ` - `ch <- ` (send statement) Fixes #539 --- internal/injector/aspect/join/directive.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/internal/injector/aspect/join/directive.go b/internal/injector/aspect/join/directive.go index d7b7a7ec..62ec0388 100644 --- a/internal/injector/aspect/join/directive.go +++ b/internal/injector/aspect/join/directive.go @@ -49,10 +49,14 @@ func (d directive) matchesChain(chain *context.NodeChain) bool { } } - // If the parent is an assignment statement, so we also check it for directives. if parent := chain.Parent(); parent != nil { - if _, isAssign := parent.Node().(*dst.AssignStmt); isAssign && d.matchesChain(parent) { - return true + switch parent.Node().(type) { + // Also check whether the parent carries the directive if it's one of the node types that would + // typically carry directives that applies to its nested node. + case *dst.AssignStmt, *dst.CallExpr, *dst.DeferStmt, *dst.ExprStmt, *dst.GoStmt, *dst.LabeledStmt, *dst.ReturnStmt, *dst.SendStmt: + if d.matchesChain(parent) { + return true + } } } From 6f9b844c348b1a795595804fb5762deee0d91061 Mon Sep 17 00:00:00 2001 From: Romain Marcadier Date: Mon, 10 Feb 2025 16:53:52 +0100 Subject: [PATCH 2/4] Add unit test coverage --- internal/injector/aspect/join/directive.go | 29 +++- .../injector/aspect/join/directive_test.go | 164 ++++++++++++++++++ 2 files changed, 191 insertions(+), 2 deletions(-) diff --git a/internal/injector/aspect/join/directive.go b/internal/injector/aspect/join/directive.go index 62ec0388..96a27788 100644 --- a/internal/injector/aspect/join/directive.go +++ b/internal/injector/aspect/join/directive.go @@ -6,6 +6,7 @@ package join import ( + "go/token" "strings" "unicode" "unicode/utf8" @@ -19,6 +20,16 @@ import ( type directive string +// Directive matches nodes that are prefaced by a special pragma comment, which +// is a single-line style comment without any blanks between the leading // and +// the directive name. Directives apply to the node they are directly attached +// to, but also to certain nested nodes: +// - For assignments, it applies to the RHS only; unless it's a delcaration +// assignment (the := token), in which case it also applies to the LHS, +// - For call expressions, it applies only to the function part (not the +// arguments)n +// - For channel send operations, it only applies to the value being sent, +// - For defer, go, and return statements, it applies to the value side. func Directive(name string) directive { return directive(name) } @@ -50,10 +61,24 @@ func (d directive) matchesChain(chain *context.NodeChain) bool { } if parent := chain.Parent(); parent != nil { - switch parent.Node().(type) { + switch node := parent.Node().(type) { // Also check whether the parent carries the directive if it's one of the node types that would // typically carry directives that applies to its nested node. - case *dst.AssignStmt, *dst.CallExpr, *dst.DeferStmt, *dst.ExprStmt, *dst.GoStmt, *dst.LabeledStmt, *dst.ReturnStmt, *dst.SendStmt: + case *dst.AssignStmt: + checkParent := chain.PropertyName() == "Rhs" + checkParent = checkParent || (node.Tok == token.DEFINE && chain.PropertyName() == "Lhs") + if checkParent && d.matchesChain(parent) { + return true + } + case *dst.CallExpr: + if chain.PropertyName() == "Fun" && d.matchesChain(parent) { + return true + } + case *dst.SendStmt: + if chain.PropertyName() == "Value" && d.matchesChain(parent) { + return true + } + case *dst.DeferStmt, *dst.ExprStmt, *dst.GoStmt, *dst.ReturnStmt: if d.matchesChain(parent) { return true } diff --git a/internal/injector/aspect/join/directive_test.go b/internal/injector/aspect/join/directive_test.go index 45c6df24..a82b5a34 100644 --- a/internal/injector/aspect/join/directive_test.go +++ b/internal/injector/aspect/join/directive_test.go @@ -6,8 +6,18 @@ package join import ( + "bytes" + "go/parser" + "go/printer" + "go/token" + "strings" "testing" + "github.com/DataDog/orchestrion/internal/injector/aspect/context" + "github.com/dave/dst" + "github.com/dave/dst/decorator" + "github.com/dave/dst/dstutil" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -26,3 +36,157 @@ func TestDirectiveMatch(t *testing.T) { // Not a directive (not a single-line comment syntax) require.False(t, dir.matches("\t/*test:directive*/")) } + +func TestDirective(t *testing.T) { + type testCase struct { + preamble string + statement string + expectMatches []string // The [testCase.statement] is always matched (last), and must not be represented here. + } + tests := map[string]testCase{ + "assignment-declaration": { + statement: "foo := func(...int) {}(1, 2, 3)", + expectMatches: []string{ + "foo", // Matches because it's being defined here + "func(...int) {}", + "func(...int) {}(1, 2, 3)", + }, + }, + "assignment": { + preamble: "var foo func(...int)", + statement: "foo = func(...int) {}(1, 2, 3)", + expectMatches: []string{ + "func(...int) {}", + "func(...int) {}(1, 2, 3)", + }, + }, + "multi-assignment": { + preamble: "var foo func(...int)", + statement: "_, foo = nil, func(...int) {}(1, 2, 3)", + expectMatches: []string{ + "nil", + "func(...int) {}", + "func(...int) {}(1, 2, 3)", + }, + }, + "call": { + preamble: "var foo func(...int)", + statement: "foo(1, 2, 3)", + expectMatches: []string{ + "foo", + "foo(1, 2, 3)", // Quirck -- this is an ExprStmt, so it matches twice (the statement, the expression) + }, + }, + "immediately-invoked-function-expression": { + statement: "func(...int) {}(1, 2, 3)", + expectMatches: []string{ + "func(...int) {}", + "func(...int) {}(1, 2, 3)", // Quirck -- this is an ExprStmt, so it matches twice (the statement, the expression) + }, + }, + "defer": { + statement: "defer func(...int) {}(1, 2, 3)", + expectMatches: []string{ + "func(...int) {}", + "func(...int) {}(1, 2, 3)", + }, + }, + "go": { + statement: "go func(...int) {}(1, 2, 3)", + expectMatches: []string{ + "func(...int) {}", + "func(...int) {}(1, 2, 3)", + }, + }, + "return": { + statement: "return func(...int) int { return 0 }(1, 2, 3)", + expectMatches: []string{ + "func(...int) int { return 0 }", + "func(...int) int { return 0 }(1, 2, 3)", + }, + }, + "chan-send": { + preamble: "var ch chan <-int", + statement: "ch <- func(...int) int { return 0 }(1, 2, 3)", + expectMatches: []string{ + "func(...int) int { return 0 }", + "func(...int) int { return 0 }(1, 2, 3)", + }, + }, + } + + const pragma = "//test:directive" + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + source := strings.Join( + []string{ + "package main", + "func main() {", + tc.preamble, + pragma, + tc.statement, + "}", + }, + "\n", + ) + + fset := token.NewFileSet() + astFile, err := parser.ParseFile(fset, "input.go", source, parser.ParseComments) + require.NoError(t, err) + + dec := decorator.NewDecorator(fset) + + dstFile, err := dec.DecorateFile(astFile) + require.NoError(t, err) + + visitor := &visitor{pragma: Directive("test:directive")} + dstutil.Apply( + dstFile, + visitor.pre, + visitor.post, + ) + require.NotEmpty(t, visitor.matches) + + descriptors := make([]string, len(visitor.matches)) + for idx, match := range visitor.matches { + node := dec.Ast.Nodes[match.Node()] + + var str bytes.Buffer + printer.Fprint(&str, fset, node) + + descriptors[idx] = str.String() + } + assert.Equal(t, tc.statement, strings.TrimPrefix(descriptors[len(descriptors)-1], pragma+"\n"), + "the statement itself should always be matched") + assert.Equal(t, tc.expectMatches, descriptors[:len(descriptors)-1]) + }) + } +} + +type visitor struct { + pragma directive + file *dst.File + node *context.NodeChain + matches []*context.NodeChain +} + +func (v *visitor) pre(cursor *dstutil.Cursor) bool { + if cursor.Node() == nil { + return false + } + + if file, ok := cursor.Node().(*dst.File); ok { + v.file = file + } + + v.node = v.node.Child(cursor) + return true +} + +func (v *visitor) post(cursor *dstutil.Cursor) bool { + if v.pragma.matchesChain(v.node) { + v.matches = append(v.matches, v.node) + } + v.node = v.node.Parent() + return true +} From c34dde5192aaab2596932d23b4982056dc8dc9de Mon Sep 17 00:00:00 2001 From: Romain Marcadier Date: Mon, 10 Feb 2025 17:08:27 +0100 Subject: [PATCH 3/4] Make linter happy --- internal/injector/aspect/join/directive_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/injector/aspect/join/directive_test.go b/internal/injector/aspect/join/directive_test.go index a82b5a34..f3c6abf9 100644 --- a/internal/injector/aspect/join/directive_test.go +++ b/internal/injector/aspect/join/directive_test.go @@ -183,7 +183,7 @@ func (v *visitor) pre(cursor *dstutil.Cursor) bool { return true } -func (v *visitor) post(cursor *dstutil.Cursor) bool { +func (v *visitor) post(*dstutil.Cursor) bool { if v.pragma.matchesChain(v.node) { v.matches = append(v.matches, v.node) } From c5ab2ed453711d8555b92f6bc699fb0ccbddeeda Mon Sep 17 00:00:00 2001 From: Romain Marcadier Date: Tue, 11 Feb 2025 18:19:14 +0100 Subject: [PATCH 4/4] Feedback from @eliottness --- internal/injector/aspect/join/directive.go | 56 +++++++++++++--------- 1 file changed, 34 insertions(+), 22 deletions(-) diff --git a/internal/injector/aspect/join/directive.go b/internal/injector/aspect/join/directive.go index 96a27788..673607aa 100644 --- a/internal/injector/aspect/join/directive.go +++ b/internal/injector/aspect/join/directive.go @@ -60,28 +60,40 @@ func (d directive) matchesChain(chain *context.NodeChain) bool { } } - if parent := chain.Parent(); parent != nil { - switch node := parent.Node().(type) { - // Also check whether the parent carries the directive if it's one of the node types that would - // typically carry directives that applies to its nested node. - case *dst.AssignStmt: - checkParent := chain.PropertyName() == "Rhs" - checkParent = checkParent || (node.Tok == token.DEFINE && chain.PropertyName() == "Lhs") - if checkParent && d.matchesChain(parent) { - return true - } - case *dst.CallExpr: - if chain.PropertyName() == "Fun" && d.matchesChain(parent) { - return true - } - case *dst.SendStmt: - if chain.PropertyName() == "Value" && d.matchesChain(parent) { - return true - } - case *dst.DeferStmt, *dst.ExprStmt, *dst.GoStmt, *dst.ReturnStmt: - if d.matchesChain(parent) { - return true - } + parent := chain.Parent() + if parent == nil { + return false + } + + switch node := parent.Node().(type) { + // Also check whether the parent carries the directive if it's one of the node types that would + // typically carry directives that applies to its nested node. + case *dst.AssignStmt: + // For assignments, the directive only applies downwards to the RHS, unless it's a declaration, + // then it also applies to any declared identifier. + checkParent := chain.PropertyName() == "Rhs" + checkParent = checkParent || (node.Tok == token.DEFINE && chain.PropertyName() == "Lhs") + if checkParent && d.matchesChain(parent) { + return true + } + case *dst.CallExpr: + // For call expressions, the directive only applies to the called function, not its type + // signature or arguments list. + if chain.PropertyName() == "Fun" && d.matchesChain(parent) { + return true + } + case *dst.SendStmt: + // For chanel send statements, the directive only applies to the value being sent, not to the + // receiving channel. + if chain.PropertyName() == "Value" && d.matchesChain(parent) { + return true + } + case *dst.DeferStmt, *dst.ExprStmt, *dst.GoStmt, *dst.ReturnStmt: + // Defer statements, go statements, and return statements all forward the directive to the + // value(s); and expression statements are just wrappers of expressions, so naturally directives + // that apply to the statement also apply to the expression. + if d.matchesChain(parent) { + return true } }