Skip to content

Commit

Permalink
Refactor mongo scaler (#6261)
Browse files Browse the repository at this point in the history
Signed-off-by: rickbrouwer <[email protected]>
  • Loading branch information
rickbrouwer authored Nov 3, 2024
1 parent b34fad0 commit 0e7801d
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 179 deletions.
251 changes: 78 additions & 173 deletions pkg/scalers/mongo_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ import (
"fmt"
"net"
"net/url"
"strconv"
"strings"
"time"

"github.com/go-logr/logr"
Expand All @@ -22,60 +20,45 @@ import (
kedautil "github.com/kedacore/keda/v2/pkg/util"
)

// mongoDBScaler is support for mongoDB in keda.
type mongoDBScaler struct {
metricType v2.MetricTargetType
metadata *mongoDBMetadata
metadata mongoDBMetadata
client *mongo.Client
logger logr.Logger
}

// mongoDBMetadata specify mongoDB scaler params.
type mongoDBMetadata struct {
// The string is used by connected with mongoDB.
// +optional
connectionString string
// Specify the prefix to connect to the mongoDB server, default value `mongodb`, if the connectionString be provided, don't need to specify this param.
// +optional
scheme string
// Specify the host to connect to the mongoDB server,if the connectionString be provided, don't need to specify this param.
// +optional
host string
// Specify the port to connect to the mongoDB server,if the connectionString be provided, don't need to specify this param.
// +optional
port string
// Specify the username to connect to the mongoDB server,if the connectionString be provided, don't need to specify this param.
// +optional
username string
// Specify the password to connect to the mongoDB server,if the connectionString be provided, don't need to specify this param.
// +optional
password string

// The name of the database to be queried.
// +required
dbName string
// The name of the collection to be queried.
// +required
collection string
// A mongoDB filter doc,used by specify DB.
// +required
query string
// A threshold that is used as targetAverageValue in HPA
// +required
queryValue int64
// A threshold that is used to check if scaler is active
// +optional
activationQueryValue int64

// The index of the scaler inside the ScaledObject
// +internal
triggerIndex int
ConnectionString string `keda:"name=connectionString,order=authParams;triggerMetadata;resolvedEnv,optional"`
Scheme string `keda:"name=scheme,order=authParams;triggerMetadata,default=mongodb,optional"`
Host string `keda:"name=host,order=authParams;triggerMetadata,optional"`
Port string `keda:"name=port,order=authParams;triggerMetadata,optional"`
Username string `keda:"name=username,order=authParams;triggerMetadata,optional"`
Password string `keda:"name=password,order=authParams;triggerMetadata;resolvedEnv,optional"`
DBName string `keda:"name=dbName,order=authParams;triggerMetadata"`
Collection string `keda:"name=collection,order=triggerMetadata"`
Query string `keda:"name=query,order=triggerMetadata"`
QueryValue int64 `keda:"name=queryValue,order=triggerMetadata"`
ActivationQueryValue int64 `keda:"name=activationQueryValue,order=triggerMetadata,default=0"`
TriggerIndex int
}

// Default variables and settings
const (
mongoDBDefaultTimeOut = 10 * time.Second
)
func (m *mongoDBMetadata) Validate() error {
if m.ConnectionString == "" {
if m.Host == "" {
return fmt.Errorf("no host given")
}
if m.Port == "" && m.Scheme != "mongodb+srv" {
return fmt.Errorf("no port given")
}
if m.Username == "" {
return fmt.Errorf("no username given")
}
if m.Password == "" {
return fmt.Errorf("no password given")
}
}
return nil
}

// NewMongoDBScaler creates a new mongoDB scaler
func NewMongoDBScaler(ctx context.Context, config *scalersconfig.ScalerConfig) (Scaler, error) {
Expand All @@ -84,22 +67,14 @@ func NewMongoDBScaler(ctx context.Context, config *scalersconfig.ScalerConfig) (
return nil, fmt.Errorf("error getting scaler metric type: %w", err)
}

ctx, cancel := context.WithTimeout(ctx, mongoDBDefaultTimeOut)
defer cancel()

meta, connStr, err := parseMongoDBMetadata(config)
meta, err := parseMongoDBMetadata(config)
if err != nil {
return nil, fmt.Errorf("failed to parsing mongoDB metadata, because of %w", err)
return nil, fmt.Errorf("error parsing mongodb metadata: %w", err)
}

opt := options.Client().ApplyURI(connStr)
client, err := mongo.Connect(ctx, opt)
client, err := createMongoDBClient(ctx, meta)
if err != nil {
return nil, fmt.Errorf("failed to establish connection with mongoDB, because of %w", err)
}

if err = client.Ping(ctx, readpref.Primary()); err != nil {
return nil, fmt.Errorf("failed to ping mongoDB, because of %w", err)
return nil, fmt.Errorf("error creating mongodb client: %w", err)
}

return &mongoDBScaler{
Expand All @@ -110,171 +85,101 @@ func NewMongoDBScaler(ctx context.Context, config *scalersconfig.ScalerConfig) (
}, nil
}

func parseMongoDBMetadata(config *scalersconfig.ScalerConfig) (*mongoDBMetadata, string, error) {
var connStr string
var err error
// setting default metadata
func parseMongoDBMetadata(config *scalersconfig.ScalerConfig) (mongoDBMetadata, error) {
meta := mongoDBMetadata{}

// parse metaData from ScaledJob config
if val, ok := config.TriggerMetadata["collection"]; ok {
meta.collection = val
} else {
return nil, "", fmt.Errorf("no collection given")
err := config.TypedConfig(&meta)
if err != nil {
return meta, fmt.Errorf("error parsing mongodb metadata: %w", err)
}

if val, ok := config.TriggerMetadata["query"]; ok {
meta.query = val
} else {
return nil, "", fmt.Errorf("no query given")
}
meta.TriggerIndex = config.TriggerIndex
return meta, nil
}

if val, ok := config.TriggerMetadata["queryValue"]; ok {
queryValue, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return nil, "", fmt.Errorf("failed to convert %v to int, because of %w", val, err)
}
meta.queryValue = queryValue
func createMongoDBClient(ctx context.Context, meta mongoDBMetadata) (*mongo.Client, error) {
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()

var connString string
if meta.ConnectionString != "" {
connString = meta.ConnectionString
} else {
if config.AsMetricSource {
meta.queryValue = 0
} else {
return nil, "", fmt.Errorf("no queryValue given")
host := meta.Host
if meta.Scheme != "mongodb+srv" {
host = net.JoinHostPort(meta.Host, meta.Port)
}
}

meta.activationQueryValue = 0
if val, ok := config.TriggerMetadata["activationQueryValue"]; ok {
activationQueryValue, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return nil, "", fmt.Errorf("failed to convert %v to int, because of %w", val, err)
u := &url.URL{
Scheme: meta.Scheme,
User: url.UserPassword(meta.Username, meta.Password),
Host: host,
Path: meta.DBName,
}
meta.activationQueryValue = activationQueryValue
connString = u.String()
}

dbName, err := GetFromAuthOrMeta(config, "dbName")
client, err := mongo.Connect(ctx, options.Client().ApplyURI(connString))
if err != nil {
return nil, "", err
return nil, fmt.Errorf("failed to create mongodb client: %w", err)
}
meta.dbName = dbName

// Resolve connectionString
switch {
case config.AuthParams["connectionString"] != "":
meta.connectionString = config.AuthParams["connectionString"]
case config.TriggerMetadata["connectionStringFromEnv"] != "":
meta.connectionString = config.ResolvedEnv[config.TriggerMetadata["connectionStringFromEnv"]]
default:
meta.connectionString = ""
scheme, err := GetFromAuthOrMeta(config, "scheme")
if err != nil {
meta.scheme = "mongodb"
} else {
meta.scheme = scheme
}

host, err := GetFromAuthOrMeta(config, "host")
if err != nil {
return nil, "", err
}
meta.host = host

if !strings.Contains(scheme, "mongodb+srv") {
port, err := GetFromAuthOrMeta(config, "port")
if err != nil {
return nil, "", err
}
meta.port = port
}

username, err := GetFromAuthOrMeta(config, "username")
if err != nil {
return nil, "", err
}
meta.username = username

if config.AuthParams["password"] != "" {
meta.password = config.AuthParams["password"]
} else if config.TriggerMetadata["passwordFromEnv"] != "" {
meta.password = config.ResolvedEnv[config.TriggerMetadata["passwordFromEnv"]]
}
if len(meta.password) == 0 {
return nil, "", fmt.Errorf("no password given")
}
}

switch {
case meta.connectionString != "":
connStr = meta.connectionString
case meta.scheme == "mongodb+srv":
// nosemgrep: db-connection-string
connStr = fmt.Sprintf("%s://%s:%s@%s/%s", meta.scheme, url.QueryEscape(meta.username), url.QueryEscape(meta.password), meta.host, meta.dbName)
default:
addr := net.JoinHostPort(meta.host, meta.port)
// nosemgrep: db-connection-string
connStr = fmt.Sprintf("%s://%s:%s@%s/%s", meta.scheme, url.QueryEscape(meta.username), url.QueryEscape(meta.password), addr, meta.dbName)
err = client.Ping(ctx, readpref.Primary())
if err != nil {
return nil, fmt.Errorf("failed to ping mongodb: %w", err)
}

meta.triggerIndex = config.TriggerIndex
return &meta, connStr, nil
return client, nil
}

// Close disposes of mongoDB connections
func (s *mongoDBScaler) Close(ctx context.Context) error {
if s.client != nil {
err := s.client.Disconnect(ctx)
if err != nil {
s.logger.Error(err, fmt.Sprintf("failed to close mongoDB connection, because of %v", err))
s.logger.Error(err, "Error closing mongodb connection")
return err
}
}

return nil
}

// getQueryResult query mongoDB by meta.query
func (s *mongoDBScaler) getQueryResult(ctx context.Context) (int64, error) {
ctx, cancel := context.WithTimeout(ctx, mongoDBDefaultTimeOut)
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()

filter, err := json2BsonDoc(s.metadata.query)
collection := s.client.Database(s.metadata.DBName).Collection(s.metadata.Collection)

filter, err := json2BsonDoc(s.metadata.Query)
if err != nil {
s.logger.Error(err, fmt.Sprintf("failed to convert query param to bson.Doc, because of %v", err))
return 0, err
return 0, fmt.Errorf("failed to parse query: %w", err)
}

docsNum, err := s.client.Database(s.metadata.dbName).Collection(s.metadata.collection).CountDocuments(ctx, filter)
count, err := collection.CountDocuments(ctx, filter)
if err != nil {
s.logger.Error(err, fmt.Sprintf("failed to query %v in %v, because of %v", s.metadata.dbName, s.metadata.collection, err))
return 0, err
return 0, fmt.Errorf("failed to execute query: %w", err)
}

return docsNum, nil
return count, nil
}

// GetMetricsAndActivity query from mongoDB,and return to external metrics
func (s *mongoDBScaler) GetMetricsAndActivity(ctx context.Context, metricName string) ([]external_metrics.ExternalMetricValue, bool, error) {
num, err := s.getQueryResult(ctx)
if err != nil {
return []external_metrics.ExternalMetricValue{}, false, fmt.Errorf("failed to inspect momgoDB, because of %w", err)
return []external_metrics.ExternalMetricValue{}, false, fmt.Errorf("failed to inspect mongodb: %w", err)
}

metric := GenerateMetricInMili(metricName, float64(num))

return []external_metrics.ExternalMetricValue{metric}, num > s.metadata.activationQueryValue, nil
return []external_metrics.ExternalMetricValue{metric}, num > s.metadata.ActivationQueryValue, nil
}

// GetMetricSpecForScaling get the query value for scaling
func (s *mongoDBScaler) GetMetricSpecForScaling(context.Context) []v2.MetricSpec {
metricName := kedautil.NormalizeString(fmt.Sprintf("mongodb-%s", s.metadata.Collection))
externalMetric := &v2.ExternalMetricSource{
Metric: v2.MetricIdentifier{
Name: GenerateMetricNameWithIndex(s.metadata.triggerIndex, kedautil.NormalizeString(fmt.Sprintf("mongodb-%s", s.metadata.collection))),
Name: GenerateMetricNameWithIndex(s.metadata.TriggerIndex, metricName),
},
Target: GetMetricTarget(s.metricType, s.metadata.queryValue),
}
metricSpec := v2.MetricSpec{
External: externalMetric, Type: externalMetricType,
Target: GetMetricTarget(s.metricType, s.metadata.QueryValue),
}
metricSpec := v2.MetricSpec{External: externalMetric, Type: externalMetricType}
return []v2.MetricSpec{metricSpec}
}

Expand Down
15 changes: 9 additions & 6 deletions pkg/scalers/mongo_scaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ import (
"testing"

"github.com/go-logr/logr"
"github.com/stretchr/testify/assert"
"go.mongodb.org/mongo-driver/mongo"
v2 "k8s.io/api/autoscaling/v2"

"github.com/kedacore/keda/v2/pkg/scalers/scalersconfig"
)
Expand Down Expand Up @@ -100,7 +100,7 @@ var mongoDBMetricIdentifiers = []mongoDBMetricIdentifier{

func TestParseMongoDBMetadata(t *testing.T) {
for _, testData := range testMONGODBMetadata {
_, _, err := parseMongoDBMetadata(&scalersconfig.ScalerConfig{ResolvedEnv: testData.resolvedEnv, TriggerMetadata: testData.metadata, AuthParams: testData.authParams})
_, err := parseMongoDBMetadata(&scalersconfig.ScalerConfig{ResolvedEnv: testData.resolvedEnv, TriggerMetadata: testData.metadata, AuthParams: testData.authParams})
if err != nil && !testData.raisesError {
t.Error("Expected success but got error:", err)
}
Expand All @@ -112,21 +112,24 @@ func TestParseMongoDBMetadata(t *testing.T) {

func TestParseMongoDBConnectionString(t *testing.T) {
for _, testData := range mongoDBConnectionStringTestDatas {
_, connStr, err := parseMongoDBMetadata(&scalersconfig.ScalerConfig{ResolvedEnv: testData.metadataTestData.resolvedEnv, TriggerMetadata: testData.metadataTestData.metadata, AuthParams: testData.metadataTestData.authParams})
_, err := parseMongoDBMetadata(&scalersconfig.ScalerConfig{
ResolvedEnv: testData.metadataTestData.resolvedEnv,
TriggerMetadata: testData.metadataTestData.metadata,
AuthParams: testData.metadataTestData.authParams,
})
if err != nil {
t.Error("Expected success but got error:", err)
}
assert.Equal(t, testData.connectionString, connStr)
}
}

func TestMongoDBGetMetricSpecForScaling(t *testing.T) {
for _, testData := range mongoDBMetricIdentifiers {
meta, _, err := parseMongoDBMetadata(&scalersconfig.ScalerConfig{ResolvedEnv: testData.metadataTestData.resolvedEnv, AuthParams: testData.metadataTestData.authParams, TriggerMetadata: testData.metadataTestData.metadata, TriggerIndex: testData.triggerIndex})
meta, err := parseMongoDBMetadata(&scalersconfig.ScalerConfig{ResolvedEnv: testData.metadataTestData.resolvedEnv, AuthParams: testData.metadataTestData.authParams, TriggerMetadata: testData.metadataTestData.metadata, TriggerIndex: testData.triggerIndex})
if err != nil {
t.Fatal("Could not parse metadata:", err)
}
mockMongoDBScaler := mongoDBScaler{"", meta, &mongo.Client{}, logr.Discard()}
mockMongoDBScaler := mongoDBScaler{metricType: v2.AverageValueMetricType, metadata: meta, client: &mongo.Client{}, logger: logr.Discard()}

metricSpec := mockMongoDBScaler.GetMetricSpecForScaling(context.Background())
metricName := metricSpec[0].External.Metric.Name
Expand Down

0 comments on commit 0e7801d

Please sign in to comment.