Skip to content

Commit

Permalink
Fix bug where you cannot remove cluster workspace restrictions (#55)
Browse files Browse the repository at this point in the history
  • Loading branch information
vandyliu authored May 11, 2024
1 parent 97a5dcc commit 6a1193a
Show file tree
Hide file tree
Showing 10 changed files with 79 additions and 41 deletions.
5 changes: 4 additions & 1 deletion internal/provider/datasources/data_source_clusters.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,10 @@ func (d *clustersDataSource) Read(
if len(data.CloudProvider.ValueString()) > 0 {
params.Provider = (*platform.ListClustersParamsProvider)(data.CloudProvider.ValueStringPointer())
}
params.Names, diags = utils.TypesSetToStringSlicePtr(ctx, data.Names)
names, diags := utils.TypesSetToStringSlice(ctx, data.Names)
if len(names) > 0 {
params.Names = &names
}
if diags.HasError() {
resp.Diagnostics.Append(diags...)
return
Expand Down
15 changes: 12 additions & 3 deletions internal/provider/datasources/data_source_deployments.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,17 +89,26 @@ func (d *deploymentsDataSource) Read(
Limit: lo.ToPtr(1000),
}
var diags diag.Diagnostics
params.DeploymentIds, diags = utils.TypesSetToStringSlicePtr(ctx, data.DeploymentIds)
deploymentIds, diags := utils.TypesSetToStringSlice(ctx, data.DeploymentIds)
if len(deploymentIds) > 0 {
params.DeploymentIds = &deploymentIds
}
if diags.HasError() {
resp.Diagnostics.Append(diags...)
return
}
params.WorkspaceIds, diags = utils.TypesSetToStringSlicePtr(ctx, data.WorkspaceIds)
workspaceIds, diags := utils.TypesSetToStringSlice(ctx, data.WorkspaceIds)
if len(workspaceIds) > 0 {
params.WorkspaceIds = &workspaceIds
}
if diags.HasError() {
resp.Diagnostics.Append(diags...)
return
}
params.Names, diags = utils.TypesSetToStringSlicePtr(ctx, data.Names)
names, diags := utils.TypesSetToStringSlice(ctx, data.Names)
if len(names) > 0 {
params.Names = &names
}
if diags.HasError() {
resp.Diagnostics.Append(diags...)
return
Expand Down
10 changes: 8 additions & 2 deletions internal/provider/datasources/data_source_workspaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,19 @@ func (d *workspacesDataSource) Read(
Limit: lo.ToPtr(1000),
}
var diags diag.Diagnostics
params.WorkspaceIds, diags = utils.TypesSetToStringSlicePtr(ctx, data.WorkspaceIds)
workspaceIds, diags := utils.TypesSetToStringSlice(ctx, data.WorkspaceIds)
if len(workspaceIds) > 0 {
params.WorkspaceIds = &workspaceIds
}
if diags.HasError() {
resp.Diagnostics.AddError("Client Error", fmt.Sprintf("Unable to read workspaces, got error %v", diags.Errors()[0].Summary()))
resp.Diagnostics.Append(diags...)
return
}
params.Names, diags = utils.TypesSetToStringSlicePtr(ctx, data.Names)
names, diags := utils.TypesSetToStringSlice(ctx, data.Names)
if len(names) > 0 {
params.Names = &names
}
if diags.HasError() {
resp.Diagnostics.Append(diags...)
return
Expand Down
24 changes: 17 additions & 7 deletions internal/provider/models/deployment.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,23 @@ func (data *Deployment) ReadFromResponse(
data.SchedulerSize = types.StringPointerValue((*string)(deployment.SchedulerSize))
data.IsHighAvailability = types.BoolPointerValue(deployment.IsHighAvailability)
data.IsDevelopmentMode = types.BoolPointerValue(deployment.IsDevelopmentMode)
data.ScalingStatus, diags = ScalingStatusTypesObject(ctx, deployment.ScalingStatus)
if diags.HasError() {
return diags
}
data.ScalingSpec, diags = ScalingSpecTypesObject(ctx, deployment.ScalingSpec)
if diags.HasError() {
return diags

// Currently, the scaling status and spec are only available in development mode
// However, there is a bug in the API where the scaling status and spec are returned even if the deployment is not in development mode for updated deployments
// This is a workaround to handle the bug until the API is fixed
// Issue here: https://github.com/astronomer/astro/issues/21073
if deployment.IsDevelopmentMode != nil && *deployment.IsDevelopmentMode {
data.ScalingStatus, diags = ScalingStatusTypesObject(ctx, deployment.ScalingStatus)
if diags.HasError() {
return diags
}
data.ScalingSpec, diags = ScalingSpecTypesObject(ctx, deployment.ScalingSpec)
if diags.HasError() {
return diags
}
} else {
data.ScalingStatus = types.ObjectNull(schemas.ScalingStatusAttributeTypes())
data.ScalingSpec = types.ObjectNull(schemas.ScalingSpecAttributeTypes())
}

return nil
Expand Down
12 changes: 8 additions & 4 deletions internal/provider/resources/resource_cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ func (r *ClusterResource) Create(
}

// workspaceIds
createAwsDedicatedClusterRequest.WorkspaceIds, diags = utils.TypesSetToStringSlicePtr(ctx, data.WorkspaceIds)
workspaceIds, diags := utils.TypesSetToStringSlice(ctx, data.WorkspaceIds)
createAwsDedicatedClusterRequest.WorkspaceIds = &workspaceIds
if diags.HasError() {
resp.Diagnostics.Append(diags...)
return
Expand Down Expand Up @@ -135,7 +136,8 @@ func (r *ClusterResource) Create(
}

// workspaceIds
createAzureDedicatedClusterRequest.WorkspaceIds, diags = utils.TypesSetToStringSlicePtr(ctx, data.WorkspaceIds)
workspaceIds, diags := utils.TypesSetToStringSlice(ctx, data.WorkspaceIds)
createAzureDedicatedClusterRequest.WorkspaceIds = &workspaceIds
if diags.HasError() {
resp.Diagnostics.Append(diags...)
return
Expand Down Expand Up @@ -166,7 +168,8 @@ func (r *ClusterResource) Create(
}

// workspaceIds
createGcpDedicatedClusterRequest.WorkspaceIds, diags = utils.TypesSetToStringSlicePtr(ctx, data.WorkspaceIds)
workspaceIds, diags := utils.TypesSetToStringSlice(ctx, data.WorkspaceIds)
createGcpDedicatedClusterRequest.WorkspaceIds = &workspaceIds
if diags.HasError() {
resp.Diagnostics.Append(diags...)
return
Expand Down Expand Up @@ -318,7 +321,8 @@ func (r *ClusterResource) Update(
}

// workspaceIds
updateDedicatedClusterRequest.WorkspaceIds, diags = utils.TypesSetToStringSlicePtr(ctx, data.WorkspaceIds)
workspaceIds, diags := utils.TypesSetToStringSlice(ctx, data.WorkspaceIds)
updateDedicatedClusterRequest.WorkspaceIds = &workspaceIds
if diags.HasError() {
resp.Diagnostics.Append(diags...)
return
Expand Down
14 changes: 8 additions & 6 deletions internal/provider/resources/resource_cluster_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,11 @@ func TestAcc_ResourceClusterAwsWithDedicatedDeployments(t *testing.T) {
Config: astronomerprovider.ProviderConfig(t, true) +
workspace(workspaceName, workspaceName, utils.TestResourceDescription, false) +
cluster(clusterInput{
Name: awsClusterName,
Region: "us-east-1",
CloudProvider: "AWS",
DbInstanceType: "db.m6g.large",
Name: awsClusterName,
Region: "us-east-1",
CloudProvider: "AWS",
DbInstanceType: "db.m6g.large",
RestrictedWorkspaceResourceVarName: workspaceResourceVar,
}) +
dedicatedDeployment(dedicatedDeploymentInput{
ClusterResourceVar: awsResourceVar,
Expand All @@ -73,7 +74,7 @@ func TestAcc_ResourceClusterAwsWithDedicatedDeployments(t *testing.T) {
resource.TestCheckResourceAttr(awsResourceVar, "cloud_provider", "AWS"),
resource.TestCheckResourceAttr(awsResourceVar, "db_instance_type", "db.m6g.large"),
resource.TestCheckResourceAttrSet(awsResourceVar, "vpc_subnet_range"),
resource.TestCheckResourceAttr(awsResourceVar, "workspace_ids.#", "0"),
resource.TestCheckResourceAttr(awsResourceVar, "workspace_ids.#", "1"),

// Check via API that cluster exists
testAccCheckClusterExistence(t, awsClusterName, true, true),
Expand All @@ -88,7 +89,7 @@ func TestAcc_ResourceClusterAwsWithDedicatedDeployments(t *testing.T) {
testAccCheckDeploymentExistence(t, awsDeploymentName, true, true),
),
},
// Just update cluster
// Just update cluster and remove workspace restrictions
{
Config: astronomerprovider.ProviderConfig(t, true) +
workspace(workspaceName, workspaceName, utils.TestResourceDescription, false) +
Expand Down Expand Up @@ -128,6 +129,7 @@ func TestAcc_ResourceClusterAwsWithDedicatedDeployments(t *testing.T) {
),
},
// Change properties of cluster and deployment and check they have been updated in terraform state
// Add back workspace restrictions
{
Config: astronomerprovider.ProviderConfig(t, true) +
workspace(workspaceName, workspaceName, utils.TestResourceDescription, false) +
Expand Down
18 changes: 12 additions & 6 deletions internal/provider/resources/resource_deployment.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ func (r *DeploymentResource) Create(
}

// contact emails
createStandardDeploymentRequest.ContactEmails, diags = utils.TypesSetToStringSlicePtr(ctx, data.ContactEmails)
contactEmails, diags := utils.TypesSetToStringSlice(ctx, data.ContactEmails)
createStandardDeploymentRequest.ContactEmails = &contactEmails
if diags.HasError() {
resp.Diagnostics.Append(diags...)
return
Expand Down Expand Up @@ -203,7 +204,8 @@ func (r *DeploymentResource) Create(
}

// contact emails
createDedicatedDeploymentRequest.ContactEmails, diags = utils.TypesSetToStringSlicePtr(ctx, data.ContactEmails)
contactEmails, diags := utils.TypesSetToStringSlice(ctx, data.ContactEmails)
createDedicatedDeploymentRequest.ContactEmails = &contactEmails
if diags.HasError() {
resp.Diagnostics.Append(diags...)
return
Expand Down Expand Up @@ -260,7 +262,8 @@ func (r *DeploymentResource) Create(
}

// contact emails
createHybridDeploymentRequest.ContactEmails, diags = utils.TypesSetToStringSlicePtr(ctx, data.ContactEmails)
contactEmails, diags := utils.TypesSetToStringSlice(ctx, data.ContactEmails)
createHybridDeploymentRequest.ContactEmails = &contactEmails
if diags.HasError() {
resp.Diagnostics.Append(diags...)
return
Expand Down Expand Up @@ -412,7 +415,8 @@ func (r *DeploymentResource) Update(
}

// contact emails
updateStandardDeploymentRequest.ContactEmails, diags = utils.TypesSetToStringSlicePtr(ctx, data.ContactEmails)
contactEmails, diags := utils.TypesSetToStringSlice(ctx, data.ContactEmails)
updateStandardDeploymentRequest.ContactEmails = &contactEmails
if diags.HasError() {
resp.Diagnostics.Append(diags...)
return
Expand Down Expand Up @@ -469,7 +473,8 @@ func (r *DeploymentResource) Update(
}

// contact emails
updateDedicatedDeploymentRequest.ContactEmails, diags = utils.TypesSetToStringSlicePtr(ctx, data.ContactEmails)
contactEmails, diags := utils.TypesSetToStringSlice(ctx, data.ContactEmails)
updateDedicatedDeploymentRequest.ContactEmails = &contactEmails
if diags.HasError() {
resp.Diagnostics.Append(diags...)
return
Expand Down Expand Up @@ -524,7 +529,8 @@ func (r *DeploymentResource) Update(
}

// contact emails
updateHybridDeploymentRequest.ContactEmails, diags = utils.TypesSetToStringSlicePtr(ctx, data.ContactEmails)
contactEmails, diags := utils.TypesSetToStringSlice(ctx, data.ContactEmails)
updateHybridDeploymentRequest.ContactEmails = &contactEmails
if diags.HasError() {
resp.Diagnostics.Append(diags...)
return
Expand Down
2 changes: 1 addition & 1 deletion internal/provider/resources/resource_deployment_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ func standardDeployment(input standardDeploymentInput) string {
}]
override = {
is_hibernating = true
override_until = "2025-04-25T12:58:00+05:30"
override_until = "2030-04-25T12:58:00+05:30"
}
}
}`
Expand Down
10 changes: 4 additions & 6 deletions internal/utils/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,11 @@ func ObjectSet[T any](ctx context.Context, values *[]T, objectAttributeTypes map
return types.SetValue(types.ObjectType{AttrTypes: objectAttributeTypes}, objs)
}

// TypesSetToStringSlicePtr converts a types.Set to a pointer to a slice of strings
// TypesSetToStringSlice converts a types.Set to a slice of strings
// This is useful for converting a set of strings from the Terraform framework to a slice of strings used for calling the API
// We prefer to use a pointer to a slice of strings because our API client query params usually have type *[]string
// and we can easily assign the query param to the result of this function (regardless if the result is nil or not)
func TypesSetToStringSlicePtr(ctx context.Context, s types.Set) (*[]string, diag.Diagnostics) {
func TypesSetToStringSlice(ctx context.Context, s types.Set) ([]string, diag.Diagnostics) {
if len(s.Elements()) == 0 {
return nil, nil
return []string{}, nil
}
var typesStringSlice []types.String
diags := s.ElementsAs(ctx, &typesStringSlice, false)
Expand All @@ -55,5 +53,5 @@ func TypesSetToStringSlicePtr(ctx context.Context, s types.Set) (*[]string, diag
resp := lo.Map(typesStringSlice, func(v types.String, _ int) string {
return v.ValueString()
})
return &resp, nil
return resp, nil
}
10 changes: 5 additions & 5 deletions internal/utils/set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,20 @@ func TestUnit_StringSet(t *testing.T) {
})
}

func TestUnit_TypesSetToStringSlicePtr(t *testing.T) {
func TestUnit_TypesSetToStringSlice(t *testing.T) {
t.Run("empty", func(t *testing.T) {
s := types.SetValueMust(types.StringType, []attr.Value{})

result, diags := utils.TypesSetToStringSlicePtr(context.Background(), s)
result, diags := utils.TypesSetToStringSlice(context.Background(), s)
assert.Nil(t, diags)
assert.Nil(t, result)
assert.Empty(t, result)
})

t.Run("with values", func(t *testing.T) {
s := types.SetValueMust(types.StringType, []attr.Value{types.StringValue("string1"), types.StringValue("string2")})

expected := &[]string{"string1", "string2"}
result, diags := utils.TypesSetToStringSlicePtr(context.Background(), s)
expected := []string{"string1", "string2"}
result, diags := utils.TypesSetToStringSlice(context.Background(), s)
assert.Nil(t, diags)
assert.Equal(t, expected, result)
})
Expand Down

0 comments on commit 6a1193a

Please sign in to comment.