From b91b6de1559e4bd6d3aa05b4ea222b2acdff9d13 Mon Sep 17 00:00:00 2001 From: neel-astro Date: Mon, 26 Aug 2024 12:04:16 +0530 Subject: [PATCH 1/5] Add workload identity to hosted deployment --- astro-client-core/api.gen.go | 36 ++-- cloud/deployment/deployment.go | 27 ++- cloud/deployment/deployment_test.go | 203 ++++++++++++++------- cloud/deployment/fromfile/fromfile.go | 24 +++ cloud/deployment/fromfile/fromfile_test.go | 32 +++- cloud/deployment/inspect/inspect.go | 4 +- cloud/deployment/inspect/inspect_test.go | 13 +- cmd/cloud/deployment.go | 4 +- cmd/cloud/deployment_test.go | 57 ++++++ 9 files changed, 300 insertions(+), 100 deletions(-) diff --git a/astro-client-core/api.gen.go b/astro-client-core/api.gen.go index 42c31563a..93dfdcdf8 100644 --- a/astro-client-core/api.gen.go +++ b/astro-client-core/api.gen.go @@ -50,10 +50,11 @@ const ( // Defines values for ClusterCohort. const ( - ClusterCohortCRITICAL ClusterCohort = "CRITICAL" - ClusterCohortDEFAULT ClusterCohort = "DEFAULT" - ClusterCohortINTERNAL ClusterCohort = "INTERNAL" - ClusterCohortSTABLE ClusterCohort = "STABLE" + ClusterCohortCRITICAL ClusterCohort = "CRITICAL" + ClusterCohortDEFAULT ClusterCohort = "DEFAULT" + ClusterCohortINTERNAL ClusterCohort = "INTERNAL" + ClusterCohortPREDEFAULT ClusterCohort = "PRE_DEFAULT" + ClusterCohortSTABLE ClusterCohort = "STABLE" ) // Defines values for ClusterStatus. @@ -87,10 +88,11 @@ const ( // Defines values for ClusterDetailedCohort. const ( - ClusterDetailedCohortCRITICAL ClusterDetailedCohort = "CRITICAL" - ClusterDetailedCohortDEFAULT ClusterDetailedCohort = "DEFAULT" - ClusterDetailedCohortINTERNAL ClusterDetailedCohort = "INTERNAL" - ClusterDetailedCohortSTABLE ClusterDetailedCohort = "STABLE" + ClusterDetailedCohortCRITICAL ClusterDetailedCohort = "CRITICAL" + ClusterDetailedCohortDEFAULT ClusterDetailedCohort = "DEFAULT" + ClusterDetailedCohortINTERNAL ClusterDetailedCohort = "INTERNAL" + ClusterDetailedCohortPREDEFAULT ClusterDetailedCohort = "PRE_DEFAULT" + ClusterDetailedCohortSTABLE ClusterDetailedCohort = "STABLE" ) // Defines values for ClusterDetailedStatus. @@ -544,10 +546,11 @@ const ( // Defines values for SharedClusterCohort. const ( - SharedClusterCohortCRITICAL SharedClusterCohort = "CRITICAL" - SharedClusterCohortDEFAULT SharedClusterCohort = "DEFAULT" - SharedClusterCohortINTERNAL SharedClusterCohort = "INTERNAL" - SharedClusterCohortSTABLE SharedClusterCohort = "STABLE" + SharedClusterCohortCRITICAL SharedClusterCohort = "CRITICAL" + SharedClusterCohortDEFAULT SharedClusterCohort = "DEFAULT" + SharedClusterCohortINTERNAL SharedClusterCohort = "INTERNAL" + SharedClusterCohortPREDEFAULT SharedClusterCohort = "PRE_DEFAULT" + SharedClusterCohortSTABLE SharedClusterCohort = "STABLE" ) // Defines values for SharedClusterStatus. @@ -1243,6 +1246,7 @@ type Bundle struct { type Cluster struct { AppliedHarmonyVersion *string `json:"appliedHarmonyVersion,omitempty"` AppliedTemplateVersion string `json:"appliedTemplateVersion"` + BlockInternetAccess *bool `json:"blockInternetAccess,omitempty"` CloudProvider ClusterCloudProvider `json:"cloudProvider"` Cohort *ClusterCohort `json:"cohort,omitempty"` CreatedAt time.Time `json:"createdAt"` @@ -1293,6 +1297,7 @@ type ClusterType string type ClusterDetailed struct { AppliedHarmonyVersion *string `json:"appliedHarmonyVersion,omitempty"` AppliedTemplateVersion string `json:"appliedTemplateVersion"` + BlockInternetAccess *bool `json:"blockInternetAccess,omitempty"` CloudProvider ClusterDetailedCloudProvider `json:"cloudProvider"` Cohort *ClusterDetailedCohort `json:"cohort,omitempty"` CreatedAt time.Time `json:"createdAt"` @@ -1461,6 +1466,7 @@ type ConnectionAuthTypeParameter struct { // CreateAwsClusterRequest defines model for CreateAwsClusterRequest. type CreateAwsClusterRequest struct { + BlockInternetAccess *bool `json:"blockInternetAccess,omitempty"` DbInstanceType string `json:"dbInstanceType"` DisableHarmonyVersionUpgrades *bool `json:"disableHarmonyVersionUpgrades,omitempty"` HarmonyVersion *string `json:"harmonyVersion,omitempty"` @@ -1480,6 +1486,7 @@ type CreateAwsClusterRequestType string // CreateAzureClusterRequest defines model for CreateAzureClusterRequest. type CreateAzureClusterRequest struct { + BlockInternetAccess *bool `json:"blockInternetAccess,omitempty"` DbInstanceType string `json:"dbInstanceType"` DisableHarmonyVersionUpgrades *bool `json:"disableHarmonyVersionUpgrades,omitempty"` HarmonyVersion *string `json:"harmonyVersion,omitempty"` @@ -1753,6 +1760,7 @@ type CreateEnvironmentObjectRequestScope string // CreateGcpClusterRequest defines model for CreateGcpClusterRequest. type CreateGcpClusterRequest struct { + BlockInternetAccess *bool `json:"blockInternetAccess,omitempty"` DbInstanceType string `json:"dbInstanceType"` DisableHarmonyVersionUpgrades *bool `json:"disableHarmonyVersionUpgrades,omitempty"` HarmonyVersion *string `json:"harmonyVersion,omitempty"` @@ -2968,6 +2976,7 @@ type SelfSignupType string // SharedCluster defines model for SharedCluster. type SharedCluster struct { + BlockInternetAccess *bool `json:"blockInternetAccess,omitempty"` CloudProvider SharedClusterCloudProvider `json:"cloudProvider"` Cohort *SharedClusterCohort `json:"cohort,omitempty"` CreatedAt time.Time `json:"createdAt"` @@ -3095,6 +3104,7 @@ type TriggerGitDeployRequestDeployType string // UpdateAwsClusterRequest defines model for UpdateAwsClusterRequest. type UpdateAwsClusterRequest struct { + BlockInternetAccess *bool `json:"blockInternetAccess,omitempty"` DbInstanceType string `json:"dbInstanceType"` DbInstanceVersion *string `json:"dbInstanceVersion,omitempty"` DisableHarmonyVersionUpgrades *bool `json:"disableHarmonyVersionUpgrades,omitempty"` @@ -3107,6 +3117,7 @@ type UpdateAwsClusterRequest struct { // UpdateAzureClusterRequest defines model for UpdateAzureClusterRequest. type UpdateAzureClusterRequest struct { + BlockInternetAccess *bool `json:"blockInternetAccess,omitempty"` DbInstanceType string `json:"dbInstanceType"` DbInstanceVersion *string `json:"dbInstanceVersion,omitempty"` DisableHarmonyVersionUpgrades *bool `json:"disableHarmonyVersionUpgrades,omitempty"` @@ -3245,6 +3256,7 @@ type UpdateEnvironmentObjectRequestScope string // UpdateGcpClusterRequest defines model for UpdateGcpClusterRequest. type UpdateGcpClusterRequest struct { + BlockInternetAccess *bool `json:"blockInternetAccess,omitempty"` DbInstanceType string `json:"dbInstanceType"` DbInstanceVersion *string `json:"dbInstanceVersion,omitempty"` DisableHarmonyVersionUpgrades *bool `json:"disableHarmonyVersionUpgrades,omitempty"` diff --git a/cloud/deployment/deployment.go b/cloud/deployment/deployment.go index c09a876ac..24af5f9ed 100644 --- a/cloud/deployment/deployment.go +++ b/cloud/deployment/deployment.go @@ -63,12 +63,11 @@ const ( ) var ( - sleepTime = 180 - tickNum = 10 - timeoutNum = 180 - listLimit = 1000 - dedicatedDeploymentRequest = astroplatformcore.UpdateDedicatedDeploymentRequest{} - dagDeployEnabled bool + sleepTime = 180 + tickNum = 10 + timeoutNum = 180 + listLimit = 1000 + dagDeployEnabled bool ) func newTableOut() *printutil.Table { @@ -212,6 +211,7 @@ func Logs(deploymentID, ws, deploymentName, keyword string, logWebserver, logSch return nil } +// TODO: move these input arguements to a struct, and drop the nolint func Create(name, workspaceID, description, clusterID, runtimeVersion, dagDeploy, executor, cloudProvider, region, schedulerSize, highAvailability, developmentMode, cicdEnforcement, defaultTaskPodCpu, defaultTaskPodMemory, resourceQuotaCpu, resourceQuotaMemory, workloadIdentity string, deploymentType astroplatformcore.DeploymentType, schedulerAU, schedulerReplicas int, platformCoreClient astroplatformcore.CoreClient, coreClient astrocore.CoreClient, waitForStatus bool) error { //nolint var organizationID string var currentWorkspace astrocore.Workspace @@ -325,6 +325,10 @@ func Create(name, workspaceID, description, clusterID, runtimeVersion, dagDeploy if resourceQuotaMemory == "" { resourceQuotaMemory = configOption.ResourceQuotas.ResourceQuota.Memory.Default } + var deplWorkloadIdentity *string + if workloadIdentity != "" { + deplWorkloadIdentity = &workloadIdentity + } // build standard input if IsDeploymentStandard(deploymentType) { var requestedCloudProvider astroplatformcore.CreateStandardDeploymentRequestCloudProvider @@ -360,6 +364,7 @@ func Create(name, workspaceID, description, clusterID, runtimeVersion, dagDeploy DefaultTaskPodMemory: defaultTaskPodMemory, ResourceQuotaCpu: resourceQuotaCpu, ResourceQuotaMemory: resourceQuotaMemory, + WorkloadIdentity: deplWorkloadIdentity, } if strings.EqualFold(executor, CeleryExecutor) || strings.EqualFold(executor, CELERY) { standardDeploymentRequest.WorkerQueues = &defautWorkerQueue @@ -409,6 +414,7 @@ func Create(name, workspaceID, description, clusterID, runtimeVersion, dagDeploy DefaultTaskPodMemory: defaultTaskPodMemory, ResourceQuotaCpu: resourceQuotaCpu, ResourceQuotaMemory: resourceQuotaMemory, + WorkloadIdentity: deplWorkloadIdentity, } if strings.EqualFold(executor, CeleryExecutor) || strings.EqualFold(executor, CELERY) { dedicatedDeploymentRequest.WorkerQueues = &defautWorkerQueue @@ -738,6 +744,7 @@ func HealthPoll(deploymentID, ws string, sleepTime, tickNum, timeoutNum int, pla } } +// TODO: move these input arguements to a struct, and drop the nolint func Update(deploymentID, name, ws, description, deploymentName, dagDeploy, executor, schedulerSize, highAvailability, developmentMode, cicdEnforcement, defaultTaskPodCpu, defaultTaskPodMemory, resourceQuotaCpu, resourceQuotaMemory, workloadIdentity string, schedulerAU, schedulerReplicas int, wQueueList []astroplatformcore.WorkerQueueRequest, hybridQueueList []astroplatformcore.HybridWorkerQueueRequest, newEnvironmentVariables []astroplatformcore.DeploymentEnvironmentVariableRequest, force bool, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient) error { //nolint var queueCreateUpdate, confirmWithUser bool // get deployment @@ -911,6 +918,10 @@ func Update(deploymentID, name, ws, description, deploymentName, dagDeploy, exec if resourceQuotaMemory == "" { resourceQuotaMemory = *currentDeployment.ResourceQuotaMemory } + var deplWorkloadIdentity *string + if workloadIdentity != "" { + deplWorkloadIdentity = &workloadIdentity + } if IsDeploymentStandard(*currentDeployment.Type) { var requestedExecutor astroplatformcore.UpdateStandardDeploymentRequestExecutor switch strings.ToUpper(executor) { @@ -941,6 +952,7 @@ func Update(deploymentID, name, ws, description, deploymentName, dagDeploy, exec EnvironmentVariables: deploymentEnvironmentVariablesRequest, DefaultTaskPodCpu: defaultTaskPodCpu, DefaultTaskPodMemory: defaultTaskPodMemory, + WorkloadIdentity: deplWorkloadIdentity, } switch schedulerSize { case strings.ToLower(string(astrocore.CreateStandardDeploymentRequestSchedulerSizeSMALL)): @@ -988,7 +1000,7 @@ func Update(deploymentID, name, ws, description, deploymentName, dagDeploy, exec case strings.ToUpper(KUBERNETES): requestedExecutor = astroplatformcore.UpdateDedicatedDeploymentRequestExecutorKUBERNETES } - dedicatedDeploymentRequest = astroplatformcore.UpdateDedicatedDeploymentRequest{ + dedicatedDeploymentRequest := astroplatformcore.UpdateDedicatedDeploymentRequest{ Description: &description, Name: name, Executor: requestedExecutor, @@ -1004,6 +1016,7 @@ func Update(deploymentID, name, ws, description, deploymentName, dagDeploy, exec ResourceQuotaMemory: resourceQuotaMemory, EnvironmentVariables: deploymentEnvironmentVariablesRequest, WorkerQueues: &workerQueuesRequest, + WorkloadIdentity: deplWorkloadIdentity, } switch schedulerSize { case strings.ToLower(string(astrocore.CreateStandardDeploymentRequestSchedulerSizeSMALL)): diff --git a/cloud/deployment/deployment_test.go b/cloud/deployment/deployment_test.go index 56848c9f6..d59ee57ea 100644 --- a/cloud/deployment/deployment_test.go +++ b/cloud/deployment/deployment_test.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "os" + "strings" "testing" "time" @@ -17,6 +18,7 @@ import ( "github.com/astronomer/astro-cli/context" testUtil "github.com/astronomer/astro-cli/pkg/testing" "github.com/astronomer/astro-cli/pkg/util" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" ) @@ -30,12 +32,45 @@ func TestDeployment(t *testing.T) { } var ( - hybridQueueList = []astroplatformcore.HybridWorkerQueueRequest{} - workerQueueRequest = []astroplatformcore.WorkerQueueRequest{} - newEnvironmentVariables = []astroplatformcore.DeploymentEnvironmentVariableRequest{} - errMock = errors.New("mock error") - errCreateFailed = errors.New("failed to create deployment") - nodePools = []astroplatformcore.NodePool{ + hybridQueueList = []astroplatformcore.HybridWorkerQueueRequest{} + workerQueueRequest = []astroplatformcore.WorkerQueueRequest{} + newEnvironmentVariables = []astroplatformcore.DeploymentEnvironmentVariableRequest{} + errMock = errors.New("mock error") + errCreateFailed = errors.New("failed to create deployment") + nodePools = []astroplatformcore.NodePool{} + mockListClustersResponse = astroplatformcore.ListClustersResponse{} + cluster = astroplatformcore.Cluster{} + mockGetClusterResponse = astroplatformcore.GetClusterResponse{} + standardType = astroplatformcore.DeploymentTypeSTANDARD + dedicatedType = astroplatformcore.DeploymentTypeDEDICATED + hybridType = astroplatformcore.DeploymentTypeHYBRID + testRegion = "region" + testProvider = astroplatformcore.DeploymentCloudProviderGCP + testCluster = "cluster" + testWorkloadIdentity = "test-workload-identity" + mockCoreDeploymentResponse = []astroplatformcore.Deployment{} + mockListDeploymentsResponse = astroplatformcore.ListDeploymentsResponse{} + emptyListDeploymentsResponse = astroplatformcore.ListDeploymentsResponse{} + schedulerAU = 0 + clusterID = "cluster-id" + executorCelery = astroplatformcore.DeploymentExecutorCELERY + executorKubernetes = astroplatformcore.DeploymentExecutorKUBERNETES + highAvailability = true + isDevelopmentMode = true + resourceQuotaCPU = "1cpu" + ResourceQuotaMemory = "1" + schedulerSize = astroplatformcore.DeploymentSchedulerSizeSMALL + deploymentResponse = astroplatformcore.GetDeploymentResponse{} + deploymentResponse2 = astroplatformcore.GetDeploymentResponse{} + GetDeploymentOptionsResponseOK = astrocore.GetDeploymentOptionsResponse{} + workspaceTestDescription = "test workspace" + workspace1 = astrocore.Workspace{} + workspaces = []astrocore.Workspace{} + ListWorkspacesResponseOK = astrocore.ListWorkspacesResponse{} +) + +func MockResponseInit() { + nodePools = []astroplatformcore.NodePool{ { Id: "test-pool-id", IsDefault: false, @@ -76,13 +111,6 @@ var ( }, JSON200: &cluster, } - standardType = astroplatformcore.DeploymentTypeSTANDARD - dedicatedType = astroplatformcore.DeploymentTypeDEDICATED - hybridType = astroplatformcore.DeploymentTypeHYBRID - testRegion = "region" - testProvider = astroplatformcore.DeploymentCloudProviderGCP - testCluster = "cluster" - testWorkloadIdentity = "test-workload-identity" mockCoreDeploymentResponse = []astroplatformcore.Deployment{ { Id: "test-id-1", @@ -117,16 +145,7 @@ var ( Deployments: []astroplatformcore.Deployment{}, }, } - schedulerAU = 0 - clusterID = "cluster-id" - executorCelery = astroplatformcore.DeploymentExecutorCELERY - executorKubernetes = astroplatformcore.DeploymentExecutorKUBERNETES - highAvailability = true - isDevelopmentMode = true - resourceQuotaCPU = "1cpu" - ResourceQuotaMemory = "1" - schedulerSize = astroplatformcore.DeploymentSchedulerSizeSMALL - deploymentResponse = astroplatformcore.GetDeploymentResponse{ + deploymentResponse = astroplatformcore.GetDeploymentResponse{ HTTPResponse: &http.Response{ StatusCode: 200, }, @@ -213,7 +232,7 @@ var ( }, } workspaceTestDescription = "test workspace" - workspace1 = astrocore.Workspace{ + workspace1 = astrocore.Workspace{ Name: "test-workspace", Description: &workspaceTestDescription, ApiKeyOnlyDeploymentsDefault: false, @@ -234,7 +253,7 @@ var ( Workspaces: workspaces, }, } -) +} const ( org = "test-org-id" @@ -250,8 +269,25 @@ var ( ) func (s *Suite) SetupTest() { + // init mocks + mockPlatformCoreClient = new(astroplatformcore_mocks.ClientWithResponsesInterface) + mockCoreClient = new(astrocore_mocks.ClientWithResponsesInterface) + + // init responses object + MockResponseInit() +} + +func (s *Suite) TearDownSubTest() { + // assert expectations + mockPlatformCoreClient.AssertExpectations(s.T()) + mockCoreClient.AssertExpectations(s.T()) + + // reset mocks mockPlatformCoreClient = new(astroplatformcore_mocks.ClientWithResponsesInterface) mockCoreClient = new(astrocore_mocks.ClientWithResponsesInterface) + + // reset responses object + MockResponseInit() } func (s *Suite) TestList() { @@ -1076,6 +1112,31 @@ func (s *Suite) TestCreate() { mockPlatformCoreClient.AssertExpectations(s.T()) }) + s.Run("success with hosted deployment with workload identity", func() { + mockWorkloadIdentity := "arn:aws:iam::1234567890:role/unit-test-1" + // Set up mock responses and expectations + mockCoreClient.On("GetDeploymentOptionsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&GetDeploymentOptionsResponseOK, nil).Once() + mockCoreClient.On("ListWorkspacesWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&ListWorkspacesResponseOK, nil).Once() + mockPlatformCoreClient.On("ListClustersWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListClustersResponse, nil).Once() + mockPlatformCoreClient.On("CreateDeploymentWithResponse", mock.Anything, mock.Anything, mock.MatchedBy( + func(input astroplatformcore.CreateDeploymentRequest) bool { + request, _ := input.AsCreateDedicatedDeploymentRequest() + return *request.WorkloadIdentity == mockWorkloadIdentity + }, + )).Return(&mockCreateDeploymentResponse, nil).Once() + + // Mock user input for deployment name + defer testUtil.MockUserInput(s.T(), "test-name")() + + // Call the Create function with a non-empty workload ID + err := Create("test-name", ws, "test-desc", csID, "12.0.0", dagDeploy, CeleryExecutor, "aws", "us-west-2", strings.ToLower(string(astrocore.DeploymentSchedulerSizeSMALL)), "", "", "", "", "", "", "", mockWorkloadIdentity, astroplatformcore.DeploymentTypeDEDICATED, 0, 0, mockPlatformCoreClient, mockCoreClient, false) + s.NoError(err) + + // Assert expectations + mockCoreClient.AssertExpectations(s.T()) + mockPlatformCoreClient.AssertExpectations(s.T()) + }) + s.Run("success with standard/dedicated type different scheduler sizes", func() { // Set up mock responses and expectations mockCoreClient.On("GetDeploymentOptionsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&GetDeploymentOptionsResponseOK, nil).Times(8) @@ -1515,8 +1576,6 @@ func (s *Suite) TestUpdate() { //nolint // success with dedicated updating to kubernetes executor err = Update("test-id-1", "", ws, "", "", "", KubeExecutor, "", "", "", "", "", "", "", "", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, false, mockCoreClient, mockPlatformCoreClient) s.NoError(err) - mockCoreClient.AssertExpectations(s.T()) - mockPlatformCoreClient.AssertExpectations(s.T()) }) s.Run("successfully update schedulerSize and highAvailability and CICDEnforement", func() { mockCoreClient.On("GetDeploymentOptionsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&GetDeploymentOptionsResponseOK, nil).Times(6) @@ -1569,10 +1628,6 @@ func (s *Suite) TestUpdate() { //nolint deploymentResponse.JSON200.Executor = &executorKubernetes err = Update("test-id-1", "", ws, "update", "", "enable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, false, mockCoreClient, mockPlatformCoreClient) s.NoError(err) - deploymentResponse.JSON200.Executor = &executorKubernetes - mockCoreClient.AssertExpectations(s.T()) - mockPlatformCoreClient.AssertExpectations(s.T()) - deploymentResponse.JSON200.Executor = &executorCelery }) s.Run("successfully update developmentMode", func() { @@ -1602,9 +1657,6 @@ func (s *Suite) TestUpdate() { //nolint // success with dedicated type err = Update("", "", ws, "", "test-1", "enable", CeleryExecutor, "medium", "disable", "enable", "disable", "", "", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, false, mockCoreClient, mockPlatformCoreClient) s.NoError(err) - - mockCoreClient.AssertExpectations(s.T()) - mockPlatformCoreClient.AssertExpectations(s.T()) }) s.Run("failed to validate resources", func() { @@ -1619,8 +1671,6 @@ func (s *Suite) TestUpdate() { //nolint deploymentResponse.JSON200.Type = &hybridType err = Update("test-id-1", "", ws, "update", "", "enable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "10Gi", "2CPU", "10Gi", "", 100, 100, workerQueueRequest, hybridQueueList, newEnvironmentVariables, false, mockCoreClient, mockPlatformCoreClient) s.ErrorIs(err, ErrInvalidResourceRequest) - mockCoreClient.AssertExpectations(s.T()) - mockPlatformCoreClient.AssertExpectations(s.T()) }) s.Run("list deployments failure", func() { @@ -1628,8 +1678,6 @@ func (s *Suite) TestUpdate() { //nolint err := Update("test-id-1", "", ws, "update", "", "enable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, false, mockCoreClient, mockPlatformCoreClient) s.ErrorIs(err, errMock) - mockCoreClient.AssertExpectations(s.T()) - mockPlatformCoreClient.AssertExpectations(s.T()) }) s.Run("invalid deployment id", func() { @@ -1654,8 +1702,6 @@ func (s *Suite) TestUpdate() { //nolint // invalid selection err = Update("", "", ws, "update", "", "enable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, false, mockCoreClient, mockPlatformCoreClient) s.ErrorContains(err, "invalid Deployment selected") - - mockPlatformCoreClient.AssertExpectations(s.T()) }) s.Run("cancel update", func() { @@ -1671,8 +1717,6 @@ func (s *Suite) TestUpdate() { //nolint err := Update("test-id-1", "", ws, "update", "", "disable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, false, mockCoreClient, mockPlatformCoreClient) s.NoError(err) - mockCoreClient.AssertExpectations(s.T()) - mockPlatformCoreClient.AssertExpectations(s.T()) }) s.Run("update deployment failure", func() { @@ -1685,8 +1729,6 @@ func (s *Suite) TestUpdate() { //nolint err := Update("test-id-1", "", ws, "update", "", "", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, false, mockCoreClient, mockPlatformCoreClient) s.ErrorIs(err, errMock) s.NotContains(err.Error(), organization.AstronomerConnectionErrMsg) - mockCoreClient.AssertExpectations(s.T()) - mockPlatformCoreClient.AssertExpectations(s.T()) }) s.Run("do not update deployment to enable dag deploy if already enabled", func() { @@ -1696,8 +1738,6 @@ func (s *Suite) TestUpdate() { //nolint err := Update("test-id-1", "", ws, "update", "", "enable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, false, mockCoreClient, mockPlatformCoreClient) s.NoError(err) - mockCoreClient.AssertExpectations(s.T()) - mockPlatformCoreClient.AssertExpectations(s.T()) }) s.Run("throw warning to enable dag deploy if ci-cd enforcement is enabled", func() { @@ -1711,25 +1751,28 @@ func (s *Suite) TestUpdate() { //nolint defer testUtil.MockUserInput(s.T(), "n")() err := Update("test-id-1", "", ws, "update", "", "enable", CeleryExecutor, "medium", "enable", "", "enable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, false, mockCoreClient, mockPlatformCoreClient) s.NoError(err) - mockCoreClient.AssertExpectations(s.T()) - mockPlatformCoreClient.AssertExpectations(s.T()) }) s.Run("do not update deployment to disable dag deploy if already disabled", func() { + deploymentResponse.JSON200.IsDagDeployEnabled = false mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsResponse, nil).Times(1) mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Times(1) - mockCoreClient.On("GetDeploymentOptionsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&GetDeploymentOptionsResponseOK, nil).Times(1) - mockPlatformCoreClient.On("GetClusterWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockGetClusterResponse, nil).Times(1) err := Update("test-id-1", "", ws, "update", "", "disable", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, false, mockCoreClient, mockPlatformCoreClient) s.NoError(err) - mockCoreClient.AssertExpectations(s.T()) - mockPlatformCoreClient.AssertExpectations(s.T()) }) s.Run("update deployment to change executor to KubernetesExecutor", func() { mockCoreClient.On("GetDeploymentOptionsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&GetDeploymentOptionsResponseOK, nil).Times(3) - mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&mockUpdateDeploymentResponse, nil).Times(3) + mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.MatchedBy( + func(input astroplatformcore.UpdateDeploymentRequest) bool { + // converting to hybrid deployment request type works for all three tests because the executor and worker queues are only being checked and + // it's common in all three deployment types, if we have to test more than we should break this into multiple test scenarios + request, err := input.AsUpdateHybridDeploymentRequest() + s.NoError(err) + return request.Executor == KUBERNETES && request.WorkerQueues == nil + }, + )).Return(&mockUpdateDeploymentResponse, nil).Times(3) mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsResponse, nil).Times(3) mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Times(3) mockPlatformCoreClient.On("GetClusterWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockGetClusterResponse, nil).Times(1) @@ -1755,9 +1798,6 @@ func (s *Suite) TestUpdate() { //nolint err = Update("test-id-1", "", ws, "update", "", "", KubeExecutor, "medium", "enable", "", "disable", "", "", "", "", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, false, mockCoreClient, mockPlatformCoreClient) s.NoError(err) - s.Equal((*[]astroplatformcore.WorkerQueueRequest)(nil), dedicatedDeploymentRequest.WorkerQueues) - mockCoreClient.AssertExpectations(s.T()) - mockPlatformCoreClient.AssertExpectations(s.T()) }) s.Run("update deployment to change executor to CeleryExecutor", func() { @@ -1791,25 +1831,22 @@ func (s *Suite) TestUpdate() { //nolint deploymentResponse.JSON200.Type = &hybridType err = Update("test-id-1", "", ws, "update", "", "", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, false, mockCoreClient, mockPlatformCoreClient) s.NoError(err) - - mockCoreClient.AssertExpectations(s.T()) - mockPlatformCoreClient.AssertExpectations(s.T()) }) s.Run("do not update deployment if user says no to the executor change", func() { + // change type to hybrid + deploymentResponse.JSON200.Type = &hybridType + deploymentResponse.JSON200.Executor = &executorKubernetes + mockCoreClient.On("GetDeploymentOptionsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&GetDeploymentOptionsResponseOK, nil).Times(1) mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsResponse, nil).Times(1) mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Times(1) mockPlatformCoreClient.On("GetClusterWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockGetClusterResponse, nil).Times(1) - deploymentResponse.JSON200.Executor = &executorKubernetes - defer testUtil.MockUserInput(s.T(), "n")() err := Update("test-id-1", "", ws, "update", "", "", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, false, mockCoreClient, mockPlatformCoreClient) s.NoError(err) - mockCoreClient.AssertExpectations(s.T()) - mockPlatformCoreClient.AssertExpectations(s.T()) }) s.Run("no node pools on hybrid cluster", func() { @@ -1824,17 +1861,43 @@ func (s *Suite) TestUpdate() { //nolint }, } - mockCoreClient.On("GetDeploymentOptionsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&GetDeploymentOptionsResponseOK, nil) - mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&mockUpdateDeploymentResponse, nil) - mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsResponse, nil) - mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil) - mockPlatformCoreClient.On("GetClusterWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockGetClusterResponseWithNoNodePools, nil) + // change type to hybrid + deploymentResponse.JSON200.Type = &hybridType + deploymentResponse.JSON200.Executor = &executorKubernetes + + defer testUtil.MockUserInput(s.T(), "y")() + + mockCoreClient.On("GetDeploymentOptionsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&GetDeploymentOptionsResponseOK, nil).Once() + mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&mockUpdateDeploymentResponse, nil).Once() + mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsResponse, nil).Once() + mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Once() + mockPlatformCoreClient.On("GetClusterWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockGetClusterResponseWithNoNodePools, nil).Once() err := Update("test-id-1", "", ws, "update", "", "", CeleryExecutor, "medium", "enable", "", "disable", "2CPU", "2Gi", "2CPU", "2Gi", "", 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, false, mockCoreClient, mockPlatformCoreClient) s.NoError(err) + }) - mockCoreClient.AssertExpectations(s.T()) - mockPlatformCoreClient.AssertExpectations(s.T()) + s.Run("update workload identity for a hosted deployment", func() { + mockWorkloadIdentity := "arn:aws:iam::1234567890:role/unit-test-1" + + // change type to dedicated + deploymentResponse.JSON200.Type = &dedicatedType + + // Set up mock responses and expectations + mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.MatchedBy( + func(input astroplatformcore.UpdateDeploymentRequest) bool { + request, err := input.AsUpdateStandardDeploymentRequest() + assert.NoError(s.T(), err) + return request.WorkloadIdentity != nil && *request.WorkloadIdentity == mockWorkloadIdentity + }, + )).Return(&mockUpdateDeploymentResponse, nil).Once() + mockCoreClient.On("GetDeploymentOptionsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&GetDeploymentOptionsResponseOK, nil).Once() + mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsResponse, nil).Once() + mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Once() + + // Call the Create function with a non-empty workload ID + err := Update("test-id-1", "", ws, "update", "", "", CeleryExecutor, "small", "enable", "", "disable", "", "", "", "", mockWorkloadIdentity, 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, true, mockCoreClient, mockPlatformCoreClient) + s.NoError(err) }) } diff --git a/cloud/deployment/fromfile/fromfile.go b/cloud/deployment/fromfile/fromfile.go index 92a906cf1..bb3ea4047 100644 --- a/cloud/deployment/fromfile/fromfile.go +++ b/cloud/deployment/fromfile/fromfile.go @@ -337,6 +337,11 @@ func createOrUpdateDeployment(deploymentFromFile *inspect.FormattedDeployment, c schedulerSize = astroplatformcore.CreateStandardDeploymentRequestSchedulerSizeEXTRALARGE } + var deplWorkloadIdentity *string + if deploymentFromFile.Deployment.Configuration.WorkloadIdentity != "" { + deplWorkloadIdentity = &deploymentFromFile.Deployment.Configuration.WorkloadIdentity + } + standardDeploymentRequest := astroplatformcore.CreateStandardDeploymentRequest{ AstroRuntimeVersion: deploymentFromFile.Deployment.Configuration.RunTimeVersion, CloudProvider: &requestedCloudProvider, @@ -356,6 +361,7 @@ func createOrUpdateDeployment(deploymentFromFile *inspect.FormattedDeployment, c ResourceQuotaMemory: deploymentFromFile.Deployment.Configuration.ResourceQuotaMemory, WorkerQueues: &listQueuesRequest, SchedulerSize: schedulerSize, + WorkloadIdentity: deplWorkloadIdentity, } if standardDeploymentRequest.IsDevelopmentMode != nil && *standardDeploymentRequest.IsDevelopmentMode { hibernationSchedules := ToDeploymentHibernationSchedules(deploymentFromFile.Deployment.HibernationSchedules) @@ -391,6 +397,11 @@ func createOrUpdateDeployment(deploymentFromFile *inspect.FormattedDeployment, c schedulerSize = astroplatformcore.CreateDedicatedDeploymentRequestSchedulerSizeEXTRALARGE } + var deplWorkloadIdentity *string + if deploymentFromFile.Deployment.Configuration.WorkloadIdentity != "" { + deplWorkloadIdentity = &deploymentFromFile.Deployment.Configuration.WorkloadIdentity + } + dedicatedDeploymentRequest := astroplatformcore.CreateDedicatedDeploymentRequest{ AstroRuntimeVersion: deploymentFromFile.Deployment.Configuration.RunTimeVersion, Description: &deploymentFromFile.Deployment.Configuration.Description, @@ -409,6 +420,7 @@ func createOrUpdateDeployment(deploymentFromFile *inspect.FormattedDeployment, c ResourceQuotaMemory: deploymentFromFile.Deployment.Configuration.ResourceQuotaMemory, WorkerQueues: &listQueuesRequest, SchedulerSize: schedulerSize, + WorkloadIdentity: deplWorkloadIdentity, } if dedicatedDeploymentRequest.IsDevelopmentMode != nil && *dedicatedDeploymentRequest.IsDevelopmentMode { hibernationSchedules := ToDeploymentHibernationSchedules(deploymentFromFile.Deployment.HibernationSchedules) @@ -511,6 +523,11 @@ func createOrUpdateDeployment(deploymentFromFile *inspect.FormattedDeployment, c deploymentFromFile.Deployment.Configuration.ResourceQuotaMemory = *existingDeployment.ResourceQuotaMemory } + var deplWorkloadIdentity *string + if deploymentFromFile.Deployment.Configuration.WorkloadIdentity != "" { + deplWorkloadIdentity = &deploymentFromFile.Deployment.Configuration.WorkloadIdentity + } + standardDeploymentRequest := astroplatformcore.UpdateStandardDeploymentRequest{ Description: &deploymentFromFile.Deployment.Configuration.Description, Name: deploymentFromFile.Deployment.Configuration.Name, @@ -528,6 +545,7 @@ func createOrUpdateDeployment(deploymentFromFile *inspect.FormattedDeployment, c SchedulerSize: schedulerSize, ContactEmails: &deploymentFromFile.Deployment.AlertEmails, EnvironmentVariables: envVars, + WorkloadIdentity: deplWorkloadIdentity, } if existingDeployment.IsDevelopmentMode != nil && *existingDeployment.IsDevelopmentMode { hibernationSchedules := ToDeploymentHibernationSchedules(deploymentFromFile.Deployment.HibernationSchedules) @@ -575,6 +593,11 @@ func createOrUpdateDeployment(deploymentFromFile *inspect.FormattedDeployment, c deploymentFromFile.Deployment.Configuration.ResourceQuotaMemory = *existingDeployment.ResourceQuotaMemory } + var deplWorkloadIdentity *string + if deploymentFromFile.Deployment.Configuration.WorkloadIdentity != "" { + deplWorkloadIdentity = &deploymentFromFile.Deployment.Configuration.WorkloadIdentity + } + dedicatedDeploymentRequest := astroplatformcore.UpdateDedicatedDeploymentRequest{ Description: &deploymentFromFile.Deployment.Configuration.Description, Name: deploymentFromFile.Deployment.Configuration.Name, @@ -592,6 +615,7 @@ func createOrUpdateDeployment(deploymentFromFile *inspect.FormattedDeployment, c SchedulerSize: schedulerSize, ContactEmails: &deploymentFromFile.Deployment.AlertEmails, EnvironmentVariables: envVars, + WorkloadIdentity: deplWorkloadIdentity, } if existingDeployment.IsDevelopmentMode != nil && *existingDeployment.IsDevelopmentMode { hibernationSchedules := ToDeploymentHibernationSchedules(deploymentFromFile.Deployment.HibernationSchedules) diff --git a/cloud/deployment/fromfile/fromfile_test.go b/cloud/deployment/fromfile/fromfile_test.go index 672c2d04e..fb74eb79a 100644 --- a/cloud/deployment/fromfile/fromfile_test.go +++ b/cloud/deployment/fromfile/fromfile_test.go @@ -1258,7 +1258,8 @@ deployment: "deployment_type": "STANDARD", "region": "test-region", "cloud_provider": "aws", - "is_development_mode": true + "is_development_mode": true, + "workload_identity": "test-workload-identity" }, "worker_queues": [ { @@ -1307,7 +1308,14 @@ deployment: mockCoreClient.On("ListWorkspacesWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&ListWorkspacesResponseOK, nil).Times(1) mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsResponse, nil).Times(1) mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsCreateResponse, nil).Times(2) - mockPlatformCoreClient.On("CreateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockCreateDeploymentResponse, nil).Once() + mockPlatformCoreClient.On("CreateDeploymentWithResponse", mock.Anything, mock.Anything, mock.MatchedBy( + func(input astroplatformcore.CreateDeploymentRequest) bool { + request, err := input.AsCreateStandardDeploymentRequest() + s.NoError(err) + return request.WorkloadIdentity != nil && *request.WorkloadIdentity == "test-workload-identity" && + request.Type == astroplatformcore.CreateStandardDeploymentRequestTypeSTANDARD + }, + )).Return(&mockCreateDeploymentResponse, nil).Once() mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&mockUpdateDeploymentResponse, nil).Times(1) mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Times(3) err = CreateOrUpdate("deployment.yaml", "create", mockPlatformCoreClient, mockCoreClient, out) @@ -1767,6 +1775,7 @@ deployment: scheduler_size: medium workspace_name: test-workspace deployment_type: STANDARD + workload_identity: test-workload-identity worker_queues: - name: default is_default: true @@ -1809,7 +1818,14 @@ deployment: mockCoreClient.On("ListWorkspacesWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&ListWorkspacesResponseOK, nil).Times(1) mockPlatformCoreClient.On("GetDeploymentOptionsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&GetDeploymentOptionsResponseOK, nil).Times(1) mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsCreateResponse, nil).Times(3) - mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&mockUpdateDeploymentResponse, nil).Times(1) + mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.MatchedBy( + func(input astroplatformcore.UpdateDeploymentRequest) bool { + request, err := input.AsUpdateStandardDeploymentRequest() + s.NoError(err) + return request.WorkloadIdentity != nil && *request.WorkloadIdentity == "test-workload-identity" && + request.Type == astroplatformcore.UpdateStandardDeploymentRequestTypeSTANDARD + }, + )).Return(&mockUpdateDeploymentResponse, nil).Times(1) mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Times(3) err = CreateOrUpdate("deployment.yaml", "update", mockPlatformCoreClient, mockCoreClient, out) @@ -1853,6 +1869,7 @@ deployment: scheduler_size: medium workspace_name: test-workspace deployment_type: DEDICATED + workload_identity: test-workload-identity worker_queues: - name: default is_default: true @@ -1885,7 +1902,14 @@ deployment: mockCoreClient.On("ListWorkspacesWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&ListWorkspacesResponseOK, nil).Times(1) mockPlatformCoreClient.On("GetDeploymentOptionsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&GetDeploymentOptionsResponseOK, nil).Times(1) mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsCreateResponse, nil).Times(3) - mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&mockUpdateDeploymentResponse, nil).Times(1) + mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.MatchedBy( + func(input astroplatformcore.UpdateDeploymentRequest) bool { + request, err := input.AsUpdateDedicatedDeploymentRequest() + s.NoError(err) + return request.WorkloadIdentity != nil && *request.WorkloadIdentity == "test-workload-identity" && + request.Type == astroplatformcore.UpdateDedicatedDeploymentRequestTypeDEDICATED + }, + )).Return(&mockUpdateDeploymentResponse, nil).Times(1) mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Times(3) mockPlatformCoreClient.On("GetClusterWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockGetClusterResponse, nil).Once() mockPlatformCoreClient.On("ListClustersWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListClustersResponse, nil).Once() diff --git a/cloud/deployment/inspect/inspect.go b/cloud/deployment/inspect/inspect.go index 6d3a0c2a4..c679c88cd 100644 --- a/cloud/deployment/inspect/inspect.go +++ b/cloud/deployment/inspect/inspect.go @@ -272,7 +272,9 @@ func getDeploymentConfig(coreDeploymentPointer *astroplatformcore.Deployment, pl if coreDeployment.Region != nil { deploymentMap["region"] = *coreDeployment.Region } - + if coreDeployment.WorkloadIdentity != nil { + deploymentMap["workload_identity"] = *coreDeployment.WorkloadIdentity + } return deploymentMap, nil } diff --git a/cloud/deployment/inspect/inspect_test.go b/cloud/deployment/inspect/inspect_test.go index 979025b2b..0513017d6 100644 --- a/cloud/deployment/inspect/inspect_test.go +++ b/cloud/deployment/inspect/inspect_test.go @@ -458,10 +458,17 @@ func TestGetDeploymentInspectInfo(t *testing.T) { func TestGetDeploymentConfig(t *testing.T) { t.Run("returns deployment config for the requested cloud deployment", func(t *testing.T) { sourceDeployment.Type = &hybridType - sourceDeployment.WorkloadIdentity = &workloadIdentity cloudProvider := astroplatformcore.DeploymentCloudProviderAWS sourceDeployment.CloudProvider = &cloudProvider var actualDeploymentConfig deploymentConfig + sourceDeployment.WorkloadIdentity = &workloadIdentity + + // TODO: use test suite and have setup and teardown functions to make the tests stateless + defer func() { + // clear workload identity + sourceDeployment.WorkloadIdentity = nil + }() + testUtil.InitTestConfig(testUtil.LocalPlatform) expectedDeploymentConfig := deploymentConfig{ Name: sourceDeployment.Name, @@ -477,15 +484,13 @@ func TestGetDeploymentConfig(t *testing.T) { DeploymentType: string(*sourceDeployment.Type), CloudProvider: string(*sourceDeployment.CloudProvider), IsDevelopmentMode: *sourceDeployment.IsDevelopmentMode, + WorkloadIdentity: workloadIdentity, } rawDeploymentConfig, err := getDeploymentConfig(&sourceDeployment, mockPlatformCoreClient) assert.NoError(t, err) err = decodeToStruct(rawDeploymentConfig, &actualDeploymentConfig) assert.NoError(t, err) assert.Equal(t, expectedDeploymentConfig, actualDeploymentConfig) - - // clear workload identity - sourceDeployment.WorkloadIdentity = nil }) t.Run("returns deployment config for the requested cloud standard deployment", func(t *testing.T) { diff --git a/cmd/cloud/deployment.go b/cmd/cloud/deployment.go index acbbef1ea..cd3c1942e 100644 --- a/cmd/cloud/deployment.go +++ b/cmd/cloud/deployment.go @@ -395,6 +395,7 @@ func newDeploymentCreateCmd(out io.Writer) *cobra.Command { cmd.Flags().StringVarP(&inputFile, "deployment-file", "", "", "Location of file containing the Deployment to create. File can be in either JSON or YAML format.") cmd.Flags().BoolVarP(&waitForStatus, "wait", "i", false, "Wait for the Deployment to become healthy before ending the command") cmd.Flags().BoolVarP(&cleanOutput, "clean-output", "", false, "clean output to only include inspect yaml or json file in any situation.") + cmd.Flags().StringVarP(&workloadIdentity, "workload-identity", "", "", "The Workload Identity to use for the Deployment") if organization.IsOrgHosted() { cmd.Flags().StringVarP(&deploymentType, "cluster-type", "", standard, "The Cluster Type to use for the Deployment. Possible values can be standard or dedicated. This flag has been deprecated for the --type flag.") err := cmd.Flags().MarkDeprecated("cluster-type", "use --type instead") @@ -414,7 +415,6 @@ func newDeploymentCreateCmd(out io.Writer) *cobra.Command { } else { cmd.Flags().IntVarP(&schedulerAU, "scheduler-au", "s", 0, "The Deployment's scheduler resources in AUs") cmd.Flags().IntVarP(&schedulerReplicas, "scheduler-replicas", "r", 0, "The number of scheduler replicas for the Deployment") - cmd.Flags().StringVarP(&workloadIdentity, "workload-identity", "", "", "The Workload Identity to use for the Deployment") } cmd.Flags().StringVarP(&clusterID, "cluster-id", "c", "", "Cluster to create the Deployment in") return cmd @@ -445,6 +445,7 @@ func newDeploymentUpdateCmd(out io.Writer) *cobra.Command { cmd.Flags().StringVarP(&deploymentName, "deployment-name", "", "", "Name of the deployment to update") cmd.Flags().StringVarP(&dagDeploy, "dag-deploy", "", "", "Enables DAG-only deploys for the deployment") cmd.Flags().BoolVarP(&cleanOutput, "clean-output", "c", false, "clean output to only include inspect yaml or json file in any situation.") + cmd.Flags().StringVarP(&workloadIdentity, "workload-identity", "", "", "The Workload Identity to use for the Deployment") if organization.IsOrgHosted() { cmd.Flags().StringVarP(&schedulerSize, "scheduler-size", "", "", "The size of Scheduler for the Deployment. Possible values can be small, medium, large, extra_large") cmd.Flags().StringVarP(&highAvailability, "high-availability", "a", "", "Enables High Availability for the Deployment") @@ -456,7 +457,6 @@ func newDeploymentUpdateCmd(out io.Writer) *cobra.Command { } else { cmd.Flags().IntVarP(&updateSchedulerAU, "scheduler-au", "s", 0, "The Deployment's Scheduler resources in AUs.") cmd.Flags().IntVarP(&updateSchedulerReplicas, "scheduler-replicas", "r", 0, "The number of Scheduler replicas for the Deployment.") - cmd.Flags().StringVarP(&workloadIdentity, "workload-identity", "", "", "The Workload Identity to use for the Deployment") } return cmd } diff --git a/cmd/cloud/deployment_test.go b/cmd/cloud/deployment_test.go index b3495d7b6..85809d8f2 100644 --- a/cmd/cloud/deployment_test.go +++ b/cmd/cloud/deployment_test.go @@ -778,6 +778,33 @@ deployment: mockPlatformCoreClient.AssertExpectations(t) mockCoreClient.AssertExpectations(t) }) + + t.Run("creates a hosted deployment with workload identity", func(t *testing.T) { + ctx, err := context.GetCurrentContext() + assert.NoError(t, err) + workloadIdentity := "arn:aws:iam::1234567890:role/unit-test-1" + mockCreateDeploymentResponse.JSON200.WorkloadIdentity = &workloadIdentity + ctx.SetContextKey("organization_product", "HOSTED") + ctx.SetContextKey("organization", "test-org-id") + ctx.SetContextKey("workspace", ws) + ctx.SetContextKey("organization_short_name", "test-org") + mockCoreClient.On("GetDeploymentOptionsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&GetDeploymentOptionsResponseAlphaOK, nil).Once() + mockCoreClient.On("ListWorkspacesWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&ListWorkspacesResponseOK, nil).Once() + mockPlatformCoreClient.On("CreateDeploymentWithResponse", mock.Anything, mock.Anything, mock.MatchedBy(func(i astroplatformcore.CreateDeploymentRequest) bool { + input, _ := i.AsCreateStandardDeploymentRequest() + return input.WorkloadIdentity != nil && *input.WorkloadIdentity == workloadIdentity + })).Return(&mockCreateDeploymentResponse, nil).Once() + astroCoreClient = mockCoreClient + platformCoreClient = mockPlatformCoreClient + cmdArgs := []string{ + "create", "--name", "test-name", "--workspace-id", ws, "--type", "standard", "--workload-identity", workloadIdentity, "--cloud-provider", "aws", "--region", "us-west-2", + } + + _, err = execDeploymentCmd(cmdArgs...) + assert.NoError(t, err) + mockPlatformCoreClient.AssertExpectations(t) + mockCoreClient.AssertExpectations(t) + }) } func TestDeploymentUpdate(t *testing.T) { @@ -1037,6 +1064,36 @@ deployment: mockPlatformCoreClient.AssertExpectations(t) mockCoreClient.AssertExpectations(t) }) + + t.Run("updates a hosted deployment with workload identity", func(t *testing.T) { + ctx, err := context.GetCurrentContext() + assert.NoError(t, err) + ctx.SetContextKey("organization_product", "HOSTED") + ctx.SetContextKey("organization", "test-org-id") + ctx.SetContextKey("workspace", ws) + + workloadIdentity := "arn:aws:iam::1234567890:role/unit-test-1" + mockUpdateDeploymentResponse.JSON200.WorkloadIdentity = &workloadIdentity + + mockCoreClient.On("GetDeploymentOptionsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&GetDeploymentOptionsResponseAlphaOK, nil).Once() + mockPlatformCoreClient.On("UpdateDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything, mock.MatchedBy(func(i astroplatformcore.UpdateDeploymentRequest) bool { + input, _ := i.AsUpdateDedicatedDeploymentRequest() + return input.WorkloadIdentity != nil && *input.WorkloadIdentity == workloadIdentity + })).Return(&mockUpdateDeploymentResponse, nil).Times(1) + mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsResponse, nil).Times(1) + mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&hostedDeploymentResponse, nil).Times(1) + + astroCoreClient = mockCoreClient + platformCoreClient = mockPlatformCoreClient + cmdArgs := []string{ + "update", "test-id-1", "--name", "test-name", "--workload-identity", workloadIdentity, + } + + _, err = execDeploymentCmd(cmdArgs...) + assert.NoError(t, err) + mockPlatformCoreClient.AssertExpectations(t) + mockCoreClient.AssertExpectations(t) + }) } func TestDeploymentDelete(t *testing.T) { From d14a55ea8f78d04189dd623a521b63a5a4abd685 Mon Sep 17 00:00:00 2001 From: neel-astro Date: Mon, 26 Aug 2024 12:10:08 +0530 Subject: [PATCH 2/5] fix typo s/arguements/arguments --- cloud/deployment/deployment.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cloud/deployment/deployment.go b/cloud/deployment/deployment.go index 24af5f9ed..8fc51537b 100644 --- a/cloud/deployment/deployment.go +++ b/cloud/deployment/deployment.go @@ -211,7 +211,7 @@ func Logs(deploymentID, ws, deploymentName, keyword string, logWebserver, logSch return nil } -// TODO: move these input arguements to a struct, and drop the nolint +// TODO: move these input arguments to a struct, and drop the nolint func Create(name, workspaceID, description, clusterID, runtimeVersion, dagDeploy, executor, cloudProvider, region, schedulerSize, highAvailability, developmentMode, cicdEnforcement, defaultTaskPodCpu, defaultTaskPodMemory, resourceQuotaCpu, resourceQuotaMemory, workloadIdentity string, deploymentType astroplatformcore.DeploymentType, schedulerAU, schedulerReplicas int, platformCoreClient astroplatformcore.CoreClient, coreClient astrocore.CoreClient, waitForStatus bool) error { //nolint var organizationID string var currentWorkspace astrocore.Workspace @@ -744,7 +744,7 @@ func HealthPoll(deploymentID, ws string, sleepTime, tickNum, timeoutNum int, pla } } -// TODO: move these input arguements to a struct, and drop the nolint +// TODO: move these input arguments to a struct, and drop the nolint func Update(deploymentID, name, ws, description, deploymentName, dagDeploy, executor, schedulerSize, highAvailability, developmentMode, cicdEnforcement, defaultTaskPodCpu, defaultTaskPodMemory, resourceQuotaCpu, resourceQuotaMemory, workloadIdentity string, schedulerAU, schedulerReplicas int, wQueueList []astroplatformcore.WorkerQueueRequest, hybridQueueList []astroplatformcore.HybridWorkerQueueRequest, newEnvironmentVariables []astroplatformcore.DeploymentEnvironmentVariableRequest, force bool, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient) error { //nolint var queueCreateUpdate, confirmWithUser bool // get deployment From f30b593c54ec0b953ae37d4572c57280181eec6a Mon Sep 17 00:00:00 2001 From: neel-astro Date: Mon, 26 Aug 2024 22:02:14 +0530 Subject: [PATCH 3/5] drop workload identity changes from deployment inspect --- cloud/deployment/inspect/inspect.go | 3 --- cloud/deployment/inspect/inspect_test.go | 13 ++++--------- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/cloud/deployment/inspect/inspect.go b/cloud/deployment/inspect/inspect.go index c679c88cd..03e1ffe15 100644 --- a/cloud/deployment/inspect/inspect.go +++ b/cloud/deployment/inspect/inspect.go @@ -272,9 +272,6 @@ func getDeploymentConfig(coreDeploymentPointer *astroplatformcore.Deployment, pl if coreDeployment.Region != nil { deploymentMap["region"] = *coreDeployment.Region } - if coreDeployment.WorkloadIdentity != nil { - deploymentMap["workload_identity"] = *coreDeployment.WorkloadIdentity - } return deploymentMap, nil } diff --git a/cloud/deployment/inspect/inspect_test.go b/cloud/deployment/inspect/inspect_test.go index 0513017d6..979025b2b 100644 --- a/cloud/deployment/inspect/inspect_test.go +++ b/cloud/deployment/inspect/inspect_test.go @@ -458,17 +458,10 @@ func TestGetDeploymentInspectInfo(t *testing.T) { func TestGetDeploymentConfig(t *testing.T) { t.Run("returns deployment config for the requested cloud deployment", func(t *testing.T) { sourceDeployment.Type = &hybridType + sourceDeployment.WorkloadIdentity = &workloadIdentity cloudProvider := astroplatformcore.DeploymentCloudProviderAWS sourceDeployment.CloudProvider = &cloudProvider var actualDeploymentConfig deploymentConfig - sourceDeployment.WorkloadIdentity = &workloadIdentity - - // TODO: use test suite and have setup and teardown functions to make the tests stateless - defer func() { - // clear workload identity - sourceDeployment.WorkloadIdentity = nil - }() - testUtil.InitTestConfig(testUtil.LocalPlatform) expectedDeploymentConfig := deploymentConfig{ Name: sourceDeployment.Name, @@ -484,13 +477,15 @@ func TestGetDeploymentConfig(t *testing.T) { DeploymentType: string(*sourceDeployment.Type), CloudProvider: string(*sourceDeployment.CloudProvider), IsDevelopmentMode: *sourceDeployment.IsDevelopmentMode, - WorkloadIdentity: workloadIdentity, } rawDeploymentConfig, err := getDeploymentConfig(&sourceDeployment, mockPlatformCoreClient) assert.NoError(t, err) err = decodeToStruct(rawDeploymentConfig, &actualDeploymentConfig) assert.NoError(t, err) assert.Equal(t, expectedDeploymentConfig, actualDeploymentConfig) + + // clear workload identity + sourceDeployment.WorkloadIdentity = nil }) t.Run("returns deployment config for the requested cloud standard deployment", func(t *testing.T) { From 0d951b08f5d52196b46190af5774ddfc05239068 Mon Sep 17 00:00:00 2001 From: Neel Dalsania Date: Tue, 27 Aug 2024 20:45:50 +0530 Subject: [PATCH 4/5] Add issue to the TODO statement to keep track of it --- cloud/deployment/deployment.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cloud/deployment/deployment.go b/cloud/deployment/deployment.go index 8fc51537b..576e9303b 100644 --- a/cloud/deployment/deployment.go +++ b/cloud/deployment/deployment.go @@ -134,6 +134,7 @@ func List(ws string, fromAllWorkspaces bool, platformCoreClient astroplatformcor return nil } +// TODO (https://github.com/astronomer/astro-cli/issues/1709): move these input arguments to a struct, and drop the nolint func Logs(deploymentID, ws, deploymentName, keyword string, logWebserver, logScheduler, logTriggerer, logWorkers, warnLogs, errorLogs, infoLogs bool, logCount int, platformCoreClient astroplatformcore.CoreClient, coreClient astrocore.CoreClient) error { var logLevel string var i int @@ -211,7 +212,7 @@ func Logs(deploymentID, ws, deploymentName, keyword string, logWebserver, logSch return nil } -// TODO: move these input arguments to a struct, and drop the nolint +// TODO (https://github.com/astronomer/astro-cli/issues/1709): move these input arguments to a struct, and drop the nolint func Create(name, workspaceID, description, clusterID, runtimeVersion, dagDeploy, executor, cloudProvider, region, schedulerSize, highAvailability, developmentMode, cicdEnforcement, defaultTaskPodCpu, defaultTaskPodMemory, resourceQuotaCpu, resourceQuotaMemory, workloadIdentity string, deploymentType astroplatformcore.DeploymentType, schedulerAU, schedulerReplicas int, platformCoreClient astroplatformcore.CoreClient, coreClient astrocore.CoreClient, waitForStatus bool) error { //nolint var organizationID string var currentWorkspace astrocore.Workspace @@ -744,7 +745,7 @@ func HealthPoll(deploymentID, ws string, sleepTime, tickNum, timeoutNum int, pla } } -// TODO: move these input arguments to a struct, and drop the nolint +// TODO (https://github.com/astronomer/astro-cli/issues/1709): move these input arguments to a struct, and drop the nolint func Update(deploymentID, name, ws, description, deploymentName, dagDeploy, executor, schedulerSize, highAvailability, developmentMode, cicdEnforcement, defaultTaskPodCpu, defaultTaskPodMemory, resourceQuotaCpu, resourceQuotaMemory, workloadIdentity string, schedulerAU, schedulerReplicas int, wQueueList []astroplatformcore.WorkerQueueRequest, hybridQueueList []astroplatformcore.HybridWorkerQueueRequest, newEnvironmentVariables []astroplatformcore.DeploymentEnvironmentVariableRequest, force bool, coreClient astrocore.CoreClient, platformCoreClient astroplatformcore.CoreClient) error { //nolint var queueCreateUpdate, confirmWithUser bool // get deployment From a6b78bd8a35d1c80b9455c2462911c75e5f08f97 Mon Sep 17 00:00:00 2001 From: Neel Dalsania Date: Tue, 27 Aug 2024 20:46:27 +0530 Subject: [PATCH 5/5] code review changes: fix typo Co-authored-by: kushalmalani --- cloud/deployment/deployment_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cloud/deployment/deployment_test.go b/cloud/deployment/deployment_test.go index d59ee57ea..8507bfb7b 100644 --- a/cloud/deployment/deployment_test.go +++ b/cloud/deployment/deployment_test.go @@ -1895,7 +1895,7 @@ func (s *Suite) TestUpdate() { //nolint mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsResponse, nil).Once() mockPlatformCoreClient.On("GetDeploymentWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&deploymentResponse, nil).Once() - // Call the Create function with a non-empty workload ID + // Call the Update function with a non-empty workload ID err := Update("test-id-1", "", ws, "update", "", "", CeleryExecutor, "small", "enable", "", "disable", "", "", "", "", mockWorkloadIdentity, 0, 0, workerQueueRequest, hybridQueueList, newEnvironmentVariables, true, mockCoreClient, mockPlatformCoreClient) s.NoError(err) })