Skip to content

Commit

Permalink
Fix deployment selection filter and validation logic (#1573)
Browse files Browse the repository at this point in the history
* fix deployment selection filter logic

* fix edge cases

* fix linter errors

* add new unit tests

* code review changes: fix error naming and usage

* code review changes: use lower case and replace the message in inspect package

* add missing unit tests in wq package

* code review changes: s/workspaces/Workspace
  • Loading branch information
neel-astro committed Feb 29, 2024
1 parent 61efb70 commit ad44036
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 27 deletions.
50 changes: 27 additions & 23 deletions cloud/deployment/deployment.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,24 +45,25 @@ var (
)

const (
noWorkspaceMsg = "no workspaces with id (%s) found"
KubeExecutor = "KubernetesExecutor"
CeleryExecutor = "CeleryExecutor"
KUBERNETES = "KUBERNETES"
CELERY = "CELERY"
notApplicable = "N/A"
gcpCloud = "gcp"
awsCloud = "aws"
azureCloud = "azure"
standard = "standard"
LargeScheduler = "large"
MediumScheduler = "medium"
SmallScheduler = "small"
SMALL = "SMALL"
MEDIUM = "MEDIUM"
LARGE = "LARGE"
disable = "disable"
enable = "enable"
NoDeploymentInWSMsg = "no Deployments found in workspace"
noWorkspaceMsg = "no Workspace with id (%s) found"
KubeExecutor = "KubernetesExecutor"
CeleryExecutor = "CeleryExecutor"
KUBERNETES = "KUBERNETES"
CELERY = "CELERY"
notApplicable = "N/A"
gcpCloud = "gcp"
awsCloud = "aws"
azureCloud = "azure"
standard = "standard"
LargeScheduler = "large"
MediumScheduler = "medium"
SmallScheduler = "small"
SMALL = "SMALL"
MEDIUM = "MEDIUM"
LARGE = "LARGE"
disable = "disable"
enable = "enable"
)

var (
Expand Down Expand Up @@ -124,7 +125,7 @@ func List(ws string, fromAllWorkspaces bool, platformCoreClient astroplatformcor
}

if len(deployments) == 0 {
fmt.Printf("No Deployments found in workspace %s\n", ansi.Bold(ws))
fmt.Printf("%s %s\n", NoDeploymentInWSMsg, ansi.Bold(ws))
return nil
}

Expand Down Expand Up @@ -1155,7 +1156,7 @@ func Delete(deploymentID, ws, deploymentName string, forceDelete bool, platformC
}

if currentDeployment.Id == "" {
fmt.Printf("No Deployments found in workspace %s to delete\n", ansi.Bold(ws))
fmt.Printf("%s %s to delete\n", NoDeploymentInWSMsg, ansi.Bold(ws))
return nil
}

Expand Down Expand Up @@ -1195,7 +1196,7 @@ func UpdateDeploymentHibernationOverride(deploymentID, ws, deploymentName string
return err
}
if currentDeployment.Id == "" {
fmt.Printf("No Deployments found in workspace %s to %s\n", ansi.Bold(ws), action)
fmt.Printf("%s %s to %s\n", NoDeploymentInWSMsg, ansi.Bold(ws), action)
return nil
}

Expand Down Expand Up @@ -1623,7 +1624,7 @@ func GetDeployment(ws, deploymentID, deploymentName string, disableCreateFlow bo

// select deployment if deploymentID is empty
if deploymentID == "" {
currentDeployment, err = deploymentSelectionProcess(ws, deployments, selectionFilter, platformCoreClient, coreClient)
currentDeployment, err = deploymentSelectionProcess(ws, deployments, selectionFilter, platformCoreClient, coreClient, disableCreateFlow)
if err != nil {
return astroplatformcore.Deployment{}, err
}
Expand All @@ -1645,11 +1646,14 @@ func GetDeployment(ws, deploymentID, deploymentName string, disableCreateFlow bo
return currentDeployment, nil
}

func deploymentSelectionProcess(ws string, deployments []astroplatformcore.Deployment, deploymentFilter func(deployment astroplatformcore.Deployment) bool, platformCoreClient astroplatformcore.CoreClient, coreClient astrocore.CoreClient) (astroplatformcore.Deployment, error) {
func deploymentSelectionProcess(ws string, deployments []astroplatformcore.Deployment, deploymentFilter func(deployment astroplatformcore.Deployment) bool, platformCoreClient astroplatformcore.CoreClient, coreClient astrocore.CoreClient, disableCreateFlow bool) (astroplatformcore.Deployment, error) {
// filter deployments
if deploymentFilter != nil {
deployments = util.Filter(deployments, deploymentFilter)
}
if len(deployments) == 0 && disableCreateFlow {
return astroplatformcore.Deployment{}, fmt.Errorf("%s %s", NoDeploymentInWSMsg, ws) //nolint:goerr113
}
currentDeployment, err := SelectDeployment(deployments, "Select a Deployment")
if err != nil {
return astroplatformcore.Deployment{}, err
Expand Down
29 changes: 28 additions & 1 deletion cloud/deployment/deployment_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1214,7 +1214,7 @@ func TestCreate(t *testing.T) {
mockCoreClient.On("ListWorkspacesWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&ListWorkspacesResponseOK, nil).Once()

err := Create("test-name", "wrong-workspace", "test-desc", csID, "4.2.5", dagDeploy, CeleryExecutor, "", "", "", "", "", "", "", "", "", "", "", astroplatformcore.DeploymentTypeHYBRID, 0, 0, mockPlatformCoreClient, mockCoreClient, false)
assert.ErrorContains(t, err, "no workspaces with id")
assert.ErrorContains(t, err, "no Workspace with id")
mockCoreClient.AssertExpectations(t)
})
}
Expand Down Expand Up @@ -1934,6 +1934,33 @@ func TestUpdateDeploymentHibernationOverride(t *testing.T) {
mockPlatformCoreClient.AssertExpectations(t)
})

t.Run("returns an error if none of the deployments are in development mode", func(t *testing.T) {
mockPlatformCoreClient = new(astroplatformcore_mocks.ClientWithResponsesInterface)

var deploymentList []astroplatformcore.Deployment
for _, deployment := range mockCoreDeploymentResponse {
if deployment.IsDevelopmentMode != nil && *deployment.IsDevelopmentMode == true {
continue
}
deploymentList = append(deploymentList, deployment)
}
mockDeploymentListResponse := astroplatformcore.ListDeploymentsResponse{
HTTPResponse: &http.Response{
StatusCode: 200,
},
JSON200: &astroplatformcore.DeploymentsPaginated{
Deployments: deploymentList,
},
}

mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockDeploymentListResponse, nil).Once()

err := UpdateDeploymentHibernationOverride("", ws, "", true, nil, false, mockPlatformCoreClient)
assert.Error(t, err)
assert.Equal(t, err.Error(), fmt.Sprintf("%s %s", NoDeploymentInWSMsg, ws))
mockPlatformCoreClient.AssertExpectations(t)
})

t.Run("cancels if requested", func(t *testing.T) {
mockPlatformCoreClient = new(astroplatformcore_mocks.ClientWithResponsesInterface)
mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockListDeploymentsResponse, nil).Once()
Expand Down
2 changes: 1 addition & 1 deletion cloud/deployment/inspect/inspect.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func Inspect(wsID, deploymentName, deploymentID, outputFormat string, platformCo
}

if requestedDeployment.Id == "" {
fmt.Printf("No Deployments found in workspace %s\n", ansi.Bold(wsID))
fmt.Printf("%s %s\n", deployment.NoDeploymentInWSMsg, ansi.Bold(wsID))
return nil
}
// create a map for deployment.information
Expand Down
4 changes: 2 additions & 2 deletions cloud/deployment/workerqueue/workerqueue.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func CreateOrUpdate(ws, deploymentID, deploymentName, name, action, workerType s
}

if requestedDeployment.Id == "" {
fmt.Printf("No Deployments found in workspace %s\n", ansi.Bold(ws))
fmt.Printf("%s %s\n", deployment.NoDeploymentInWSMsg, ansi.Bold(ws))
return nil
}

Expand Down Expand Up @@ -517,7 +517,7 @@ func Delete(ws, deploymentID, deploymentName, name string, force bool, platformC
}

if requestedDeployment.Id == "" {
fmt.Printf("No Deployments found in workspace %s\n", ansi.Bold(ws))
fmt.Printf("%s %s\n", deployment.NoDeploymentInWSMsg, ansi.Bold(ws))
return nil
}

Expand Down
16 changes: 16 additions & 0 deletions cloud/deployment/workerqueue/workerqueue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,22 @@ func TestCreate(t *testing.T) {
assert.ErrorIs(t, err, deployment.ErrInvalidDeploymentKey)
mockPlatformCoreClient.AssertExpectations(t)
})
t.Run("exit early when no deployments in the workspace", func(t *testing.T) {
out := new(bytes.Buffer)
defer testUtil.MockUserInput(t, "test-invalid-deployment-id")()
mockDeploymentListResponse := astroplatformcore.ListDeploymentsResponse{
HTTPResponse: &http.Response{
StatusCode: 200,
},
JSON200: &astroplatformcore.DeploymentsPaginated{
Deployments: []astroplatformcore.Deployment{},
},
}
mockPlatformCoreClient.On("ListDeploymentsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&mockDeploymentListResponse, nil).Times(1)
err := CreateOrUpdate("test-ws-id", "", "", "", "", "", 0, 0, 0, false, mockPlatformCoreClient, mockCoreClient, out)
assert.NoError(t, err)
mockPlatformCoreClient.AssertExpectations(t)
})
t.Run("returns an error when selecting a node pool fails", func(t *testing.T) {
out := new(bytes.Buffer)
mockPlatformCoreClient.On("GetDeploymentOptionsWithResponse", mock.Anything, mock.Anything, mock.Anything).Return(&GetDeploymentOptionsPlatformResponseOK, nil).Times(1)
Expand Down

0 comments on commit ad44036

Please sign in to comment.