Skip to content

Commit

Permalink
Use correct cloud environment in azcore client (#3799)
Browse files Browse the repository at this point in the history
When instantiating the new azcore-based Azure client, the cloud
environment (public, usgov, etc.) was looked up in a wrong way. Since it
defaults to "public", it didn't error and worked for most users. This PR
refactors and adds more tests for more solid coverage.

Fixes #3795.
  • Loading branch information
thomas11 authored Dec 17, 2024
1 parent f89092f commit c57661c
Show file tree
Hide file tree
Showing 6 changed files with 202 additions and 26 deletions.
23 changes: 23 additions & 0 deletions provider/pkg/provider/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@ import (
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
azcloud "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/go-autorest/autorest"
azureEnv "github.com/Azure/go-autorest/autorest/azure"
"github.com/hashicorp/go-azure-helpers/authentication"
"github.com/hashicorp/go-azure-helpers/sender"
"github.com/manicminer/hamilton/environments"
"github.com/pkg/errors"
"github.com/pulumi/pulumi-azure-native/v2/provider/pkg/azure"
"github.com/pulumi/pulumi/sdk/v3/go/common/util/logging"

goversion "github.com/hashicorp/go-version"
Expand All @@ -29,6 +32,26 @@ type authConfig struct {
useCli bool
}

func (a *authConfig) autorestEnvironment() (azureEnv.Environment, error) {
envName := a.Environment
env, err := azureEnv.EnvironmentFromName(envName)
if err != nil {
env, err = azureEnv.EnvironmentFromName(fmt.Sprintf("AZURE%sCLOUD", envName))
if err != nil {
return azureEnv.Environment{}, errors.Wrapf(err, "environment %q was not found", envName)
}
}
return env, nil
}

func (a *authConfig) cloud() azcloud.Configuration {
cloudName := "public"
if a.Config != nil && a.Config.Environment != "" {
cloudName = a.Config.Environment
}
return azure.GetCloudByName(cloudName)
}

type oidcConfig struct {
oidcToken string
oidcTokenFilePath string
Expand Down
9 changes: 1 addition & 8 deletions provider/pkg/provider/auth_azidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (
azcloud "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/pkg/errors"
"github.com/pulumi/pulumi-azure-native/v2/provider/pkg/azure"
"github.com/pulumi/pulumi/sdk/v3/go/common/util/logging"
)

Expand Down Expand Up @@ -263,17 +262,11 @@ func (k *azureNativeProvider) readAuthConfig() (*authConfiguration, error) {
}
}

cloudName := k.getConfig("environment", "ARM_ENVIRONMENT")
if cloudName == "" {
cloudName = "public"
}
cloud := azure.GetCloudByName(cloudName)

return &authConfiguration{
clientId: k.getConfig("clientId", "ARM_CLIENT_ID"),
tenantId: k.getConfig("tenantId", "ARM_TENANT_ID"),
auxTenants: auxTenants,
cloud: cloud,
cloud: k.cloud,

clientSecret: k.getConfig("clientSecret", "ARM_CLIENT_SECRET"),
clientCertPath: k.getConfig("clientCertificatePath", "ARM_CLIENT_CERTIFICATE_PATH"),
Expand Down
17 changes: 8 additions & 9 deletions provider/pkg/provider/auth_azidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,14 @@ import (
var testPfxCert []byte

func TestGetAuthConfig(t *testing.T) {
setAuthEnvVariables := func(value, boolValue, cloudValue string) {
setAuthEnvVariables := func(value, boolValue string) {
if value != "" {
t.Setenv("ARM_AUXILIARY_TENANT_IDS", `["`+value+`"]`)
}
t.Setenv("ARM_CLIENT_CERTIFICATE_PASSWORD", value)
t.Setenv("ARM_CLIENT_CERTIFICATE_PATH", value)
t.Setenv("ARM_CLIENT_ID", value)
t.Setenv("ARM_CLIENT_SECRET", value)
t.Setenv("ARM_ENVIRONMENT", cloudValue)
t.Setenv("ARM_OIDC_TOKEN", value)
t.Setenv("ARM_OIDC_TOKEN_FILE_PATH", value)
t.Setenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN", value)
Expand All @@ -40,7 +39,7 @@ func TestGetAuthConfig(t *testing.T) {
}

t.Run("empty", func(t *testing.T) {
setAuthEnvVariables("", "", "")
setAuthEnvVariables("", "")
p := azureNativeProvider{}
c, err := p.readAuthConfig()
require.NoError(t, err)
Expand All @@ -50,7 +49,6 @@ func TestGetAuthConfig(t *testing.T) {
require.Empty(t, c.clientCertPath)
require.Empty(t, c.clientId)
require.Empty(t, c.clientSecret)
require.Equal(t, cloud.AzurePublic, c.cloud)
require.Empty(t, c.oidcToken)
require.Empty(t, c.oidcTokenFilePath)
require.Empty(t, c.oidcTokenRequestToken)
Expand All @@ -61,7 +59,7 @@ func TestGetAuthConfig(t *testing.T) {
})

t.Run("values from config take precedence", func(t *testing.T) {
setAuthEnvVariables("env", "false", "china")
setAuthEnvVariables("env", "false")

p := azureNativeProvider{
config: map[string]string{
Expand All @@ -79,6 +77,7 @@ func TestGetAuthConfig(t *testing.T) {
"useMsi": "true",
"useOidc": "true",
},
cloud: cloud.AzureGovernment,
}

c, err := p.readAuthConfig()
Expand All @@ -89,7 +88,6 @@ func TestGetAuthConfig(t *testing.T) {
require.Equal(t, "conf", c.clientCertPath)
require.Equal(t, "conf", c.clientId)
require.Equal(t, "conf", c.clientSecret)
require.Equal(t, cloud.AzureGovernment, c.cloud)
require.Equal(t, "conf", c.oidcToken)
require.Equal(t, "conf", c.oidcTokenFilePath)
require.Equal(t, "conf", c.oidcTokenRequestToken)
Expand All @@ -100,8 +98,10 @@ func TestGetAuthConfig(t *testing.T) {
})

t.Run("values from env", func(t *testing.T) {
p := azureNativeProvider{}
setAuthEnvVariables("env", "true", "china")
p := azureNativeProvider{
cloud: cloud.AzureChina,
}
setAuthEnvVariables("env", "true")

c, err := p.readAuthConfig()
require.NoError(t, err)
Expand All @@ -111,7 +111,6 @@ func TestGetAuthConfig(t *testing.T) {
require.Equal(t, "env", c.clientCertPath)
require.Equal(t, "env", c.clientId)
require.Equal(t, "env", c.clientSecret)
require.Equal(t, cloud.AzureChina, c.cloud)
require.Equal(t, "env", c.oidcToken)
require.Equal(t, "env", c.oidcTokenFilePath)
require.Equal(t, "env", c.oidcTokenRequestToken)
Expand Down
112 changes: 112 additions & 0 deletions provider/pkg/provider/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"os"
"testing"

azcloud "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
"github.com/hashicorp/go-azure-helpers/authentication"
goversion "github.com/hashicorp/go-version"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -259,3 +261,113 @@ func TestAuthConfigs(t *testing.T) {
})
}
}

func TestGetAuthConfigCloud(t *testing.T) {
t.Run("Default is public", func(t *testing.T) {
p := azureNativeProvider{
config: map[string]string{
"useMsi": "true",
},
}
conf, err := p.getAuthConfig()
require.NoError(t, err)
assert.Equal(t, azcloud.AzurePublic, conf.cloud())
})

t.Run("Public", func(t *testing.T) {
p := azureNativeProvider{
config: map[string]string{
"environment": "public",
"useMsi": "true",
},
}
conf, err := p.getAuthConfig()
require.NoError(t, err)
assert.Equal(t, azcloud.AzurePublic, conf.cloud())
})

t.Run("China", func(t *testing.T) {
p := azureNativeProvider{
config: map[string]string{
"environment": "china",
"useMsi": "true",
},
}
conf, err := p.getAuthConfig()
require.NoError(t, err)
assert.Equal(t, azcloud.AzureChina, conf.cloud())
})

t.Run("US Government", func(t *testing.T) {
p := azureNativeProvider{
config: map[string]string{
"environment": "usgovernment",
"useMsi": "true",
},
}
conf, err := p.getAuthConfig()
require.NoError(t, err)
assert.Equal(t, azcloud.AzureGovernment, conf.cloud())
})

t.Run("Public from environment variable", func(t *testing.T) {
p := azureNativeProvider{}
t.Setenv("ARM_ENVIRONMENT", "public")
t.Setenv("ARM_USE_MSI", "true")
conf, err := p.getAuthConfig()
require.NoError(t, err)
assert.Equal(t, azcloud.AzurePublic, conf.cloud())
})

t.Run("China from environment variable", func(t *testing.T) {
p := azureNativeProvider{}
t.Setenv("ARM_ENVIRONMENT", "china")
t.Setenv("ARM_USE_MSI", "true")
conf, err := p.getAuthConfig()
require.NoError(t, err)
assert.Equal(t, azcloud.AzureChina, conf.cloud())
})

t.Run("US Government from environment variable", func(t *testing.T) {
p := azureNativeProvider{}
t.Setenv("ARM_ENVIRONMENT", "usgovernment")
t.Setenv("ARM_USE_MSI", "true")
conf, err := p.getAuthConfig()
require.NoError(t, err)
assert.Equal(t, azcloud.AzureGovernment, conf.cloud())
})
}

func TestGetCloud(t *testing.T) {
t.Run("Public", func(t *testing.T) {
a := authConfig{
Config: &authentication.Config{
Environment: "public",
},
}
assert.Equal(t, azcloud.AzurePublic, a.cloud())
})

t.Run("China", func(t *testing.T) {
a := authConfig{
Config: &authentication.Config{
Environment: "china",
},
}
assert.Equal(t, azcloud.AzureChina, a.cloud())
})

t.Run("US Government", func(t *testing.T) {
a := authConfig{
Config: &authentication.Config{
Environment: "usgov",
},
}
assert.Equal(t, azcloud.AzureGovernment, a.cloud())
})

t.Run("Unknown", func(t *testing.T) {
a := authConfig{}
assert.Equal(t, azcloud.AzurePublic, a.cloud())
})
}
17 changes: 8 additions & 9 deletions provider/pkg/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"google.golang.org/protobuf/types/known/emptypb"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
azcloud "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/go-autorest/autorest"
azureEnv "github.com/Azure/go-autorest/autorest/azure"
Expand Down Expand Up @@ -61,6 +62,7 @@ type azureNativeProvider struct {
version string
subscriptionID string
environment azureEnv.Environment
cloud azcloud.Configuration
resourceMap *resources.PartialAzureAPIMetadata
config map[string]string
schemaBytes []byte
Expand Down Expand Up @@ -183,15 +185,12 @@ func (k *azureNativeProvider) Configure(ctx context.Context,
return nil, err
}

envName := authConfig.Environment
env, err := azureEnv.EnvironmentFromName(envName)
k.environment, err = authConfig.autorestEnvironment()
if err != nil {
env, err = azureEnv.EnvironmentFromName(fmt.Sprintf("AZURE%sCLOUD", envName))
if err != nil {
return nil, errors.Wrapf(err, "environment %q was not found", envName)
}
return nil, err
}
k.environment = env

k.cloud = authConfig.cloud()

hamiltonEnv := k.autorestEnvToHamiltonEnv()

Expand Down Expand Up @@ -239,7 +238,7 @@ func (k *azureNativeProvider) Configure(ctx context.Context,
return nil, fmt.Errorf("creating Azure client: %w", err)
}

k.customResources, err = customresources.BuildCustomResources(&env, k.azureClient, k.LookupResource, k.newCrudClient, k.subscriptionID,
k.customResources, err = customresources.BuildCustomResources(&k.environment, k.azureClient, k.LookupResource, k.newCrudClient, k.subscriptionID,
resourceManagerBearerAuth, resourceManagerAuth, keyVaultBearerAuth, userAgent, credential)
if err != nil {
return nil, fmt.Errorf("initializing custom resources: %w", err)
Expand All @@ -256,7 +255,7 @@ func (k *azureNativeProvider) Configure(ctx context.Context,
func (k *azureNativeProvider) newAzureClient(armAuth autorest.Authorizer, tokenCred azcore.TokenCredential, userAgent string) (azure.AzureClient, error) {
if util.EnableAzcoreBackend() {
logging.V(9).Infof("AzureClient: using azcore and azidentity")
return azure.NewAzCoreClient(tokenCred, userAgent, azure.GetCloudByName(k.environment.Name), nil)
return azure.NewAzCoreClient(tokenCred, userAgent, k.cloud, nil)
}
logging.V(9).Infof("AzureClient: using autorest")
return azure.NewAzureClient(k.environment, armAuth, userAgent), nil
Expand Down
50 changes: 50 additions & 0 deletions provider/pkg/provider/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"reflect"
"testing"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/fake"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/pulumi/pulumi-azure-native/v2/provider/pkg/convert"
Expand Down Expand Up @@ -422,6 +423,55 @@ func TestUsesCorrectAzureClient(t *testing.T) {
})
}

func TestAzcoreAzureClientUsesCorrectCloud(t *testing.T) {
for expectedHost, cloudInstance := range map[string]cloud.Configuration{
"https://management.azure.com": cloud.AzurePublic,
"https://management.chinacloudapi.cn": cloud.AzureChina,
"https://management.usgovcloudapi.net": cloud.AzureGovernment,
} {
p := azureNativeProvider{
cloud: cloudInstance,
}

client, err := p.newAzureClient(nil, &fake.TokenCredential{}, "pulumi")
require.NoError(t, err)
require.NotNil(t, client)

// Use reflection to get the value of the private 'host' field
clientValue := reflect.ValueOf(client).Elem()
hostField := clientValue.FieldByName("host")
require.True(t, hostField.IsValid(), "host field should be valid", expectedHost)

assert.Equal(t, expectedHost, hostField.String())
}
}

func TestAutorestAzureClientUsesCorrectCloud(t *testing.T) {
for expectedEnv, environment := range map[string]azure.Environment{
azure.PublicCloud.Name: azure.PublicCloud,
azure.ChinaCloud.Name: azure.ChinaCloud,
azure.USGovernmentCloud.Name: azure.USGovernmentCloud,
} {
p := azureNativeProvider{
environment: environment,
}
t.Setenv("PULUMI_ENABLE_AZCORE_BACKEND", "false")

client, err := p.newAzureClient(nil, nil, "pulumi")
require.NoError(t, err)
require.NotNil(t, client)

// Use reflection to get the value of the private 'environment' field
clientValue := reflect.ValueOf(client).Elem()
environmentField := clientValue.FieldByName("environment")
require.True(t, environmentField.IsValid(), "environment field should be valid")
nameField := environmentField.FieldByName("Name")
require.True(t, nameField.IsValid(), "environment.name field should be valid")

assert.Equal(t, expectedEnv, nameField.String())
}
}

type mockAzureClient struct {
getIds []string
}
Expand Down

0 comments on commit c57661c

Please sign in to comment.