From 76ea818f7029f4c734b2741c83abdf5fd0dfedea Mon Sep 17 00:00:00 2001 From: Greg Neiheisel <1036482+schnie@users.noreply.github.com> Date: Fri, 20 Dec 2024 13:36:02 -0500 Subject: [PATCH] Adds Ability to Chain Cobra RunE Commands (#1771) --- cmd/root.go | 49 ++++++------------------------------ cmd/root_hooks.go | 56 +++++++++++++++++++++++++++++++++++++++++ cmd/utils/utils.go | 14 +++++++++++ cmd/utils/utils_test.go | 40 +++++++++++++++++++++++++++++ 4 files changed, 117 insertions(+), 42 deletions(-) create mode 100644 cmd/root_hooks.go diff --git a/cmd/root.go b/cmd/root.go index 8108bb730..30f4ef49d 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -1,30 +1,22 @@ package cmd import ( - "errors" "fmt" - "net/http" "os" - "strings" - "time" - - "github.com/astronomer/astro-cli/cmd/registry" - "github.com/sirupsen/logrus" airflowclient "github.com/astronomer/astro-cli/airflow-client" astrocore "github.com/astronomer/astro-cli/astro-client-core" astroiamcore "github.com/astronomer/astro-cli/astro-client-iam-core" astroplatformcore "github.com/astronomer/astro-cli/astro-client-platform-core" cloudCmd "github.com/astronomer/astro-cli/cmd/cloud" + "github.com/astronomer/astro-cli/cmd/registry" softwareCmd "github.com/astronomer/astro-cli/cmd/software" - "github.com/astronomer/astro-cli/config" + "github.com/astronomer/astro-cli/cmd/utils" "github.com/astronomer/astro-cli/context" "github.com/astronomer/astro-cli/houston" "github.com/astronomer/astro-cli/pkg/ansi" "github.com/astronomer/astro-cli/pkg/httputil" - "github.com/astronomer/astro-cli/version" - - "github.com/google/go-github/v48/github" + "github.com/sirupsen/logrus" "github.com/spf13/cobra" ) @@ -73,37 +65,10 @@ func NewRootCmd() *cobra.Command { \__\/\__\/ \_____\/ \__\/ \_\/ \_\/ \_____\/ \_____\/ \_____\/\________\/ Welcome to the Astro CLI, the modern command line interface for data orchestration. You can use it for Astro, Astronomer Software, or Local Development.`, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - // Check for latest version - if config.CFG.UpgradeMessage.GetBool() { - // create github client with 3 second timeout, setting an aggressive timeout since its not mandatory to get a response in each command execution - githubClient := github.NewClient(&http.Client{Timeout: 3 * time.Second}) - // compare current version to latest - err = version.CompareVersions(githubClient, "astronomer", "astro-cli") - if err != nil { - softwareCmd.InitDebugLogs = append(softwareCmd.InitDebugLogs, "Error comparing CLI versions: "+err.Error()) - } - } - if isCloudCtx { - err = cloudCmd.Setup(cmd, platformCoreClient, astroCoreClient) - if err != nil { - if strings.Contains(err.Error(), "token is invalid or malformed") { - return errors.New("API Token is invalid or malformed") //nolint - } - if strings.Contains(err.Error(), "the API token given has expired") { - return errors.New("API Token is expired") //nolint - } - softwareCmd.InitDebugLogs = append(softwareCmd.InitDebugLogs, "Error during cmd setup: "+err.Error()) - } - } - // common PersistentPreRunE component between software & cloud - // setting up log verbosity and dumping debug logs collected during CLI-initialization - if err := softwareCmd.SetUpLogs(os.Stdout, verboseLevel); err != nil { - return err - } - softwareCmd.PrintDebugLogs() - return nil - }, + PersistentPreRunE: utils.ChainRunEs( + SetupLoggingPersistentPreRunE, + CreateRootPersistentPreRunE(astroCoreClient, platformCoreClient), + ), } rootCmd.AddCommand( diff --git a/cmd/root_hooks.go b/cmd/root_hooks.go new file mode 100644 index 000000000..640ebc540 --- /dev/null +++ b/cmd/root_hooks.go @@ -0,0 +1,56 @@ +package cmd + +import ( + "errors" + "net/http" + "os" + "strings" + "time" + + astrocore "github.com/astronomer/astro-cli/astro-client-core" + astroplatformcore "github.com/astronomer/astro-cli/astro-client-platform-core" + cloudCmd "github.com/astronomer/astro-cli/cmd/cloud" + softwareCmd "github.com/astronomer/astro-cli/cmd/software" + "github.com/astronomer/astro-cli/config" + "github.com/astronomer/astro-cli/context" + "github.com/astronomer/astro-cli/version" + "github.com/google/go-github/v48/github" + "github.com/spf13/cobra" +) + +// SetupLoggingPersistentPreRunE is a pre-run hook shared between software & cloud +// setting up log verbosity. +func SetupLoggingPersistentPreRunE(_ *cobra.Command, _ []string) error { + return softwareCmd.SetUpLogs(os.Stdout, verboseLevel) +} + +// CreateRootPersistentPreRunE takes clients as arguments and returns a cobra +// pre-run hook that sets up the context and checks for the latest version. +func CreateRootPersistentPreRunE(astroCoreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient) func(cmd *cobra.Command, args []string) error { + return func(cmd *cobra.Command, args []string) error { + // Check for latest version + if config.CFG.UpgradeMessage.GetBool() { + // create github client with 3 second timeout, setting an aggressive timeout since its not mandatory to get a response in each command execution + githubClient := github.NewClient(&http.Client{Timeout: 3 * time.Second}) + // compare current version to latest + err := version.CompareVersions(githubClient, "astronomer", "astro-cli") + if err != nil { + softwareCmd.InitDebugLogs = append(softwareCmd.InitDebugLogs, "Error comparing CLI versions: "+err.Error()) + } + } + if context.IsCloudContext() { + err := cloudCmd.Setup(cmd, platformCoreClient, astroCoreClient) + if err != nil { + if strings.Contains(err.Error(), "token is invalid or malformed") { + return errors.New("API Token is invalid or malformed") //nolint + } + if strings.Contains(err.Error(), "the API token given has expired") { + return errors.New("API Token is expired") //nolint + } + softwareCmd.InitDebugLogs = append(softwareCmd.InitDebugLogs, "Error during cmd setup: "+err.Error()) + } + } + softwareCmd.PrintDebugLogs() + return nil + } +} diff --git a/cmd/utils/utils.go b/cmd/utils/utils.go index 44b4aa572..e07ad247c 100644 --- a/cmd/utils/utils.go +++ b/cmd/utils/utils.go @@ -7,6 +7,20 @@ import ( "github.com/spf13/cobra" ) +type RunE func(cmd *cobra.Command, args []string) error + +// ChainRunEs chains multiple RunE functions together for cleaner composition. +func ChainRunEs(runEs ...RunE) RunE { + return func(cmd *cobra.Command, args []string) error { + for _, runE := range runEs { + if err := runE(cmd, args); err != nil { + return err + } + } + return nil + } +} + func EnsureProjectDir(cmd *cobra.Command, args []string) error { isProjectDir, err := config.IsProjectDir(config.WorkingPath) if err != nil { diff --git a/cmd/utils/utils_test.go b/cmd/utils/utils_test.go index 073468fca..15efc0871 100644 --- a/cmd/utils/utils_test.go +++ b/cmd/utils/utils_test.go @@ -1,6 +1,7 @@ package utils import ( + "errors" "testing" "github.com/astronomer/astro-cli/config" @@ -46,3 +47,42 @@ func TestGetDefaultDeployDescription(t *testing.T) { descriptionWithDags := GetDefaultDeployDescription(true) assert.Equal(t, "Deployed via ", descriptionWithDags) } + +func TestChainRunEsExecutesAllFunctionsSuccessfully(t *testing.T) { + runE1 := func(cmd *cobra.Command, args []string) error { + return nil + } + runE2 := func(cmd *cobra.Command, args []string) error { + return nil + } + chain := ChainRunEs(runE1, runE2) + err := chain(&cobra.Command{}, []string{}) + assert.NoError(t, err) +} + +func TestChainRunEsReturnsErrorIfAnyFunctionFails(t *testing.T) { + runE1 := func(cmd *cobra.Command, args []string) error { + return nil + } + runE2 := func(cmd *cobra.Command, args []string) error { + return errors.New("error in runE2") + } + chain := ChainRunEs(runE1, runE2) + err := chain(&cobra.Command{}, []string{}) + assert.Error(t, err) + assert.Equal(t, "error in runE2", err.Error()) +} + +func TestChainRunEsStopsExecutionAfterError(t *testing.T) { + runE1 := func(cmd *cobra.Command, args []string) error { + return errors.New("error in runE1") + } + runE2 := func(cmd *cobra.Command, args []string) error { + t.FailNow() // This should not be called + return nil + } + chain := ChainRunEs(runE1, runE2) + err := chain(&cobra.Command{}, []string{}) + assert.Error(t, err) + assert.Equal(t, "error in runE1", err.Error()) +}