Skip to content

Commit

Permalink
Support unsafeSsl for RabbitMQ scaler (kedacore#4571)
Browse files Browse the repository at this point in the history
  • Loading branch information
dttung2905 authored May 30, 2023
1 parent d917094 commit 1d2e7ca
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 42 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,13 @@ To learn more about active deprecations, we recommend checking [GitHub Discussio
### Improvements

- **General**: Metrics Adapter: remove deprecated Prometheus Metrics and non-gRPC code ([#3930](https://github.com/kedacore/keda/issues/3930))
- **Azure Data Exporer Scaler**: Use azidentity SDK ([#4489](https://github.com/kedacore/keda/issues/4489))
- **Azure Data Explorer Scaler**: Use azidentity SDK ([#4489](https://github.com/kedacore/keda/issues/4489))
- **External Scaler**: Add tls options in TriggerAuth metadata. ([#3565](https://github.com/kedacore/keda/issues/3565))
- **GCP PubSub Scaler**: Make it more flexible for metrics ([#4243](https://github.com/kedacore/keda/issues/4243))
- **Kafka Scaler:** Add support for OAuth extensions ([#4544](https://github.com/kedacore/keda/issues/4544))
- **Pulsar Scaler**: Improve error messages for unsuccessful connections ([#4563](https://github.com/kedacore/keda/issues/4563))
- **Security:** Enable secret scanning in GitHub repo
- **RabbitMQ Scaler**: Add support for `unsafeSsl` in trigger metadata ([#4448](https://github.com/kedacore/keda/issues/4448))

### Fixes

Expand Down
93 changes: 65 additions & 28 deletions pkg/scalers/rabbitmq_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ type rabbitMQMetadata struct {
key string
keyPassword string
enableTLS bool
unsafeSsl bool
}

type queueInfo struct {
Expand Down Expand Up @@ -124,7 +125,7 @@ func NewRabbitMQScaler(config *ScalerConfig) (Scaler, error) {
return nil, fmt.Errorf("error parsing rabbitmq metadata: %w", err)
}
s.metadata = meta
s.httpClient = kedautil.CreateHTTPClient(meta.timeout, false)
s.httpClient = kedautil.CreateHTTPClient(meta.timeout, meta.unsafeSsl)

if meta.protocol == amqpProtocol {
// Override vhost if requested.
Expand All @@ -149,10 +150,7 @@ func NewRabbitMQScaler(config *ScalerConfig) (Scaler, error) {
return s, nil
}

func parseRabbitMQMetadata(config *ScalerConfig) (*rabbitMQMetadata, error) {
meta := rabbitMQMetadata{}

// Resolve protocol type
func resolveProtocol(config *ScalerConfig, meta *rabbitMQMetadata) error {
meta.protocol = defaultProtocol
if val, ok := config.AuthParams["protocol"]; ok {
meta.protocol = val
Expand All @@ -161,10 +159,12 @@ func parseRabbitMQMetadata(config *ScalerConfig) (*rabbitMQMetadata, error) {
meta.protocol = val
}
if meta.protocol != amqpProtocol && meta.protocol != httpProtocol && meta.protocol != autoProtocol {
return nil, fmt.Errorf("the protocol has to be either `%s`, `%s`, or `%s` but is `%s`", amqpProtocol, httpProtocol, autoProtocol, meta.protocol)
return fmt.Errorf("the protocol has to be either `%s`, `%s`, or `%s` but is `%s`", amqpProtocol, httpProtocol, autoProtocol, meta.protocol)
}
return nil
}

// Resolve host value
func resolveHostValue(config *ScalerConfig, meta *rabbitMQMetadata) error {
switch {
case config.AuthParams["host"] != "":
meta.host = config.AuthParams["host"]
Expand All @@ -173,10 +173,31 @@ func parseRabbitMQMetadata(config *ScalerConfig) (*rabbitMQMetadata, error) {
case config.TriggerMetadata["hostFromEnv"] != "":
meta.host = config.ResolvedEnv[config.TriggerMetadata["hostFromEnv"]]
default:
return nil, fmt.Errorf("no host setting given")
return fmt.Errorf("no host setting given")
}
return nil
}

// Resolve TLS authentication parameters
func resolveTimeout(config *ScalerConfig, meta *rabbitMQMetadata) error {
if val, ok := config.TriggerMetadata["timeout"]; ok {
timeoutMS, err := strconv.Atoi(val)
if err != nil {
return fmt.Errorf("unable to parse timeout: %w", err)
}
if meta.protocol == amqpProtocol {
return fmt.Errorf("amqp protocol doesn't support custom timeouts: %w", err)
}
if timeoutMS <= 0 {
return fmt.Errorf("timeout must be greater than 0: %w", err)
}
meta.timeout = time.Duration(timeoutMS) * time.Millisecond
} else {
meta.timeout = config.GlobalHTTPTimeout
}
return nil
}

func resolveTLSAuthParams(config *ScalerConfig, meta *rabbitMQMetadata) error {
meta.enableTLS = false
if val, ok := config.AuthParams["tls"]; ok {
val = strings.TrimSpace(val)
Expand All @@ -186,9 +207,29 @@ func parseRabbitMQMetadata(config *ScalerConfig) (*rabbitMQMetadata, error) {
meta.key = config.AuthParams["key"]
meta.enableTLS = true
} else if val != "disable" {
return nil, fmt.Errorf("err incorrect value for TLS given: %s", val)
return fmt.Errorf("err incorrect value for TLS given: %s", val)
}
}
return nil
}

func parseRabbitMQMetadata(config *ScalerConfig) (*rabbitMQMetadata, error) {
meta := rabbitMQMetadata{}

// Resolve protocol type
if err := resolveProtocol(config, &meta); err != nil {
return nil, err
}

// Resolve host value
if err := resolveHostValue(config, &meta); err != nil {
return nil, err
}

// Resolve TLS authentication parameters
if err := resolveTLSAuthParams(config, &meta); err != nil {
return nil, err
}

meta.keyPassword = config.AuthParams["keyPassword"]

Expand All @@ -198,6 +239,15 @@ func parseRabbitMQMetadata(config *ScalerConfig) (*rabbitMQMetadata, error) {
return nil, fmt.Errorf("both key and cert must be provided")
}

meta.unsafeSsl = false
if val, ok := config.TriggerMetadata["unsafeSsl"]; ok {
boolVal, err := strconv.ParseBool(val)
if err != nil {
return nil, fmt.Errorf("failed to parse unsafeSsl value. Must be either true or false")
}
meta.unsafeSsl = boolVal
}

// If the protocol is auto, check the host scheme.
if meta.protocol == autoProtocol {
parsedURL, err := url.Parse(meta.host)
Expand Down Expand Up @@ -254,22 +304,9 @@ func parseRabbitMQMetadata(config *ScalerConfig) (*rabbitMQMetadata, error) {
}

// Resolve timeout
if val, ok := config.TriggerMetadata["timeout"]; ok {
timeoutMS, err := strconv.Atoi(val)
if err != nil {
return nil, fmt.Errorf("unable to parse timeout: %w", err)
}
if meta.protocol == amqpProtocol {
return nil, fmt.Errorf("amqp protocol doesn't support custom timeouts: %w", err)
}
if timeoutMS <= 0 {
return nil, fmt.Errorf("timeout must be greater than 0: %w", err)
}
meta.timeout = time.Duration(timeoutMS) * time.Millisecond
} else {
meta.timeout = config.GlobalHTTPTimeout
if err := resolveTimeout(config, &meta); err != nil {
return nil, err
}

meta.scalerIndex = config.ScalerIndex

return &meta, nil
Expand Down Expand Up @@ -396,7 +433,7 @@ func getConnectionAndChannel(host string, meta *rabbitMQMetadata) (*amqp.Connect
var conn *amqp.Connection
var err error
if meta.enableTLS {
tlsConfig, configErr := kedautil.NewTLSConfigWithPassword(meta.cert, meta.key, meta.keyPassword, meta.ca, false)
tlsConfig, configErr := kedautil.NewTLSConfigWithPassword(meta.cert, meta.key, meta.keyPassword, meta.ca, meta.unsafeSsl)
if configErr == nil {
conn, err = amqp.DialTLS(host, tlsConfig)
}
Expand Down Expand Up @@ -538,7 +575,7 @@ func (s *rabbitMQScaler) GetMetricSpecForScaling(context.Context) []v2.MetricSpe
func (s *rabbitMQScaler) GetMetricsAndActivity(_ context.Context, metricName string) ([]external_metrics.ExternalMetricValue, bool, error) {
messages, publishRate, err := s.getQueueStatus()
if err != nil {
return []external_metrics.ExternalMetricValue{}, false, s.anonimizeRabbitMQError(err)
return []external_metrics.ExternalMetricValue{}, false, s.anonymizeRabbitMQError(err)
}

var metric external_metrics.ExternalMetricValue
Expand Down Expand Up @@ -623,7 +660,7 @@ func getMaximum(q []queueInfo) (int, int, float64) {
}

// Mask host for log purposes
func (s *rabbitMQScaler) anonimizeRabbitMQError(err error) error {
func (s *rabbitMQScaler) anonymizeRabbitMQError(err error) error {
errorMessage := fmt.Sprintf("error inspecting rabbitMQ: %s", err)
return fmt.Errorf(rabbitMQAnonymizePattern.ReplaceAllString(errorMessage, "user:password@"))
}
40 changes: 27 additions & 13 deletions pkg/scalers/rabbitmq_scaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -126,6 +127,10 @@ var testRabbitMQMetadata = []parseRabbitMQMetadataTestData{
{map[string]string{"mode": "QueueLength", "value": "1000", "queueName": "sample", "host": "http://", "useRegex": "true", "excludeUnacknowledged": "true"}, false, map[string]string{}},
// amqp and excludeUnacknowledged
{map[string]string{"mode": "QueueLength", "value": "1000", "queueName": "sample", "host": "amqp://", "useRegex": "true", "excludeUnacknowledged": "true"}, true, map[string]string{}},
// unsafeSsl true
{map[string]string{"queueName": "sample", "host": "https://", "unsafeSsl": "true"}, false, map[string]string{}},
// unsafeSsl wrong input
{map[string]string{"queueName": "sample", "host": "https://", "unsafeSsl": "random"}, true, map[string]string{}},
}

var testRabbitMQAuthParamData = []parseRabbitMQAuthParamTestData{
Expand All @@ -150,18 +155,27 @@ var rabbitMQMetricIdentifiers = []rabbitMQMetricIdentifier{
}

func TestRabbitMQParseMetadata(t *testing.T) {
for _, testData := range testRabbitMQMetadata {
_, err := parseRabbitMQMetadata(&ScalerConfig{ResolvedEnv: sampleRabbitMqResolvedEnv, TriggerMetadata: testData.metadata, AuthParams: testData.authParams})
for idx, testData := range testRabbitMQMetadata {
meta, err := parseRabbitMQMetadata(&ScalerConfig{ResolvedEnv: sampleRabbitMqResolvedEnv, TriggerMetadata: testData.metadata, AuthParams: testData.authParams})
if err != nil && !testData.isError {
t.Error("Expected success but got error", err)
}
if testData.isError && err == nil {
t.Error("Expected error but got success")
t.Errorf("Expected error but got success in test case %d", idx)
}
if val, ok := testData.metadata["unsafeSsl"]; ok && err == nil {
boolVal, err := strconv.ParseBool(val)
if err != nil && !testData.isError {
t.Errorf("Expect error but got success in test case %d", idx)
}
if boolVal != meta.unsafeSsl {
t.Errorf("Expect %t but got %t in test case %d", boolVal, meta.unsafeSsl, idx)
}
}
}
}

func TestRabbitMQParseAuthParamdata(t *testing.T) {
func TestRabbitMQParseAuthParamData(t *testing.T) {
for _, testData := range testRabbitMQAuthParamData {
metadata, err := parseRabbitMQMetadata(&ScalerConfig{ResolvedEnv: sampleRabbitMqResolvedEnv, TriggerMetadata: testData.metadata, AuthParams: testData.authParams})
if err != nil && !testData.isError {
Expand Down Expand Up @@ -251,7 +265,7 @@ var testQueueInfoTestData = []getQueueInfoTestData{
{`Password is incorrect`, http.StatusUnauthorized, false, nil, ""},
}

var vhostPathes = []string{"/myhost", "", "/", "//", rabbitRootVhostPath}
var vhostPaths = []string{"/myhost", "", "/", "//", rabbitRootVhostPath}

var testQueueInfoTestDataSingleVhost = []getQueueInfoTestData{
{`{"messages": 4, "messages_unacknowledged": 1, "message_stats": {"publish_details": {"rate": 1.4}}, "name": "evaluate_trials"}`, http.StatusOK, true, map[string]string{"hostFromEnv": "plainHost", "vhostName": "myhost"}, "/myhost"},
Expand All @@ -265,7 +279,7 @@ var testQueueInfoTestDataSingleVhost = []getQueueInfoTestData{
func TestGetQueueInfo(t *testing.T) {
allTestData := []getQueueInfoTestData{}
for _, testData := range testQueueInfoTestData {
for _, vhostPath := range vhostPathes {
for _, vhostPath := range vhostPaths {
testData := testData
testData.vhostPath = vhostPath
allTestData = append(allTestData, testData)
Expand Down Expand Up @@ -400,12 +414,12 @@ var testRegexQueueInfoTestData = []getQueueInfoTestData{
{`{"items":[]}`, http.StatusOK, false, map[string]string{"mode": "MessageRate", "value": "1000", "useRegex": "true", "operation": "avg"}, ""},
}

var vhostPathesForRegex = []string{"", "/test-vh", rabbitRootVhostPath}
var vhostPathsForRegex = []string{"", "/test-vh", rabbitRootVhostPath}

func TestGetQueueInfoWithRegex(t *testing.T) {
allTestData := []getQueueInfoTestData{}
for _, testData := range testRegexQueueInfoTestData {
for _, vhostPath := range vhostPathesForRegex {
for _, vhostPath := range vhostPathsForRegex {
testData := testData
testData.vhostPath = vhostPath
allTestData = append(allTestData, testData)
Expand Down Expand Up @@ -485,7 +499,7 @@ var testRegexPageSizeTestData = []getRegexPageSizeTestData{
func TestGetPageSizeWithRegex(t *testing.T) {
allTestData := []getRegexPageSizeTestData{}
for _, testData := range testRegexPageSizeTestData {
for _, vhostPath := range vhostPathesForRegex {
for _, vhostPath := range vhostPathsForRegex {
testData := testData
testData.queueInfo.vhostPath = vhostPath
allTestData = append(allTestData, testData)
Expand Down Expand Up @@ -568,7 +582,7 @@ type rabbitMQErrorTestData struct {
message string
}

var anonimizeRabbitMQErrorTestData = []rabbitMQErrorTestData{
var anonymizeRabbitMQErrorTestData = []rabbitMQErrorTestData{
{fmt.Errorf("https://user1:[email protected]"), "error inspecting rabbitMQ: https://user:[email protected]"},
{fmt.Errorf("https://fdasr345_-:[email protected]"), "error inspecting rabbitMQ: https://user:[email protected]"},
{fmt.Errorf("https://user1:[email protected]"), "error inspecting rabbitMQ: https://user:[email protected]"},
Expand All @@ -580,7 +594,7 @@ var anonimizeRabbitMQErrorTestData = []rabbitMQErrorTestData{
{fmt.Errorf("the queue https://user1:[email protected]/api/virtual is unavailable"), "error inspecting rabbitMQ: the queue https://user:[email protected]/api/virtual is unavailable"},
}

func TestRabbitMQAnonimizeRabbitMQError(t *testing.T) {
func TestRabbitMQAnonymizeRabbitMQError(t *testing.T) {
metadata := map[string]string{
"queueName": "evaluate_trials",
"hostFromEnv": host,
Expand All @@ -596,8 +610,8 @@ func TestRabbitMQAnonimizeRabbitMQError(t *testing.T) {
metadata: meta,
httpClient: nil,
}
for _, testData := range anonimizeRabbitMQErrorTestData {
err := s.anonimizeRabbitMQError(testData.err)
for _, testData := range anonymizeRabbitMQErrorTestData {
err := s.anonymizeRabbitMQError(testData.err)
assert.Equal(t, fmt.Sprint(err), testData.message)
}
}
Expand Down

0 comments on commit 1d2e7ca

Please sign in to comment.