diff --git a/ciao-vendor/ciao-vendor.go b/ciao-vendor/ciao-vendor.go index 0fc771bef..b1861f03f 100644 --- a/ciao-vendor/ciao-vendor.go +++ b/ciao-vendor/ciao-vendor.go @@ -313,6 +313,28 @@ func updateNonVendoredDeps(deps piList, projectRoot string) error { return nil } +func getCurrentBranch(repo string) (string, error) { + cmd := exec.Command("git", "symbolic-ref", "HEAD") + cmd.Dir = repo + output, err := cmd.Output() + if err != nil { + return "", err + } + scanner := bufio.NewScanner(bytes.NewBuffer(output)) + if !scanner.Scan() { + return "", fmt.Errorf("Unable to determine current branch of %s", + repo) + } + branch := strings.TrimSpace(scanner.Text()) + const prefix = "refs/heads/" + if !strings.HasPrefix(branch, prefix) { + return "", fmt.Errorf("Unable to determine current branch of %s", + repo) + } + + return branch[len(prefix):], nil +} + func checkoutVersion(sourceRoot string) { for k, v := range repos { cmd := exec.Command("git", "checkout", v.version) @@ -885,6 +907,55 @@ func updates(sourceRoot, projectRoot string) error { return nil } +func test(sudo bool, sourceRoot, projectRoot, pkg, version string, goTestFlags []string) error { + fmt.Printf("Go getting %s\n", pkg) + cmd := exec.Command("go", "get", "-t", "-u", pkg) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + err := cmd.Run() + if err != nil { + return fmt.Errorf("Unable to go get %s", pkg) + } + + branch, err := getCurrentBranch(path.Join(sourceRoot, pkg)) + if err != nil { + return fmt.Errorf("Unable to determine current branch of %s: %v", pkg, err) + } + cmd = exec.Command("git", "checkout", version) + cmd.Dir = path.Join(sourceRoot, pkg) + err = cmd.Run() + if err != nil { + return fmt.Errorf("Unable to checkout version %s of %s: %v", + version, pkg, err) + } + + var args []string + var command string + if sudo { + command = "sudo" + args = []string{"-E", "go"} + } else { + command = "go" + } + args = append(args, "test") + args = append(args, goTestFlags...) + args = append(args, pkg) + cmd = exec.Command(command, args...) + cmd.Dir = path.Join(sourceRoot, pkg) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if sudo { + cmd.Stdin = os.Stdin + } + err = cmd.Run() + + cmd = exec.Command("git", "checkout", branch) + cmd.Dir = path.Join(sourceRoot, pkg) + _ = cmd.Run() + + return err +} + func runCommand(cwd, sourceRoot string, args []string) error { var err error @@ -910,6 +981,17 @@ func runCommand(cwd, sourceRoot string, args []string) error { err = uses(fs.Args()[0], projectRoot, direct) case "updates": err = updates(sourceRoot, projectRoot) + case "test": + fs := flag.NewFlagSet("test", flag.ExitOnError) + sudo := false + fs.BoolVar(&sudo, "s", false, "run tests with sudo") + + if err := fs.Parse(args[2:]); err != nil { + return err + } + + args = fs.Args() + err = test(sudo, sourceRoot, projectRoot, args[0], args[1], args[2:]) } return err @@ -918,8 +1000,12 @@ func runCommand(cwd, sourceRoot string, args []string) error { func main() { if !((len(os.Args) == 2 && (os.Args[1] == "vendor" || os.Args[1] == "check" || os.Args[1] == "deps" || - os.Args[1] == "packages" || os.Args[1] == "updates")) || (len(os.Args) >= 3 && os.Args[1] == "uses")) { - fmt.Fprintln(os.Stderr, "Usage: ciao-vendor vendor|check|deps|packages") + os.Args[1] == "packages" || os.Args[1] == "updates")) || + (len(os.Args) >= 3 && (os.Args[1] == "uses")) || + (len(os.Args) >= 4 && (os.Args[1] == "test"))) { + fmt.Fprintln(os.Stderr, "Usage: ciao-vendor vendor|check|deps|packages|updates") + fmt.Fprintln(os.Stderr, "Usage: ciao-vendor uses [-d] package") + fmt.Fprintln(os.Stderr, "Usage: ciao-vendor test package version [go-test flags]") os.Exit(1) }