Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Provide support for RDS MySQL IAM Authentication #140

Merged
merged 6 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,11 @@ For some database backends some special functionality is available:
which will use the equivalent of `rds generate-db-auth-token`
for the password. For this driver, the `AWS_REGION` environment variable
must be set.
* rds-mysql: This type of URL expects a working AWS configuration
which will use the equivalent of `rds generate-db-auth-token`
for the password. For this driver, the `AWS_REGION` environment variable
must be set.


Why this exporter exists
========================
Expand Down
13 changes: 7 additions & 6 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,13 @@ type Job struct {
}

type connection struct {
conn *sqlx.DB
url string
driver string
host string
database string
user string
conn *sqlx.DB
url string
driver string
host string
database string
user string
tokenExpirationTime time.Time
}

// Query is an SQL query that is executed on a connection
Expand Down
102 changes: 93 additions & 9 deletions job.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,27 @@ var (
CloudSQLPrefix = "cloudsql+"
)

func handleRDSMySQLIAMAuth(conn string) (string, time.Time, error) {
dsn := strings.TrimPrefix(conn, "rds-mysql://")
config, err := mysql.ParseDSN(dsn)
if err != nil {
return "", time.Time{}, fmt.Errorf("failed to parse MySQL DSN: %v", err)
}

sess := session.Must(session.NewSessionWithOptions(session.Options{
SharedConfigState: session.SharedConfigEnable,
}))

token, err := rdsutils.BuildAuthToken(config.Addr, os.Getenv("AWS_REGION"), config.User, sess.Config.Credentials)
if err != nil {
return "", time.Time{}, fmt.Errorf("failed to build RDS auth token: %v", err)
}

expirationTime := time.Now().Add(14 * time.Minute)

return token, expirationTime, nil
}

// Init will initialize the metric descriptors
func (j *Job) Init(logger log.Logger, queries map[string]string) error {
j.log = log.With(logger, "job", j.Name)
Expand Down Expand Up @@ -207,23 +228,53 @@ func (j *Job) updateConnections() {
continue
}

// MySQL DSNs do not parse cleanly as URLs as of Go 1.12.8+
if strings.HasPrefix(conn, "mysql://") {
config, err := mysql.ParseDSN(strings.TrimPrefix(conn, "mysql://"))
// Handle both RDS MySQL and regular MySQL connections
if strings.HasPrefix(conn, "rds-mysql://") || strings.HasPrefix(conn, "mysql://") {
isRDS := strings.HasPrefix(conn, "rds-mysql://")
var dsn string
var expirationTime time.Time

trimmedConn := conn
if isRDS {
trimmedConn = strings.TrimPrefix(conn, "rds-mysql://")
} else {
trimmedConn = strings.TrimPrefix(conn, "mysql://")
}

config, err := mysql.ParseDSN(trimmedConn)
if err != nil {
level.Error(j.log).Log("msg", "Failed to parse MySQL DSN", "url", conn, "err", err)
continue
}

if isRDS {
authToken, tokenExpiration, err := handleRDSMySQLIAMAuth(conn)
if err != nil {
level.Error(j.log).Log("msg", "Failed to build RDS auth token", "url", conn, "err", err)
continue
}
config.Passwd = authToken
config.AllowCleartextPasswords = true
expirationTime = tokenExpiration
}

dsn = config.FormatDSN()
if isRDS {
dsn = "rds-mysql://" + dsn
}

j.conns = append(j.conns, &connection{
conn: nil,
url: conn,
driver: "mysql",
host: config.Addr,
database: config.DBName,
user: config.User,
conn: nil,
url: dsn,
driver: "mysql",
host: config.Addr,
database: config.DBName,
user: config.User,
tokenExpirationTime: expirationTime,
})
continue
}

if strings.HasPrefix(conn, "rds-postgres://") {
// Reuse Postgres driver by stripping "rds-" from connection URL after building the RDS authentication token
conn = strings.TrimPrefix(conn, "rds-")
Expand Down Expand Up @@ -438,12 +489,45 @@ func (j *Job) runOnce() error {
func (c *connection) connect(job *Job) error {
// already connected
if c.conn != nil {
if strings.HasPrefix(c.url, "rds-mysql://") && time.Now().After(c.tokenExpirationTime) {
level.Warn(job.log).Log("msg", "Connection token expired, reconnecting")

authToken, expirationTime, err := handleRDSMySQLIAMAuth(c.url)
if err != nil {
return fmt.Errorf("failed to refresh RDS MySQL IAM Auth token: %w", err)
}

config, err := mysql.ParseDSN(strings.TrimPrefix(c.url, "rds-mysql://"))
if err != nil {
return fmt.Errorf("failed to parse MySQL DSN: %w", err)
}

config.Passwd = authToken
dsn := "rds-mysql://" + config.FormatDSN()

// Close the existing connection
c.conn.Close()
c.conn = nil

// Update the connection details
c.tokenExpirationTime = expirationTime
c.url = dsn

// Connect to the database with the new token
conn, err := sqlx.Connect(c.driver, strings.TrimPrefix(dsn, "rds-mysql://"))
if err != nil {
return fmt.Errorf("failed to connect to the database: %w", err)
}
c.conn = conn
return nil
}
return nil
}
dsn := c.url
switch c.driver {
case "mysql":
dsn = strings.TrimPrefix(dsn, "mysql://")
dsn = strings.TrimPrefix(dsn, "rds-mysql://")
case "clickhouse+tcp", "clickhouse+http": // Support both http and tcp connections
dsn = strings.TrimPrefix(dsn, "clickhouse+")
c.driver = "clickhouse"
Expand Down
Loading