Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor AWS DynamoDB Streams scaler configuration #6351

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 18 additions & 51 deletions pkg/scalers/aws_dynamodb_streams_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package scalers
import (
"context"
"fmt"
"strconv"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
Expand Down Expand Up @@ -31,11 +30,11 @@ type awsDynamoDBStreamsScaler struct {
}

type awsDynamoDBStreamsMetadata struct {
targetShardCount int64
activationTargetShardCount int64
tableName string
awsRegion string
awsEndpoint string
TargetShardCount int64 `keda:"name=shardCount, order=triggerMetadata, default=2"`
ActivationTargetShardCount int64 `keda:"name=activationShardCount, order=triggerMetadata, default=0"`
TableName string `keda:"name=tableName, order=triggerMetadata"`
AwsRegion string `keda:"name=awsRegion, order=triggerMetadata"`
AwsEndpoint string `keda:"name=awsEndpoint, order=triggerMetadata, optional"`
awsAuthorization awsutils.AuthorizationMetadata
triggerIndex int
}
Expand All @@ -49,7 +48,7 @@ func NewAwsDynamoDBStreamsScaler(ctx context.Context, config *scalersconfig.Scal

logger := InitializeLogger(config, "aws_dynamodb_streams_scaler")

meta, err := parseAwsDynamoDBStreamsMetadata(config, logger)
meta, err := parseAwsDynamoDBStreamsMetadata(config)
if err != nil {
return nil, fmt.Errorf("error parsing dynamodb stream metadata: %w", err)
}
Expand All @@ -58,7 +57,7 @@ func NewAwsDynamoDBStreamsScaler(ctx context.Context, config *scalersconfig.Scal
if err != nil {
return nil, fmt.Errorf("error when creating dynamodbstream client: %w", err)
}
streamArn, err := getDynamoDBStreamsArn(ctx, dbClient, &meta.tableName)
streamArn, err := getDynamoDBStreamsArn(ctx, dbClient, &meta.TableName)
if err != nil {
return nil, fmt.Errorf("error dynamodb stream arn: %w", err)
}
Expand All @@ -74,43 +73,11 @@ func NewAwsDynamoDBStreamsScaler(ctx context.Context, config *scalersconfig.Scal
}, nil
}

func parseAwsDynamoDBStreamsMetadata(config *scalersconfig.ScalerConfig, logger logr.Logger) (*awsDynamoDBStreamsMetadata, error) {
func parseAwsDynamoDBStreamsMetadata(config *scalersconfig.ScalerConfig) (*awsDynamoDBStreamsMetadata, error) {
meta := awsDynamoDBStreamsMetadata{}
meta.targetShardCount = defaultTargetDBStreamsShardCount

if val, ok := config.TriggerMetadata["awsRegion"]; ok && val != "" {
meta.awsRegion = val
} else {
return nil, fmt.Errorf("no awsRegion given")
}

if val, ok := config.TriggerMetadata["awsEndpoint"]; ok {
meta.awsEndpoint = val
}

if val, ok := config.TriggerMetadata["tableName"]; ok && val != "" {
meta.tableName = val
} else {
return nil, fmt.Errorf("no tableName given")
}

if val, ok := config.TriggerMetadata["shardCount"]; ok && val != "" {
shardCount, err := strconv.ParseInt(val, 10, 64)
if err != nil {
meta.targetShardCount = defaultTargetDBStreamsShardCount
logger.Error(err, "error parsing dyanmodb stream metadata shardCount, using default %n", defaultTargetDBStreamsShardCount)
} else {
meta.targetShardCount = shardCount
}
}
if val, ok := config.TriggerMetadata["activationShardCount"]; ok && val != "" {
shardCount, err := strconv.ParseInt(val, 10, 64)
if err != nil {
meta.activationTargetShardCount = defaultActivationTargetDBStreamsShardCount
logger.Error(err, "error parsing dyanmodb stream metadata activationTargetShardCount, using default %n", defaultActivationTargetDBStreamsShardCount)
} else {
meta.activationTargetShardCount = shardCount
}
if err := config.TypedConfig(&meta); err != nil {
return nil, fmt.Errorf("error parsing dynamodb stream metadata: %w", err)
}

auth, err := awsutils.GetAwsAuthorization(config.TriggerUniqueKey, config.PodIdentity, config.TriggerMetadata, config.AuthParams, config.ResolvedEnv)
Expand All @@ -125,18 +92,18 @@ func parseAwsDynamoDBStreamsMetadata(config *scalersconfig.ScalerConfig, logger
}

func createClientsForDynamoDBStreamsScaler(ctx context.Context, metadata *awsDynamoDBStreamsMetadata) (*dynamodb.Client, *dynamodbstreams.Client, error) {
cfg, err := awsutils.GetAwsConfig(ctx, metadata.awsRegion, metadata.awsAuthorization)
cfg, err := awsutils.GetAwsConfig(ctx, metadata.AwsRegion, metadata.awsAuthorization)
if err != nil {
return nil, nil, err
}
dbClient := dynamodb.NewFromConfig(*cfg, func(options *dynamodb.Options) {
if metadata.awsEndpoint != "" {
options.BaseEndpoint = aws.String(metadata.awsEndpoint)
if metadata.AwsEndpoint != "" {
options.BaseEndpoint = aws.String(metadata.AwsEndpoint)
}
})
dbStreamClient := dynamodbstreams.NewFromConfig(*cfg, func(options *dynamodbstreams.Options) {
if metadata.awsEndpoint != "" {
options.BaseEndpoint = aws.String(metadata.awsEndpoint)
if metadata.AwsEndpoint != "" {
options.BaseEndpoint = aws.String(metadata.AwsEndpoint)
}
})

Expand Down Expand Up @@ -176,9 +143,9 @@ func (s *awsDynamoDBStreamsScaler) Close(_ context.Context) error {
func (s *awsDynamoDBStreamsScaler) GetMetricSpecForScaling(_ context.Context) []v2.MetricSpec {
externalMetric := &v2.ExternalMetricSource{
Metric: v2.MetricIdentifier{
Name: GenerateMetricNameWithIndex(s.metadata.triggerIndex, kedautil.NormalizeString(fmt.Sprintf("aws-dynamodb-streams-%s", s.metadata.tableName))),
Name: GenerateMetricNameWithIndex(s.metadata.triggerIndex, kedautil.NormalizeString(fmt.Sprintf("aws-dynamodb-streams-%s", s.metadata.TableName))),
},
Target: GetMetricTarget(s.metricType, s.metadata.targetShardCount),
Target: GetMetricTarget(s.metricType, s.metadata.TargetShardCount),
}
metricSpec := v2.MetricSpec{External: externalMetric, Type: externalMetricType}
return []v2.MetricSpec{metricSpec}
Expand All @@ -195,7 +162,7 @@ func (s *awsDynamoDBStreamsScaler) GetMetricsAndActivity(ctx context.Context, me

metric := GenerateMetricInMili(metricName, float64(shardCount))

return []external_metrics.ExternalMetricValue{metric}, shardCount > s.metadata.activationTargetShardCount, nil
return []external_metrics.ExternalMetricValue{metric}, shardCount > s.metadata.ActivationTargetShardCount, nil
}

// GetDynamoDBStreamShardCount Get DynamoDB Stream Shard Count
Expand Down
84 changes: 37 additions & 47 deletions pkg/scalers/aws_dynamodb_streams_scaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,10 @@ var testAwsDynamoDBStreamMetadata = []parseAwsDynamoDBStreamsMetadataTestData{
"awsRegion": testAWSDynamoDBStreamsRegion},
authParams: testAWSKinesisAuthentication,
expected: &awsDynamoDBStreamsMetadata{
targetShardCount: 2,
activationTargetShardCount: 1,
tableName: testAWSDynamoDBSmallTable,
awsRegion: testAWSDynamoDBStreamsRegion,
TargetShardCount: 2,
ActivationTargetShardCount: 1,
TableName: testAWSDynamoDBSmallTable,
AwsRegion: testAWSDynamoDBStreamsRegion,
awsAuthorization: awsutils.AuthorizationMetadata{
AwsAccessKeyID: testAWSDynamoDBStreamsAccessKeyID,
AwsSecretAccessKey: testAWSDynamoDBStreamsSecretAccessKey,
Expand All @@ -159,11 +159,11 @@ var testAwsDynamoDBStreamMetadata = []parseAwsDynamoDBStreamsMetadataTestData{
"awsEndpoint": testAWSDynamoDBStreamsEndpoint},
authParams: testAWSKinesisAuthentication,
expected: &awsDynamoDBStreamsMetadata{
targetShardCount: 2,
activationTargetShardCount: 1,
tableName: testAWSDynamoDBSmallTable,
awsRegion: testAWSDynamoDBStreamsRegion,
awsEndpoint: testAWSDynamoDBStreamsEndpoint,
TargetShardCount: 2,
ActivationTargetShardCount: 1,
TableName: testAWSDynamoDBSmallTable,
AwsRegion: testAWSDynamoDBStreamsRegion,
AwsEndpoint: testAWSDynamoDBStreamsEndpoint,
awsAuthorization: awsutils.AuthorizationMetadata{
AwsAccessKeyID: testAWSDynamoDBStreamsAccessKeyID,
AwsSecretAccessKey: testAWSDynamoDBStreamsSecretAccessKey,
Expand Down Expand Up @@ -204,10 +204,10 @@ var testAwsDynamoDBStreamMetadata = []parseAwsDynamoDBStreamsMetadataTestData{
"awsRegion": testAWSDynamoDBStreamsRegion},
authParams: testAWSKinesisAuthentication,
expected: &awsDynamoDBStreamsMetadata{
targetShardCount: defaultTargetDBStreamsShardCount,
activationTargetShardCount: defaultActivationTargetDBStreamsShardCount,
tableName: testAWSDynamoDBSmallTable,
awsRegion: testAWSDynamoDBStreamsRegion,
TargetShardCount: defaultTargetDBStreamsShardCount,
ActivationTargetShardCount: defaultActivationTargetDBStreamsShardCount,
TableName: testAWSDynamoDBSmallTable,
AwsRegion: testAWSDynamoDBStreamsRegion,
awsAuthorization: awsutils.AuthorizationMetadata{
AwsAccessKeyID: testAWSDynamoDBStreamsAccessKeyID,
AwsSecretAccessKey: testAWSDynamoDBStreamsSecretAccessKey,
Expand All @@ -224,20 +224,10 @@ var testAwsDynamoDBStreamMetadata = []parseAwsDynamoDBStreamsMetadataTestData{
"tableName": testAWSDynamoDBSmallTable,
"shardCount": "a",
"awsRegion": testAWSDynamoDBStreamsRegion},
authParams: testAWSKinesisAuthentication,
expected: &awsDynamoDBStreamsMetadata{
targetShardCount: defaultTargetDBStreamsShardCount,
tableName: testAWSDynamoDBSmallTable,
awsRegion: testAWSDynamoDBStreamsRegion,
awsAuthorization: awsutils.AuthorizationMetadata{
AwsAccessKeyID: testAWSDynamoDBStreamsAccessKeyID,
AwsSecretAccessKey: testAWSDynamoDBStreamsSecretAccessKey,
PodIdentityOwner: true,
},
triggerIndex: 4,
},
isError: false,
comment: "properly formed table name and region, wrong shard count",
wozniakjan marked this conversation as resolved.
Show resolved Hide resolved
authParams: testAWSKinesisAuthentication,
expected: &awsDynamoDBStreamsMetadata{},
isError: true,
comment: "invalid value - should cause error",
triggerIndex: 4,
},
{
Expand Down Expand Up @@ -278,9 +268,9 @@ var testAwsDynamoDBStreamMetadata = []parseAwsDynamoDBStreamsMetadataTestData{
"awsSessionToken": testAWSDynamoDBStreamsSessionToken,
},
expected: &awsDynamoDBStreamsMetadata{
targetShardCount: 2,
tableName: testAWSDynamoDBSmallTable,
awsRegion: testAWSDynamoDBStreamsRegion,
TargetShardCount: 2,
TableName: testAWSDynamoDBSmallTable,
AwsRegion: testAWSDynamoDBStreamsRegion,
awsAuthorization: awsutils.AuthorizationMetadata{
AwsAccessKeyID: testAWSDynamoDBStreamsAccessKeyID,
AwsSecretAccessKey: testAWSDynamoDBStreamsSecretAccessKey,
Expand Down Expand Up @@ -330,9 +320,9 @@ var testAwsDynamoDBStreamMetadata = []parseAwsDynamoDBStreamsMetadataTestData{
"awsRoleArn": testAWSDynamoDBStreamsRoleArn,
},
expected: &awsDynamoDBStreamsMetadata{
targetShardCount: 2,
tableName: testAWSDynamoDBSmallTable,
awsRegion: testAWSDynamoDBStreamsRegion,
TargetShardCount: 2,
TableName: testAWSDynamoDBSmallTable,
AwsRegion: testAWSDynamoDBStreamsRegion,
awsAuthorization: awsutils.AuthorizationMetadata{
AwsRoleArn: testAWSDynamoDBStreamsRoleArn,
PodIdentityOwner: true,
Expand All @@ -350,9 +340,9 @@ var testAwsDynamoDBStreamMetadata = []parseAwsDynamoDBStreamsMetadataTestData{
"identityOwner": "operator"},
authParams: map[string]string{},
expected: &awsDynamoDBStreamsMetadata{
targetShardCount: 2,
tableName: testAWSDynamoDBSmallTable,
awsRegion: testAWSDynamoDBStreamsRegion,
TargetShardCount: 2,
TableName: testAWSDynamoDBSmallTable,
AwsRegion: testAWSDynamoDBStreamsRegion,
awsAuthorization: awsutils.AuthorizationMetadata{
PodIdentityOwner: false,
},
Expand All @@ -370,15 +360,15 @@ var awsDynamoDBStreamMetricIdentifiers = []awsDynamoDBStreamsMetricIdentifier{
}

var awsDynamoDBStreamsGetMetricTestData = []*awsDynamoDBStreamsMetadata{
{tableName: testAWSDynamoDBBigTable},
{tableName: testAWSDynamoDBSmallTable},
{tableName: testAWSDynamoDBErrorTable},
{tableName: testAWSDynamoDBInvalidTable},
{TableName: testAWSDynamoDBBigTable},
{TableName: testAWSDynamoDBSmallTable},
{TableName: testAWSDynamoDBErrorTable},
{TableName: testAWSDynamoDBInvalidTable},
}

func TestParseAwsDynamoDBStreamsMetadata(t *testing.T) {
for _, testData := range testAwsDynamoDBStreamMetadata {
result, err := parseAwsDynamoDBStreamsMetadata(&scalersconfig.ScalerConfig{TriggerMetadata: testData.metadata, ResolvedEnv: testAwsDynamoDBStreamAuthentication, AuthParams: testData.authParams, TriggerIndex: testData.triggerIndex}, logr.Discard())
result, err := parseAwsDynamoDBStreamsMetadata(&scalersconfig.ScalerConfig{TriggerMetadata: testData.metadata, ResolvedEnv: testAwsDynamoDBStreamAuthentication, AuthParams: testData.authParams, TriggerIndex: testData.triggerIndex})
if err != nil && !testData.isError {
t.Errorf("Expected success because %s got error, %s", testData.comment, err)
}
Expand All @@ -395,11 +385,11 @@ func TestParseAwsDynamoDBStreamsMetadata(t *testing.T) {
func TestAwsDynamoDBStreamsGetMetricSpecForScaling(t *testing.T) {
for _, testData := range awsDynamoDBStreamMetricIdentifiers {
ctx := context.Background()
meta, err := parseAwsDynamoDBStreamsMetadata(&scalersconfig.ScalerConfig{TriggerMetadata: testData.metadataTestData.metadata, ResolvedEnv: testAwsDynamoDBStreamAuthentication, AuthParams: testData.metadataTestData.authParams, TriggerIndex: testData.triggerIndex}, logr.Discard())
meta, err := parseAwsDynamoDBStreamsMetadata(&scalersconfig.ScalerConfig{TriggerMetadata: testData.metadataTestData.metadata, ResolvedEnv: testAwsDynamoDBStreamAuthentication, AuthParams: testData.metadataTestData.authParams, TriggerIndex: testData.triggerIndex})
if err != nil {
t.Fatal("Could not parse metadata:", err)
}
streamArn, err := getDynamoDBStreamsArn(ctx, &mockAwsDynamoDB{}, &meta.tableName)
streamArn, err := getDynamoDBStreamsArn(ctx, &mockAwsDynamoDB{}, &meta.TableName)
if err != nil {
t.Fatal("Could not get dynamodb stream arn:", err)
}
Expand All @@ -418,12 +408,12 @@ func TestAwsDynamoDBStreamsScalerGetMetrics(t *testing.T) {
var err error
var streamArn *string
ctx := context.Background()
streamArn, err = getDynamoDBStreamsArn(ctx, &mockAwsDynamoDB{}, &meta.tableName)
streamArn, err = getDynamoDBStreamsArn(ctx, &mockAwsDynamoDB{}, &meta.TableName)
if err == nil {
scaler := awsDynamoDBStreamsScaler{"", meta, streamArn, &mockAwsDynamoDBStreams{}, logr.Discard()}
value, _, err = scaler.GetMetricsAndActivity(context.Background(), "MetricName")
}
switch meta.tableName {
switch meta.TableName {
case testAWSDynamoDBErrorTable:
assert.Error(t, err, "expect error because of dynamodb stream api error")
case testAWSDynamoDBInvalidTable:
Expand All @@ -442,12 +432,12 @@ func TestAwsDynamoDBStreamsScalerIsActive(t *testing.T) {
var err error
var streamArn *string
ctx := context.Background()
streamArn, err = getDynamoDBStreamsArn(ctx, &mockAwsDynamoDB{}, &meta.tableName)
streamArn, err = getDynamoDBStreamsArn(ctx, &mockAwsDynamoDB{}, &meta.TableName)
if err == nil {
scaler := awsDynamoDBStreamsScaler{"", meta, streamArn, &mockAwsDynamoDBStreams{}, logr.Discard()}
_, value, err = scaler.GetMetricsAndActivity(context.Background(), "MetricName")
}
switch meta.tableName {
switch meta.TableName {
case testAWSDynamoDBErrorTable:
assert.Error(t, err, "expect error because of dynamodb stream api error")
case testAWSDynamoDBInvalidTable:
Expand Down
Loading