diff --git a/spancheck.go b/spancheck.go index ebfc1ac..6f069a0 100644 --- a/spancheck.go +++ b/spancheck.go @@ -3,7 +3,6 @@ package spancheck import ( "go/ast" "go/types" - "log" "regexp" "golang.org/x/tools/go/analysis" @@ -170,7 +169,7 @@ func runFunc(pass *analysis.Pass, node ast.Node, config *Config) { for _, sv := range spanVars { if config.endCheckEnabled { // Check if there's no End to the span. - if ret := missingSpanCalls(pass, g, sv, "End", func(pass *analysis.Pass, ret *ast.ReturnStmt) *ast.ReturnStmt { return ret }, nil); ret != nil { + if ret := getMissingSpanCalls(pass, g, sv, "End", func(pass *analysis.Pass, ret *ast.ReturnStmt) *ast.ReturnStmt { return ret }, nil); ret != nil { pass.ReportRangef(sv.stmt, "%s.End is not called on all paths, possible memory leak", sv.vr.Name()) pass.ReportRangef(ret, "return can be reached without calling %s.End", sv.vr.Name()) } @@ -178,7 +177,7 @@ func runFunc(pass *analysis.Pass, node ast.Node, config *Config) { if config.setStatusEnabled { // Check if there's no SetStatus to the span setting an error. - if ret := missingSpanCalls(pass, g, sv, "SetStatus", returnsErr, config.ignoreChecksSignatures); ret != nil { + if ret := getMissingSpanCalls(pass, g, sv, "SetStatus", getErrorReturn, config.ignoreChecksSignatures); ret != nil { pass.ReportRangef(sv.stmt, "%s.SetStatus is not called on all paths", sv.vr.Name()) pass.ReportRangef(ret, "return can be reached without calling %s.SetStatus", sv.vr.Name()) } @@ -186,7 +185,7 @@ func runFunc(pass *analysis.Pass, node ast.Node, config *Config) { if config.recordErrorEnabled && sv.spanType == spanOpenTelemetry { // RecordError only exists in OpenTelemetry // Check if there's no RecordError to the span setting an error. - if ret := missingSpanCalls(pass, g, sv, "RecordError", returnsErr, config.ignoreChecksSignatures); ret != nil { + if ret := getMissingSpanCalls(pass, g, sv, "RecordError", getErrorReturn, config.ignoreChecksSignatures); ret != nil { pass.ReportRangef(sv.stmt, "%s.RecordError is not called on all paths", sv.vr.Name()) pass.ReportRangef(ret, "return can be reached without calling %s.RecordError", sv.vr.Name()) } @@ -236,9 +235,9 @@ func getID(node ast.Node) *ast.Ident { return nil } -// missingSpanCalls finds a path through the CFG, from stmt (which defines +// getMissingSpanCalls finds a path through the CFG, from stmt (which defines // the 'span' variable v) to a return statement, that doesn't call the passed selector on the span. -func missingSpanCalls( +func getMissingSpanCalls( pass *analysis.Pass, g *cfg.CFG, sv spanVar, @@ -246,66 +245,12 @@ func missingSpanCalls( checkErr func(pass *analysis.Pass, ret *ast.ReturnStmt) *ast.ReturnStmt, ignoreCheckSig *regexp.Regexp, ) *ast.ReturnStmt { - // usesCall reports whether stmts contain a use of the selName call on variable v. - usesCall := func(pass *analysis.Pass, stmts []ast.Node) bool { - found, reAssigned := false, false - for _, subStmt := range stmts { - stack := []ast.Node{} - ast.Inspect(subStmt, func(n ast.Node) bool { - switch n := n.(type) { - case *ast.FuncLit: - if len(stack) > 0 { - return false // don't stray into nested functions - } - case *ast.CallExpr: - if ident, ok := n.Fun.(*ast.Ident); ok { - fnSig := pass.TypesInfo.ObjectOf(ident).String() - if ignoreCheckSig != nil && ignoreCheckSig.MatchString(fnSig) { - found = true - return false - } - } - case nil: - stack = stack[:len(stack)-1] // pop - return true - } - stack = append(stack, n) // push - - // Check whether the span was assigned over top of its old value. - _, spanStart := isSpanStart(pass.TypesInfo, n) - if spanStart { - if id := getID(stack[len(stack)-3]); id != nil && id.Obj.Decl == sv.id.Obj.Decl { - reAssigned = true - return false - } - } - - if n, ok := n.(*ast.SelectorExpr); ok { - // Selector (End, SetStatus, RecordError) hit. - if n.Sel.Name == selName { - id, ok := n.X.(*ast.Ident) - found = ok && id.Obj.Decl == sv.id.Obj.Decl - } - - // Check if an ignore signature matches. - fnSig := pass.TypesInfo.ObjectOf(n.Sel).String() - if ignoreCheckSig != nil && ignoreCheckSig.MatchString(fnSig) { - found = true - } - } - - return !found - }) - } - return found && !reAssigned - } - // blockUses computes "uses" for each block, caching the result. memo := make(map[*cfg.Block]bool) blockUses := func(pass *analysis.Pass, b *cfg.Block) bool { res, ok := memo[b] if !ok { - res = usesCall(pass, b.Nodes) + res = usesCall(pass, b.Nodes, sv, selName, ignoreCheckSig, 0) memo[b] = res } return res @@ -325,12 +270,9 @@ outer: } } } - if defBlock == nil { - log.Default().Print("[ERROR] internal error: can't find defining block for span var") - } // Is the call "used" in the remainder of its defining block? - if usesCall(pass, rest) { + if usesCall(pass, rest, sv, selName, ignoreCheckSig, 0) { return nil } @@ -356,12 +298,12 @@ outer: } // Found path to return statement? - if ret := returnsErr(pass, b.Return()); ret != nil { + if ret := getErrorReturn(pass, b.Return()); ret != nil { return ret // found } // Recur - if ret := returnsErr(pass, search(b.Succs)); ret != nil { + if ret := getErrorReturn(pass, search(b.Succs)); ret != nil { return ret } } @@ -371,7 +313,75 @@ outer: return search(defBlock.Succs) } -func returnsErr(pass *analysis.Pass, ret *ast.ReturnStmt) *ast.ReturnStmt { +// usesCall reports whether stmts contain a use of the selName call on variable v. +func usesCall(pass *analysis.Pass, stmts []ast.Node, sv spanVar, selName string, ignoreCheckSig *regexp.Regexp, depth int) bool { + if depth > 1 { // for perf reasons, do not dive too deep thru func literals, just one level deep check. + return false + } + + found, reAssigned := false, false + for _, subStmt := range stmts { + stack := []ast.Node{} + ast.Inspect(subStmt, func(n ast.Node) bool { + switch n := n.(type) { + case *ast.FuncLit: + if len(stack) > 0 { + cfgs := pass.ResultOf[ctrlflow.Analyzer].(*ctrlflow.CFGs) + g := cfgs.FuncLit(n) + if g != nil && len(g.Blocks) > 0 { + return usesCall(pass, g.Blocks[0].Nodes, sv, selName, ignoreCheckSig, depth+1) + } + + return false + } + case *ast.CallExpr: + if ident, ok := n.Fun.(*ast.Ident); ok { + fnSig := pass.TypesInfo.ObjectOf(ident).String() + if ignoreCheckSig != nil && ignoreCheckSig.MatchString(fnSig) { + found = true + return false + } + } + case nil: + if len(stack) > 0 { + stack = stack[:len(stack)-1] // pop + return true + } + return false + } + stack = append(stack, n) // push + + // Check whether the span was assigned over top of its old value. + _, spanStart := isSpanStart(pass.TypesInfo, n) + if spanStart { + if id := getID(stack[len(stack)-3]); id != nil && id.Obj.Decl == sv.id.Obj.Decl { + reAssigned = true + return false + } + } + + if n, ok := n.(*ast.SelectorExpr); ok { + // Selector (End, SetStatus, RecordError) hit. + if n.Sel.Name == selName { + id, ok := n.X.(*ast.Ident) + found = ok && id.Obj.Decl == sv.id.Obj.Decl + } + + // Check if an ignore signature matches. + fnSig := pass.TypesInfo.ObjectOf(n.Sel).String() + if ignoreCheckSig != nil && ignoreCheckSig.MatchString(fnSig) { + found = true + } + } + + return !found + }) + } + + return found && !reAssigned +} + +func getErrorReturn(pass *analysis.Pass, ret *ast.ReturnStmt) *ast.ReturnStmt { if ret == nil { return nil } diff --git a/testdata/base/base.go b/testdata/base/base.go index 427b695..eb27852 100644 --- a/testdata/base/base.go +++ b/testdata/base/base.go @@ -132,3 +132,22 @@ func _() { _, span := trace.StartSpanWithRemoteParent(context.Background(), "foo", trace.SpanContext{}) defer span.End() } + +// This tests that we detect when the span is closed within a deferred func. +// https://github.com/jjti/go-spancheck/issues/12 +func _() { + _, span := otel.Tracer("foo").Start(context.Background(), "bar") + defer func() { + span.End() + }() +} + +// Despite above, we do not wander more than one level deep into the defer stack. +func _() { + _, span := otel.Tracer("foo").Start(context.Background(), "bar") // want "span.End is not called on all paths, possible memory leak" + defer func() { + defer func() { + span.End() + }() + }() +} // want "return can be reached without calling span.End"