diff --git a/astro-client/client.go b/astro-client/client.go index bef18713f..da4f0a583 100644 --- a/astro-client/client.go +++ b/astro-client/client.go @@ -13,7 +13,7 @@ import ( ) const ( - AstronomerConnectionErrMsg = "cannot connect to Astronomer. Try to log in with astro login or check your internet connection and user permissions.\n\nDetails" + AstronomerConnectionErrMsg = "cannot connect to Astronomer. Try to log in with astro login or check your internet connection and user permissions. If you are using an API Key or Token make sure your context is correct.\n\nDetails" permissionsErrMsg = "you do not have the appropriate permissions for that" ) diff --git a/cloud/deployment/fromfile/fromfile.go b/cloud/deployment/fromfile/fromfile.go index 6ff8fa118..de6b1b640 100644 --- a/cloud/deployment/fromfile/fromfile.go +++ b/cloud/deployment/fromfile/fromfile.go @@ -83,6 +83,7 @@ func CreateOrUpdate(inputFile, action string, client astro.Client, out io.Writer if err != nil { return err } + existingDeployments, err = client.ListDeployments(c.Organization, workspaceID) if err != nil { return err diff --git a/cloud/deployment/inspect/inspect.go b/cloud/deployment/inspect/inspect.go index df4c1256f..1e752a519 100644 --- a/cloud/deployment/inspect/inspect.go +++ b/cloud/deployment/inspect/inspect.go @@ -303,10 +303,14 @@ func getTemplate(formattedDeployment *FormattedDeployment) FormattedDeployment { template := *formattedDeployment template.Deployment.Configuration.Name = "" template.Deployment.Metadata = nil + newEnvVars := []EnvironmentVariable{} for i := range template.Deployment.EnvVars { - // zero out updated at timestamp - template.Deployment.EnvVars[i].UpdatedAt = "" + if !template.Deployment.EnvVars[i].IsSecret { + newEnvVars = append(newEnvVars, template.Deployment.EnvVars[i]) + } } + template.Deployment.EnvVars = newEnvVars + return template } diff --git a/cloud/deployment/inspect/inspect_test.go b/cloud/deployment/inspect/inspect_test.go index 303f59ca4..673983561 100644 --- a/cloud/deployment/inspect/inspect_test.go +++ b/cloud/deployment/inspect/inspect_test.go @@ -866,10 +866,8 @@ func TestFormatPrintableDeployment(t *testing.T) { environment_variables: - is_secret: false key: foo + updated_at: NOW value: bar - - is_secret: true - key: bar - value: baz configuration: name: "" description: description @@ -1017,12 +1015,8 @@ func TestFormatPrintableDeployment(t *testing.T) { { "is_secret": false, "key": "foo", + "updated_at": "NOW", "value": "bar" - }, - { - "is_secret": true, - "key": "bar", - "value": "baz" } ], "configuration": { @@ -1446,9 +1440,17 @@ func TestGetTemplate(t *testing.T) { assert.NoError(t, err) expected.Deployment.Configuration.Name = "" expected.Deployment.Metadata = nil + newEnvVars := []EnvironmentVariable{} for i := range expected.Deployment.EnvVars { - expected.Deployment.EnvVars[i].UpdatedAt = "" + if !expected.Deployment.EnvVars[i].IsSecret { + newEnvVars = append(newEnvVars, expected.Deployment.EnvVars[i]) + } + } + expected.Deployment.EnvVars = newEnvVars + for i := range expected.Deployment.EnvVars { + expected.Deployment.EnvVars[i].UpdatedAt = "NOW" } + actual := getTemplate(&decoded) assert.Equal(t, expected, actual) }) @@ -1473,9 +1475,16 @@ func TestGetTemplate(t *testing.T) { expected.Deployment.Configuration.Name = "" expected.Deployment.Metadata = nil expected.Deployment.EnvVars = nil + newEnvVars := []EnvironmentVariable{} for i := range expected.Deployment.EnvVars { - expected.Deployment.EnvVars[i].UpdatedAt = "" + if !expected.Deployment.EnvVars[i].IsSecret { + newEnvVars = append(newEnvVars, expected.Deployment.EnvVars[i]) + } } + for i := range expected.Deployment.EnvVars { + expected.Deployment.EnvVars[i].UpdatedAt = "NOW" + } + expected.Deployment.EnvVars = newEnvVars actual := getTemplate(&decoded) assert.Equal(t, expected, actual) }) @@ -1496,9 +1505,16 @@ func TestGetTemplate(t *testing.T) { expected.Deployment.Configuration.Name = "" expected.Deployment.Metadata = nil expected.Deployment.AlertEmails = nil + newEnvVars := []EnvironmentVariable{} + for i := range expected.Deployment.EnvVars { + if !expected.Deployment.EnvVars[i].IsSecret { + newEnvVars = append(newEnvVars, expected.Deployment.EnvVars[i]) + } + } for i := range expected.Deployment.EnvVars { expected.Deployment.EnvVars[i].UpdatedAt = "" } + expected.Deployment.EnvVars = newEnvVars actual := getTemplate(&decoded) assert.Equal(t, expected, actual) }) diff --git a/cloud/workspace/workspace_test.go b/cloud/workspace/workspace_test.go index 9a1fe7d10..88c34bc02 100644 --- a/cloud/workspace/workspace_test.go +++ b/cloud/workspace/workspace_test.go @@ -109,7 +109,7 @@ func TestListError(t *testing.T) { buf := new(bytes.Buffer) err := List(astroAPI, buf) - assert.EqualError(t, err, "cannot connect to Astronomer. Try to log in with astro login or check your internet connection and user permissions.\n\nDetails: Error processing GraphQL request: API error (500): Internal Server Error") + assert.EqualError(t, err, "cannot connect to Astronomer. Try to log in with astro login or check your internet connection and user permissions. If you are using an API Key or Token make sure your context is correct.\n\nDetails: Error processing GraphQL request: API error (500): Internal Server Error") } func TestGetWorkspaceSelection(t *testing.T) { diff --git a/cmd/cloud/setup.go b/cmd/cloud/setup.go index 430476e96..1c1e1cea6 100644 --- a/cmd/cloud/setup.go +++ b/cmd/cloud/setup.go @@ -18,20 +18,26 @@ import ( "github.com/astronomer/astro-cli/cloud/organization" "github.com/astronomer/astro-cli/context" "github.com/astronomer/astro-cli/pkg/httputil" + "github.com/astronomer/astro-cli/pkg/util" + "github.com/golang-jwt/jwt/v4" "github.com/pkg/errors" "github.com/spf13/cobra" ) var ( - authLogin = auth.Login - - client = httputil.NewHTTPClient() + authLogin = auth.Login + defaultDomain = "astronomer.io" + client = httputil.NewHTTPClient() + isDeploymentFile = false + parseAPIToken = util.ParseAPIToken + errNotAPIToken = errors.New("the API token given does not appear to be an Astro API Token") ) const ( accessTokenExpThreshold = 5 * time.Minute topLvlCmd = "astro" + deploymentCmd = "deployment" ) type TokenResponse struct { @@ -44,6 +50,18 @@ type TokenResponse struct { ErrorDescription string `json:"error_description,omitempty"` } +type CustomClaims struct { + OrgAuthServiceID string `json:"org_id"` + Scope string `json:"scope"` + Permissions []string `json:"permissions"` + Version string `json:"version"` + IsAstronomerGenerated bool `json:"isAstronomerGenerated"` + RsaKeyID string `json:"kid"` + APITokenID string `json:"apiTokenId"` + jwt.RegisteredClaims +} + +//nolint:gocognit func Setup(cmd *cobra.Command, args []string, client astro.Client, coreClient astrocore.CoreClient) error { // If the user is trying to login or logout no need to go through auth setup. if cmd.CalledAs() == "login" || cmd.CalledAs() == "logout" { @@ -79,11 +97,25 @@ func Setup(cmd *cobra.Command, args []string, client astro.Client, coreClient as return nil } + // if deployment inspect, create, or udpate commands are used + deploymentCmds := []string{"inspect", "create", "update"} + if util.Contains(deploymentCmds, cmd.CalledAs()) && cmd.Parent().Use == deploymentCmd { + isDeploymentFile = true + } + + // Check for APITokens before API keys or refresh tokens + apiToken, err := checkAPIToken(isDeploymentFile, args) + if err != nil { + return err + } + if apiToken { + return nil + } + // run auth setup for any command that requires auth apiKey, err := checkAPIKeys(client, coreClient, args) if err != nil { - fmt.Println(err) - fmt.Println("\nThere was an error using API keys, using regular auth instead") + return err } if apiKey { return nil @@ -207,7 +239,7 @@ func checkAPIKeys(astroClient astro.Client, coreClient astrocore.CoreClient, arg c, err := context.GetCurrentContext() // get current context if err != nil { // set context - domain := "astronomer.io" + domain := defaultDomain if !context.Exists(domain) { err := context.SetContext(domain) if err != nil { @@ -308,3 +340,71 @@ func checkAPIKeys(astroClient astro.Client, coreClient astrocore.CoreClient, arg } return true, nil } + +func checkAPIToken(isDeploymentFile bool, args []string) (bool, error) { + // check os variables + astroAPIToken := os.Getenv("ASTRO_API_TOKEN") + if astroAPIToken == "" { + return false, nil + } + if !isDeploymentFile { + fmt.Println("Using an Astro API Token") + } + + // get authConfig + c, err := context.GetCurrentContext() // get current context + if err != nil { + // set context + domain := defaultDomain + if !context.Exists(domain) { + err := context.SetContext(domain) + if err != nil { + return false, err + } + } + + // Switch context + err = context.Switch(domain) + if err != nil { + return false, err + } + + c, err = context.GetContext(domain) // get current context + if err != nil { + return false, err + } + } + + err = c.SetContextKey("token", "Bearer "+astroAPIToken) + if err != nil { + return false, err + } + + err = c.SetExpiresIn(time.Now().AddDate(1, 0, 0).Unix()) + if err != nil { + return false, err + } + // Parse the token to peek at the custom claims + claims, err := parseAPIToken(astroAPIToken) + if err != nil { + return false, err + } + if len(claims.Permissions) == 0 { + return false, errNotAPIToken + } + workspaceID = strings.Replace(claims.Permissions[1], "workspaceId:", "", 1) + orgID := strings.Replace(claims.Permissions[2], "organizationId:", "", 1) + orgShortName := strings.Replace(claims.Permissions[3], "orgShortNameId:", "", 1) + // If using api keys for virtual runtimes, we dont need to look up for this endpoint + if !(len(args) > 0 && strings.HasPrefix(args[0], "vr-")) { + err := c.SetContextKey("workspace", workspaceID) // c.Workspace + if err != nil { + fmt.Println("no workspace set") + } + } + err = c.SetOrganizationContext(orgID, orgShortName) + if err != nil { + fmt.Println("no organization context set") + } + return true, nil +} diff --git a/cmd/cloud/setup_test.go b/cmd/cloud/setup_test.go index 1ea87d3b5..1be2a630f 100644 --- a/cmd/cloud/setup_test.go +++ b/cmd/cloud/setup_test.go @@ -13,6 +13,7 @@ import ( astro_mocks "github.com/astronomer/astro-cli/astro-client/mocks" "github.com/astronomer/astro-cli/context" testUtil "github.com/astronomer/astro-cli/pkg/testing" + "github.com/astronomer/astro-cli/pkg/util" "github.com/pkg/errors" "github.com/spf13/cobra" "github.com/stretchr/testify/assert" @@ -111,6 +112,23 @@ func TestSetup(t *testing.T) { assert.NoError(t, err) }) + t.Run("deployment cmd", func(t *testing.T) { + testUtil.SetupOSArgsForGinkgo() + cmd := &cobra.Command{Use: "inspect"} + cmd, err := cmd.ExecuteC() + assert.NoError(t, err) + + rootCmd := &cobra.Command{Use: "deployment"} + rootCmd.AddCommand(cmd) + + authLogin = func(domain, token string, client astro.Client, coreClient astrocore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { + return nil + } + + err = Setup(cmd, []string{}, nil, nil) + assert.NoError(t, err) + }) + t.Run("deploy cmd", func(t *testing.T) { testUtil.SetupOSArgsForGinkgo() cmd := &cobra.Command{Use: "deploy"} @@ -318,3 +336,63 @@ func TestCheckToken(t *testing.T) { assert.Contains(t, err.Error(), "failed to login") }) } + +func TestCheckAPIToken(t *testing.T) { + testUtil.InitTestConfig(testUtil.CloudPlatform) + t.Run("test context switch", func(t *testing.T) { + permissions := []string{ + "", + "workspaceId:workspace-id", + "organizationId:org-ID", + "orgShortNameId:org-short-name", + } + mockClaims := util.CustomClaims{ + Permissions: permissions, + } + + authLogin = func(domain, token string, client astro.Client, coreClient astrocore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { + return nil + } + + parseAPIToken = func(astroAPIToken string) (*util.CustomClaims, error) { + return &mockClaims, nil + } + + t.Setenv("ASTRO_API_TOKEN", "token") + + // Switch context + domain := "astronomer-dev.io" + err := context.Switch(domain) + assert.NoError(t, err) + + // run CheckAPIKeys + _, err = checkAPIToken(true, []string{}) + assert.NoError(t, err) + }) + + t.Run("bad claims", func(t *testing.T) { + permissions := []string{} + mockClaims := util.CustomClaims{ + Permissions: permissions, + } + + authLogin = func(domain, token string, client astro.Client, coreClient astrocore.CoreClient, out io.Writer, shouldDisplayLoginLink bool) error { + return nil + } + + parseAPIToken = func(astroAPIToken string) (*util.CustomClaims, error) { + return &mockClaims, nil + } + + t.Setenv("ASTRO_API_TOKEN", "token") + + // Switch context + domain := "astronomer-dev.io" + err := context.Switch(domain) + assert.NoError(t, err) + + // run CheckAPIKeys + _, err = checkAPIToken(true, []string{}) + assert.ErrorIs(t, err, errNotAPIToken) + }) +} diff --git a/config/context.go b/config/context.go index e2a2e99c0..6b5d900e0 100644 --- a/config/context.go +++ b/config/context.go @@ -3,11 +3,8 @@ package config import ( "errors" "fmt" - "os" "strings" "time" - - "github.com/astronomer/astro-cli/pkg/printutil" ) var ( @@ -21,14 +18,6 @@ const ( contextsKey = "contexts" ) -// newTableOut construct new printutil.Table -func newTableOut() *printutil.Table { - return &printutil.Table{ - Padding: []int{36, 36}, - Header: []string{"CONTEXT DOMAIN", "WORKSPACE"}, - } -} - // Contexts holds all available Context structs in a map type Contexts struct { Contexts map[string]Context `mapstructure:"contexts"` @@ -188,11 +177,6 @@ func (c *Context) SwitchContext() error { return err } - tab := newTableOut() - tab.AddRow([]string{ctx.Domain, ctx.Workspace}, false) - tab.SuccessMsg = "\n Switched context" - tab.Print(os.Stdout) - return nil } diff --git a/config/context_test.go b/config/context_test.go index e9ac155f2..ebcd28522 100644 --- a/config/context_test.go +++ b/config/context_test.go @@ -11,12 +11,6 @@ import ( var err error -func TestNewTableOut(t *testing.T) { - tab := newTableOut() - assert.NotNil(t, tab) - assert.Equal(t, []int{36, 36}, tab.Padding) -} - func TestGetCurrentContextError(t *testing.T) { fs := afero.NewMemMapFs() configRaw := []byte(`cloud: diff --git a/context/context.go b/context/context.go index 9ac612a9d..97808b9b1 100644 --- a/context/context.go +++ b/context/context.go @@ -3,6 +3,7 @@ package context import ( "fmt" "io" + "os" "regexp" "strings" @@ -31,6 +32,14 @@ var tab = printutil.Table{ ColorRowCode: [2]string{"\033[1;32m", "\033[0m"}, } +// newTableOut construct new printutil.Table +func newTableOut() *printutil.Table { + return &printutil.Table{ + Padding: []int{36, 36}, + Header: []string{"CONTEXT DOMAIN", "WORKSPACE"}, + } +} + // ContextExists checks to see if context exist in config func Exists(domain string) bool { c := config.Context{Domain: domain} @@ -107,7 +116,24 @@ func SwitchContext(cmd *cobra.Command, args []string) error { if len(args) == 1 { domain = args[0] } - return Switch(domain) + + err := Switch(domain) + if err != nil { + return err + } + + c := config.Context{Domain: domain} + ctx, err := c.GetContext() + if err != nil { + return err + } + + tab := newTableOut() + tab.AddRow([]string{ctx.Domain, ctx.Workspace}, false) + tab.SuccessMsg = "\n Switched context" + tab.Print(os.Stdout) + + return nil } func ListContext(cmd *cobra.Command, args []string, out io.Writer) error { diff --git a/context/context_test.go b/context/context_test.go index a6665a107..3954d77d9 100644 --- a/context/context_test.go +++ b/context/context_test.go @@ -11,6 +11,12 @@ import ( "github.com/stretchr/testify/assert" ) +func TestNewTableOut(t *testing.T) { + tab := newTableOut() + assert.NotNil(t, tab) + assert.Equal(t, []int{36, 36}, tab.Padding) +} + func TestExists(t *testing.T) { testUtil.InitTestConfig(testUtil.LocalPlatform) // Check that we don't have localhost123 in test config from testUtils.NewTestConfig() diff --git a/go.mod b/go.mod index 7a5c0a599..131e1ee08 100644 --- a/go.mod +++ b/go.mod @@ -34,6 +34,7 @@ require ( github.com/docker/distribution v2.7.1+incompatible github.com/fatih/camelcase v1.0.0 github.com/ghodss/yaml v1.0.0 + github.com/golang-jwt/jwt/v4 v4.5.0 github.com/google/go-github/v48 v48.2.0 github.com/hashicorp/go-version v1.3.0 github.com/mitchellh/mapstructure v1.4.2 diff --git a/go.sum b/go.sum index e9a17d070..b2c21c520 100644 --- a/go.sum +++ b/go.sum @@ -468,6 +468,8 @@ github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXP github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= +github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= +github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20190129154638-5b532d6fd5ef/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= diff --git a/pkg/util/util.go b/pkg/util/util.go index ed37a18e0..7126be039 100644 --- a/pkg/util/util.go +++ b/pkg/util/util.go @@ -7,10 +7,24 @@ import ( "strings" "github.com/Masterminds/semver" + "github.com/golang-jwt/jwt/v4" + "github.com/pkg/errors" + // "Masterminds/semver" does not support the format of pre-release tags for SQL CLI, so we're using "hashicorp/go-version" goVersion "github.com/hashicorp/go-version" ) +type CustomClaims struct { + OrgAuthServiceID string `json:"org_id"` + Scope string `json:"scope"` + Permissions []string `json:"permissions"` + Version string `json:"version"` + IsAstronomerGenerated bool `json:"isAstronomerGenerated"` + RsaKeyID string `json:"kid"` + APITokenID string `json:"apiTokenId"` + jwt.RegisteredClaims +} + // coerce a string into SemVer if possible func Coerce(version string) *semver.Version { v, err := semver.NewVersion(version) @@ -103,3 +117,14 @@ func IsRequiredVersionMet(currentVersion, requiredVersion string) (bool, error) } return false, nil } + +func ParseAPIToken(astroAPIToken string) (*CustomClaims, error) { + // Parse the token to peek at the custom claims + jwtParser := jwt.NewParser() + parsedToken, _, err := jwtParser.ParseUnverified(astroAPIToken, &CustomClaims{}) + claims, ok := parsedToken.Claims.(*CustomClaims) + if !ok { + return nil, errors.Wrap(err, "failed to parse auth token") + } + return claims, nil +}