Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Assignment tracking for many-to-one assignments #181

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 27 additions & 25 deletions assertion/function/assertiontree/backprop.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import (
"go.uber.org/nilaway/annotation"
"go.uber.org/nilaway/config"
"go.uber.org/nilaway/util"
"go.uber.org/nilaway/util/asthelper"
"golang.org/x/exp/slices"
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/cfg"
Expand Down Expand Up @@ -591,20 +590,10 @@ buildShadowMask:
if ok && lhsNode != nil {
// Add assignment entries to the consumers of lhsNode for informative printing of errors
for _, c := range lhsNode.ConsumeTriggers() {
var lhsExprStr, rhsExprStr string
var err error
if lhsExprStr, err = asthelper.PrintExpr(lhsVal, rootNode.Pass(), true /* isShortenExpr */); err != nil {
err := addAssignmentToConsumer(lhsVal, rhsVal, rootNode.Pass(), c.Annotation)
if err != nil {
return err
}
if rhsExprStr, err = asthelper.PrintExpr(rhsVal, rootNode.Pass(), true /* isShortenExpr */); err != nil {
return err
}

c.Annotation.AddAssignment(annotation.Assignment{
LHSExprStr: lhsExprStr,
RHSExprStr: rhsExprStr,
Position: util.TruncatePosition(util.PosToLocation(lhsVal.Pos(), rootNode.Pass())),
})
}

// If the lhsVal path is not only trackable but tracked, we add it as
Expand Down Expand Up @@ -645,20 +634,10 @@ buildShadowMask:
continue
}
for _, t := range rootNode.triggers[beforeTriggersLastIndex:len(rootNode.triggers)] {
var lhsExprStr, rhsExprStr string
var err error
if lhsExprStr, err = asthelper.PrintExpr(lhsVal, rootNode.Pass(), true /* isShortenExpr */); err != nil {
return err
}
if rhsExprStr, err = asthelper.PrintExpr(rhsVal, rootNode.Pass(), true /* isShortenExpr */); err != nil {
err := addAssignmentToConsumer(lhsVal, rhsVal, rootNode.Pass(), t.Consumer.Annotation)
if err != nil {
return err
}

t.Consumer.Annotation.AddAssignment(annotation.Assignment{
LHSExprStr: lhsExprStr,
RHSExprStr: rhsExprStr,
Position: util.TruncatePosition(util.PosToLocation(lhsVal.Pos(), rootNode.Pass())),
})
}
default:
return errors.New("rhs expression in a 1-1 assignment was multiply returning - " +
Expand Down Expand Up @@ -734,18 +713,36 @@ func backpropAcrossManyToOneAssignment(rootNode *RootAssertionNode, lhs, rhs []a
rootNode.addProductionsForAssignmentFields(fieldProducers, lhsVal)
}

// beforeTriggersLastIndex is used to find the newly added triggers on the next line
beforeTriggersLastIndex := len(rootNode.triggers)

rootNode.AddGuardMatch(lhsVal, ContinueTracking)
rootNode.AddProduction(&annotation.ProduceTrigger{
Annotation: producers[i].GetShallow().Annotation,
Expr: lhsVal,
}, producers[i].GetDeepSlice()...)

// Update consumers of newly added triggers with assignment entries for informative printing of errors
if len(rootNode.triggers) > 0 {
for _, t := range rootNode.triggers[beforeTriggersLastIndex:len(rootNode.triggers)] {
err := addAssignmentToConsumer(lhsVal, rhsVal, rootNode.Pass(), t.Consumer.Annotation)
if err != nil {
return err
}
}
}

// Phase 2
consumeTrigger, err := exprAsAssignmentConsumer(rootNode, lhsVal, rhsVal)
if err != nil {
return err
}
if consumeTrigger != nil {
// Update consumeTrigger with assignment entries for informative printing of errors
if err = addAssignmentToConsumer(lhsVal, rhsVal, rootNode.Pass(), consumeTrigger); err != nil {
return err
}

// lhsVal is a field read, so this is a field assignment
// since multiple return functions aren't trackable, this is a completed trigger
// as long as the type of the expression being assigned doesn't bar nilness
Expand All @@ -767,6 +764,11 @@ func backpropAcrossManyToOneAssignment(rootNode *RootAssertionNode, lhs, rhs []a
}

if consumer := exprAsConsumedByAssignment(rootNode, lhsVal); consumer != nil {
// Update consumeTrigger with assignment entries for informative printing of errors
if err = addAssignmentToConsumer(lhsVal, rhsVal, rootNode.Pass(), consumer.Annotation); err != nil {
return err
}

rootNode.AddConsumption(consumer)
}
}
Expand Down
22 changes: 22 additions & 0 deletions assertion/function/assertiontree/backprop_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (

"go.uber.org/nilaway/annotation"
"go.uber.org/nilaway/util"
"go.uber.org/nilaway/util/asthelper"
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/cfg"
)
Expand Down Expand Up @@ -749,3 +750,24 @@ func CheckGuardOnFullTrigger(trigger annotation.FullTrigger) annotation.FullTrig
}
return trigger
}

// addAssignmentToConsumer updates the consumer with assignment entries for informative printing of errors
func addAssignmentToConsumer(lhs, rhs ast.Expr, pass *analysis.Pass, consumer annotation.ConsumingAnnotationTrigger) error {
var lhsExprStr, rhsExprStr string
var err error

if lhsExprStr, err = asthelper.PrintExpr(lhs, pass, true /* isShortenExpr */); err != nil {
return fmt.Errorf("converting LHS of assignment to string: %w", err)
}
if rhsExprStr, err = asthelper.PrintExpr(rhs, pass, true /* isShortenExpr */); err != nil {
return fmt.Errorf("converting RHS of assignment to string: %w", err)
}

consumer.AddAssignment(annotation.Assignment{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not something to be handled in this PR, but do you think we can actually move the logic (the asthelper.PrintExpr) inside the AddAssignment method? meaning make it take the AST nodes and do the expr shortening inside. I feel that would make the code simpler, but I may have missed some context.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could. However, that would mean passing the *analysis.Pass to ConsumeTrigger, which is not required by ConsumeTrigger otherwise. Maybe we can do this refactor later if that seems to be the better approach.

LHSExprStr: lhsExprStr,
RHSExprStr: rhsExprStr,
Position: util.TruncatePosition(util.PosToLocation(lhs.Pos(), pass)),
})

return nil
}
93 changes: 93 additions & 0 deletions testdata/src/go.uber.org/errormessage/errormessage.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,3 +262,96 @@ func test21() {
var 世界 *int = nil
print(*世界) //want "`nil` to `世界`"
}

// below tests check assignment flow tracking across many-to-one assignments

// nilable(result 0)
func retPtrErr() (*int, error) {
return nil, nil
}

func test22(i int) {
switch i {
case 0:
x, err := retPtrErr()
if err != nil {
return
}
print(*x) //want "`retPtrErr\\(\\)` to `x`"

case 1:
if x, err := retPtrErr(); err == nil {
y := x
print(*y) //want "`retPtrErr\\(\\)` to `x`"
}

case 2:
var x *int
var err error
x, err = retPtrErr()
if err != nil {
return
}
print(*x) //want "`retPtrErr\\(\\)` to `x`"

case 3:
var x, err = retPtrErr()
if err != nil {
return
}
print(*x) //want "`retPtrErr\\(\\)` to `x`"
}
}

// nilable(mp[])
func test23(mp map[int]*int, i int) {
switch i {
case 0:
v, ok := mp[0]
if ok {
print(*v) //want "`mp\\[0\\]` to `v`"
}

case 1:
if v, ok := mp[0]; ok {
print(*v) //want "`mp\\[0\\]` to `v`"
}
case 2:
var v *int
var ok bool
v, ok = mp[0]
if ok {
print(*v) //want "`mp\\[0\\]` to `v`"
}
case 3:
var v, ok = mp[0]
if ok {
print(*v) //want "`mp\\[0\\]` to `v`"
}
}
}

// nilable(result 0, result 2)
func retMultiple() (*int, *int, *int) {
return nil, new(int), nil
}

func test24() {
a, b, c := retMultiple()
if dummy {
b = a
}
print(*a) //want "`retMultiple\\(\\)` to `a`"
print(*b) //want "`a` to `b`"
print(*c) //want "`retMultiple\\(\\)` to `c`"
}

// nilable(A[])
type A []*int

// nonnil(a)
func test25(a A) {
a[0], a[1], _ = retMultiple()
print(*a[0]) //want "`retMultiple\\(\\)` to `a\\[0\\]`"
print(*a[1])
}
Loading