diff --git a/internal/backends/nodejs/grab.go b/internal/backends/nodejs/grab.go index 4dd28e3f..3d3dc81b 100644 --- a/internal/backends/nodejs/grab.go +++ b/internal/backends/nodejs/grab.go @@ -90,7 +90,7 @@ func findImports(ctx context.Context, dir string) (map[string]bool, error) { foundImportPaths := map[string]bool{} js := javascript.GetLanguage() - jsPkgs, err := util.GuessWithTreeSitter(ctx, dir, js, importsQuery, jsPathGlobs, []string{}) + jsPkgs, err := util.GuessWithTreeSitter(ctx, dir, js, importsQuery, jsPathGlobs, nodeIgnorePathSegments) if err != nil { return nil, err } @@ -100,7 +100,7 @@ func findImports(ctx context.Context, dir string) (map[string]bool, error) { } ts := typescript.GetLanguage() - tsPkgs, err := util.GuessWithTreeSitter(ctx, dir, ts, importsQuery, tsPathGlobs, []string{}) + tsPkgs, err := util.GuessWithTreeSitter(ctx, dir, ts, importsQuery, tsPathGlobs, nodeIgnorePathSegments) if err != nil { return nil, err } @@ -110,7 +110,7 @@ func findImports(ctx context.Context, dir string) (map[string]bool, error) { } tsx := tsx.GetLanguage() - tsxPkgs, err := util.GuessWithTreeSitter(ctx, dir, tsx, importsQuery, tsxPathGlobs, []string{}) + tsxPkgs, err := util.GuessWithTreeSitter(ctx, dir, tsx, importsQuery, tsxPathGlobs, nodeIgnorePathSegments) if err != nil { return nil, err } diff --git a/internal/backends/nodejs/nodejs.go b/internal/backends/nodejs/nodejs.go index c7669def..a836e37e 100644 --- a/internal/backends/nodejs/nodejs.go +++ b/internal/backends/nodejs/nodejs.go @@ -372,6 +372,10 @@ var nodejsGuessRegexps = util.Regexps([]string{ `(?m)(?:require|import)\s*\(\s*['"]([^'"{}]+)['"]\s*\)`, }) +var nodeIgnorePathSegments = map[string]bool{ + "node_modules": true, +} + var jsPathGlobs = []string{ "*.js", "*.jsx", diff --git a/internal/backends/python/grab.go b/internal/backends/python/grab.go index e078765e..4125965d 100644 --- a/internal/backends/python/grab.go +++ b/internal/backends/python/grab.go @@ -26,12 +26,13 @@ var importsQuery = ` (comment)? @pragma) ` -var pyPathGlobs = []string{"*.py"} +var pyPathSegmentPatterns = []string{"*.py"} -var pyIgnoreGlobs = []string{ - "**/__pycache__/**", - "**/venv/**", - "**/.pythonlibs/**", +var pyIgnorePathSegments = map[string]bool{ + "__pycache__": true, + "venv": true, + ".pythonlibs": true, + ".git": true, } var internalModules = map[string]bool{ @@ -263,7 +264,7 @@ func findImports(ctx context.Context, dir string) (map[string]bool, error) { span, ctx := tracer.StartSpanFromContext(ctx, "python.grab.findImports") defer span.Finish() py := python.GetLanguage() - pkgs, err := util.GuessWithTreeSitter(ctx, dir, py, importsQuery, pyPathGlobs, pyIgnoreGlobs) + pkgs, err := util.GuessWithTreeSitter(ctx, dir, py, importsQuery, pyPathSegmentPatterns, pyIgnorePathSegments) if err != nil { return nil, err diff --git a/internal/util/tree-sitter.go b/internal/util/tree-sitter.go index 5341bd24..53b3fd9b 100644 --- a/internal/util/tree-sitter.go +++ b/internal/util/tree-sitter.go @@ -23,42 +23,56 @@ type importPragma struct { Package string } +const ( + // Represents filesystem nodes, including directories. + MaximumVisits = 5000 +) + // GuessWithTreeSitter guesses the imports of a directory using tree-sitter. // For every file in dir that matches a pattern in searchGlobPatterns, but // not in ignoreGlobPatterns, it will parse the file using lang and queryImports. // When there's a capture tagged as `@import`, it reports the capture as an import. // If there's a capture tagged as `@pragma` that's on the same line as an import, // it will include the pragma in the results. -func GuessWithTreeSitter(ctx context.Context, dir string, lang *sitter.Language, queryImports string, searchGlobPatterns, ignoreGlobPatterns []string) ([]string, error) { +func GuessWithTreeSitter(ctx context.Context, root string, lang *sitter.Language, queryImports string, pathSegmentPatterns []string, ignorePathSegments map[string]bool) ([]string, error) { //nolint:ineffassign,wastedassign,staticcheck span, ctx := tracer.StartSpanFromContext(ctx, "GuessWithTreeSitter") defer span.Finish() - dirFS := os.DirFS(dir) + dirFS := os.DirFS(root) - ignoredPaths := map[string]bool{} - for _, pattern := range ignoreGlobPatterns { - globIgnorePaths, err := fs.Glob(dirFS, pattern) + var visited int + pathsToSearch := []string{} + err := fs.WalkDir(dirFS, ".", func(dir string, d fs.DirEntry, err error) error { + dir = path.Join(root, dir) if err != nil { - return nil, err + return err } - for _, gPath := range globIgnorePaths { - ignoredPaths[gPath] = true + visited += 1 + + // Avoid locking up UPM on pathological project configurations + if visited > MaximumVisits { + return fs.SkipAll } - } - pathsToSearch := []string{} - for _, pattern := range searchGlobPatterns { - globSearchPaths, err := fs.Glob(dirFS, pattern) - if err != nil { - return nil, err + if ignorePathSegments[d.Name()] { + return fs.SkipDir } - for _, gPath := range globSearchPaths { - if !ignoredPaths[gPath] { - pathsToSearch = append(pathsToSearch, path.Join(dir, gPath)) + for _, pattern := range pathSegmentPatterns { + var ok bool + if ok, err = path.Match(pattern, d.Name()); ok { + pathsToSearch = append(pathsToSearch, dir) + } + if err != nil { + return err } } + + return nil + }) + if err != nil { + return nil, err } query, err := sitter.NewQuery([]byte(queryImports), lang) diff --git a/internal/util/tree-sitter_test.go b/internal/util/tree-sitter_test.go new file mode 100644 index 00000000..2b135b0f --- /dev/null +++ b/internal/util/tree-sitter_test.go @@ -0,0 +1,109 @@ +package util + +import ( + "context" + "os" + "path" + "strings" + "testing" + + "github.com/smacker/go-tree-sitter/python" +) + +var importsQuery = ` +(module + [(import_statement + name: [(dotted_name) @import + (aliased_import + name: (dotted_name) @import)]) + + (import_from_statement + module_name: (dotted_name) @import)] + + . + + (comment)? @pragma) +` + +func writeFile(dir, name string, contents []byte) error { + err := os.MkdirAll(dir, 0o755) + if err != nil { + return err + } + err = os.WriteFile(path.Join(dir, name), contents, 0644) + if err != nil { + return err + } + return nil +} + +func TestTreeSitter(t *testing.T) { + testDir := t.TempDir() + + expected := map[string]bool{ + "from_root": true, + "from_inner": true, + } + + innerDir := path.Join(testDir, "src", "my_module", "inner") + venvDir := path.Join(testDir, "venv", "ignored_module") + + var err error + err = writeFile(testDir, "root.py", []byte("import from_root")) + if err != nil { + t.Error(err) + } + err = writeFile(innerDir, "inner.py", []byte("import from_inner")) + if err != nil { + t.Error(err) + } + err = writeFile(venvDir, "ignored.py", []byte("import from_venv")) + if err != nil { + t.Error(err) + } + + pathSegmentPatterns := []string{"*.py"} + ignorePathSegments := map[string]bool{ + "venv": true, + } + + ctx := context.Background() + py := python.GetLanguage() + foundMap := map[string]bool{} + { + var found []string + found, err = GuessWithTreeSitter(ctx, testDir, py, importsQuery, pathSegmentPatterns, ignorePathSegments) + if err != nil { + t.Error(err) + } + + for _, pkg := range found { + foundMap[pkg] = true + } + } + + for pkg := range foundMap { + if expected[pkg] { + delete(expected, pkg) + delete(foundMap, pkg) + } else { + t.Error("Missing match: ", pkg) + } + } + + if len(expected) > 0 { + formatted := []string{} + for pkg := range expected { + formatted = append(formatted, pkg) + } + t.Error("Not all expected checks were passed. Missing:", strings.Join(formatted, ", ")) + } + + if len(foundMap) > 0 { + formatted := []string{} + for pkg := range foundMap { + formatted = append(formatted, pkg) + } + t.Error("Not all expected checks were passed. Extra:", strings.Join(formatted, ", ")) + } +}