diff --git a/contextcheck.go b/contextcheck.go index 543a802..0434eb8 100644 --- a/contextcheck.go +++ b/contextcheck.go @@ -11,6 +11,7 @@ import ( "github.com/gostaticanalysis/analysisutil" "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/analysis/passes/buildssa" + "golang.org/x/tools/go/packages" "golang.org/x/tools/go/ssa" ) @@ -18,7 +19,7 @@ func NewAnalyzer() *analysis.Analyzer { return &analysis.Analyzer{ Name: "contextcheck", Doc: "check the function whether use a non-inherited context", - Run: NewRun(), + Run: NewRun(nil), Requires: []*analysis.Analyzer{ buildssa.Analyzer, }, @@ -41,6 +42,7 @@ const ( var ( checkedMap = make(map[string]bool) checkedMapLock sync.RWMutex + c *collector ) type runner struct { @@ -51,8 +53,11 @@ type runner struct { skipFile map[*ast.File]bool } -func NewRun() func(pass *analysis.Pass) (interface{}, error) { +func NewRun(pkgs []*packages.Package) func(pass *analysis.Pass) (interface{}, error) { + c = newCollector(pkgs) return func(pass *analysis.Pass) (interface{}, error) { + defer c.DecUse(pass) + r := new(runner) r.run(pass) return nil, nil @@ -264,15 +269,14 @@ func (r *runner) collectCtxRef(f *ssa.Function) (refMap map[ssa.Instruction]bool return } -func (r *runner) buildPkg(f *ssa.Function) { +func (r *runner) buildPkg(f *ssa.Function) (ff *ssa.Function) { if f.Blocks != nil { + ff = f return } - // only build the pkg which is in the same repo - if r.checkIsSameRepo(f.Pkg.Pkg.Path()) { - f.Pkg.Build() - } + ff = c.GetFunction(f) + return } func (r *runner) checkIsSameRepo(s string) bool { @@ -324,7 +328,9 @@ func (r *runner) checkFuncWithCtx(f *ssa.Function) { // if ff has no ctx, start deep traversal check if !r.checkIsEntry(ff, instr.Pos()) { - r.buildPkg(ff) + if ff = r.buildPkg(ff); ff == nil { + continue + } checkingMap := make(map[string]bool) checkingMap[key] = true @@ -386,7 +392,9 @@ func (r *runner) checkFuncWithoutCtx(f *ssa.Function, checkingMap map[string]boo } checkingMap[key] = true - r.buildPkg(ff) + if ff = r.buildPkg(ff); ff == nil { + continue + } valid := r.checkFuncWithoutCtx(ff, checkingMap) setValue(key, valid) diff --git a/dep.go b/dep.go new file mode 100644 index 0000000..ecef1cf --- /dev/null +++ b/dep.go @@ -0,0 +1,116 @@ +package contextcheck + +import ( + "go/types" + "sync/atomic" + + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/analysis/passes/buildssa" + "golang.org/x/tools/go/packages" + "golang.org/x/tools/go/ssa" +) + +type pkgInfo struct { + pkgPkg *packages.Package // to find references later + ssaPkg *ssa.Package // to find func which has been built + refCnt int32 // reference count +} + +type collector struct { + m map[string]*pkgInfo +} + +func newCollector(pkgs []*packages.Package) (c *collector) { + c = &collector{ + m: make(map[string]*pkgInfo), + } + + // self-reference + for _, pkg := range pkgs { + c.m[pkg.PkgPath] = &pkgInfo{ + pkgPkg: pkg, + refCnt: 1, + } + } + + // import reference + for _, pkg := range pkgs { + for _, imp := range pkg.Imports { + if val, ok := c.m[imp.PkgPath]; ok { + val.refCnt++ + } + } + } + + return +} + +func (c *collector) DecUse(pass *analysis.Pass) { + curPkg, ok := c.m[pass.Pkg.Path()] + if !ok { + return + } + + if atomic.AddInt32(&curPkg.refCnt, -1) != 0 { + curPkg.ssaPkg = pass.ResultOf[buildssa.Analyzer].(*buildssa.SSA).Pkg + return + } + + var release func(info *pkgInfo) + release = func(info *pkgInfo) { + for _, pkg := range info.pkgPkg.Imports { + if val, ok := c.m[pkg.PkgPath]; ok { + if atomic.AddInt32(&val.refCnt, -1) == 0 { + release(val) + } + } + } + + info.pkgPkg = nil + info.ssaPkg = nil + } + release(curPkg) +} + +func (c *collector) GetFunction(f *ssa.Function) (ff *ssa.Function) { + info, ok := c.m[f.Pkg.Pkg.Path()] + if !ok { + return + } + + // without recv => get by Func + recv := f.Signature.Recv() + if recv == nil { + ff = info.ssaPkg.Func(f.Name()) + return + } + + // with recv => find in prog according to type + ntp, ptp := getNamedType(recv.Type()) + if ntp == nil { + return + } + sel := info.ssaPkg.Prog.MethodSets.MethodSet(ntp).Lookup(ntp.Obj().Pkg(), f.Name()) + if sel == nil { + sel = info.ssaPkg.Prog.MethodSets.MethodSet(ptp).Lookup(ntp.Obj().Pkg(), f.Name()) + } + if sel == nil { + return + } + ff = info.ssaPkg.Prog.MethodValue(sel) + return +} + +func getNamedType(tp types.Type) (ntp *types.Named, ptp *types.Pointer) { + switch t := tp.(type) { + case *types.Named: + ntp = t + ptp = types.NewPointer(tp) + case *types.Pointer: + if n, ok := t.Elem().(*types.Named); ok { + ntp = n + ptp = t + } + } + return +}