Skip to content

Commit

Permalink
fix: detect span calls within func literals (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
jjti authored Feb 23, 2024
1 parent 0bf80bc commit 9c61e03
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 68 deletions.
146 changes: 78 additions & 68 deletions spancheck.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package spancheck
import (
"go/ast"
"go/types"
"log"
"regexp"

"golang.org/x/tools/go/analysis"
Expand Down Expand Up @@ -170,23 +169,23 @@ func runFunc(pass *analysis.Pass, node ast.Node, config *Config) {
for _, sv := range spanVars {
if config.endCheckEnabled {
// Check if there's no End to the span.
if ret := missingSpanCalls(pass, g, sv, "End", func(pass *analysis.Pass, ret *ast.ReturnStmt) *ast.ReturnStmt { return ret }, nil); ret != nil {
if ret := getMissingSpanCalls(pass, g, sv, "End", func(pass *analysis.Pass, ret *ast.ReturnStmt) *ast.ReturnStmt { return ret }, nil); ret != nil {
pass.ReportRangef(sv.stmt, "%s.End is not called on all paths, possible memory leak", sv.vr.Name())
pass.ReportRangef(ret, "return can be reached without calling %s.End", sv.vr.Name())
}
}

if config.setStatusEnabled {
// Check if there's no SetStatus to the span setting an error.
if ret := missingSpanCalls(pass, g, sv, "SetStatus", returnsErr, config.ignoreChecksSignatures); ret != nil {
if ret := getMissingSpanCalls(pass, g, sv, "SetStatus", getErrorReturn, config.ignoreChecksSignatures); ret != nil {
pass.ReportRangef(sv.stmt, "%s.SetStatus is not called on all paths", sv.vr.Name())
pass.ReportRangef(ret, "return can be reached without calling %s.SetStatus", sv.vr.Name())
}
}

if config.recordErrorEnabled && sv.spanType == spanOpenTelemetry { // RecordError only exists in OpenTelemetry
// Check if there's no RecordError to the span setting an error.
if ret := missingSpanCalls(pass, g, sv, "RecordError", returnsErr, config.ignoreChecksSignatures); ret != nil {
if ret := getMissingSpanCalls(pass, g, sv, "RecordError", getErrorReturn, config.ignoreChecksSignatures); ret != nil {
pass.ReportRangef(sv.stmt, "%s.RecordError is not called on all paths", sv.vr.Name())
pass.ReportRangef(ret, "return can be reached without calling %s.RecordError", sv.vr.Name())
}
Expand Down Expand Up @@ -236,76 +235,22 @@ func getID(node ast.Node) *ast.Ident {
return nil
}

// missingSpanCalls finds a path through the CFG, from stmt (which defines
// getMissingSpanCalls finds a path through the CFG, from stmt (which defines
// the 'span' variable v) to a return statement, that doesn't call the passed selector on the span.
func missingSpanCalls(
func getMissingSpanCalls(
pass *analysis.Pass,
g *cfg.CFG,
sv spanVar,
selName string,
checkErr func(pass *analysis.Pass, ret *ast.ReturnStmt) *ast.ReturnStmt,
ignoreCheckSig *regexp.Regexp,
) *ast.ReturnStmt {
// usesCall reports whether stmts contain a use of the selName call on variable v.
usesCall := func(pass *analysis.Pass, stmts []ast.Node) bool {
found, reAssigned := false, false
for _, subStmt := range stmts {
stack := []ast.Node{}
ast.Inspect(subStmt, func(n ast.Node) bool {
switch n := n.(type) {
case *ast.FuncLit:
if len(stack) > 0 {
return false // don't stray into nested functions
}
case *ast.CallExpr:
if ident, ok := n.Fun.(*ast.Ident); ok {
fnSig := pass.TypesInfo.ObjectOf(ident).String()
if ignoreCheckSig != nil && ignoreCheckSig.MatchString(fnSig) {
found = true
return false
}
}
case nil:
stack = stack[:len(stack)-1] // pop
return true
}
stack = append(stack, n) // push

// Check whether the span was assigned over top of its old value.
_, spanStart := isSpanStart(pass.TypesInfo, n)
if spanStart {
if id := getID(stack[len(stack)-3]); id != nil && id.Obj.Decl == sv.id.Obj.Decl {
reAssigned = true
return false
}
}

if n, ok := n.(*ast.SelectorExpr); ok {
// Selector (End, SetStatus, RecordError) hit.
if n.Sel.Name == selName {
id, ok := n.X.(*ast.Ident)
found = ok && id.Obj.Decl == sv.id.Obj.Decl
}

// Check if an ignore signature matches.
fnSig := pass.TypesInfo.ObjectOf(n.Sel).String()
if ignoreCheckSig != nil && ignoreCheckSig.MatchString(fnSig) {
found = true
}
}

return !found
})
}
return found && !reAssigned
}

// blockUses computes "uses" for each block, caching the result.
memo := make(map[*cfg.Block]bool)
blockUses := func(pass *analysis.Pass, b *cfg.Block) bool {
res, ok := memo[b]
if !ok {
res = usesCall(pass, b.Nodes)
res = usesCall(pass, b.Nodes, sv, selName, ignoreCheckSig, 0)
memo[b] = res
}
return res
Expand All @@ -325,12 +270,9 @@ outer:
}
}
}
if defBlock == nil {
log.Default().Print("[ERROR] internal error: can't find defining block for span var")
}

// Is the call "used" in the remainder of its defining block?
if usesCall(pass, rest) {
if usesCall(pass, rest, sv, selName, ignoreCheckSig, 0) {
return nil
}

Expand All @@ -356,12 +298,12 @@ outer:
}

// Found path to return statement?
if ret := returnsErr(pass, b.Return()); ret != nil {
if ret := getErrorReturn(pass, b.Return()); ret != nil {
return ret // found
}

// Recur
if ret := returnsErr(pass, search(b.Succs)); ret != nil {
if ret := getErrorReturn(pass, search(b.Succs)); ret != nil {
return ret
}
}
Expand All @@ -371,7 +313,75 @@ outer:
return search(defBlock.Succs)
}

func returnsErr(pass *analysis.Pass, ret *ast.ReturnStmt) *ast.ReturnStmt {
// usesCall reports whether stmts contain a use of the selName call on variable v.
func usesCall(pass *analysis.Pass, stmts []ast.Node, sv spanVar, selName string, ignoreCheckSig *regexp.Regexp, depth int) bool {
if depth > 1 { // for perf reasons, do not dive too deep thru func literals, just one level deep check.
return false
}

found, reAssigned := false, false
for _, subStmt := range stmts {
stack := []ast.Node{}
ast.Inspect(subStmt, func(n ast.Node) bool {
switch n := n.(type) {
case *ast.FuncLit:
if len(stack) > 0 {
cfgs := pass.ResultOf[ctrlflow.Analyzer].(*ctrlflow.CFGs)
g := cfgs.FuncLit(n)
if g != nil && len(g.Blocks) > 0 {
return usesCall(pass, g.Blocks[0].Nodes, sv, selName, ignoreCheckSig, depth+1)
}

return false
}
case *ast.CallExpr:
if ident, ok := n.Fun.(*ast.Ident); ok {
fnSig := pass.TypesInfo.ObjectOf(ident).String()
if ignoreCheckSig != nil && ignoreCheckSig.MatchString(fnSig) {
found = true
return false
}
}
case nil:
if len(stack) > 0 {
stack = stack[:len(stack)-1] // pop
return true
}
return false
}
stack = append(stack, n) // push

// Check whether the span was assigned over top of its old value.
_, spanStart := isSpanStart(pass.TypesInfo, n)
if spanStart {
if id := getID(stack[len(stack)-3]); id != nil && id.Obj.Decl == sv.id.Obj.Decl {
reAssigned = true
return false
}
}

if n, ok := n.(*ast.SelectorExpr); ok {
// Selector (End, SetStatus, RecordError) hit.
if n.Sel.Name == selName {
id, ok := n.X.(*ast.Ident)
found = ok && id.Obj.Decl == sv.id.Obj.Decl
}

// Check if an ignore signature matches.
fnSig := pass.TypesInfo.ObjectOf(n.Sel).String()
if ignoreCheckSig != nil && ignoreCheckSig.MatchString(fnSig) {
found = true
}
}

return !found
})
}

return found && !reAssigned
}

func getErrorReturn(pass *analysis.Pass, ret *ast.ReturnStmt) *ast.ReturnStmt {
if ret == nil {
return nil
}
Expand Down
19 changes: 19 additions & 0 deletions testdata/base/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,22 @@ func _() {
_, span := trace.StartSpanWithRemoteParent(context.Background(), "foo", trace.SpanContext{})
defer span.End()
}

// This tests that we detect when the span is closed within a deferred func.
// https://github.com/jjti/go-spancheck/issues/12
func _() {
_, span := otel.Tracer("foo").Start(context.Background(), "bar")
defer func() {
span.End()
}()
}

// Despite above, we do not wander more than one level deep into the defer stack.
func _() {
_, span := otel.Tracer("foo").Start(context.Background(), "bar") // want "span.End is not called on all paths, possible memory leak"
defer func() {
defer func() {
span.End()
}()
}()
} // want "return can be reached without calling span.End"

0 comments on commit 9c61e03

Please sign in to comment.