Skip to content

Commit

Permalink
Merge pull request #8 from arnavrneo/better-imports
Browse files Browse the repository at this point in the history
Ver to 1.2.1
  • Loading branch information
arnavrneo authored Dec 26, 2023
2 parents 76f4e0c + e603439 commit 3bd6c05
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 62 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -77,5 +77,5 @@ jobs:
- uses: ncipollo/release-action@v1
with:
skipIfReleaseExists: true
tag: v1.2
tag: v1.2.1
artifacts: "build/*"
29 changes: 25 additions & 4 deletions cmd/clean/clean.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (

var reqPath string

func cleanReq(reqPath string, dirPath string, venvPath string, ignoreDirs string, printReq bool) {
func cleanReq(reqPath string, dirPath string, venvPath string, ignoreDirs string, printReq bool, debug bool) {

// TODO: better way of showing errors
if reqPath == " " {
Expand Down Expand Up @@ -61,30 +61,49 @@ func cleanReq(reqPath string, dirPath string, venvPath string, ignoreDirs string
importsInfo[j] = true
}

if debug {
fmt.Print("Imports from project: => ")
for i := range importsInfo {
fmt.Print(i, " ")
}
fmt.Println()
}

// read and write
reqFile, err := os.OpenFile(reqPath, os.O_RDWR, 0644)
utils.Check(err)
scanner := bufio.NewScanner(reqFile)
var bs []byte
buf := bytes.NewBuffer(bs) // stores all the user imports

matchedImports := []string{}

for scanner.Scan() {
text := scanner.Text()
if importsInfo[strings.Split(text, "==")[0]] {
if printReq == true {
fmt.Println(text)
if debug || printReq {
matchedImports = append(matchedImports, strings.Split(text, "==")[0])
}
_, err = buf.WriteString(scanner.Text() + "\n")
utils.Check(err)
}
}

if debug {
fmt.Print("Imports Matched: => ")
for _, i := range matchedImports {
fmt.Print(i, " ")
}
}

err = reqFile.Truncate(0)
utils.Check(err)
_, err = reqFile.Seek(0, 0)
utils.Check(err)
_, err = buf.WriteTo(reqFile)
utils.Check(err)

fmt.Print("\nSuccessfully cleaned requirements.txt!\n")
}

// Cmd represents the clean command
Expand All @@ -99,10 +118,12 @@ var Cmd = &cobra.Command{
utils.Check(err)
ignoreDirs, err := cmd.Flags().GetString("ignore")
utils.Check(err)
debug, err := cmd.Flags().GetBool("debug")
utils.Check(err)
printReq, err := cmd.Flags().GetBool("print")
utils.Check(err)

cleanReq(reqPath, dirPath, venvPath, ignoreDirs, printReq)
cleanReq(reqPath, dirPath, venvPath, ignoreDirs, printReq, debug)
},
}

Expand Down
89 changes: 73 additions & 16 deletions cmd/create/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import (

var (
savePath string
debug bool
pinType string
)

var importRegEx = `^import\s+([^\s]+)(\s+as\s+([^\s]+))?$`
Expand All @@ -39,6 +41,18 @@ var mappings embed.FS

// MAIN

func getPinType(pinType string) string {
if pinType == "gt" {
return ">="
} else if pinType == "compat" {
return "~="
} else if pinType == "none" {
return ""
} else {
return "=="
}
}

// GetPaths returns the path of python files in the dir
func GetPaths(dirs string, ignoreDirs string) ([]string, []string) {
var pyFiles []string
Expand Down Expand Up @@ -70,8 +84,9 @@ func GetPaths(dirs string, ignoreDirs string) ([]string, []string) {
if ignoreDirList != nil {
for _, l := range ignoreDirList {
if f.Name() == l.Name() {
// TODO: better debug info
fmt.Println("Skipped: ", l)
if debug {
fmt.Println("Skipped Directory: ", l)
}
return filepath.SkipDir
}
}
Expand All @@ -89,6 +104,11 @@ func GetPaths(dirs string, ignoreDirs string) ([]string, []string) {
dirList = append(dirList, file)
}
}

if debug {
fmt.Print("Found ", len(pyFiles), " .py files.\n")
}

return pyFiles, dirList
}

Expand Down Expand Up @@ -159,16 +179,22 @@ func ReadImports(pyFiles []string, dirList []string) map[string]struct{} {
}
}

totalImportCount := len(imports)
dirImports := make(map[string]bool)
// ignore all the directory imports
for k := range imports {
for _, l := range dirList {
if strings.Contains(l, k) {
delete(imports, k)
for _, m := range strings.Split(l, "/") {
if k == m {
dirImports[k] = true
delete(imports, k)
}
}
}
}

// python inbuilt imports
inbuiltImportCount := 0
predefinedLib, err := stdlib.Open("stdlib")
utils.Check(err)
scanner := bufio.NewScanner(predefinedLib)
Expand All @@ -180,9 +206,18 @@ func ReadImports(pyFiles []string, dirList []string) map[string]struct{} {

for j := range imports {
if inbuiltImportsSet[j] {
inbuiltImportCount += 1
delete(imports, j)
}
}

if debug {
fmt.Println("Total Imports: ", totalImportCount)
fmt.Println("Total Directory Imports: ", len(dirImports))
fmt.Println("Total Python Inbuilt Imports: ", inbuiltImportCount)
fmt.Println("Total User Imports: ", totalImportCount-(len(dirImports)+inbuiltImportCount))
}

return imports
}

Expand Down Expand Up @@ -273,10 +308,11 @@ func writeRequirements(venvDir string, codesDir string, savePath string, print b
// ver string
// }

_, err := os.Stat(filepath.Join(savePath, "requirements.txt"))
if !os.IsNotExist(err) {
fmt.Printf("requirements.txt already exists. It will be overwritten.\n")
}
// TODO: prompt for req.txt being overwritten
//_, err := os.Stat(filepath.Join(savePath, "requirements.txt"))
//if !os.IsNotExist(err) {
// fmt.Printf("requirements.txt already exists. It will be overwritten.\n")
//}
file, err := os.Create(filepath.Join(savePath, "requirements.txt"))
utils.Check(err)

Expand All @@ -303,6 +339,7 @@ func writeRequirements(venvDir string, codesDir string, savePath string, print b
}

// imports from pypi server
// TODO: can we change o(n2) to something less demanding?
var pypiStore []string
for i := range imports {
cntr := 0
Expand All @@ -312,29 +349,47 @@ func writeRequirements(venvDir string, codesDir string, savePath string, print b
}
}
if cntr == 0 {
// name, ver := fetchPyPIServer(i)
pypiStore = append(pypiStore, i)
}
}
// TODO: do we need the following level of verbosity?
//fmt.Println(pypiStore)
pypiSet := FetchPyPIServer(pypiStore)

if debug {
fmt.Println("Total Local Imports (from venv): ", len(localSet))
fmt.Println("Total PyPI server Imports: ", len(pypiSet))
}

importsInfo := make(map[string]string)
maps.Copy(importsInfo, localSet)
maps.Copy(importsInfo, pypiSet)

pintype := getPinType(pinType)

for i, j := range importsInfo {
if i != "" || j != "" {
fullImport := i + "==" + j + "\n"
if _, err := file.Write([]byte(fullImport)); err != nil {
panic(err)
}
if print {
fmt.Println(strings.TrimSuffix(fullImport, "\n"))
if pintype != "" {
fullImport := i + pintype + j + "\n"
if _, err := file.Write([]byte(fullImport)); err != nil {
panic(err)
}
if print {
fmt.Println(strings.TrimSuffix(fullImport, "\n"))
}
} else {
fullImport := i + "\n"
if _, err := file.Write([]byte(fullImport)); err != nil {
panic(err)
}
if print {
fmt.Println(strings.TrimSuffix(fullImport, "\n"))
}
}
}
}

fmt.Println("Created successfully!")
fmt.Printf("Created successfully!\nSaved to %s\n", savePath+"requirements.txt")

}

Expand All @@ -353,6 +408,7 @@ var Cmd = &cobra.Command{
utils.Check(err3)
printReq, err4 := cmd.Flags().GetBool("print")
utils.Check(err4)
debug, _ = cmd.Flags().GetBool("debug")

writeRequirements(venvPath, dirPath, savePath, printReq, ignoreDirs)
},
Expand All @@ -369,5 +425,6 @@ func init() {
// Cobra supports local flags which will only run when this command
// is called directly, e.g.:
// createCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle")
Cmd.Flags().StringVarP(&pinType, "mode", "m", "", "imports pin-type: ==, >=, ~=")
Cmd.Flags().StringVarP(&savePath, "savePath", "s", "./", "save path for requirements.txt")
}
73 changes: 32 additions & 41 deletions cmd/create/create_test.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,12 @@
package create

import (
"bufio"
"os"
"strings"
"testing"
)

// func Test_getPaths(t *testing.T) {
// dirs := "/home/runner/work/pyreqs/pyreqs/ultralytics"
// _, dirList := getPaths(dirs, "")

// for _, j := range dirList {
// if strings.Contains(j, "venv") || strings.Contains(j, "env") || strings.Contains(j, "__pycache__") || strings.Contains(j, ".git") || strings.Contains(j, ".tox") {
// t.Error("Directory list has ignored directories.")
// }
// }
// }


func Test_fetchPyPIServer(t *testing.T) {
testImports := []string{"pandas", "numpy", "notapakage"}
fetchedImports := FetchPyPIServer(testImports)
Expand All @@ -28,32 +18,33 @@ func Test_fetchPyPIServer(t *testing.T) {
}
}

// func Test_writeRequirements(t *testing.T) {
// writeRequirements("", "/home/runner/work/pyreqs/pyreqs/ultralytics", "./", false)

// reqRead, err := os.Open("requirements.txt")
// if err != nil {
// panic(err)
// }
// defer reqRead.Close()

// testImports := map[string]bool{"tensorflow": true, "coremltools": true, "clip": true,
// "pytest": true, "comet-ml": true, "pycocotools": true, "shapely": true, "nncf": true,
// "thop": true, "onnxsim": true, "onnxruntime": true, "albumentations": true, "ipython": true,
// "x2paddle": true, "dvclive": true, "opencv-python": true, "matplotlib": true, "wandb": true,
// "ncnn": true, "setuptools": true, "SciPy": true, "yt-dlp": true, "psutil": true, "super-gradients": true,
// "torchvision": true, "Pillow": true, "tflite-runtime": true, "tflite-support": true, "seaborn": true,
// "tqdm": true, "lap": true, "requests": true, "numpy": true, "tritonclient": true, "pandas": true}

// var imports []string
// scanner := bufio.NewScanner(reqRead)
// for scanner.Scan() {
// imports = append(imports, scanner.Text())
// }

// for _, j := range imports {
// if !testImports[strings.Split(j, "==")[0]] {
// t.Error("Missing or extra imports found.")
// }
// }
// }
func Test_writeRequirements(t *testing.T) {
writeRequirements("", "testdata/", "./", false, " ")

reqRead, err := os.Open("requirements.txt")
if err != nil {
panic(err)
}
defer reqRead.Close()

testImports := map[string]bool{
"torch": true,
"ipython": true,
"tensorflow": true,
"pandas": true,
"PyYAML": true,
"transformers": true,
"numpy": true}

var imports []string
scanner := bufio.NewScanner(reqRead)
for scanner.Scan() {
imports = append(imports, scanner.Text())
}

for _, j := range imports {
if !testImports[strings.Split(j, "==")[0]] {
t.Error("Missing or extra imports found.")
}
}
}
2 changes: 2 additions & 0 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ func init() {
rootCmd.PersistentFlags().StringP("venvPath", "v", " ", "directory to venv (virtual env)")
rootCmd.PersistentFlags().StringP("ignore", "i", " ", "ignore specific directories (each seperated by comma)")
rootCmd.PersistentFlags().BoolP("print", "p", false, "print requirements.txt to terminal")
rootCmd.PersistentFlags().Bool("debug", false, "print the debug information")

// Cobra also supports local flags, which will only run
// when this action is called directly.

Expand Down

0 comments on commit 3bd6c05

Please sign in to comment.