Skip to content

Commit

Permalink
Refactor AWS DynamoDB Streams scaler configuration (kedacore#6351)
Browse files Browse the repository at this point in the history
* Refactor AWS DynamoDB Streams scaler configuration

Signed-off-by: Omer Aplatony <[email protected]>

* fixed unit tests

Signed-off-by: Omer Aplatony <[email protected]>

* Fix invalid value test

Signed-off-by: Omer Aplatony <[email protected]>

* go fmt

Signed-off-by: Omer Aplatony <[email protected]>

---------

Signed-off-by: Omer Aplatony <[email protected]>
  • Loading branch information
omerap12 authored and rickbrouwer committed Dec 4, 2024
1 parent 5954970 commit 362614b
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 98 deletions.
69 changes: 18 additions & 51 deletions pkg/scalers/aws_dynamodb_streams_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package scalers
import (
"context"
"fmt"
"strconv"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
Expand Down Expand Up @@ -31,11 +30,11 @@ type awsDynamoDBStreamsScaler struct {
}

type awsDynamoDBStreamsMetadata struct {
targetShardCount int64
activationTargetShardCount int64
tableName string
awsRegion string
awsEndpoint string
TargetShardCount int64 `keda:"name=shardCount, order=triggerMetadata, default=2"`
ActivationTargetShardCount int64 `keda:"name=activationShardCount, order=triggerMetadata, default=0"`
TableName string `keda:"name=tableName, order=triggerMetadata"`
AwsRegion string `keda:"name=awsRegion, order=triggerMetadata"`
AwsEndpoint string `keda:"name=awsEndpoint, order=triggerMetadata, optional"`
awsAuthorization awsutils.AuthorizationMetadata
triggerIndex int
}
Expand All @@ -49,7 +48,7 @@ func NewAwsDynamoDBStreamsScaler(ctx context.Context, config *scalersconfig.Scal

logger := InitializeLogger(config, "aws_dynamodb_streams_scaler")

meta, err := parseAwsDynamoDBStreamsMetadata(config, logger)
meta, err := parseAwsDynamoDBStreamsMetadata(config)
if err != nil {
return nil, fmt.Errorf("error parsing dynamodb stream metadata: %w", err)
}
Expand All @@ -58,7 +57,7 @@ func NewAwsDynamoDBStreamsScaler(ctx context.Context, config *scalersconfig.Scal
if err != nil {
return nil, fmt.Errorf("error when creating dynamodbstream client: %w", err)
}
streamArn, err := getDynamoDBStreamsArn(ctx, dbClient, &meta.tableName)
streamArn, err := getDynamoDBStreamsArn(ctx, dbClient, &meta.TableName)
if err != nil {
return nil, fmt.Errorf("error dynamodb stream arn: %w", err)
}
Expand All @@ -74,43 +73,11 @@ func NewAwsDynamoDBStreamsScaler(ctx context.Context, config *scalersconfig.Scal
}, nil
}

func parseAwsDynamoDBStreamsMetadata(config *scalersconfig.ScalerConfig, logger logr.Logger) (*awsDynamoDBStreamsMetadata, error) {
func parseAwsDynamoDBStreamsMetadata(config *scalersconfig.ScalerConfig) (*awsDynamoDBStreamsMetadata, error) {
meta := awsDynamoDBStreamsMetadata{}
meta.targetShardCount = defaultTargetDBStreamsShardCount

if val, ok := config.TriggerMetadata["awsRegion"]; ok && val != "" {
meta.awsRegion = val
} else {
return nil, fmt.Errorf("no awsRegion given")
}

if val, ok := config.TriggerMetadata["awsEndpoint"]; ok {
meta.awsEndpoint = val
}

if val, ok := config.TriggerMetadata["tableName"]; ok && val != "" {
meta.tableName = val
} else {
return nil, fmt.Errorf("no tableName given")
}

if val, ok := config.TriggerMetadata["shardCount"]; ok && val != "" {
shardCount, err := strconv.ParseInt(val, 10, 64)
if err != nil {
meta.targetShardCount = defaultTargetDBStreamsShardCount
logger.Error(err, "error parsing dyanmodb stream metadata shardCount, using default %n", defaultTargetDBStreamsShardCount)
} else {
meta.targetShardCount = shardCount
}
}
if val, ok := config.TriggerMetadata["activationShardCount"]; ok && val != "" {
shardCount, err := strconv.ParseInt(val, 10, 64)
if err != nil {
meta.activationTargetShardCount = defaultActivationTargetDBStreamsShardCount
logger.Error(err, "error parsing dyanmodb stream metadata activationTargetShardCount, using default %n", defaultActivationTargetDBStreamsShardCount)
} else {
meta.activationTargetShardCount = shardCount
}
if err := config.TypedConfig(&meta); err != nil {
return nil, fmt.Errorf("error parsing dynamodb stream metadata: %w", err)
}

auth, err := awsutils.GetAwsAuthorization(config.TriggerUniqueKey, config.PodIdentity, config.TriggerMetadata, config.AuthParams, config.ResolvedEnv)
Expand All @@ -125,18 +92,18 @@ func parseAwsDynamoDBStreamsMetadata(config *scalersconfig.ScalerConfig, logger
}

func createClientsForDynamoDBStreamsScaler(ctx context.Context, metadata *awsDynamoDBStreamsMetadata) (*dynamodb.Client, *dynamodbstreams.Client, error) {
cfg, err := awsutils.GetAwsConfig(ctx, metadata.awsRegion, metadata.awsAuthorization)
cfg, err := awsutils.GetAwsConfig(ctx, metadata.AwsRegion, metadata.awsAuthorization)
if err != nil {
return nil, nil, err
}
dbClient := dynamodb.NewFromConfig(*cfg, func(options *dynamodb.Options) {
if metadata.awsEndpoint != "" {
options.BaseEndpoint = aws.String(metadata.awsEndpoint)
if metadata.AwsEndpoint != "" {
options.BaseEndpoint = aws.String(metadata.AwsEndpoint)
}
})
dbStreamClient := dynamodbstreams.NewFromConfig(*cfg, func(options *dynamodbstreams.Options) {
if metadata.awsEndpoint != "" {
options.BaseEndpoint = aws.String(metadata.awsEndpoint)
if metadata.AwsEndpoint != "" {
options.BaseEndpoint = aws.String(metadata.AwsEndpoint)
}
})

Expand Down Expand Up @@ -176,9 +143,9 @@ func (s *awsDynamoDBStreamsScaler) Close(_ context.Context) error {
func (s *awsDynamoDBStreamsScaler) GetMetricSpecForScaling(_ context.Context) []v2.MetricSpec {
externalMetric := &v2.ExternalMetricSource{
Metric: v2.MetricIdentifier{
Name: GenerateMetricNameWithIndex(s.metadata.triggerIndex, kedautil.NormalizeString(fmt.Sprintf("aws-dynamodb-streams-%s", s.metadata.tableName))),
Name: GenerateMetricNameWithIndex(s.metadata.triggerIndex, kedautil.NormalizeString(fmt.Sprintf("aws-dynamodb-streams-%s", s.metadata.TableName))),
},
Target: GetMetricTarget(s.metricType, s.metadata.targetShardCount),
Target: GetMetricTarget(s.metricType, s.metadata.TargetShardCount),
}
metricSpec := v2.MetricSpec{External: externalMetric, Type: externalMetricType}
return []v2.MetricSpec{metricSpec}
Expand All @@ -195,7 +162,7 @@ func (s *awsDynamoDBStreamsScaler) GetMetricsAndActivity(ctx context.Context, me

metric := GenerateMetricInMili(metricName, float64(shardCount))

return []external_metrics.ExternalMetricValue{metric}, shardCount > s.metadata.activationTargetShardCount, nil
return []external_metrics.ExternalMetricValue{metric}, shardCount > s.metadata.ActivationTargetShardCount, nil
}

// GetDynamoDBStreamShardCount Get DynamoDB Stream Shard Count
Expand Down
84 changes: 37 additions & 47 deletions pkg/scalers/aws_dynamodb_streams_scaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,10 @@ var testAwsDynamoDBStreamMetadata = []parseAwsDynamoDBStreamsMetadataTestData{
"awsRegion": testAWSDynamoDBStreamsRegion},
authParams: testAWSKinesisAuthentication,
expected: &awsDynamoDBStreamsMetadata{
targetShardCount: 2,
activationTargetShardCount: 1,
tableName: testAWSDynamoDBSmallTable,
awsRegion: testAWSDynamoDBStreamsRegion,
TargetShardCount: 2,
ActivationTargetShardCount: 1,
TableName: testAWSDynamoDBSmallTable,
AwsRegion: testAWSDynamoDBStreamsRegion,
awsAuthorization: awsutils.AuthorizationMetadata{
AwsAccessKeyID: testAWSDynamoDBStreamsAccessKeyID,
AwsSecretAccessKey: testAWSDynamoDBStreamsSecretAccessKey,
Expand All @@ -159,11 +159,11 @@ var testAwsDynamoDBStreamMetadata = []parseAwsDynamoDBStreamsMetadataTestData{
"awsEndpoint": testAWSDynamoDBStreamsEndpoint},
authParams: testAWSKinesisAuthentication,
expected: &awsDynamoDBStreamsMetadata{
targetShardCount: 2,
activationTargetShardCount: 1,
tableName: testAWSDynamoDBSmallTable,
awsRegion: testAWSDynamoDBStreamsRegion,
awsEndpoint: testAWSDynamoDBStreamsEndpoint,
TargetShardCount: 2,
ActivationTargetShardCount: 1,
TableName: testAWSDynamoDBSmallTable,
AwsRegion: testAWSDynamoDBStreamsRegion,
AwsEndpoint: testAWSDynamoDBStreamsEndpoint,
awsAuthorization: awsutils.AuthorizationMetadata{
AwsAccessKeyID: testAWSDynamoDBStreamsAccessKeyID,
AwsSecretAccessKey: testAWSDynamoDBStreamsSecretAccessKey,
Expand Down Expand Up @@ -204,10 +204,10 @@ var testAwsDynamoDBStreamMetadata = []parseAwsDynamoDBStreamsMetadataTestData{
"awsRegion": testAWSDynamoDBStreamsRegion},
authParams: testAWSKinesisAuthentication,
expected: &awsDynamoDBStreamsMetadata{
targetShardCount: defaultTargetDBStreamsShardCount,
activationTargetShardCount: defaultActivationTargetDBStreamsShardCount,
tableName: testAWSDynamoDBSmallTable,
awsRegion: testAWSDynamoDBStreamsRegion,
TargetShardCount: defaultTargetDBStreamsShardCount,
ActivationTargetShardCount: defaultActivationTargetDBStreamsShardCount,
TableName: testAWSDynamoDBSmallTable,
AwsRegion: testAWSDynamoDBStreamsRegion,
awsAuthorization: awsutils.AuthorizationMetadata{
AwsAccessKeyID: testAWSDynamoDBStreamsAccessKeyID,
AwsSecretAccessKey: testAWSDynamoDBStreamsSecretAccessKey,
Expand All @@ -224,20 +224,10 @@ var testAwsDynamoDBStreamMetadata = []parseAwsDynamoDBStreamsMetadataTestData{
"tableName": testAWSDynamoDBSmallTable,
"shardCount": "a",
"awsRegion": testAWSDynamoDBStreamsRegion},
authParams: testAWSKinesisAuthentication,
expected: &awsDynamoDBStreamsMetadata{
targetShardCount: defaultTargetDBStreamsShardCount,
tableName: testAWSDynamoDBSmallTable,
awsRegion: testAWSDynamoDBStreamsRegion,
awsAuthorization: awsutils.AuthorizationMetadata{
AwsAccessKeyID: testAWSDynamoDBStreamsAccessKeyID,
AwsSecretAccessKey: testAWSDynamoDBStreamsSecretAccessKey,
PodIdentityOwner: true,
},
triggerIndex: 4,
},
isError: false,
comment: "properly formed table name and region, wrong shard count",
authParams: testAWSKinesisAuthentication,
expected: &awsDynamoDBStreamsMetadata{},
isError: true,
comment: "invalid value - should cause error",
triggerIndex: 4,
},
{
Expand Down Expand Up @@ -278,9 +268,9 @@ var testAwsDynamoDBStreamMetadata = []parseAwsDynamoDBStreamsMetadataTestData{
"awsSessionToken": testAWSDynamoDBStreamsSessionToken,
},
expected: &awsDynamoDBStreamsMetadata{
targetShardCount: 2,
tableName: testAWSDynamoDBSmallTable,
awsRegion: testAWSDynamoDBStreamsRegion,
TargetShardCount: 2,
TableName: testAWSDynamoDBSmallTable,
AwsRegion: testAWSDynamoDBStreamsRegion,
awsAuthorization: awsutils.AuthorizationMetadata{
AwsAccessKeyID: testAWSDynamoDBStreamsAccessKeyID,
AwsSecretAccessKey: testAWSDynamoDBStreamsSecretAccessKey,
Expand Down Expand Up @@ -330,9 +320,9 @@ var testAwsDynamoDBStreamMetadata = []parseAwsDynamoDBStreamsMetadataTestData{
"awsRoleArn": testAWSDynamoDBStreamsRoleArn,
},
expected: &awsDynamoDBStreamsMetadata{
targetShardCount: 2,
tableName: testAWSDynamoDBSmallTable,
awsRegion: testAWSDynamoDBStreamsRegion,
TargetShardCount: 2,
TableName: testAWSDynamoDBSmallTable,
AwsRegion: testAWSDynamoDBStreamsRegion,
awsAuthorization: awsutils.AuthorizationMetadata{
AwsRoleArn: testAWSDynamoDBStreamsRoleArn,
PodIdentityOwner: true,
Expand All @@ -350,9 +340,9 @@ var testAwsDynamoDBStreamMetadata = []parseAwsDynamoDBStreamsMetadataTestData{
"identityOwner": "operator"},
authParams: map[string]string{},
expected: &awsDynamoDBStreamsMetadata{
targetShardCount: 2,
tableName: testAWSDynamoDBSmallTable,
awsRegion: testAWSDynamoDBStreamsRegion,
TargetShardCount: 2,
TableName: testAWSDynamoDBSmallTable,
AwsRegion: testAWSDynamoDBStreamsRegion,
awsAuthorization: awsutils.AuthorizationMetadata{
PodIdentityOwner: false,
},
Expand All @@ -370,15 +360,15 @@ var awsDynamoDBStreamMetricIdentifiers = []awsDynamoDBStreamsMetricIdentifier{
}

var awsDynamoDBStreamsGetMetricTestData = []*awsDynamoDBStreamsMetadata{
{tableName: testAWSDynamoDBBigTable},
{tableName: testAWSDynamoDBSmallTable},
{tableName: testAWSDynamoDBErrorTable},
{tableName: testAWSDynamoDBInvalidTable},
{TableName: testAWSDynamoDBBigTable},
{TableName: testAWSDynamoDBSmallTable},
{TableName: testAWSDynamoDBErrorTable},
{TableName: testAWSDynamoDBInvalidTable},
}

func TestParseAwsDynamoDBStreamsMetadata(t *testing.T) {
for _, testData := range testAwsDynamoDBStreamMetadata {
result, err := parseAwsDynamoDBStreamsMetadata(&scalersconfig.ScalerConfig{TriggerMetadata: testData.metadata, ResolvedEnv: testAwsDynamoDBStreamAuthentication, AuthParams: testData.authParams, TriggerIndex: testData.triggerIndex}, logr.Discard())
result, err := parseAwsDynamoDBStreamsMetadata(&scalersconfig.ScalerConfig{TriggerMetadata: testData.metadata, ResolvedEnv: testAwsDynamoDBStreamAuthentication, AuthParams: testData.authParams, TriggerIndex: testData.triggerIndex})
if err != nil && !testData.isError {
t.Errorf("Expected success because %s got error, %s", testData.comment, err)
}
Expand All @@ -395,11 +385,11 @@ func TestParseAwsDynamoDBStreamsMetadata(t *testing.T) {
func TestAwsDynamoDBStreamsGetMetricSpecForScaling(t *testing.T) {
for _, testData := range awsDynamoDBStreamMetricIdentifiers {
ctx := context.Background()
meta, err := parseAwsDynamoDBStreamsMetadata(&scalersconfig.ScalerConfig{TriggerMetadata: testData.metadataTestData.metadata, ResolvedEnv: testAwsDynamoDBStreamAuthentication, AuthParams: testData.metadataTestData.authParams, TriggerIndex: testData.triggerIndex}, logr.Discard())
meta, err := parseAwsDynamoDBStreamsMetadata(&scalersconfig.ScalerConfig{TriggerMetadata: testData.metadataTestData.metadata, ResolvedEnv: testAwsDynamoDBStreamAuthentication, AuthParams: testData.metadataTestData.authParams, TriggerIndex: testData.triggerIndex})
if err != nil {
t.Fatal("Could not parse metadata:", err)
}
streamArn, err := getDynamoDBStreamsArn(ctx, &mockAwsDynamoDB{}, &meta.tableName)
streamArn, err := getDynamoDBStreamsArn(ctx, &mockAwsDynamoDB{}, &meta.TableName)
if err != nil {
t.Fatal("Could not get dynamodb stream arn:", err)
}
Expand All @@ -418,12 +408,12 @@ func TestAwsDynamoDBStreamsScalerGetMetrics(t *testing.T) {
var err error
var streamArn *string
ctx := context.Background()
streamArn, err = getDynamoDBStreamsArn(ctx, &mockAwsDynamoDB{}, &meta.tableName)
streamArn, err = getDynamoDBStreamsArn(ctx, &mockAwsDynamoDB{}, &meta.TableName)
if err == nil {
scaler := awsDynamoDBStreamsScaler{"", meta, streamArn, &mockAwsDynamoDBStreams{}, logr.Discard()}
value, _, err = scaler.GetMetricsAndActivity(context.Background(), "MetricName")
}
switch meta.tableName {
switch meta.TableName {
case testAWSDynamoDBErrorTable:
assert.Error(t, err, "expect error because of dynamodb stream api error")
case testAWSDynamoDBInvalidTable:
Expand All @@ -442,12 +432,12 @@ func TestAwsDynamoDBStreamsScalerIsActive(t *testing.T) {
var err error
var streamArn *string
ctx := context.Background()
streamArn, err = getDynamoDBStreamsArn(ctx, &mockAwsDynamoDB{}, &meta.tableName)
streamArn, err = getDynamoDBStreamsArn(ctx, &mockAwsDynamoDB{}, &meta.TableName)
if err == nil {
scaler := awsDynamoDBStreamsScaler{"", meta, streamArn, &mockAwsDynamoDBStreams{}, logr.Discard()}
_, value, err = scaler.GetMetricsAndActivity(context.Background(), "MetricName")
}
switch meta.tableName {
switch meta.TableName {
case testAWSDynamoDBErrorTable:
assert.Error(t, err, "expect error because of dynamodb stream api error")
case testAWSDynamoDBInvalidTable:
Expand Down

0 comments on commit 362614b

Please sign in to comment.