Skip to content

Commit

Permalink
Adds Ability to Chain Cobra RunE Commands (#1771)
Browse files Browse the repository at this point in the history
  • Loading branch information
schnie authored Dec 20, 2024
1 parent d2482e4 commit 76ea818
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 42 deletions.
49 changes: 7 additions & 42 deletions cmd/root.go
Original file line number Diff line number Diff line change
@@ -1,30 +1,22 @@
package cmd

import (
"errors"
"fmt"
"net/http"
"os"
"strings"
"time"

"github.com/astronomer/astro-cli/cmd/registry"
"github.com/sirupsen/logrus"

airflowclient "github.com/astronomer/astro-cli/airflow-client"
astrocore "github.com/astronomer/astro-cli/astro-client-core"
astroiamcore "github.com/astronomer/astro-cli/astro-client-iam-core"
astroplatformcore "github.com/astronomer/astro-cli/astro-client-platform-core"
cloudCmd "github.com/astronomer/astro-cli/cmd/cloud"
"github.com/astronomer/astro-cli/cmd/registry"
softwareCmd "github.com/astronomer/astro-cli/cmd/software"
"github.com/astronomer/astro-cli/config"
"github.com/astronomer/astro-cli/cmd/utils"
"github.com/astronomer/astro-cli/context"
"github.com/astronomer/astro-cli/houston"
"github.com/astronomer/astro-cli/pkg/ansi"
"github.com/astronomer/astro-cli/pkg/httputil"
"github.com/astronomer/astro-cli/version"

"github.com/google/go-github/v48/github"
"github.com/sirupsen/logrus"
"github.com/spf13/cobra"
)

Expand Down Expand Up @@ -73,37 +65,10 @@ func NewRootCmd() *cobra.Command {
\__\/\__\/ \_____\/ \__\/ \_\/ \_\/ \_____\/ \_____\/ \_____\/\________\/
Welcome to the Astro CLI, the modern command line interface for data orchestration. You can use it for Astro, Astronomer Software, or Local Development.`,
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
// Check for latest version
if config.CFG.UpgradeMessage.GetBool() {
// create github client with 3 second timeout, setting an aggressive timeout since its not mandatory to get a response in each command execution
githubClient := github.NewClient(&http.Client{Timeout: 3 * time.Second})
// compare current version to latest
err = version.CompareVersions(githubClient, "astronomer", "astro-cli")
if err != nil {
softwareCmd.InitDebugLogs = append(softwareCmd.InitDebugLogs, "Error comparing CLI versions: "+err.Error())
}
}
if isCloudCtx {
err = cloudCmd.Setup(cmd, platformCoreClient, astroCoreClient)
if err != nil {
if strings.Contains(err.Error(), "token is invalid or malformed") {
return errors.New("API Token is invalid or malformed") //nolint
}
if strings.Contains(err.Error(), "the API token given has expired") {
return errors.New("API Token is expired") //nolint
}
softwareCmd.InitDebugLogs = append(softwareCmd.InitDebugLogs, "Error during cmd setup: "+err.Error())
}
}
// common PersistentPreRunE component between software & cloud
// setting up log verbosity and dumping debug logs collected during CLI-initialization
if err := softwareCmd.SetUpLogs(os.Stdout, verboseLevel); err != nil {
return err
}
softwareCmd.PrintDebugLogs()
return nil
},
PersistentPreRunE: utils.ChainRunEs(
SetupLoggingPersistentPreRunE,
CreateRootPersistentPreRunE(astroCoreClient, platformCoreClient),
),
}

rootCmd.AddCommand(
Expand Down
56 changes: 56 additions & 0 deletions cmd/root_hooks.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package cmd

import (
"errors"
"net/http"
"os"
"strings"
"time"

astrocore "github.com/astronomer/astro-cli/astro-client-core"
astroplatformcore "github.com/astronomer/astro-cli/astro-client-platform-core"
cloudCmd "github.com/astronomer/astro-cli/cmd/cloud"
softwareCmd "github.com/astronomer/astro-cli/cmd/software"
"github.com/astronomer/astro-cli/config"
"github.com/astronomer/astro-cli/context"
"github.com/astronomer/astro-cli/version"
"github.com/google/go-github/v48/github"
"github.com/spf13/cobra"
)

// SetupLoggingPersistentPreRunE is a pre-run hook shared between software & cloud
// setting up log verbosity.
func SetupLoggingPersistentPreRunE(_ *cobra.Command, _ []string) error {
return softwareCmd.SetUpLogs(os.Stdout, verboseLevel)
}

// CreateRootPersistentPreRunE takes clients as arguments and returns a cobra
// pre-run hook that sets up the context and checks for the latest version.
func CreateRootPersistentPreRunE(astroCoreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient) func(cmd *cobra.Command, args []string) error {
return func(cmd *cobra.Command, args []string) error {
// Check for latest version
if config.CFG.UpgradeMessage.GetBool() {
// create github client with 3 second timeout, setting an aggressive timeout since its not mandatory to get a response in each command execution
githubClient := github.NewClient(&http.Client{Timeout: 3 * time.Second})
// compare current version to latest
err := version.CompareVersions(githubClient, "astronomer", "astro-cli")
if err != nil {
softwareCmd.InitDebugLogs = append(softwareCmd.InitDebugLogs, "Error comparing CLI versions: "+err.Error())
}
}
if context.IsCloudContext() {
err := cloudCmd.Setup(cmd, platformCoreClient, astroCoreClient)
if err != nil {
if strings.Contains(err.Error(), "token is invalid or malformed") {
return errors.New("API Token is invalid or malformed") //nolint
}
if strings.Contains(err.Error(), "the API token given has expired") {
return errors.New("API Token is expired") //nolint
}
softwareCmd.InitDebugLogs = append(softwareCmd.InitDebugLogs, "Error during cmd setup: "+err.Error())
}
}
softwareCmd.PrintDebugLogs()
return nil
}
}
14 changes: 14 additions & 0 deletions cmd/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,20 @@ import (
"github.com/spf13/cobra"
)

type RunE func(cmd *cobra.Command, args []string) error

// ChainRunEs chains multiple RunE functions together for cleaner composition.
func ChainRunEs(runEs ...RunE) RunE {
return func(cmd *cobra.Command, args []string) error {
for _, runE := range runEs {
if err := runE(cmd, args); err != nil {
return err
}
}
return nil
}
}

func EnsureProjectDir(cmd *cobra.Command, args []string) error {
isProjectDir, err := config.IsProjectDir(config.WorkingPath)
if err != nil {
Expand Down
40 changes: 40 additions & 0 deletions cmd/utils/utils_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package utils

import (
"errors"
"testing"

"github.com/astronomer/astro-cli/config"
Expand Down Expand Up @@ -46,3 +47,42 @@ func TestGetDefaultDeployDescription(t *testing.T) {
descriptionWithDags := GetDefaultDeployDescription(true)
assert.Equal(t, "Deployed via <astro deploy --dags>", descriptionWithDags)
}

func TestChainRunEsExecutesAllFunctionsSuccessfully(t *testing.T) {
runE1 := func(cmd *cobra.Command, args []string) error {
return nil
}
runE2 := func(cmd *cobra.Command, args []string) error {
return nil
}
chain := ChainRunEs(runE1, runE2)
err := chain(&cobra.Command{}, []string{})
assert.NoError(t, err)
}

func TestChainRunEsReturnsErrorIfAnyFunctionFails(t *testing.T) {
runE1 := func(cmd *cobra.Command, args []string) error {
return nil
}
runE2 := func(cmd *cobra.Command, args []string) error {
return errors.New("error in runE2")
}
chain := ChainRunEs(runE1, runE2)
err := chain(&cobra.Command{}, []string{})
assert.Error(t, err)
assert.Equal(t, "error in runE2", err.Error())
}

func TestChainRunEsStopsExecutionAfterError(t *testing.T) {
runE1 := func(cmd *cobra.Command, args []string) error {
return errors.New("error in runE1")
}
runE2 := func(cmd *cobra.Command, args []string) error {
t.FailNow() // This should not be called
return nil
}
chain := ChainRunEs(runE1, runE2)
err := chain(&cobra.Command{}, []string{})
assert.Error(t, err)
assert.Equal(t, "error in runE1", err.Error())
}

0 comments on commit 76ea818

Please sign in to comment.